m0_74709400 2022-11-16 16:44 采纳率: 50%

# pytorch函数拟合出现的问题

###### 用代码块功能插入代码，请勿粘贴截图
``````import torch
from torch import nn,optim
import torch.nn.functional as F
from matplotlib import pyplot as plt

class unLinear(nn.Module):
def __init__(self,input_feature,num_hidden,output_size):
super(unLinear,self).__init__()
self.hidden=nn.Linear(input_feature,num_hidden)#一个层就是一个函数
self.out=nn.Linear(num_hidden,output_size)#可以把层理解成函数的右值引用

def forward(self,x):
# x=F.relu(self.hidden(x))
# x = torch.sigmoid(self.hidden(x))
x=torch.tanh(self.hidden(x))
x=self.out(x)
return x

def train(self,inputs,target,criterion,optimizer,epoches):
print(inputs.size())
print(target.size())
loss=0
for epoch in range(epoches):
output = model.forward(inputs)
# if epoch%1000==0:
#     plt.scatter(inputs.detach().numpy(), output.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
#     plt.show()
loss = criterion(output, target)
loss.backward()
optimizer.step()
return self, loss

model = unLinear(input_feature=1,num_hidden=20,output_size=1)
x=torch.torch.arange(-2,2,0.1)
y=x.pow(3)+0.1*torch.rand(x.size())
# print(x)
# print(y)
plt.scatter(x.detach().numpy(), y.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
plt.show()

inputs=torch.unsqueeze(x,dim=1)
target=torch.unsqueeze(y,dim=1)
criterion=nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)

new_model=model.train(inputs=inputs,target=target,criterion=criterion,optimizer=optimizer,epoches=10000)

# plt.scatter(x.numpy(),y.numpy(),c='#00CED1',s=10,alpha=0.8,label="test")
# plt.show()

x_predict=torch.unsqueeze(torch.arange(-2,2,0.05),dim=1)
y_predict=model.forward(x_predict)
# y_predict=model.forward(inputs)
# print(inputs.size())
# print(x_predict.size())
# print(y_predict.detach().numpy())
x_predict=torch.squeeze(x_predict)
y_predict=torch.squeeze(y_predict)
x_predict=x_predict.detach().numpy()
y_predict=y_predict.detach().numpy()
# print(y_predict)
plt.scatter(x_predict,y_predict,s=10,alpha=0.8,label="test")
plt.show()

``````

• 写回答

#### 4条回答默认 最新

• 爱晚乏客游 2022-11-16 17:59
关注

问题出在你不应该将函数名命名为train，因为你本身继承的是nn.Module()这个类，如果你去看源码说明的话，你会发现这个类本身有个函数脚train，model.train()的意思是将模型转成训练模式，你在执行47行这里的model.train（xxxx）并不会像你想的那样调用的是你重载的unLinear.train（），而是它的父类，也就是nn.Module中的trian函数，所以就会报错说你的参数错误。实际上，你只要改个函数名，只要不是父类中定义好的函数名字冲突都可以，例如train_net(),下面调用的47行这里一起改就可以了。

nn.Module这个父类中的成员函数train()

简单修改下

本回答被题主选为最佳回答 , 对您是否有帮助呢?
评论 编辑记录

• 系统已结题 11月25日
• 已采纳回答 11月17日
• 创建了问题 11月16日

#### 悬赏问题

• ¥15 WPF动态创建页面内容
• ¥15 如何对TBSS的结果进行统计学的分析已完成置换检验，如何在最终的TBSS输出结果提取除具体值及如何做进一步相关性分析
• ¥15 SQL数据库操作问题
• ¥100 关于lm339比较电路出现的问题
• ¥15 Matlab安装yalmip和cplex功能安装失败
• ¥15 加装宝马安卓中控改变开机画面
• ¥15 STK安装问题问问大家，这种情况应该怎么办
• ¥15 关于罗技鼠标宏lua文件的问题
• ¥15 halcon ocr mlp 识别问题
• ¥15 已知曲线满足正余弦函数，根据其峰值，还原出整条曲线