AutoEncoderModel.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np


class AutoEncoder(nn.Module):
    """Fully connected autoencoder that compresses a state vector into a
    low-dimensional latent code and reconstructs it."""

    def __init__(self, state_dim, hidden_dim, latent_dim):
        super().__init__()
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        # Encoder: state_dim -> hidden_dim -> hidden_dim -> latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, latent_dim)
        )
        # Decoder: latent_dim -> hidden_dim -> hidden_dim -> state_dim
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, state_dim)
        )

    def forward(self, state):
        latent = self.encoder(state)
        state_pre = self.decoder(latent)
        return state_pre
    def autoencoder_train(self, batch_size, epochs, lr, datax_train, datay_train,
                          datax_val, datay_val):
        """Train with MSE reconstruction loss and return per-epoch train and
        validation losses. Targets are passed separately, so for a plain
        autoencoder datay_* should equal datax_*."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        # Batches are moved to the device inside the loops, so the full
        # datasets do not need to be copied to the GPU up front.
        train_dataset = torch.utils.data.TensorDataset(datax_train, datay_train)
        val_dataset = torch.utils.data.TensorDataset(datax_val, datay_val)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        lossfunc = nn.MSELoss()
        op = optim.Adam(self.parameters(), lr=lr)
        TrainLoss = []
        ValLoss = []
        for epoch in range(epochs):
            # Training pass
            self.train()
            train_loss = 0
            for X_batch, Y_batch in train_loader:
                X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
                outputs = self.forward(X_batch)
                loss = lossfunc(outputs, Y_batch)
                train_loss += loss.item()
                op.zero_grad()
                loss.backward()
                op.step()
            # Validation pass: evaluation mode, no gradient tracking
            self.eval()
            val_loss = 0
            with torch.no_grad():
                for X_batch, Y_batch in val_loader:
                    X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
                    outputs = self.forward(X_batch)
                    loss = lossfunc(outputs, Y_batch)
                    val_loss += loss.item()
            TrainLoss.append(train_loss / len(train_loader))
            ValLoss.append(val_loss / len(val_loader))
            # Report losses on a log10 scale so convergence is easier to read
            train_log_loss = np.log10(train_loss / len(train_loader))
            val_log_loss = np.log10(val_loss / len(val_loader))
            print(
                f"LatentDim[{self.latent_dim}], HiddenDim[{self.hidden_dim}], LR[{lr}], Epoch [{epoch + 1}/{epochs}], "
                f"Train Loss (log10): {train_log_loss:.4f}, Validation Loss (log10): {val_log_loss:.4f}")
        return TrainLoss, ValLoss
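

# ----------------------------------------------------------------------
# Usage sketch (an addition, not part of the original file): shows the
# expected call pattern on random data. Shapes, hyperparameters, and the
# choice of targets (Y = X for plain reconstruction) are assumptions.
if __name__ == "__main__":
    state_dim, hidden_dim, latent_dim = 16, 64, 4
    model = AutoEncoder(state_dim, hidden_dim, latent_dim)
    x_train = torch.randn(1024, state_dim)
    x_val = torch.randn(256, state_dim)
    train_hist, val_hist = model.autoencoder_train(
        batch_size=64, epochs=5, lr=1e-3,
        datax_train=x_train, datay_train=x_train,
        datax_val=x_val, datay_val=x_val)
    print(f"final train MSE: {train_hist[-1]:.6f}, final val MSE: {val_hist[-1]:.6f}")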