xmu_yjx 2024-01-16 22:45 采纳率: 50%
浏览 5

数据科学、深度学习入门

这段回归代码在设置x=torch.randn(20,1)*10时会梯度爆炸,很明显也是直线为什么拟合不了


import torch
import random
import matplotlib.pyplot as plt
torch.manual_seed(1)

x=torch.rand(20,1)*10
y=2*x+(5+torch.randn(20,1))
epoch=1000
lr=0.05
w=torch.randn((1),requires_grad=True)
b=torch.randn((1),requires_grad=True)
for i in range(epoch):
    wx=torch.mul(w,x)
    print(wx)
    y_pred=torch.add(wx,b)
    loss=(0.5*(y-y_pred)**2).mean()
    print(loss)
    loss.backward()     
    b.data.sub_(lr*b.grad)
    w.data.sub_(lr*w.grad)
    
    w.grad.zero_()
    b.grad.zero_()
    if i%20==0:
        plt.scatter(x.data.numpy(),y.data.numpy())
        plt.plot(x.data.numpy(),y_pred.data.numpy(),'r-',lw=5)
        plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color':  'red'})
        plt.xlim(1.5, 10)
        plt.ylim(8, 28)
        plt.title("Iteration: {}\nw: {} b: {}".format(i, w.data.numpy(), b.data.numpy()))
        plt.pause(0.5)
        if loss.data.numpy()<0.1:
            break
  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2024-01-17 08:27
    关注

    【以下回答由 GPT 生成】

    优化建议:

    为了解决梯度爆炸的问题并改善拟合结果,您可以尝试以下几个优化方案:

    1. 数据标准化: 由于 x=torch.randn(20,1)*10 生成的数据具有较大的方差,这可能导致梯度爆炸。您可以尝试对 x 数据进行标准化,即将数据缩放到均值为0,方差为1的范围内。可以使用以下代码来实现数据标准化:
    x = (x - x.mean()) / x.std()
    
    1. 调整学习率: 当遇到梯度爆炸的问题时,减小学习率通常是一个有效的解决方法。您可以尝试将学习率 lr 的值减小为原来的一半或更小的值,以减缓梯度的更新速度,从而避免梯度爆炸的问题。

    2. 增加正则化项: 过拟合可能是导致拟合结果不佳的原因之一。您可以尝试在损失函数中添加正则化项,以控制模型的复杂度。例如,可以将 loss 修改为以下形式:

    loss = (0.5 * (y - y_pred)**2).mean() + lambda_reg * torch.norm(w) + lambda_reg * torch.norm(b)
    

    其中 lambda_reg 是正则化系数,用于控制正则化项的权重。通过增加正则化项,您可以降低模型的过拟合程度,从而改善拟合结果。

    请根据您的需求选择其中一个或多个优化方案,并在代码中进行实验。对于每个优化方案,可以适当修改代码并观察结果的变化。



    【相关推荐】



    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 创建了问题 1月16日

悬赏问题

  • ¥200 基于同花顺supermind的量化策略脚本编辑
  • ¥20 Html备忘录页面制作
  • ¥15 黄永刚的晶体塑性子程序中输入的材料参数里的晶体取向参数是什么形式的?
  • ¥20 数学建模来解决我这个问题
  • ¥15 计算机网络ip分片偏移量计算头部是-20还是-40呀
  • ¥15 stc15f2k60s2单片机关于流水灯,时钟,定时器,矩阵键盘等方面的综合问题
  • ¥15 YOLOv8已有一个初步的检测模型,想利用这个模型对新的图片进行自动标注,生成labellmg可以识别的数据,再手动修改。如何操作?
  • ¥30 NIRfast软件使用指导
  • ¥20 matlab仿真问题,求功率谱密度
  • ¥15 求micropython modbus-RTU 从机的代码或库?