# DeepKoopmanModel.py
  1. import torch
  2. import torch.nn as nn
  3. from dataclasses import dataclass
  4. import torch.optim as optim
  5. import numpy as np
  6. @dataclass
  7. class DeepKoopManConfig:
  8. state_dim: int
  9. latent_dim: int
  10. hidden_dim: int
  11. class DeepKoopMan(nn.Module):
  12. def __init__(self, config: DeepKoopManConfig, encoder, decoder):
  13. super().__init__()
  14. self.config = config
  15. state_dim = config.state_dim
  16. latent_dim = config.latent_dim
  17. hidden_dim = config.hidden_dim
  18. # seq_len = config.seq_len
  19. self.state_dim = state_dim
  20. self.latent_dim = latent_dim
  21. self.encoder = encoder
  22. self.decoder = decoder
  23. self.lambda1 = 0.6
  24. # Koopman算子A
  25. self.A = nn.Parameter(torch.randn(latent_dim, latent_dim))
  26. def forward(self, state):
  27. latent = self.encoder(state)
  28. latent_next_pre = torch.matmul(latent, self.A)
  29. state_next_pre = self.decoder(latent_next_pre)
  30. state_pre = self.decoder(latent)
  31. return state_pre, state_next_pre
  32. def total_loss_func(self, state, state_pre, state_next, state_next_pre):
  33. mseloss = nn.MSELoss()
  34. ae_loss = mseloss(state, state_pre)
  35. pre_loss = mseloss(state_next, state_next_pre)
  36. total_loss = pre_loss * self.lambda1 + ae_loss
  37. return total_loss
  38. def deepkoopman_train(self, batch_size, epochs, lr, data_train, data_val):
  39. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  40. self.to(device)
  41. data_train_input = data_train[:, :-1, :]
  42. data_train_output = data_train[:, 1:, :]
  43. data_val_input = data_val[:, :-1, :]
  44. data_val_output = data_val[:, 1:, :]
  45. data_train_input = data_train_input.to(device)
  46. data_train_output = data_train_output.to(device)
  47. data_val_input = data_val_input.to(device)
  48. data_val_output = data_val_output.to(device)
  49. train_dataset = torch.utils.data.TensorDataset(data_train_input, data_train_output)
  50. val_dataset = torch.utils.data.TensorDataset(data_val_input, data_val_output)
  51. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  52. val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
  53. op = optim.Adam(self.parameters(), lr=lr)
  54. TrainLoss = []
  55. ValLoss = []
  56. for epoch in range(epochs):
  57. train_loss = 0
  58. for X_batch, Y_batch in train_loader:
  59. X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
  60. X_State_batch = X_batch[:, :, :self.state_dim]
  61. Y_State_batch = Y_batch[:, :, :self.state_dim]
  62. X_Pre_State_batch, Y_Pre_State_batch = self.forward(X_State_batch)
  63. loss = self.total_loss_func(X_State_batch, X_Pre_State_batch, Y_State_batch, Y_Pre_State_batch)
  64. train_loss += loss.item()
  65. op.zero_grad()
  66. loss.backward()
  67. op.step()
  68. val_loss = 0
  69. with torch.no_grad():
  70. for X_batch, Y_batch in val_loader:
  71. X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
  72. X_State_batch = X_batch[:, :, :self.state_dim]
  73. Y_State_batch = Y_batch[:, :, :self.state_dim]
  74. X_Pre_State_batch, Y_Pre_State_batch = self.forward(X_State_batch)
  75. loss = self.total_loss_func(X_State_batch, X_Pre_State_batch, Y_State_batch, Y_Pre_State_batch)
  76. val_loss += loss.item()
  77. TrainLoss.append(train_loss / len(train_loader))
  78. ValLoss.append(val_loss / len(val_loader))
  79. train_log_loss = np.log10(train_loss / len(train_loader))
  80. val_log_loss = np.log10(val_loss / len(val_loader))
  81. print(
  82. f"LatentDim[{self.latent_dim}],"
  83. f"LR[{lr}], Epoch [{epoch + 1}/{epochs}],"
  84. f"Train Loss: {train_log_loss:.2f},"
  85. f"Validation Loss: {val_log_loss:.2f}")
  86. return TrainLoss, ValLoss