透明人間 2020-04-20 18:13 采纳率: 0%
浏览 7310

关于pytorch里对cuda的报错:RuntimeError: expected device cuda:0 but got device cpu

运行时报错。

这是我的代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]], dtype=np.float32)
y_train = np.array([[1.23],[3.24],[2.3],[2.14],[2.93],[3.168],[1.779],[2.182],[2.59],[3.167],[1.042],[3.791],[3.313],[2.997],[1.1]], dtype=np.float32)


x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1,1)

    def forward(self,x):
        out = self.linear(x)
        return out

if torch.cuda.is_available():
    model = LinearRegression().cuda()
else:
    model = LinearRegression()

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr = 1e-3)

num_epoch = 100
for epcoh in range(num_epoch):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)

    out = model(inputs)

    target = y_train

    loss = criterion(out,target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epcoh+1)%20 == 0:
        print('Epoch[{}/{}],loss:{:.6f}'
              .format((epcoh+1,num_epoch,loss.data[0])))
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()
plt.plot(x_train.numpy(),y_train(),'ro',label = 'Original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting Line')
plt.show()

错误截图:
图片说明

各位大佬们,这个错误是怎么回事啊,是cpu的问题吗?以下是cpu的信息截图:
torch.cuda.current_device(): 0
torch.cuda.device(0):
torch.cuda.device_count(): 1
torch.cuda.get_device_name(0): GeForce MX250
torch.cuda.is_available(): True

  • 写回答

2条回答

  • 夏末的初雪 2020-05-29 21:49
    关注

    这个targets也需要.cuda或者是to(device)

    评论

报告相同问题?

悬赏问题

  • ¥15 c程序不知道为什么得不到结果
  • ¥40 复杂的限制性的商函数处理
  • ¥15 程序不包含适用于入口点的静态Main方法
  • ¥15 素材场景中光线烘焙后灯光失效
  • ¥15 请教一下各位,为什么我这个没有实现模拟点击
  • ¥15 执行 virtuoso 命令后,界面没有,cadence 启动不起来
  • ¥50 comfyui下连接animatediff节点生成视频质量非常差的原因
  • ¥20 有关区间dp的问题求解
  • ¥15 多电路系统共用电源的串扰问题
  • ¥15 slam rangenet++配置