m0_74709400 2022-11-16 16:44 采纳率: 50%
浏览 19
已结题

pytorch函数拟合出现的问题

问题遇到的现象和发生背景

最近在自学pytorch,到了函数拟合的部分,在网站https://www.cnblogs.com/St-Lovaer/p/13696295.html
找到了一个很符合自己学习需求的代码,具体如下

用代码块功能插入代码,请勿粘贴截图
import torch
from torch import nn,optim
import torch.nn.functional as F
from matplotlib import pyplot as plt

class unLinear(nn.Module):
    def __init__(self,input_feature,num_hidden,output_size):
        super(unLinear,self).__init__()
        self.hidden=nn.Linear(input_feature,num_hidden)#一个层就是一个函数
        self.out=nn.Linear(num_hidden,output_size)#可以把层理解成函数的右值引用

    def forward(self,x):
        # x=F.relu(self.hidden(x))
        # x = torch.sigmoid(self.hidden(x))
        x=torch.tanh(self.hidden(x))
        x=self.out(x)
        return x

    def train(self,inputs,target,criterion,optimizer,epoches):
        print(inputs.size())
        print(target.size())
        loss=0
        for epoch in range(epoches):
            output = model.forward(inputs)
            # if epoch%1000==0:
            #     plt.scatter(inputs.detach().numpy(), output.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
            #     plt.show()
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return self, loss

model = unLinear(input_feature=1,num_hidden=20,output_size=1)
x=torch.torch.arange(-2,2,0.1)
y=x.pow(3)+0.1*torch.rand(x.size())
# print(x)
# print(y)
plt.scatter(x.detach().numpy(), y.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
plt.show()

inputs=torch.unsqueeze(x,dim=1)
target=torch.unsqueeze(y,dim=1)
criterion=nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)

new_model=model.train(inputs=inputs,target=target,criterion=criterion,optimizer=optimizer,epoches=10000)

# plt.scatter(x.numpy(),y.numpy(),c='#00CED1',s=10,alpha=0.8,label="test")
# plt.show()

x_predict=torch.unsqueeze(torch.arange(-2,2,0.05),dim=1)
y_predict=model.forward(x_predict)
# y_predict=model.forward(inputs)
# print(inputs.size())
# print(x_predict.size())
# print(y_predict.detach().numpy())
x_predict=torch.squeeze(x_predict)
y_predict=torch.squeeze(y_predict)
x_predict=x_predict.detach().numpy()
y_predict=y_predict.detach().numpy()
# print(y_predict)
plt.scatter(x_predict,y_predict,s=10,alpha=0.8,label="test")
plt.show()

运行结果及报错内容

结果发现这个程序在line47处出现了报错:"TypeError:train() got an unexpected keyword argument 'inputs'

我的解答思路和尝试过的方法

我认为是数据类型的错误,但是自己也不是很会修改。还希望各位能够帮助一下,修改一下程序能让他最起码跑起来以便后续的学习。最好能告诉我一下错误的原因和之后如何避免。感谢各位了!(请勿水贴,希望悬赏能得到很好的答案!)

我想要达到的结果

最起码让这个程序跑起来,不要报错。

  • 写回答

4条回答 默认 最新

  • 爱晚乏客游 2022-11-16 17:59
    关注

    问题出在你不应该将函数名命名为train,因为你本身继承的是nn.Module()这个类,如果你去看源码说明的话,你会发现这个类本身有个函数脚train,model.train()的意思是将模型转成训练模式,你在执行47行这里的model.train(xxxx)并不会像你想的那样调用的是你重载的unLinear.train(),而是它的父类,也就是nn.Module中的trian函数,所以就会报错说你的参数错误。实际上,你只要改个函数名,只要不是父类中定义好的函数名字冲突都可以,例如train_net(),下面调用的47行这里一起改就可以了。

    img


    nn.Module这个父类中的成员函数train()

    img


    简单修改下

    img

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(3条)

报告相同问题?

问题事件

  • 系统已结题 11月25日
  • 已采纳回答 11月17日
  • 创建了问题 11月16日

悬赏问题

  • ¥15 WPF动态创建页面内容
  • ¥15 如何对TBSS的结果进行统计学的分析已完成置换检验,如何在最终的TBSS输出结果提取除具体值及如何做进一步相关性分析
  • ¥15 SQL数据库操作问题
  • ¥100 关于lm339比较电路出现的问题
  • ¥15 Matlab安装yalmip和cplex功能安装失败
  • ¥15 加装宝马安卓中控改变开机画面
  • ¥15 STK安装问题问问大家,这种情况应该怎么办
  • ¥15 关于罗技鼠标宏lua文件的问题
  • ¥15 halcon ocr mlp 识别问题
  • ¥15 已知曲线满足正余弦函数,根据其峰值,还原出整条曲线