# ESNTest.py (2.1 KB)
  1. import torch
  2. from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
  3. from BaseModel.ESNModel import ESN
  4. from DataProcess import DataLoad
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  # --- Pretrained Deep Koopman model ------------------------------------------
# The checkpoint holds a whole pickled model object (its attributes are read
# directly below), not a state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which refuses to unpickle arbitrary classes, so opt out
# explicitly (the keyword exists since PyTorch 1.13).
# weights_only=False runs the full pickle machinery: only load checkpoints
# from a trusted source.
KoopmanModel = torch.load(
    './ModelLib/DeepKoopmanModel.pt',
    map_location="cpu",
    weights_only=False,
)
# NOTE(review): encoder, decoder, A and latent_dim are not used by the rest of
# this script as visible here; they are kept for Koopman-side experiments.
encoder = KoopmanModel.encoder
decoder = KoopmanModel.decoder
# .numpy() needs a CPU tensor that is not tracked by autograd: map_location
# above puts the storage on the CPU and detach() drops the graph.
A = KoopmanModel.A.detach().numpy()
state_dim = KoopmanModel.state_dim
latent_dim = KoopmanModel.latent_dim
  # --- ESN hyperparameters (all forwarded to KoopmanESNConfig) ---------------
# NOTE(review): lr / sr / sp presumably follow common ESN naming (leak rate,
# spectral radius, sparsity) and ridge the readout regularisation strength —
# confirm against KoopmanESNConfig.
lr, sr, sp = 0.5, 0.9, 0.1
ridge = 1.0e-7

# Window sizes / offsets used when training and predicting. Presumably these
# are sample counts (start index, lengths, warm-up lengths) — confirm against
# how KoopmanESNConfig / DataLoad consume them.
train_start, train_len, train_warmup = 15_000, 5_000, 100
predict_warmup, predict_len = 100, 1_000
  # --- ESN configuration and model ---------------------------------------------
# Gather every config field in one mapping and splat it into the constructor.
# The temporary name is deleted right afterwards so the module namespace ends
# up exactly as before.
# NOTE(review): units=100 is the only literal here (presumably reservoir
# size — confirm against KoopmanESNConfig).
_esn_kwargs = dict(
    units=100,
    lr=lr, sr=sr, sp=sp, ridge=ridge,
    train_start=train_start, train_len=train_len, train_warmup=train_warmup,
    predict_warmup=predict_warmup, predict_len=predict_len,
    state_dim=state_dim,
)
ESNconfig = KoopmanESNConfig(**_esn_kwargs)
del _esn_kwargs

# NOTE(review): the plain BaseModel ESN is built from a KoopmanESNConfig;
# presumably the two are attribute-compatible — confirm.
ESNModel = ESN(config=ESNconfig)
  # --- Data -------------------------------------------------------------------
# DataLoad returns four values; the first is not needed here. The remaining
# three are the training data, the warm-up data fed to predict(), and the
# validation data the prediction is scored against. How the windows are cut
# is decided inside DataLoad from ESNconfig (presumably via train_start /
# train_len / predict_* — confirm). The file name suggests normalised data.
_, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/DataNor.csv', ESNconfig)
  # --- Plain ESN baseline -------------------------------------------------------
# Train on DataTrain, then predict starting from the warm-up data.
ESNModel.esn_train(DataTrain)
DataPreESN = ESNModel.predict(DataWarm)

# Round-trip every series through float32 tensors (.float()), presumably so
# they share the dtype of the torch Koopman model. The *_Ten and State* names
# are kept because later code may rely on them.
DataPreESN_Ten = torch.from_numpy(DataPreESN).float()
DataVal_Ten = torch.from_numpy(DataVal).float()
DataWarm_Ten = torch.from_numpy(DataWarm).float()
StatePreESN_Ten = DataPreESN_Ten
StateVal_Ten = DataVal_Ten
StateWarm_Ten = DataWarm_Ten
StatePreESN = StatePreESN_Ten.detach().numpy()
StateVal = StateVal_Ten.detach().numpy()
StateWarm = StateWarm_Ten.detach().numpy()

# Element-wise mean squared error over all samples and features.
# Check shapes explicitly: NumPy broadcasting would otherwise silently average
# over a mismatched grid (e.g. a (1, d) prediction against (n, d) targets).
if StatePreESN.shape != StateVal.shape:
    raise ValueError(
        f"ESN prediction shape {StatePreESN.shape} does not match "
        f"validation shape {StateVal.shape}"
    )
# np.mean works for any rank (the previous Frobenius-norm form required 2-D
# input) and accumulates in float64 to limit float32 round-off.
ESNMSE = np.mean(np.square(StatePreESN - StateVal), dtype=np.float64)
print(ESNMSE)
  # --- Plots: one figure per state feature --------------------------------------
# The x-axis is the sample index: the warm-up window first, then the
# prediction horizon. Lengths come from the arrays themselves (not from
# ESNconfig) so a DataLoad window that differs from predict_warmup /
# predict_len cannot cause an x/y length mismatch inside plt.plot. When the
# lengths match the config, t is identical to the config-derived axis.
n_warm = len(StateWarm)
n_pred = len(StateVal)
t = np.arange(n_warm + n_pred)
for fea in range(state_dim):
    plt.figure(fea)
    plt.plot(t[:n_warm], StateWarm[:, fea],
             linestyle="-", color='black', label='StateWarm')
    plt.plot(t[n_warm:], StatePreESN[:, fea],
             linestyle="--", color='green', label='ESNPre')
    plt.plot(t[n_warm:], StateVal[:, fea],
             linestyle="-", color='blue', label='StateReal')
    # Without a title the state_dim figures are indistinguishable.
    plt.title(f"State feature {fea}")
    plt.xlabel("Sample index")
    plt.legend()
plt.show()