# Observed error when running this script:
# TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple
def train(verbose=False):
    """Run one training epoch over ``train_dataloader`` and collect batch losses.

    Relies on module-level names defined elsewhere in the file: ``net``,
    ``train_dataloader``, ``optimizer``, ``loss_criterion`` and ``USE_GPU``.

    Args:
        verbose: Currently unused; kept so existing callers (e.g. ``main``,
            which passes ``verbose=True``) keep working.

    Returns:
        A list of Python floats, one loss value per batch, in iteration order.
    """
    net.train()
    loss_list = []
    for data in train_dataloader:
        # ``torch.autograd.Variable`` is deprecated (a no-op since PyTorch 0.4);
        # plain tensors carry autograd state themselves.
        inputs = data['inputs']
        groundtruths = data['groundtruths']
        if USE_GPU:
            inputs = inputs.cuda()
            groundtruths = groundtruths.cuda()
        # Reset gradients accumulated by the previous batch.
        optimizer.zero_grad()
        # NOTE(review): the reported "linear(): argument 'input' must be Tensor,
        # not tuple" is raised inside net's forward, not here. It usually means
        # an nn.LSTM/GRU result (a tuple: output, hidden-state) is passed
        # straight to an nn.Linear. Unpack it in the model (e.g.
        # ``out, _ = self.lstm(x)``). The model definition is not in this
        # section. TODO confirm.
        out = net(inputs)
        # Loss of the network output against the ground truth.
        loss = loss_criterion(out, groundtruths)
        # Backpropagate and let the optimizer update the parameters.
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    return loss_list
def test():
    """Evaluate ``net`` on ``test_dataloader``.

    Relies on module-level names defined elsewhere in the file: ``net``,
    ``test_dataloader``, ``error_criterion``, ``test_data_trans`` and
    ``USE_GPU``.

    Returns:
        A tuple ``(predictions, groundtruths, average_error)``:
        ``predictions`` and ``groundtruths`` are 1-D numpy arrays built from
        the concatenated batch outputs/targets (the reshape requires exactly
        one value per sample), and ``average_error`` is
        ``sqrt(sum(batch_error * batch_size) / len(test_data_trans))``.
        If ``error_criterion`` is a mean-squared-error loss, this is the RMSE
        (presumably; confirm against its definition).
    """
    error = 0.0
    predictions = []
    test_groundtruths = []
    # Put the network in evaluation mode (e.g. dropout / batch-norm inference behaviour).
    net.eval()
    # Gradients are not needed for evaluation. Without no_grad an autograd
    # graph would be built for every batch, wasting memory and time.
    with torch.no_grad():
        for data in test_dataloader:
            inputs = data['inputs']
            groundtruths = data['groundtruths']
            if USE_GPU:
                inputs = inputs.cuda()
                groundtruths = groundtruths.cuda()
            out = net(inputs)
            # Assumes error_criterion returns a per-batch mean. Weighting by
            # batch size turns the running total into a sum over samples.
            error += error_criterion(out, groundtruths).item() * groundtruths.size(0)
            # .cpu() is a no-op for tensors already on the CPU, so one path
            # serves both devices.
            predictions.extend(out.cpu().numpy().tolist())
            test_groundtruths.extend(groundtruths.cpu().numpy().tolist())
    average_error = np.sqrt(error / len(test_data_trans))
    return (
        np.array(predictions).reshape(len(predictions)),
        np.array(test_groundtruths).reshape(len(test_groundtruths)),
        average_error,
    )
def main(output_csv='D:/python目录/pythonProject/STA-LSTM-main/data/output/out_t+1.csv',
         model_path='D:/python目录/pythonProject/STA-LSTM-main/models/sta_lstm_t+1.pth'):
    """Train ``net`` for ``EPOCHES`` epochs, evaluate it, and save the results.

    Relies on module-level names defined elsewhere in the file: ``EPOCHES``,
    ``adjust_lr``, ``net``, ``train`` and ``test``.

    Args:
        output_csv: Where the prediction / ground-truth table is written.
            Defaults to the original hard-coded path.
        model_path: Where the whole ``net`` object is saved with
            ``torch.save``. Defaults to the original hard-coded path.

    Returns:
        The list of per-epoch mean training losses.
    """
    # Record when training starts.
    train_start = time.time()
    loss_recorder = []
    print('starting training... ')
    for epoch in range(EPOCHES):
        loss_list = train(verbose=True)
        # Assumes adjust_lr is a torch.optim.lr_scheduler instance. Since
        # PyTorch 1.1 it must step *after* optimizer.step() (which runs
        # inside train()); stepping first skips the initial learning rate.
        adjust_lr.step()
        mean_loss = np.mean(loss_list)
        loss_recorder.append(mean_loss)
        print('epoch = %d,loss = %.5f' % (epoch + 1, mean_loss))
    print('training time = {}s'.format(int((time.time() - train_start))))

    # Record when testing starts.
    test_start = time.time()
    predictions, test_groundtruth, average_error = test()
    print(predictions.shape)
    print(test_groundtruth.shape)
    print('test time = {}s'.format(int((time.time() - test_start) + 1.0)))
    print('average error = ', average_error)

    result = pd.DataFrame(data={'Q(t+1)': predictions, 'Q(t+1)truth': test_groundtruth})
    result.to_csv(output_csv)
    torch.save(net, model_path)
    return loss_recorder
# Run training/evaluation only when executed as a script, not when imported.
# (The original ``name == 'main'`` lost its dunder underscores and would raise
# NameError.)
if __name__ == '__main__':
    main()