问题遇到的现象和发生背景
最近在自学pytorch,到了函数拟合的部分,在网站https://www.cnblogs.com/St-Lovaer/p/13696295.html
找到了一个很符合自己学习需求的代码,具体如下
用代码块功能插入代码,请勿粘贴截图
import torch
from torch import nn,optim
import torch.nn.functional as F
from matplotlib import pyplot as plt
class unLinear(nn.Module):
def __init__(self,input_feature,num_hidden,output_size):
super(unLinear,self).__init__()
self.hidden=nn.Linear(input_feature,num_hidden)#一个层就是一个函数
self.out=nn.Linear(num_hidden,output_size)#可以把层理解成函数的右值引用
def forward(self,x):
# x=F.relu(self.hidden(x))
# x = torch.sigmoid(self.hidden(x))
x=torch.tanh(self.hidden(x))
x=self.out(x)
return x
def train(self,inputs,target,criterion,optimizer,epoches):
print(inputs.size())
print(target.size())
loss=0
for epoch in range(epoches):
output = model.forward(inputs)
# if epoch%1000==0:
# plt.scatter(inputs.detach().numpy(), output.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
# plt.show()
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return self, loss
model = unLinear(input_feature=1,num_hidden=20,output_size=1)
x=torch.torch.arange(-2,2,0.1)
y=x.pow(3)+0.1*torch.rand(x.size())
# print(x)
# print(y)
plt.scatter(x.detach().numpy(), y.detach().numpy(), c='#00CED1', s=10, alpha=0.8, label="test")
plt.show()
inputs=torch.unsqueeze(x,dim=1)
target=torch.unsqueeze(y,dim=1)
criterion=nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
new_model=model.train(inputs=inputs,target=target,criterion=criterion,optimizer=optimizer,epoches=10000)
# plt.scatter(x.numpy(),y.numpy(),c='#00CED1',s=10,alpha=0.8,label="test")
# plt.show()
x_predict=torch.unsqueeze(torch.arange(-2,2,0.05),dim=1)
y_predict=model.forward(x_predict)
# y_predict=model.forward(inputs)
# print(inputs.size())
# print(x_predict.size())
# print(y_predict.detach().numpy())
x_predict=torch.squeeze(x_predict)
y_predict=torch.squeeze(y_predict)
x_predict=x_predict.detach().numpy()
y_predict=y_predict.detach().numpy()
# print(y_predict)
plt.scatter(x_predict,y_predict,s=10,alpha=0.8,label="test")
plt.show()
运行结果及报错内容
结果发现这个程序在line47处出现了报错:"TypeError:train() got an unexpected keyword argument 'inputs'
我的解答思路和尝试过的方法
我认为是数据类型的错误,但是自己也不是很会修改。还希望各位能够帮助一下,修改一下程序能让他最起码跑起来以便后续的学习。最好能告诉我一下错误的原因和之后如何避免。感谢各位了!(请勿水贴,希望悬赏能得到很好的答案!)
我想要达到的结果
最起码让这个程序跑起来,不要报错。