这段回归代码在设置x=torch.randn(20,1)*10时会梯度爆炸,很明显也是直线为什么拟合不了
import torch
import random
import matplotlib.pyplot as plt
torch.manual_seed(1)
x=torch.rand(20,1)*10
y=2*x+(5+torch.randn(20,1))
epoch=1000
lr=0.05
w=torch.randn((1),requires_grad=True)
b=torch.randn((1),requires_grad=True)
for i in range(epoch):
wx=torch.mul(w,x)
print(wx)
y_pred=torch.add(wx,b)
loss=(0.5*(y-y_pred)**2).mean()
print(loss)
loss.backward()
b.data.sub_(lr*b.grad)
w.data.sub_(lr*w.grad)
w.grad.zero_()
b.grad.zero_()
if i%20==0:
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),y_pred.data.numpy(),'r-',lw=5)
plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.xlim(1.5, 10)
plt.ylim(8, 28)
plt.title("Iteration: {}\nw: {} b: {}".format(i, w.data.numpy(), b.data.numpy()))
plt.pause(0.5)
if loss.data.numpy()<0.1:
break