2401_83649141 2024-10-12 12:39 采纳率: 0%
浏览 9

动手学深度学习第三章线性回归的简介实现,为什么我的损失是这样的

动手学深度学习第三章线性回归的简介实现:


import torch
from torch import nn
from d2l import torch as d2l
from torch.utils import data

#               生成数据集
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features,labels = d2l.synthetic_data(true_w,true_b,1000)



#               读取数据集
def load_array(data_array,batch_size,is_train=True):
    dataset = data.TensorDataset(*data_array)
    return data.DataLoader(dataset,batch_size,shuffle=True)

batch_size = 10
data_iter = load_array((features,labels),batch_size)
# print(next(iter(data_iter)))



#               定义模型
net = nn.Sequential(nn.Linear(2,1))
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)



#               定义损失函数
loss = nn.MSELoss()



#               定义优化算法
trainer = torch.optim.SGD(net.parameters(),lr = 0.03)



#               训练
num_epochs = 3
for epoch in range(num_epochs):
    for X,y in data_iter:
        l = loss(net(X),y)
        trainer.zero_grad
        l.backward()
        trainer.step()
    l = loss(net(features),labels)
    print(f'epoch {epoch + 1},loss {l:f}')

出来的结果:
epoch 1,loss 36.045746
epoch 2,loss 106.395393
epoch 3,loss 152.102753

为什么我的损失越来越大了,而且我照着动手学深度学习书上代码一样敲得,为什么别人第一个epoch就是0.000291,我的这么大呢?

  • 写回答

2条回答 默认 最新

  • 香格里辣 2024-10-12 12:40
    关注

    以下回复参考:皆我百晓生券券喵儿等免费微信小程序作答:

    你的代码和书中的代码大致相同,但存在一些差异可能导致你的损失函数值较大。以下是可能的原因和解决方案:

    1. 数据生成方式:你使用的是d2l.synthetic_data生成的数据集,这意味着每次运行代码时生成的数据是不同的。因此,即使你按照书中的代码完全复制,由于数据的随机性,结果可能会有所不同。此外,书中可能使用了固定的随机数种子来确保每次生成的数据相同,而你没有设置随机数种子。为了更准确地比较结果,你可以尝试设置随机数种子。
    2. 初始化权重和偏置:你设置了权重的初始化方式为正态分布,均值为0,标准差为0.01。而书中可能使用了不同的初始化方式。尽管这可能会导致不同的起点损失值,但一般不会导致损失值逐渐增大。
    3. 学习率过高:你设置的学习率为0.03。如果学习率过高,可能会导致模型在优化过程中不稳定,从而使损失值增大。你可以尝试降低学习率,例如设置为0.01或更低。
    4. 批量大小(Batch Size):你设置的批量大小为10。较小的批量大小可能导致模型在每次迭代时更新的步长较大,从而导致损失值的波动。你可以尝试增加批量大小。

    为了解决这个问题,你可以尝试以下修改:

    1. 设置随机数种子以确保每次生成的数据相同。
    2. 尝试降低学习率。
    3. 增加批量大小。
    4. 检查并确认你的数据和书中的数据是否完全一致。

    此外,即使进行了上述修改,由于随机性和初始化的差异,你的结果仍然可能与书中的结果有所不同。只要你的损失值随着训练的进行而逐渐下降,并且模型在训练集上的表现良好,那么你的实现就是正确的。

    评论

报告相同问题?

问题事件

  • 创建了问题 10月12日