# KoopmanESNMain.py
  """Compare a KoopmanESN with a plain ESN on a latent-space trajectory.

  The script loads a pretrained deep Koopman model, trains a ``KoopmanESN``
  (constructed with the model's matrix ``A``) and an ``ESN`` on the data read
  from ``./DataLib/LatentNor.csv``, decodes both predictions with the model's
  decoder, prints the two mean squared errors and plots each state dimension.
  """
  import matplotlib.pyplot as plt
  import numpy as np
  import torch

  from BaseModel.ESNModel import ESN
  from DataProcess import DataLoad
  from KoopmanESNModel import KoopmanESN, KoopmanESNConfig

  # SECURITY NOTE(review): the checkpoint is read back as a pickled Python
  # object (its .encoder/.decoder/.A attributes are used below), so loading it
  # can execute arbitrary code. Only load checkpoints from a trusted source.
  # weights_only=False is spelled out because PyTorch >= 2.6 defaults to
  # weights_only=True, which refuses to unpickle arbitrary objects; the keyword
  # is available from PyTorch 1.13 onwards.
  KoopmanModel = torch.load(
      './ModelLib/DeepKoopmanModel.pt',
      map_location="cpu",
      weights_only=False,
  )

  # Pieces of the pretrained model used by the rest of the script.
  encoder = KoopmanModel.encoder
  decoder = KoopmanModel.decoder
  # Matrix A as a NumPy array; it is handed to KoopmanESN further down.
  A = KoopmanModel.A.detach().numpy()
  state_dim = KoopmanModel.state_dim
  latent_dim = KoopmanModel.latent_dim
  # Reservoir hyperparameters shared by both models.
  # NOTE(review): presumably leak rate (lr), spectral radius (sr), sparsity (sp)
  # and ridge-regression regularisation (ridge) -- confirm against
  # KoopmanESNConfig.
  lr = 0.5
  sr = 0.9
  sp = 0.1
  ridge = 1e-7

  # Windowing of the data set: where training starts, how many samples it
  # spans, and the warm-up / prediction lengths handed to the config.
  train_start = 15000
  train_len = 5000
  train_warmup = 100
  predict_warmup = 100
  predict_len = 500

  # Both configurations differ only in `units`; everything else is shared.
  _shared_settings = dict(
      lr=lr,
      sr=sr,
      sp=sp,
      ridge=ridge,
      train_start=train_start,
      train_len=train_len,
      train_warmup=train_warmup,
      predict_warmup=predict_warmup,
      predict_len=predict_len,
      state_dim=latent_dim,
  )

  # KoopmanESN configuration: as many units as the latent space has dimensions.
  Koopmanconfig = KoopmanESNConfig(units=latent_dim, **_shared_settings)

  # Plain ESN configuration: 1000 units, otherwise identical.
  ESNconfig = KoopmanESNConfig(units=1000, **_shared_settings)
  # Instantiate both models. The KoopmanESN additionally receives the matrix A
  # taken from the pretrained model. Construction order is kept as in the
  # original script.
  KoopESNModel = KoopmanESN(config=Koopmanconfig, A=A)
  ESNModel = ESN(config=ESNconfig)

  # DataLoad yields four values here; the first one is not used by this script.
  _, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', Koopmanconfig)
  # Disabled alternative kept from the original script:
  # _, DataTrain, DataWarm = DataLoad('./DataLib/DataNor.csv', config)

  # KoopmanESN: train on the training split, then predict from the warm-up data.
  KoopESNModel.koopmanesn_train(DataTrain)
  DataPreKoop = KoopESNModel.predict(DataWarm)

  # ESN: same training split and the same warm-up data.
  ESNModel.esn_train(DataTrain)
  DataPreESN = ESNModel.predict(DataWarm)
  # Map the two predictions, the validation data and the warm-up data through
  # the pretrained decoder and convert the results back to NumPy arrays.
  #
  # This is inference only and every output is detached right away, so the
  # decoder is run under no_grad(): autograd then records no graph for these
  # forward passes.
  with torch.no_grad():
      # from_numpy() requires NumPy arrays; .float() casts them to float32.
      DataPreKoop_Ten = torch.from_numpy(DataPreKoop).float()
      DataPreESN_Ten = torch.from_numpy(DataPreESN).float()
      DataVal_Ten = torch.from_numpy(DataVal).float()
      DataWarm_Ten = torch.from_numpy(DataWarm).float()

      StatePreKoop_Ten = decoder(DataPreKoop_Ten)
      StatePreESN_Ten = decoder(DataPreESN_Ten)
      StateVal_Ten = decoder(DataVal_Ten)
      StateWarm_Ten = decoder(DataWarm_Ten)

  # .detach() is kept so the conversion also works if the decoder re-enables
  # gradient tracking internally.
  StatePreKoop = StatePreKoop_Ten.detach().numpy()
  StatePreESN = StatePreESN_Ten.detach().numpy()
  StateVal = StateVal_Ten.detach().numpy()
  StateWarm = StateWarm_Ten.detach().numpy()
  # Mean squared error of each decoded prediction against the decoded
  # validation data: squared Frobenius norm of the residual divided by the
  # number of entries. ord='fro' requires the arrays to be 2-D.
  _entry_count = np.prod(StateVal.shape)

  _koop_residual = StatePreKoop - StateVal
  KoopMSE = np.linalg.norm(_koop_residual, ord='fro') ** 2 / _entry_count

  _esn_residual = StatePreESN - StateVal
  ESNMSE = np.linalg.norm(_esn_residual, ord='fro') ** 2 / _entry_count

  # Printed in this order: KoopmanESN first, plain ESN second.
  print(KoopMSE)
  print(ESNMSE)
  # Common time axis: the warm-up window followed by the prediction window.
  _n_warm = Koopmanconfig.predict_warmup
  _n_pred = Koopmanconfig.predict_len
  t = np.arange(_n_warm + _n_pred)
  _t_warm = t[:_n_warm]
  _t_pred = t[-_n_pred:]

  # Curves drawn over the prediction window: (data, linestyle, color, label).
  _prediction_curves = (
      (StatePreKoop, "--", 'red', 'KoopmanESNPre'),
      (StatePreESN, "--", 'green', 'ESNPre'),
      (StateVal, "-", 'blue', 'StateReal'),
  )

  # One figure per state dimension, numbered by the dimension index.
  # NOTE(review): the indentation of the original script was lost; the loop is
  # assumed to cover everything up to and including plt.legend().
  for fea in np.arange(state_dim):
      plt.figure(fea)
      plt.plot(_t_warm, StateWarm[:, fea],
               linestyle="-", color='black', label='StateWarm')
      for _series, _style, _color, _label in _prediction_curves:
          plt.plot(_t_pred, _series[:, fea],
                   linestyle=_style, color=_color, label=_label)
      plt.legend()

  # NOTE(review): assumed to run once after the loop so that all figures are
  # shown together -- confirm against the original file.
  plt.show()