飞飞会灰 2021-09-07 14:26 采纳率: 0%
浏览 112

Pytorch从零开始实现线性回归,为什么结果不收敛?


#定义数据集:
x = torch.randn((200,1))*10
y = x*50+12
y_ = y + torch.normal(10,30,size= (200,1)) #加入噪音


#定义损失函数:
def MSE_loss(y_,y):
    MSE = torch.sum(torch.square(y-y_))/y.size()[0]
    loss_ = MSE
    return loss_

#定义网络及迭代方式
def Linear_Net(x,y,lr = 0.02,epochs = 10):
    #导入必备库:
    import torch
    import torch.nn as nn
    #提取size:
    input_size = x.size()[0]
    input_features = x.size()[1]
    #初始化参数:
    w = torch.rand((1,1),requires_grad=True)
    b = torch.rand(1,requires_grad=True)
    loss_ = []
    loss_internal = nn.MSELoss()
    print("init_w =",w)
    print("init_b =",b)
    #计算及迭代:
    for i in range(epochs):
        y_hat = torch.mm(x,w)+b
        loss = MSE_loss(y_hat,y) #使用自带损失函数
        # loss = loss_internal(y,y_hat) #使用系统损失函数
                
        loss.backward()
        #for param in [w,b]:
            #param.data -= -lr* param.grad
        
        w.data = w.data-lr* w.grad
        b.data = w.data-lr* b.grad
        #梯度清零,否则梯度累加
        w.grad.data.zero_()
        b.grad.data.zero_()
        
              
        loss_.append(loss.item())
        
    
        #print("w============>>>>>>",w)
        #print("b============>>>>>>",b)
        print("loss=========>>>>>>",loss.item())
        
    #导出结果:
       
    结果如下:

    
    



    


img

  • 写回答

1条回答 默认 最新

  • python收藏家 2021-09-07 15:01
    关注

    lr 设置10是不是太大了 就用默认的试试

    评论

报告相同问题?

问题事件

  • 创建了问题 9月7日

悬赏问题

  • ¥15 seatunnel-web使用SQL组件时候后台报错,无法找到表格
  • ¥15 fpga自动售货机数码管(相关搜索:数字时钟)
  • ¥15 用前端向数据库插入数据,通过debug发现数据能走到后端,但是放行之后就会提示错误
  • ¥30 3天&7天&&15天&销量如何统计同一行
  • ¥30 帮我写一段可以读取LD2450数据并计算距离的Arduino代码
  • ¥15 飞机曲面部件如机翼,壁板等具体的孔位模型
  • ¥15 vs2019中数据导出问题
  • ¥20 云服务Linux系统TCP-MSS值修改?
  • ¥20 关于#单片机#的问题:项目:使用模拟iic与ov2640通讯环境:F407问题:读取的ID号总是0xff,自己调了调发现在读从机数据时,SDA线上并未有信号变化(语言-c语言)
  • ¥20 怎么在stm32门禁成品上增加查询记录功能