KoopmanESNMain.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
import inspect

import matplotlib.pyplot as plt
import numpy as np
import torch

from BaseModel.ESNModel import ESN
from DataProcess import DataLoad
from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
  7. KoopmanModel = torch.load('./ModelLib/DeepKoopmanModel.pt', map_location="cpu")
  8. encoder = KoopmanModel.encoder
  9. decoder = KoopmanModel.decoder
  10. A = KoopmanModel.A.detach().numpy()
  11. state_dim = KoopmanModel.state_dim
  12. latent_dim = KoopmanModel.latent_dim
  13. config = KoopmanESNConfig(
  14. units=latent_dim,
  15. lr=0.5,
  16. sr=0.9,
  17. sp=0.1,
  18. ridge=1e-7,
  19. train_start=19500,
  20. train_len=500,
  21. train_warmup=10,
  22. predict_warmup=100,
  23. predict_len=20,
  24. state_dim=latent_dim
  25. )
  26. # config = KoopmanESNConfig(
  27. # units=latent_dim,
  28. # lr=0.5,
  29. # sr=0.9,
  30. # sp=0.1,
  31. # ridge=1e-7,
  32. # train_start=10000,
  33. # train_len=500,
  34. # train_warmup=10,
  35. # predict_warmup=100,
  36. # predict_len=20,
  37. # state_dim=latent_dim
  38. # )
  39. KoopESNModel = KoopmanESN(
  40. config=config,
  41. A=A
  42. )
  43. ESNModel = ESN(
  44. config=config
  45. )
  46. _, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', config)
  47. # _, DataTrain, DataWarm = DataLoad('./DataLib/DataNor.csv', config)
  48. # KoopmanESN
  49. KoopESNModel.koopmanesn_train(DataTrain)
  50. DataPreKoop = KoopESNModel.predict(DataWarm)
  51. # ESN
  52. ESNModel.esn_train(DataTrain)
  53. DataPreESN = ESNModel.predict(DataWarm)
  54. DataPreKoop_Ten = torch.from_numpy(DataPreKoop).float()
  55. DataPreESN_Ten = torch.from_numpy(DataPreESN).float()
  56. DataVal_Ten = torch.from_numpy(DataVal).float()
  57. DataWarm_Ten = torch.from_numpy(DataWarm).float()
  58. StatePreKoop_Ten = decoder(DataPreKoop_Ten)
  59. StatePreESN_Ten = decoder(DataPreESN_Ten)
  60. StateVal_Ten = decoder(DataVal_Ten)
  61. StateWarm_Ten = decoder(DataWarm_Ten)
  62. StatePreKoop = StatePreKoop_Ten.detach().numpy()
  63. StatePreESN = StatePreESN_Ten.detach().numpy()
  64. StateVal = StateVal_Ten.detach().numpy()
  65. StateWarm = StateWarm_Ten.detach().numpy()
  66. KoopMSE = np.linalg.norm(StatePreKoop-StateVal, ord='fro')**2/np.prod(StateVal.shape)
  67. ESNMSE = np.linalg.norm(StatePreESN-StateVal, ord='fro')**2/np.prod(StateVal.shape)
  68. print(KoopMSE)
  69. print(ESNMSE)
  70. t = np.arange(config.predict_warmup + config.predict_len)
  71. for fea in np.arange(state_dim):
  72. plt.figure(fea)
  73. plt.plot(t[:config.predict_warmup], StateWarm[:, fea],
  74. linestyle="-", color='black', label='StateWarm')
  75. plt.plot(t[-config.predict_len:], StatePreKoop[:, fea],
  76. linestyle="--", color='red', label='KoopmanESNPre')
  77. plt.plot(t[-config.predict_len:], StatePreESN[:, fea],
  78. linestyle="--", color='green', label='ESNPre')
  79. plt.plot(t[-config.predict_len:], StateVal[:, fea],
  80. linestyle="-", color='blue', label='StateReal')
  81. plt.legend()
  82. plt.show()