# DeepKoopmanModel.py
  1. import torch
  2. import torch.nn as nn
  3. from dataclasses import dataclass
  4. import torch.optim as optim
  5. import numpy as np
  6. @dataclass
  7. class DeepKoopManConfig:
  8. state_dim: int
  9. latent_dim: int
  10. hidden_dim: int
  11. # seq_len: int
  12. class DeepKoopMan(nn.Module):
  13. def __init__(self, config: DeepKoopManConfig, encoder, decoder):
  14. super().__init__()
  15. self.config = config
  16. state_dim = config.state_dim
  17. latent_dim = config.latent_dim
  18. hidden_dim = config.hidden_dim
  19. # seq_len = config.seq_len
  20. self.state_dim = state_dim
  21. self.latent_dim = latent_dim
  22. # self.seq_len = seq_len
  23. self.encoder = encoder
  24. self.decoder = decoder
  25. self.lambda1 = 0.6
  26. # for param in self.encoder.parameters():
  27. # param.requires_grad = False
  28. # for param in self.decoder.parameters():
  29. # param.requires_grad = False
  30. # # 编码器 Encoder
  31. # self.encoder = nn.Sequential(
  32. # nn.Linear(state_dim, hidden_dim),
  33. # nn.ReLU(),
  34. # nn.Linear(hidden_dim, hidden_dim),
  35. # nn.ReLU(),
  36. # nn.Linear(hidden_dim, latent_dim)
  37. # )
  38. #
  39. # # 解码器 Decoder
  40. # self.decoder = nn.Sequential(
  41. # nn.Linear(latent_dim, hidden_dim),
  42. # nn.ReLU(),
  43. # nn.Linear(hidden_dim, hidden_dim),
  44. # nn.ReLU(),
  45. # nn.Linear(hidden_dim, state_dim)
  46. # )
  47. # Koopman算子A
  48. self.A = nn.Parameter(torch.randn(latent_dim, latent_dim))
  49. def forward(self, state):
  50. latent = self.encoder(state)
  51. dlatent_plus_dt = torch.matmul(latent, self.A)
  52. latent_next_pre = latent + dlatent_plus_dt
  53. state_next_pre = self.decoder(latent_next_pre)
  54. state_pre = self.decoder(latent)
  55. return state_pre, state_next_pre
  56. def total_loss_func(self, state, state_pre, state_next, state_next_pre):
  57. mseloss = nn.MSELoss()
  58. ae_loss = mseloss(state, state_pre)
  59. pre_loss = mseloss(state_next, state_next_pre)
  60. total_loss = pre_loss * self.lambda1 + ae_loss
  61. return total_loss
  62. def deepkoopman_train(self, batch_size, epochs, lr, data_train, data_val):
  63. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  64. self.to(device)
  65. data_train_input = data_train[:, :-1, :]
  66. data_train_output = data_train[:, 1:, :]
  67. data_val_input = data_val[:, :-1, :]
  68. data_val_output = data_val[:, 1:, :]
  69. data_train_input = data_train_input.to(device)
  70. data_train_output = data_train_output.to(device)
  71. data_val_input = data_val_input.to(device)
  72. data_val_output = data_val_output.to(device)
  73. train_dataset = torch.utils.data.TensorDataset(data_train_input, data_train_output)
  74. val_dataset = torch.utils.data.TensorDataset(data_val_input, data_val_output)
  75. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  76. val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
  77. op = optim.Adam(self.parameters(), lr=lr)
  78. TrainLoss = []
  79. ValLoss = []
  80. for epoch in range(epochs):
  81. train_loss = 0
  82. for X_batch, Y_batch in train_loader:
  83. X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
  84. X_State_batch = X_batch[:, :, :self.state_dim]
  85. Y_State_batch = Y_batch[:, :, :self.state_dim]
  86. X_Pre_State_batch, Y_Pre_State_batch = self.forward(X_State_batch)
  87. loss = self.total_loss_func(X_State_batch, X_Pre_State_batch, Y_State_batch, Y_Pre_State_batch)
  88. train_loss += loss.item()
  89. op.zero_grad()
  90. loss.backward()
  91. op.step()
  92. val_loss = 0
  93. with torch.no_grad():
  94. for X_batch, Y_batch in val_loader:
  95. X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
  96. X_State_batch = X_batch[:, :, :self.state_dim]
  97. Y_State_batch = Y_batch[:, :, :self.state_dim]
  98. X_Pre_State_batch, Y_Pre_State_batch = self.forward(X_State_batch)
  99. loss = self.total_loss_func(X_State_batch, X_Pre_State_batch, Y_State_batch, Y_Pre_State_batch)
  100. val_loss += loss.item()
  101. TrainLoss.append(train_loss / len(train_loader))
  102. ValLoss.append(val_loss / len(val_loader))
  103. train_log_loss = np.log10(train_loss / len(train_loader))
  104. val_log_loss = np.log10(val_loss / len(val_loader))
  105. print(
  106. f"LatentDim[{self.latent_dim}],"
  107. f"LR[{lr}], Epoch [{epoch + 1}/{epochs}],"
  108. f"Train Loss: {train_log_loss:.2f},"
  109. f"Validation Loss: {val_log_loss:.2f}")
  110. return TrainLoss, ValLoss