123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
"""Compare KoopmanESN against a plain ESN on Deep Koopman latent trajectories.

The script loads a pretrained Deep Koopman model, trains a KoopmanESN (built
with the Koopman matrix ``A``) and a plain ESN on latent-space data, maps both
forecasts back through the Koopman decoder, prints the mean squared error of
each forecast against the decoded validation data, and plots every decoded
feature (warm-up segment, both forecasts, and the reference trajectory).
"""
import matplotlib.pyplot as plt
import numpy as np
import torch

from KoopmanESNModel import KoopmanESNConfig, KoopmanESN
from BaseModel.ESNModel import ESN
from DataProcess import DataLoad

MODEL_PATH = './ModelLib/DeepKoopmanModel.pt'
DATA_PATH = './DataLib/LatentNor.csv'


def _mse(pred, true):
    """Return the mean squared error over all elements of two equally shaped arrays."""
    return np.mean(np.square(pred - true))


# The checkpoint is a fully pickled model object (not a state_dict), so it must
# be loaded with weights_only=False; PyTorch >= 2.6 defaults to True and would
# refuse it. Unpickling can execute arbitrary code: only load trusted files.
KoopmanModel = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
# Inference only: put any dropout / batch-norm layers (if present) in eval mode.
KoopmanModel.eval()
encoder = KoopmanModel.encoder  # not used below; kept for interactive inspection
decoder = KoopmanModel.decoder
A = KoopmanModel.A.detach().numpy()
state_dim = KoopmanModel.state_dim
latent_dim = KoopmanModel.latent_dim

# Hyperparameters shared by both reservoir models. Reservoir size and data
# dimension both equal the Koopman latent dimension because the models are
# trained on latent trajectories and their outputs are fed to the decoder.
# NOTE(review): lr / sr / sp presumably mean leak rate, spectral radius and
# sparsity -- confirm against KoopmanESNConfig.
# An earlier experiment used train_start=10000 with otherwise identical values.
config = KoopmanESNConfig(
    units=latent_dim,
    lr=0.5,
    sr=0.9,
    sp=0.1,
    ridge=1e-7,
    train_start=19500,
    train_len=500,
    train_warmup=10,
    predict_warmup=100,
    predict_len=20,
    state_dim=latent_dim,
)

KoopESNModel = KoopmanESN(config=config, A=A)
ESNModel = ESN(config=config)

# DataLoad's first return value is not needed here.
# (An earlier variant loaded './DataLib/DataNor.csv', which returned only
# three values, instead of the latent data.)
_, DataTrain, DataWarm, DataVal = DataLoad(DATA_PATH, config)

# Train each model on the same latent data, then forecast from the warm-up segment.
KoopESNModel.koopmanesn_train(DataTrain)
DataPreKoop = KoopESNModel.predict(DataWarm)

ESNModel.esn_train(DataTrain)
DataPreESN = ESNModel.predict(DataWarm)

DataPreKoop_Ten = torch.from_numpy(DataPreKoop).float()
DataPreESN_Ten = torch.from_numpy(DataPreESN).float()
DataVal_Ten = torch.from_numpy(DataVal).float()
DataWarm_Ten = torch.from_numpy(DataWarm).float()

# Decode latent trajectories without building an autograd graph.
with torch.no_grad():
    StatePreKoop_Ten = decoder(DataPreKoop_Ten)
    StatePreESN_Ten = decoder(DataPreESN_Ten)
    StateVal_Ten = decoder(DataVal_Ten)
    StateWarm_Ten = decoder(DataWarm_Ten)

StatePreKoop = StatePreKoop_Ten.numpy()
StatePreESN = StatePreESN_Ten.numpy()
StateVal = StateVal_Ten.numpy()
StateWarm = StateWarm_Ten.numpy()

# Errors are measured in decoded space, not latent space.
KoopMSE = _mse(StatePreKoop, StateVal)
ESNMSE = _mse(StatePreESN, StateVal)
print(f"KoopmanESN MSE: {KoopMSE}")
print(f"ESN MSE: {ESNMSE}")

# Shared time axis: warm-up steps first, forecast horizon last. The plots
# assume rows are time steps and columns are decoded features.
t = np.arange(config.predict_warmup + config.predict_len)
warm_t = t[:config.predict_warmup]
pred_t = t[-config.predict_len:]
for fea in range(state_dim):
    plt.figure(fea)
    plt.plot(warm_t, StateWarm[:, fea],
             linestyle="-", color='black', label='StateWarm')
    plt.plot(pred_t, StatePreKoop[:, fea],
             linestyle="--", color='red', label='KoopmanESNPre')
    plt.plot(pred_t, StatePreESN[:, fea],
             linestyle="--", color='green', label='ESNPre')
    plt.plot(pred_t, StateVal[:, fea],
             linestyle="-", color='blue', label='StateReal')
    plt.title(f"State feature {fea}")
    plt.xlabel("Time step")
    plt.legend()
plt.show()
|