"""Monte-Carlo comparison of KoopmanESN against a plain ESN in a Deep-Koopman latent space.

For ``MonteNum`` randomly drawn training start indices, both models are
configured identically, trained on data loaded from ``./DataLib/LatentNor.csv``,
and asked to predict from the warm-up segment returned by ``DataLoad``. Each
prediction and the validation data are mapped through the pretrained Koopman
decoder, and the mean squared error of the decoded prediction is recorded.
The script prints the average train+predict wall time per run for each model
and shows a boxplot of log10 MSE.
"""
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from BaseModel.ESNModel import ESN
from DataProcess import DataLoad
from KoopmanESNModel import KoopmanESNConfig, KoopmanESN


def _load_full_model(path):
    """Load a fully pickled model object from *path* onto the CPU.

    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    refuses to unpickle a whole ``nn.Module`` such as this checkpoint, so
    ``weights_only=False`` is requested explicitly. PyTorch < 1.13 does not
    know that keyword (it raises ``TypeError``), so retry without it there.

    NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(path, map_location="cpu")


def _decode(latent):
    """Run a NumPy latent-space array through the Koopman decoder; return a NumPy array."""
    # Pure inference: no_grad avoids building (and then discarding) an autograd graph.
    with torch.no_grad():
        return decoder(torch.from_numpy(latent).float()).numpy()


def _mse(pred, target):
    """Mean squared error over all elements.

    For 2-D inputs this equals ||pred - target||_F^2 / pred.size, but unlike the
    Frobenius norm it also accepts arrays of any other rank.
    """
    return float(np.mean(np.square(pred - target)))


KoopmanModel = _load_full_model('./ModelLib/DeepKoopmanModel.pt')
encoder = KoopmanModel.encoder
decoder = KoopmanModel.decoder
A = KoopmanModel.A.detach().numpy()
state_dim = KoopmanModel.state_dim
latent_dim = KoopmanModel.latent_dim

MonteNum = 100
MinStart = 0
MaxStart = 25000
# Set to an int to make the drawn training windows (and any later use of
# NumPy's global RNG) reproducible; None keeps fresh randomness on every run.
SEED = None
if SEED is not None:
    np.random.seed(SEED)
StartList = np.random.randint(MinStart, MaxStart, size=MonteNum)

ESNMSEList = []
KoopMSEList = []
KoopTime = 0
ESNTime = 0
count = 0
for count, train_start in enumerate(StartList, start=1):
    print(count)
    # KoopmanESN and ESN share one configuration so only the model differs.
    config = KoopmanESNConfig(
        units=latent_dim,
        lr=0.5,
        sr=0.9,
        sp=0.1,
        ridge=1e-7,
        train_start=train_start,
        train_len=500,
        train_warmup=10,
        predict_warmup=100,
        predict_len=20,
        state_dim=latent_dim
    )
    KoopESNModel = KoopmanESN(config=config, A=A)
    ESNModel = ESN(config=config)

    _, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', config)

    # KoopmanESN: time training + prediction together.
    time_start = time.perf_counter()
    KoopESNModel.koopmanesn_train(DataTrain)
    DataPreKoop = KoopESNModel.predict(DataWarm)
    KoopTime += time.perf_counter() - time_start

    # Plain ESN: same timing window.
    time_start = time.perf_counter()
    ESNModel.esn_train(DataTrain)
    DataPreESN = ESNModel.predict(DataWarm)
    ESNTime += time.perf_counter() - time_start

    # Errors are measured after decoding back from the latent space.
    StateVal = _decode(DataVal)
    KoopMSEList.append(_mse(_decode(DataPreKoop), StateVal))
    ESNMSEList.append(_mse(_decode(DataPreESN), StateVal))

KoopTime /= MonteNum
ESNTime /= MonteNum
KoopMSEList = np.log10(KoopMSEList)
ESNMSEList = np.log10(ESNMSEList)

print(f"KoopESN平均单次训练+预测耗时:{KoopTime:.4f}s")
print(f"ESN平均单次训练+预测耗时:{ESNTime:.4f}s")

plt.figure(figsize=(8, 5))
# Boxplot of log10 MSE; tick labels are set via xticks below because
# boxplot's `labels=` keyword is deprecated in newer matplotlib releases.
box = plt.boxplot(
    [KoopMSEList, ESNMSEList],
    patch_artist=True,
    medianprops={'color': 'red', 'linewidth': 2},
    whiskerprops={'color': 'gray', 'linestyle': '--'},
    flierprops={'marker': 'o', 'markersize': 5}
)
plt.xticks([1, 2], ['KoopmanESN', 'ESN'])

# Give each box its own fill/edge colour.
colors = ['lightblue', 'lightgreen']
edge_colors = ['navy', 'darkgreen']
for box_element, color, edge_color in zip(box['boxes'], colors, edge_colors):
    box_element.set_facecolor(color)
    box_element.set_edgecolor(edge_color)

# Chart decorations.
# NOTE(review): the plotted values are log10 prediction MSE on the validation
# window, not a training loss -- consider relabeling the title/axis.
plt.title('Model Training Loss Comparison (Boxplot)', fontsize=12)
plt.ylabel('log10 Loss Value', fontsize=10)
plt.legend([box["boxes"][0], box["boxes"][1]], ['KoopmanESN', 'ESN'])
plt.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plt.show()