# KoopmanESNMain.py: train a KoopmanESN in the Deep Koopman latent space,
# decode its forecast back to state space, and plot it against the data.
  1. import torch
  2. from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
  3. from BaseModel.ESNModel import ESN
  4. from DataProcess import DataLoad
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. KoopmanModel = torch.load('./ModelLib/DeepKoopmanModel.pt', map_location="cpu")
  8. encoder = KoopmanModel.encoder
  9. decoder = KoopmanModel.decoder
  10. A = KoopmanModel.A.detach().numpy()
  11. state_dim = KoopmanModel.state_dim
  12. latent_dim = KoopmanModel.latent_dim
  13. lr = 0.5
  14. sr = 0.9
  15. sp = 0.1
  16. ridge = 1e-7
  17. train_start = 15000
  18. train_len = 5000
  19. train_warmup = 100
  20. predict_warmup = 100
  21. predict_len = 5000
  22. Koopmanconfig = KoopmanESNConfig(
  23. units=latent_dim,
  24. lr=lr,
  25. sr=sr,
  26. sp=sp,
  27. ridge=ridge,
  28. train_start=train_start,
  29. train_len=train_len,
  30. train_warmup=train_warmup,
  31. predict_warmup=predict_warmup,
  32. predict_len=predict_len,
  33. state_dim=latent_dim
  34. )
  35. ESNconfig = KoopmanESNConfig(
  36. units=500,
  37. lr=lr,
  38. sr=sr,
  39. sp=sp,
  40. ridge=ridge,
  41. train_start=train_start,
  42. train_len=train_len,
  43. train_warmup=train_warmup,
  44. predict_warmup=predict_warmup,
  45. predict_len=predict_len,
  46. state_dim=state_dim
  47. )
  48. KoopESNModel = KoopmanESN(
  49. config=Koopmanconfig,
  50. A=A
  51. )
  52. ESNModel = ESN(
  53. config=ESNconfig
  54. )
  55. _, DataTrainKoop, DataWarmKoop, DataValKoop = DataLoad('./DataLib/LatentNor.csv', Koopmanconfig)
  56. # KoopmanESN
  57. KoopESNModel.koopmanesn_train(DataTrainKoop)
  58. DataPreKoop = KoopESNModel.predict(DataWarmKoop)
  59. DataPreKoop_Ten = torch.from_numpy(DataPreKoop).float()
  60. DataValKoop_Ten = torch.from_numpy(DataValKoop).float()
  61. DataWarmKoop_Ten = torch.from_numpy(DataWarmKoop).float()
  62. StatePreKoop_Ten = decoder(DataPreKoop_Ten)
  63. StateValKoop_Ten = decoder(DataValKoop_Ten)
  64. StateWarmKoop_Ten = decoder(DataWarmKoop_Ten)
  65. StatePreKoop = StatePreKoop_Ten.detach().numpy()
  66. StateValKoop = StateValKoop_Ten.detach().numpy()
  67. StateWarmKoop = StateWarmKoop_Ten.detach().numpy()
  68. KoopMSE = np.linalg.norm(StatePreKoop - StateValKoop, ord='fro') ** 2 / np.prod(StateValKoop.shape)
  69. print(KoopMSE)
  70. t = np.arange(Koopmanconfig.predict_warmup + Koopmanconfig.predict_len)
  71. DataMaxMin = np.loadtxt('./DataLib/DataMaxMin.csv', delimiter=',')
  72. DataMax = DataMaxMin[0, :]
  73. DataMin = DataMaxMin[1, :]
  74. StateWarmKoop = (StateWarmKoop + 1) * (DataMax-DataMin) + DataMin
  75. StatePreKoop = (StatePreKoop + 1) * (DataMax-DataMin) + DataMin
  76. StateValKoop = (StateValKoop + 1) * (DataMax-DataMin) + DataMin
  77. for fea in np.arange(state_dim):
  78. plt.figure(fea)
  79. plt.plot(t[:Koopmanconfig.predict_warmup], StateWarmKoop[:, fea])
  80. plt.plot(t[-Koopmanconfig.predict_len:], StatePreKoop[:, fea],
  81. linestyle="--", color='red', label='KoopmanESNPre')
  82. plt.plot(t[-Koopmanconfig.predict_len:], StateValKoop[:, fea],
  83. linestyle="-", color='blue', label='StateReal')
  84. plt.legend()
  85. plt.show()