"""Train a baseline ESN on the normalized dataset and compare it to ground truth.

The script loads a pickled Deep Koopman model (``state_dim`` is taken from it),
configures and trains an ESN on ``./DataLib/DataNor.csv``, prints the mean
squared error of the ESN prediction against the validation window, and plots
the warm-up, predicted and true trajectories for every state feature.
"""
import inspect

import torch
from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
from BaseModel.ESNModel import ESN
from DataProcess import DataLoad
import numpy as np
import matplotlib.pyplot as plt

MODEL_PATH = './ModelLib/DeepKoopmanModel.pt'
DATA_PATH = './DataLib/DataNor.csv'


def _load_full_model(path):
    """Load a fully pickled model object (not a state_dict) onto the CPU.

    PyTorch 2.6 switched the default of ``torch.load(weights_only=...)`` to
    ``True``, which refuses to unpickle arbitrary objects such as a whole
    model. ``weights_only=False`` is therefore passed explicitly, but only when
    the installed torch accepts that keyword, so older versions keep working.

    Unpickling can execute arbitrary code: only load checkpoints you trust.
    """
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def _mse(pred, target):
    """Return the mean squared error between two equally shaped arrays.

    Equivalent to ``||pred - target||_F ** 2 / target.size`` for 2-D inputs.
    A shape mismatch raises ``ValueError`` instead of letting NumPy broadcast
    (e.g. (N,) against (N, 1) would silently yield an (N, N) difference).
    """
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if pred.shape != target.shape:
        raise ValueError(
            f"shape mismatch: prediction {pred.shape} vs target {target.shape}"
        )
    return float(np.mean(np.square(pred - target)))


def _plot_feature_trajectories(warm, pred, real, n_features):
    """Draw one figure per feature: warm-up, ESN prediction and ground truth.

    The time axis is derived from the actual array lengths, so the prediction
    window starts right after the last warm-up sample.
    """
    n_warm, n_pred = warm.shape[0], pred.shape[0]
    t_warm = np.arange(n_warm)
    t_pred = np.arange(n_warm, n_warm + n_pred)
    for fea in range(int(n_features)):
        plt.figure(fea)
        plt.plot(t_warm, warm[:, fea], linestyle="-", color='black', label='StateWarm')
        plt.plot(t_pred, pred[:, fea], linestyle="--", color='green', label='ESNPre')
        plt.plot(t_pred, real[:, fea], linestyle="-", color='blue', label='StateReal')
        plt.legend()
    plt.show()


KoopmanModel = _load_full_model(MODEL_PATH)
# Only state_dim is used by this script; the other attributes are read but unused.
encoder = KoopmanModel.encoder
decoder = KoopmanModel.decoder
A = KoopmanModel.A.detach().numpy()
state_dim = KoopmanModel.state_dim
latent_dim = KoopmanModel.latent_dim

# ESN hyperparameters. NOTE(review): presumably leak rate, spectral radius,
# sparsity and ridge penalty -- confirm against KoopmanESNConfig.
lr = 0.5
sr = 0.9
sp = 0.1
ridge = 1e-7

# Data windows handed to KoopmanESNConfig / DataLoad.
train_start = 15000
train_len = 5000
train_warmup = 100
predict_warmup = 100
predict_len = 1000

ESNconfig = KoopmanESNConfig(
    units=100,
    lr=lr,
    sr=sr,
    sp=sp,
    ridge=ridge,
    train_start=train_start,
    train_len=train_len,
    train_warmup=train_warmup,
    predict_warmup=predict_warmup,
    predict_len=predict_len,
    state_dim=state_dim,
)
ESNModel = ESN(config=ESNconfig)

_, DataTrain, DataWarm, DataVal = DataLoad(DATA_PATH, ESNconfig)

ESNModel.esn_train(DataTrain)
DataPreESN = ESNModel.predict(DataWarm)

# Evaluate and plot in float64 NumPy; no torch conversion is needed here.
StatePreESN = np.asarray(DataPreESN, dtype=np.float64)
StateVal = np.asarray(DataVal, dtype=np.float64)
StateWarm = np.asarray(DataWarm, dtype=np.float64)

ESNMSE = _mse(StatePreESN, StateVal)
print(ESNMSE)

_plot_feature_trajectories(StateWarm, StatePreESN, StateVal, state_dim)