# DataProcess.py 1.2 KB
  1. import numpy as np
  2. import torch
  3. DataFileName = './DataLib/Data.csv'
  4. Data = np.loadtxt(DataFileName, delimiter=',')
  5. NumState = 13
  6. NumControl = 3
  7. Data = Data[:, :NumState]
  8. # 标准化
  9. DataMax = np.max(Data, axis=0)
  10. DataMin = np.min(Data, axis=0)
  11. DataNor = 2*(Data-DataMin)/(DataMax-DataMin)-1
  12. DataMaxMin = np.vstack([DataMax, DataMin])
  13. np.savetxt('./DataLib/DataMaxMin.csv', DataMaxMin, delimiter=',')
  14. # 数据间隔固定为10s
  15. deltat = 100
  16. # 每个sample的长度为51,对应相邻两个时序5000s
  17. seqlen = 51
  18. # 数据特征
  19. NumSample = DataNor.shape[0]
  20. NumFea = DataNor.shape[1]
  21. NumSampleTensor = NumSample - seqlen + 1
  22. print(NumSampleTensor)
  23. DataTensor = torch.zeros([NumSampleTensor, seqlen, NumFea])
  24. # print(DataTensor.shape)
  25. for iterSample in range(NumSampleTensor):
  26. index_start = iterSample
  27. index_end = iterSample + seqlen
  28. DataThis = DataNor[index_start:index_end, :]
  29. DataThisTensor = torch.from_numpy(DataThis)
  30. # # print(index_start)
  31. # # print(index_end)
  32. # # print(DataThis.shape)
  33. DataTensor[iterSample, :, :] = DataThisTensor
  34. print(iterSample)
  35. # 打乱
  36. idx = torch.randperm(DataTensor.size(0))
  37. shuffData = DataTensor[idx]
  38. DataFileNameTensor = './DataLib/DataSeqTensor.pt'
  39. torch.save(shuffData, DataFileNameTensor)