# KoopmanESNTest.py
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from BaseModel.ESNModel import ESN
from DataProcess import DataLoad
from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
  8. KoopmanModel = torch.load('./ModelLib/DeepKoopmanModel.pt', map_location="cpu")
  9. encoder = KoopmanModel.encoder
  10. decoder = KoopmanModel.decoder
  11. A = KoopmanModel.A.detach().numpy()
  12. state_dim = KoopmanModel.state_dim
  13. latent_dim = KoopmanModel.latent_dim
  14. MonteNum = 100
  15. MinStart = 0
  16. MaxStart = 25000
  17. StartList = np.random.randint(MinStart, MaxStart, size=MonteNum)
  18. ESNMSEList = []
  19. KoopMSEList = []
  20. KoopTime = 0
  21. ESNTime = 0
  22. count = 0
  23. for train_start in StartList:
  24. count += 1
  25. print(count)
  26. config = KoopmanESNConfig(
  27. units=latent_dim,
  28. lr=0.5,
  29. sr=0.9,
  30. sp=0.1,
  31. ridge=1e-7,
  32. train_start=train_start,
  33. train_len=500,
  34. train_warmup=10,
  35. predict_warmup=100,
  36. predict_len=20,
  37. state_dim=latent_dim
  38. )
  39. KoopESNModel = KoopmanESN(
  40. config=config,
  41. A=A
  42. )
  43. ESNModel = ESN(
  44. config=config
  45. )
  46. _, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', config)
  47. # _, DataTrain, DataWarm = DataLoad('./DataLib/DataNor.csv', config)
  48. # KoopmanESN
  49. time_start = time.time()
  50. KoopESNModel.koopmanesn_train(DataTrain)
  51. DataPreKoop = KoopESNModel.predict(DataWarm)
  52. time_end = time.time()
  53. KoopTime += time_end-time_start
  54. # ESN
  55. time_start = time.time()
  56. ESNModel.esn_train(DataTrain)
  57. DataPreESN = ESNModel.predict(DataWarm)
  58. time_end = time.time()
  59. ESNTime += time_end - time_start
  60. DataPreKoop_Ten = torch.from_numpy(DataPreKoop).float()
  61. DataPreESN_Ten = torch.from_numpy(DataPreESN).float()
  62. DataVal_Ten = torch.from_numpy(DataVal).float()
  63. DataWarm_Ten = torch.from_numpy(DataWarm).float()
  64. StatePreKoop_Ten = decoder(DataPreKoop_Ten)
  65. StatePreESN_Ten = decoder(DataPreESN_Ten)
  66. StateVal_Ten = decoder(DataVal_Ten)
  67. StateWarm_Ten = decoder(DataWarm_Ten)
  68. StatePreKoop = StatePreKoop_Ten.detach().numpy()
  69. StatePreESN = StatePreESN_Ten.detach().numpy()
  70. StateVal = StateVal_Ten.detach().numpy()
  71. StateWarm = StateWarm_Ten.detach().numpy()
  72. KoopMSE = np.linalg.norm(StatePreKoop - StateVal, ord='fro') ** 2 / np.prod(StateVal.shape)
  73. ESNMSE = np.linalg.norm(StatePreESN - StateVal, ord='fro') ** 2 / np.prod(StateVal.shape)
  74. KoopMSEList.append(KoopMSE)
  75. ESNMSEList.append(ESNMSE)
  76. KoopTime /= MonteNum
  77. ESNTime /= MonteNum
  78. KoopMSEList = np.log10(KoopMSEList)
  79. ESNMSEList = np.log10(ESNMSEList)
  80. print(f"KoopESN平均单次训练+预测耗时:{KoopTime:.4f}s")
  81. print(f"ESN平均单次训练+预测耗时:{ESNTime:.4f}s")
  82. plt.figure(figsize=(8, 5))
  83. # 绘制箱线图(先使用统一的boxprops)
  84. box = plt.boxplot(
  85. [KoopMSEList, ESNMSEList],
  86. labels=['KoopmanESN', 'ESN'],
  87. patch_artist=True,
  88. boxprops={'facecolor': 'lightblue', 'edgecolor': 'navy'}, # 初始统一设置
  89. medianprops={'color': 'red', 'linewidth': 2},
  90. whiskerprops={'color': 'gray', 'linestyle': '--'},
  91. flierprops={'marker': 'o', 'markersize': 5}
  92. )
  93. # 绘制后单独设置每个箱体的颜色
  94. colors = ['lightblue', 'lightgreen']
  95. edge_colors = ['navy', 'darkgreen']
  96. for box_element, color, edge_color in zip(box['boxes'], colors, edge_colors):
  97. box_element.set_facecolor(color)
  98. box_element.set_edgecolor(edge_color)
  99. # 添加图表元素
  100. plt.title('Model Training Loss Comparison (Boxplot)', fontsize=12)
  101. plt.ylabel('log10 Loss Value', fontsize=10)
  102. plt.legend([box["boxes"][0], box["boxes"][1]], ['KoopmanESN', 'ESN'])
  103. plt.grid(True, linestyle='--', alpha=0.3)
  104. plt.tight_layout()
  105. plt.show()