透明人間 2020-04-20 18:13 采纳率: 0%
浏览 7305

关于pytorch里对cuda的报错:RuntimeError: expected device cuda:0 but got device cpu

运行时报错。

这是我的代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

x_train = np.array([[3.3],[4.4],[5.5],[6.71],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]], dtype=np.float32)
y_train = np.array([[1.23],[3.24],[2.3],[2.14],[2.93],[3.168],[1.779],[2.182],[2.59],[3.167],[1.042],[3.791],[3.313],[2.997],[1.1]], dtype=np.float32)


x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1,1)

    def forward(self,x):
        out = self.linear(x)
        return out

if torch.cuda.is_available():
    model = LinearRegression().cuda()
else:
    model = LinearRegression()

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr = 1e-3)

num_epoch = 100
for epcoh in range(num_epoch):
    if torch.cuda.is_available():
        inputs = Variable(x_train).cuda()
        outputs = Variable(y_train).cuda()
    else:
        inputs = Variable(x_train)
        outputs = Variable(y_train)

    out = model(inputs)

    target = y_train

    loss = criterion(out,target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epcoh+1)%20 == 0:
        print('Epoch[{}/{}],loss:{:.6f}'
              .format((epcoh+1,num_epoch,loss.data[0])))
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()
plt.plot(x_train.numpy(),y_train(),'ro',label = 'Original data')
plt.plot(x_train.numpy(),predict,label = 'Fitting Line')
plt.show()

错误截图:
图片说明

各位大佬们,这个错误是怎么回事啊,是cpu的问题吗?以下是cpu的信息截图:
torch.cuda.current_device(): 0
torch.cuda.device(0):
torch.cuda.device_count(): 1
torch.cuda.get_device_name(0): GeForce MX250
torch.cuda.is_available(): True

  • 写回答

2条回答 默认 最新

  • 夏末的初雪 2020-05-29 21:49
    关注

    这个targets也需要.cuda或者是to(device)

    评论

报告相同问题?

悬赏问题

  • ¥30 这是哪个作者做的宝宝起名网站
  • ¥60 版本过低apk如何修改可以兼容新的安卓系统
  • ¥25 由IPR导致的DRIVER_POWER_STATE_FAILURE蓝屏
  • ¥50 有数据,怎么建立模型求影响全要素生产率的因素
  • ¥50 有数据,怎么用matlab求全要素生产率
  • ¥15 TI的insta-spin例程
  • ¥15 完成下列问题完成下列问题
  • ¥15 C#算法问题, 不知道怎么处理这个数据的转换
  • ¥15 YoloV5 第三方库的版本对照问题
  • ¥15 请完成下列相关问题!