# AutoEncoderTest.py (603 B)

  1. import csv
  2. from AutoEncoderModel import AutoEncoder
  3. import torch
  4. model = torch.load('./ModelLib/models/model_nlat_100_nhid_128_lr_0.001.pt', map_location="cpu")
  5. Data = torch.load('./DataLib/Data.pt')
  6. input = Data
  7. input_pre = model(input)
  8. latent = model.encoder(input)
  9. input_pre = input_pre.detach().numpy()
  10. latent = latent.detach().numpy()
  11. with open('./DataLib/DataPre.csv', mode='w', newline='') as file:
  12. writer = csv.writer(file)
  13. writer.writerows(input_pre)
  14. with open('./DataLib/LatentPre.csv', mode='w', newline='') as file:
  15. writer = csv.writer(file)
  16. writer.writerows(latent)