# KoopmanESNTest.py
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from BaseModel.ESNModel import ESN
from DataProcess import DataLoad
from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
  8. KoopmanModel = torch.load('./ModelLib/DeepKoopmanModel.pt', map_location="cpu")
  9. encoder = KoopmanModel.encoder
  10. decoder = KoopmanModel.decoder
  11. A = KoopmanModel.A.detach().numpy()
  12. state_dim = KoopmanModel.state_dim
  13. latent_dim = KoopmanModel.latent_dim
  14. MonteNum = 100
  15. MinStart = 0
  16. MaxStart = 20000
  17. StartList = np.random.randint(MinStart, MaxStart, size=MonteNum)
  18. ESNMSEList = []
  19. KoopMSEList = []
  20. KoopTime = 0
  21. ESNTime = 0
  22. count = 0
  23. lr = 0.5
  24. sr = 0.9
  25. sp = 0.1
  26. ridge = 1e-7
  27. # 短时预测
  28. # train_len = 500
  29. # train_warmup = 10
  30. # predict_warmup = 100
  31. # predict_len = 20
  32. # 长期预测
  33. train_len = 5000
  34. train_warmup = 100
  35. predict_warmup = 100
  36. predict_len = 1000
  37. for train_start in StartList:
  38. count += 1
  39. print(count)
  40. Koopmanconfig = KoopmanESNConfig(
  41. units=latent_dim,
  42. lr=lr,
  43. sr=sr,
  44. sp=sp,
  45. ridge=ridge,
  46. train_start=train_start,
  47. train_len=train_len,
  48. train_warmup=train_warmup,
  49. predict_warmup=predict_warmup,
  50. predict_len=predict_len,
  51. state_dim=latent_dim
  52. )
  53. ESNconfig = KoopmanESNConfig(
  54. units=1000,
  55. lr=lr,
  56. sr=sr,
  57. sp=sp,
  58. ridge=ridge,
  59. train_start=train_start,
  60. train_len=train_len,
  61. train_warmup=train_warmup,
  62. predict_warmup=predict_warmup,
  63. predict_len=predict_len,
  64. state_dim=latent_dim
  65. )
  66. KoopESNModel = KoopmanESN(
  67. config=Koopmanconfig,
  68. A=A
  69. )
  70. ESNModel = ESN(
  71. config=ESNconfig
  72. )
  73. _, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', Koopmanconfig)
  74. # _, DataTrain, DataWarm = DataLoad('./DataLib/DataNor.csv', config)
  75. # KoopmanESN
  76. time_start = time.time()
  77. KoopESNModel.koopmanesn_train(DataTrain)
  78. DataPreKoop = KoopESNModel.predict(DataWarm)
  79. time_end = time.time()
  80. KoopTime += time_end-time_start
  81. # ESN
  82. time_start = time.time()
  83. ESNModel.esn_train(DataTrain)
  84. DataPreESN = ESNModel.predict(DataWarm)
  85. time_end = time.time()
  86. ESNTime += time_end - time_start
  87. DataPreKoop_Ten = torch.from_numpy(DataPreKoop).float()
  88. DataPreESN_Ten = torch.from_numpy(DataPreESN).float()
  89. DataVal_Ten = torch.from_numpy(DataVal).float()
  90. DataWarm_Ten = torch.from_numpy(DataWarm).float()
  91. StatePreKoop_Ten = decoder(DataPreKoop_Ten)
  92. StatePreESN_Ten = decoder(DataPreESN_Ten)
  93. StateVal_Ten = decoder(DataVal_Ten)
  94. StateWarm_Ten = decoder(DataWarm_Ten)
  95. StatePreKoop = StatePreKoop_Ten.detach().numpy()
  96. StatePreESN = StatePreESN_Ten.detach().numpy()
  97. StateVal = StateVal_Ten.detach().numpy()
  98. StateWarm = StateWarm_Ten.detach().numpy()
  99. KoopMSE = np.linalg.norm(StatePreKoop - StateVal, ord='fro') ** 2 / np.prod(StateVal.shape)
  100. ESNMSE = np.linalg.norm(StatePreESN - StateVal, ord='fro') ** 2 / np.prod(StateVal.shape)
  101. KoopMSEList.append(KoopMSE)
  102. ESNMSEList.append(ESNMSE)
  103. KoopTime /= MonteNum
  104. ESNTime /= MonteNum
  105. KoopMSEList = np.log10(KoopMSEList)
  106. ESNMSEList = np.log10(ESNMSEList)
  107. print(f"KoopESN平均单次训练+预测耗时:{KoopTime:.4f}s")
  108. print(f"ESN平均单次训练+预测耗时:{ESNTime:.4f}s")
  109. plt.figure(figsize=(8, 5))
  110. # 绘制箱线图(先使用统一的boxprops)
  111. box = plt.boxplot(
  112. [KoopMSEList, ESNMSEList],
  113. labels=['KoopmanESN', 'ESN'],
  114. patch_artist=True,
  115. boxprops={'facecolor': 'lightblue', 'edgecolor': 'navy'}, # 初始统一设置
  116. medianprops={'color': 'red', 'linewidth': 2},
  117. whiskerprops={'color': 'gray', 'linestyle': '--'},
  118. flierprops={'marker': 'o', 'markersize': 5}
  119. )
  120. # 绘制后单独设置每个箱体的颜色
  121. colors = ['lightblue', 'lightgreen']
  122. edge_colors = ['navy', 'darkgreen']
  123. for box_element, color, edge_color in zip(box['boxes'], colors, edge_colors):
  124. box_element.set_facecolor(color)
  125. box_element.set_edgecolor(edge_color)
  126. # 添加图表元素
  127. plt.title('Model Training Loss Comparison (Boxplot)', fontsize=12)
  128. plt.ylabel('log10 Loss Value', fontsize=10)
  129. plt.legend([box["boxes"][0], box["boxes"][1]], ['KoopmanESN', 'ESN'])
  130. plt.grid(True, linestyle='--', alpha=0.3)
  131. plt.tight_layout()
  132. plt.show()