123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- import torch
- from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
- from BaseModel.ESNModel import ESN
- from DataProcess import DataLoad
- import numpy as np
- import matplotlib.pyplot as plt
- import time
# Load the pretrained deep-Koopman model, which is stored as a fully pickled
# object rather than a state_dict. Since PyTorch 2.6 torch.load defaults to
# weights_only=True, which refuses to unpickle arbitrary objects, so the old
# call fails there. weights_only=False restores the previous behaviour.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
KoopmanModel = torch.load('./ModelLib/DeepKoopmanModel.pt', map_location="cpu", weights_only=False)
# The model is only used for inference below. Switch to eval mode so any
# dropout / batch-norm layers (if present) behave deterministically.
KoopmanModel.eval()
encoder = KoopmanModel.encoder
decoder = KoopmanModel.decoder
# Learned matrix A of the Koopman model, handed to KoopmanESN as a NumPy array
# (presumably the linear latent dynamics -- confirm against KoopmanESNModel).
A = KoopmanModel.A.detach().numpy()
state_dim = KoopmanModel.state_dim
latent_dim = KoopmanModel.latent_dim

# Monte Carlo settings: MonteNum random training-window start indices drawn
# uniformly from the half-open range [MinStart, MaxStart).
MonteNum = 100
MinStart = 0
MaxStart = 25000
# None -> fresh OS entropy on every run (the original behaviour); set an int
# to make the sampled start indices reproducible.
RandomSeed = None
StartList = np.random.default_rng(RandomSeed).integers(MinStart, MaxStart, size=MonteNum)

# Accumulators filled by the Monte Carlo loop.
ESNMSEList = []
KoopMSEList = []
KoopTime = 0
ESNTime = 0
count = 0
# Monte Carlo comparison: for each random start index, train KoopmanESN and a
# plain ESN on the same latent-space window, predict from the same warm-up
# data, decode everything back to state space and record the MSE against the
# validation segment. Train+predict wall-clock time is accumulated per model.
for count, train_start in enumerate(StartList, start=1):
    print(count)
    config = KoopmanESNConfig(
        units=latent_dim,
        lr=0.5,
        sr=0.9,
        sp=0.1,
        ridge=1e-7,
        train_start=train_start,
        train_len=500,
        train_warmup=10,
        predict_warmup=100,
        predict_len=20,
        state_dim=latent_dim,
    )
    KoopESNModel = KoopmanESN(config=config, A=A)
    ESNModel = ESN(config=config)
    # Data are in the Koopman latent space; they are decoded to state space below.
    _, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', config)

    # KoopmanESN: time training and prediction together.
    # perf_counter is monotonic and high-resolution, unlike time.time().
    time_start = time.perf_counter()
    KoopESNModel.koopmanesn_train(DataTrain)
    DataPreKoop = KoopESNModel.predict(DataWarm)
    KoopTime += time.perf_counter() - time_start

    # Plain ESN baseline, timed the same way.
    time_start = time.perf_counter()
    ESNModel.esn_train(DataTrain)
    DataPreESN = ESNModel.predict(DataWarm)
    ESNTime += time.perf_counter() - time_start

    # Decode predictions and ground truth back to state space. This is pure
    # inference, so skip autograd bookkeeping.
    with torch.no_grad():
        StatePreKoop = decoder(torch.from_numpy(DataPreKoop).float()).numpy()
        StatePreESN = decoder(torch.from_numpy(DataPreESN).float()).numpy()
        StateVal = decoder(torch.from_numpy(DataVal).float()).numpy()

    # Mean squared error over every element.
    # NOTE(review): assumes the predictions align row-for-row with DataVal
    # -- confirm in DataLoad / predict.
    KoopMSEList.append(np.mean((StatePreKoop - StateVal) ** 2))
    ESNMSEList.append(np.mean((StatePreESN - StateVal) ** 2))
# Turn the accumulated totals into mean seconds per Monte Carlo run, and move
# both error lists to log10 scale for the boxplot.
KoopTime, ESNTime = KoopTime / MonteNum, ESNTime / MonteNum
KoopMSEList, ESNMSEList = np.log10(KoopMSEList), np.log10(ESNMSEList)
# Report average train+predict time per run for each model.
print(f"KoopESN平均单次训练+预测耗时:{KoopTime:.4f}s")
print(f"ESN平均单次训练+预测耗时:{ESNTime:.4f}s")
# Boxplot of log10 prediction MSE over all Monte Carlo runs, one box per model.
plt.figure(figsize=(8, 5))
# Face / edge colour for each box, in plotting order (KoopmanESN, ESN).
colors = ['lightblue', 'lightgreen']
edge_colors = ['navy', 'darkgreen']
box = plt.boxplot(
    [KoopMSEList, ESNMSEList],
    patch_artist=True,  # draw boxes as patches so their colours can be set
    medianprops={'color': 'red', 'linewidth': 2},
    whiskerprops={'color': 'gray', 'linestyle': '--'},
    flierprops={'marker': 'o', 'markersize': 5},
)
for box_element, color, edge_color in zip(box['boxes'], colors, edge_colors):
    box_element.set_facecolor(color)
    box_element.set_edgecolor(edge_color)
# boxplot() puts the boxes at x = 1..N by default. Labelling the ticks with
# xticks() avoids the `labels=` keyword, which Matplotlib 3.9 deprecated
# (renamed `tick_labels`), and still works on older Matplotlib releases.
plt.xticks([1, 2], ['KoopmanESN', 'ESN'])
# The plotted values are log10 MSE of predictions vs. validation data,
# not training losses.
plt.title('Model Prediction MSE Comparison (Boxplot)', fontsize=12)
plt.ylabel('log10 MSE', fontsize=10)
plt.legend([box["boxes"][0], box["boxes"][1]], ['KoopmanESN', 'ESN'])
plt.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plt.show()
|