"""Compare a KoopmanESN against a plain ESN on latent-space data.

The script loads a saved Deep Koopman model, builds a ``KoopmanESN`` (given
the model's ``A`` matrix) and a baseline ``ESN`` from one shared
configuration, trains both on the same training data, and lets each predict
from the same warm-up sequence.  All latent sequences are then passed through
the model's decoder, the two mean squared errors against the decoded
validation data are printed (KoopmanESN first, then ESN), and one figure per
decoded feature is drawn.
"""

import matplotlib.pyplot as plt
import numpy as np
import torch

from BaseModel.ESNModel import ESN
from DataProcess import DataLoad
from KoopmanESNModel import KoopmanESN, KoopmanESNConfig


def _decode_latent(decoder_net, latent):
    """Run a latent NumPy array through ``decoder_net`` and return a NumPy array.

    The input is converted to a float tensor first.  The forward pass runs
    under ``torch.no_grad()`` because only the values are needed here, so no
    autograd graph has to be recorded.
    """
    with torch.no_grad():
        decoded = decoder_net(torch.from_numpy(latent).float())
    return decoded.detach().numpy()


def _mse(prediction, reference):
    """Return the mean squared error between two 2-D arrays.

    Computed as the squared Frobenius norm of the difference divided by the
    number of elements in ``reference``.
    """
    squared_error = np.linalg.norm(prediction - reference, ord='fro') ** 2
    return squared_error / np.prod(reference.shape)


# NOTE(review): this unpickles a complete model object, so the file must come
# from a trusted source.  Recent PyTorch releases default ``torch.load`` to
# ``weights_only=True``, which rejects such files -- confirm against the
# installed version and pass ``weights_only=False`` explicitly if required.
KoopmanModel = torch.load('./ModelLib/DeepKoopmanModel.pt', map_location="cpu")
encoder = KoopmanModel.encoder
decoder = KoopmanModel.decoder
A = KoopmanModel.A.detach().numpy()
state_dim = KoopmanModel.state_dim
latent_dim = KoopmanModel.latent_dim

# One configuration is shared by both models; the reservoir size and the
# config's state_dim are both set to the latent dimension of the loaded model.
# (An earlier experiment used the same settings with train_start=10000.)
config = KoopmanESNConfig(
    units=latent_dim,
    lr=0.5,
    sr=0.9,
    sp=0.1,
    ridge=1e-7,
    train_start=19500,
    train_len=500,
    train_warmup=10,
    predict_warmup=100,
    predict_len=20,
    state_dim=latent_dim,
)

KoopESNModel = KoopmanESN(config=config, A=A)
ESNModel = ESN(config=config)

# The first returned value is not used.  Presumably the remaining three are
# the training, warm-up and validation segments -- verify against DataLoad.
# (An earlier variant read './DataLib/DataNor.csv' and unpacked three values.)
_, DataTrain, DataWarm, DataVal = DataLoad('./DataLib/LatentNor.csv', config)

# KoopmanESN: train, then predict from the warm-up sequence.
KoopESNModel.koopmanesn_train(DataTrain)
DataPreKoop = KoopESNModel.predict(DataWarm)

# ESN baseline: same training data, same warm-up sequence.
ESNModel.esn_train(DataTrain)
DataPreESN = ESNModel.predict(DataWarm)

# Decode every latent sequence so errors and plots use the decoder's output.
StatePreKoop = _decode_latent(decoder, DataPreKoop)
StatePreESN = _decode_latent(decoder, DataPreESN)
StateVal = _decode_latent(decoder, DataVal)
StateWarm = _decode_latent(decoder, DataWarm)

KoopMSE = _mse(StatePreKoop, StateVal)
ESNMSE = _mse(StatePreESN, StateVal)
print(KoopMSE)
print(ESNMSE)

# Shared time axis: warm-up samples first, predicted samples after them.
# Slicing from predict_warmup onward (instead of t[-predict_len:]) gives the
# same indices for any positive predict_len and an empty axis when it is 0.
t = np.arange(config.predict_warmup + config.predict_len)
t_warm = t[:config.predict_warmup]
t_pred = t[config.predict_warmup:]

# (decoded series, line style, colour, legend label) for the forecast window.
forecast_curves = (
    (StatePreKoop, "--", 'red', 'KoopmanESNPre'),
    (StatePreESN, "--", 'green', 'ESNPre'),
    (StateVal, "-", 'blue', 'StateReal'),
)

for fea in range(state_dim):
    plt.figure(fea)
    plt.plot(t_warm, StateWarm[:, fea], linestyle="-", color='black', label='StateWarm')
    for series, line_style, colour, legend_label in forecast_curves:
        plt.plot(t_pred, series[:, fea], linestyle=line_style, color=colour, label=legend_label)
    plt.legend()

# Display all per-feature figures together once they have been drawn.
plt.show()