AutoEncoderModel.py

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class AutoEncoder(nn.Module):
    def __init__(self, state_dim, hidden_dim, latent_dim):
        super().__init__()
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        # Encoder: compresses the state vector into the latent space
        self.encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.Sigmoid()
        )
        # Decoder: reconstructs the state vector from the latent code
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim),
            nn.Sigmoid()
        )
    def forward(self, state):
        latent = self.encoder(state)
        state_pre = self.decoder(latent)
        return state_pre
    def autoencoder_train(self, batch_size, epochs, lr,
                          datax_train, datay_train, datax_val, datay_val):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(device)
        self.to(device)
        # Move the full datasets to the target device once up front
        datax_train, datay_train = datax_train.to(device), datay_train.to(device)
        datax_val, datay_val = datax_val.to(device), datay_val.to(device)
        train_dataset = torch.utils.data.TensorDataset(datax_train, datay_train)
        val_dataset = torch.utils.data.TensorDataset(datax_val, datay_val)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # The validation set does not need shuffling
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        lossfunc = nn.MSELoss()
        op = optim.Adam(self.parameters(), lr=lr)
        TrainLoss = []
        ValLoss = []
        for epoch in range(epochs):
            # Training pass
            self.train()
            train_loss = 0
            for X_batch, Y_batch in train_loader:
                outputs = self(X_batch)
                loss = lossfunc(outputs, Y_batch)
                train_loss += loss.item()
                op.zero_grad()
                loss.backward()
                op.step()
            # Validation pass: no gradient tracking needed
            self.eval()
            val_loss = 0
            with torch.no_grad():
                for X_batch, Y_batch in val_loader:
                    outputs = self(X_batch)
                    loss = lossfunc(outputs, Y_batch)
                    val_loss += loss.item()
            TrainLoss.append(train_loss / len(train_loader))
            ValLoss.append(val_loss / len(val_loader))
            # Report per-epoch losses on a log10 scale for readability
            train_log_loss = np.log10(train_loss / len(train_loader))
            val_log_loss = np.log10(val_loss / len(val_loader))
            print(
                f"LatentDim[{self.latent_dim}], HiddenDim[{self.hidden_dim}], LR[{lr}], "
                f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_log_loss}, Validation Loss: {val_log_loss}")
        return TrainLoss, ValLoss
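

A minimal usage sketch, not part of the original file: it assumes synthetic data and placeholder dimensions (state_dim=64, hidden_dim=128, latent_dim=8 are illustrative choices), and for a plain autoencoder it passes the inputs as their own reconstruction targets.

if __name__ == "__main__":
    # Illustrative dimensions and synthetic data (assumptions, not from the source)
    state_dim, hidden_dim, latent_dim = 64, 128, 8
    model = AutoEncoder(state_dim, hidden_dim, latent_dim)
    # torch.rand draws from [0, 1), matching the decoder's Sigmoid output range
    x_train = torch.rand(1000, state_dim)
    x_val = torch.rand(200, state_dim)
    # Targets equal inputs: the autoencoder learns to reconstruct its input
    train_loss, val_loss = model.autoencoder_train(
        batch_size=32, epochs=10, lr=1e-3,
        datax_train=x_train, datay_train=x_train,
        datax_val=x_val, datay_val=x_val)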