m0_74709400 2022-11-16 16:44 采纳率: 50%
浏览 14
已结题

pytorch函数拟合出现的问题

问题遇到的现象和发生背景

最近在自学pytorch,到了函数拟合的部分,在网站https://www.cnblogs.com/St-Lovaer/p/13696295.html
找到了一个很符合自己学习需求的代码,具体如下

用代码块功能插入代码,请勿粘贴截图
import torch
from torch import nn,optim
import torch.nn.functional as F
from matplotlib import pyplot as plt

class unLinear(nn.Module):
    def __init__(self,input_feature,num_hidden,output_size):
        super(unLinear,self).__init__()
        self.hidden=nn.Linear(input_feature,num_hidden)#一个层就是一个函数
        self.out=nn.Linear(num_hidden,output_size)#可以把层理解成函数的右值引用

    def forward(self,x):
        # x=F.relu(self.hidden(x))
        # x = torch.sigmoid(self.hidden(x))
        x=torch.tanh(self.hidden(x))
        x=self.out(x)
        return x

    def train(self,inputs,target,criterion,optimizer,epoches):
        print(inputs.size())
        print(target.size())
        loss=0
        for epoch in range(epoches):
            output = model.forward(inputs)
            # if epoch%1000==0:
            #     plt.scatter(inputs.detach().numpy(), output.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
            #     plt.show()
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return self, loss

model = unLinear(input_feature=1,num_hidden=20,output_size=1)
x=torch.torch.arange(-2,2,0.1)
y=x.pow(3)+0.1*torch.rand(x.size())
# print(x)
# print(y)
plt.scatter(x.detach().numpy(), y.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
plt.show()

inputs=torch.unsqueeze(x,dim=1)
target=torch.unsqueeze(y,dim=1)
criterion=nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)

new_model=model.train(inputs=inputs,target=target,criterion=criterion,optimizer=optimizer,epoches=10000)

# plt.scatter(x.numpy(),y.numpy(),c='#00CED1',s=10,alpha=0.8,label="test")
# plt.show()

x_predict=torch.unsqueeze(torch.arange(-2,2,0.05),dim=1)
y_predict=model.forward(x_predict)
# y_predict=model.forward(inputs)
# print(inputs.size())
# print(x_predict.size())
# print(y_predict.detach().numpy())
x_predict=torch.squeeze(x_predict)
y_predict=torch.squeeze(y_predict)
x_predict=x_predict.detach().numpy()
y_predict=y_predict.detach().numpy()
# print(y_predict)
plt.scatter(x_predict,y_predict,s=10,alpha=0.8,label="test")
plt.show()

运行结果及报错内容

结果发现这个程序在line47处出现了报错:"TypeError:train() got an unexpected keyword argument 'inputs'

我的解答思路和尝试过的方法

我认为是数据类型的错误,但是自己也不是很会修改。还希望各位能够帮助一下,修改一下程序能让他最起码跑起来以便后续的学习。最好能告诉我一下错误的原因和之后如何避免。感谢各位了!(请勿水贴,希望悬赏能得到很好的答案!)

我想要达到的结果

最起码让这个程序跑起来,不要报错。

  • 写回答

4条回答

      报告相同问题?

      相关推荐 更多相似问题

      问题事件

      • 系统已结题 11月25日
      • 已采纳回答 11月17日
      • 创建了问题 11月16日

      悬赏问题

      • ¥15 51单片机自关机代码实现
      • ¥15 TensorFlow Object Detection API
      • ¥15 粘贴替换字符串的时候,右边引号会自动换行导致报错
      • ¥15 用verilog HDL语法仿真
      • ¥15 用超表面产生涡旋光束,怎么用matlab代码算得到的涡旋光束的模式纯度
      • ¥40 返乡没拿电脑航班取消被困在机场了,C语言实验ddl要到了
      • ¥15 find 命令优化语句问题
      • ¥15 js 使用contenteditable属性模拟富文本框 实现具体关键字高亮
      • ¥15 QT QList<QLIst<int>> 遍历问题
      • ¥15 关于#C++#2048游戏问题