在跑模型的时候,出现了以下cuda问题:
相关位置代码为
ground_truth = torch.unsqueeze(ground_truth, 1).to(device)
model.zero_grad()
output = model(sinogram)
loss = nn.functional.mse_loss(output, ground_truth)
loss.backward()
其中,尝试print(loss.item())没有问题,但到了下一步loss.backward()就报了该错误。
把代码改为在cpu上运行,没有报错。
请问大家这可能是哪里的问题,如何修改,多谢大家的讨论和回答。