DeepKoopmanModel.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from dataclasses import dataclass
@dataclass
class DeepKoopManConfig:
    state_dim: int
    latent_dim: int
    hidden_dim: int
    # seq_len: int
class DeepKoopMan(nn.Module):
    def __init__(self, config: DeepKoopManConfig, encoder, decoder):
        super().__init__()
        self.config = config
        state_dim = config.state_dim
        latent_dim = config.latent_dim
        hidden_dim = config.hidden_dim
        # seq_len = config.seq_len
        self.state_dim = state_dim
        self.latent_dim = latent_dim
        # self.seq_len = seq_len
        self.encoder = encoder
        self.decoder = decoder
        # Weight of the one-step prediction loss relative to the reconstruction loss
        self.lambda1 = 0.6
        # for param in self.encoder.parameters():
        #     param.requires_grad = False
        # for param in self.decoder.parameters():
        #     param.requires_grad = False
        # # Encoder
        # self.encoder = nn.Sequential(
        #     nn.Linear(state_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, latent_dim)
        # )
        #
        # # Decoder
        # self.decoder = nn.Sequential(
        #     nn.Linear(latent_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, state_dim)
        # )
        # Koopman operator A, learned as a dense latent-space matrix
        self.A = nn.Parameter(torch.randn(latent_dim, latent_dim))
    def forward(self, state):
        latent = self.encoder(state)                    # lift the state into the latent space
        latent_next_pre = torch.matmul(latent, self.A)  # linear one-step advance: z_next = z @ A
        state_next_pre = self.decoder(latent_next_pre)  # predicted next state
        state_pre = self.decoder(latent)                # reconstruction of the current state
        return state_pre, state_next_pre
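
    # Note: multi-step forecasts follow by iterating the operator in latent
    # space, z_{t+k} = z_t @ A^k, and decoding only at the horizon of
    # interest; this forward pass implements the single-step case k = 1.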
    def total_loss_func(self, state, state_pre, state_next, state_next_pre):
        mseloss = nn.MSELoss()
        ae_loss = mseloss(state, state_pre)             # autoencoder reconstruction error
        pre_loss = mseloss(state_next, state_next_pre)  # one-step prediction error
        total_loss = pre_loss * self.lambda1 + ae_loss
        return total_loss
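
    # In symbols, with encoder phi, decoder psi, and Koopman operator A:
    #     L = ||x - psi(phi(x))||^2 + lambda1 * ||x_next - psi(phi(x) A)||^2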
    def deepkoopman_train(self, batch_size, epochs, lr, data_train, data_val):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        # Shift each trajectory by one step to form (x_t, x_{t+1}) pairs
        data_train_input = data_train[:, :-1, :]
        data_train_output = data_train[:, 1:, :]
        data_val_input = data_val[:, :-1, :]
        data_val_output = data_val[:, 1:, :]
        data_train_input = data_train_input.to(device)
        data_train_output = data_train_output.to(device)
        data_val_input = data_val_input.to(device)
        data_val_output = data_val_output.to(device)
        train_dataset = torch.utils.data.TensorDataset(data_train_input, data_train_output)
        val_dataset = torch.utils.data.TensorDataset(data_val_input, data_val_output)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        op = optim.Adam(self.parameters(), lr=lr)
        TrainLoss = []
        ValLoss = []
        for epoch in range(epochs):
            self.train()
            train_loss = 0
            for X_batch, Y_batch in train_loader:
                X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
                X_State_batch = X_batch[:, :, :self.state_dim]
                Y_State_batch = Y_batch[:, :, :self.state_dim]
                X_Pre_State_batch, Y_Pre_State_batch = self(X_State_batch)
                loss = self.total_loss_func(X_State_batch, X_Pre_State_batch, Y_State_batch, Y_Pre_State_batch)
                train_loss += loss.item()
                op.zero_grad()
                loss.backward()
                op.step()
            self.eval()
            val_loss = 0
            with torch.no_grad():
                for X_batch, Y_batch in val_loader:
                    X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
                    X_State_batch = X_batch[:, :, :self.state_dim]
                    Y_State_batch = Y_batch[:, :, :self.state_dim]
                    X_Pre_State_batch, Y_Pre_State_batch = self(X_State_batch)
                    loss = self.total_loss_func(X_State_batch, X_Pre_State_batch, Y_State_batch, Y_Pre_State_batch)
                    val_loss += loss.item()
            TrainLoss.append(train_loss / len(train_loader))
            ValLoss.append(val_loss / len(val_loader))
            # Report losses on a log10 scale, which is easier to read once they get small
            train_log_loss = np.log10(train_loss / len(train_loader))
            val_log_loss = np.log10(val_loss / len(val_loader))
            print(
                f"LatentDim[{self.latent_dim}], "
                f"LR[{lr}], Epoch [{epoch + 1}/{epochs}], "
                f"Train Loss: {train_log_loss:.2f}, "
                f"Validation Loss: {val_log_loss:.2f}")
        return TrainLoss, ValLoss
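
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): builds encoder/decoder
# MLPs matching the commented-out architecture above and trains on synthetic
# trajectories of a lightly damped linear oscillator. All shapes and
# hyperparameters here are illustrative assumptions, not the author's setup.
if __name__ == "__main__":
    config = DeepKoopManConfig(state_dim=2, latent_dim=8, hidden_dim=64)
    encoder = nn.Sequential(
        nn.Linear(config.state_dim, config.hidden_dim), nn.ReLU(),
        nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU(),
        nn.Linear(config.hidden_dim, config.latent_dim),
    )
    decoder = nn.Sequential(
        nn.Linear(config.latent_dim, config.hidden_dim), nn.ReLU(),
        nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU(),
        nn.Linear(config.hidden_dim, config.state_dim),
    )
    model = DeepKoopMan(config, encoder, decoder)

    # Synthetic data: x_{t+1} = A_true x_t with a slowly decaying rotation,
    # stacked into a tensor of shape (num_trajectories, seq_len, state_dim).
    theta = 0.1
    A_true = 0.99 * torch.tensor([[np.cos(theta), -np.sin(theta)],
                                  [np.sin(theta),  np.cos(theta)]],
                                 dtype=torch.float32)
    num_traj, seq_len = 128, 50
    states = [torch.randn(num_traj, config.state_dim)]
    for _ in range(seq_len - 1):
        states.append(states[-1] @ A_true.T)
    data = torch.stack(states, dim=1)

    # Train on the first 96 trajectories, validate on the remaining 32
    train_hist, val_hist = model.deepkoopman_train(
        batch_size=16, epochs=20, lr=1e-3,
        data_train=data[:96], data_val=data[96:])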