lzslzz 2024-03-20 20:55 采纳率: 50%
浏览 19

损失度一直是0但准确率正常波动该是哪里错了

求问!刚学机器学习,数据没问题,但是输出的这个损失度是不是不太对啊,怎么一直是0

img

train_data, train_label, test_data, test_label = load_dataset(dataset)
dataset = torch.utils.data.TensorDataset(torch.tensor(train_data), torch.tensor(train_label))  
dataloader = torch.utils.data.DataLoader(dataset, batch_size=40, shuffle=True)  

# 2. 模型定义 ---
model = nn.Sequential(  # 顺序模型
          nn.Linear(288, 100),  # 线性层
          nn.Sigmoid(),
          nn.Linear(100, 2)
        )

# 损失函数
loss = nn.CrossEntropyLoss()

# 准确率
def accuracy(y_pred, y_true):
    correct_pred = torch.eq(torch.argmax(y_pred, 1), y_true)  
    return torch.mean(correct_pred.float())


# 训练步
def train_step(model, x, y, opt):
    current_loss = loss(model(x), torch.argmax(y, 1))
    current_loss.backward() 
    with torch.no_grad():
        acc = accuracy(model(x), y)  
    opt.step()  
    opt.zero_grad()  
    return current_loss.item(), acc.item()


# 3. 创建模型 ---
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # 标准梯度下降
ls = []  # 记录损失函数值
accs = []  # 记录准确率

# 4. 训练 ---
for _ in range(100):
    for batch_data, batch_label in dataloader:
        l, acc = train_step(model, batch_data, batch_label, opt)
        ls.append(l)
        accs.append(acc)

test_acc = accuracy(model(torch.tensor(test_data)), torch.tensor(test_label))
print("test acc:", test_acc.detach().data)
plt.plot(ls)
plt.plot(accs)
plt.legend(['loss', 'acc'])
plt.show()

  • 写回答

3条回答 默认 最新

  • 阿里嘎多学长 2024-03-20 20:56
    关注

    以下内容由CHATGPT及阿里嘎多学长共同生成、有用望采纳:


    在你的代码中,损失度(loss)始终为0的原因可能是因为你在计算损失时使用了torch.argmax(y, 1)torch.argmax(y, 1)会返回y(也就是标签)中每个样本的最大概率索引,将其转换为one-hot编码形式。而nn.CrossEntropyLoss函数已经包含了将预测概率和one-hot编码标签比较的过程,不需要你手动进行转换。

    因此,当你在计算损失时,你实际上是在比较模型的预测输出和它自己预测的概率最大值,这会导致损失始终为0,因为它们是相同的。

    正确的做法是直接将未处理的y(batch_label)传给nn.CrossEntropyLoss,如下所示:

    current_loss = loss(model(x), y)
    

    这样,损失函数将会正确地计算预测概率分布和实际标签之间的交叉熵损失。修改后,你的训练过程应该能够反映出损失的变化。

    评论

报告相同问题?

问题事件

  • 创建了问题 3月20日

悬赏问题

  • ¥15 无线连接树莓派,无法执行update,如何解决?(相关搜索:软件下载)
  • ¥15 Windows11, backspace, enter, space键失灵
  • ¥15 cfx离心泵非稳态计算
  • ¥15 动态列线图发布后出现An error has occurred. Check your logs or contact the app author for clarification.
  • ¥20 VM虚拟机崩溃,重新登录故障,移除加密访问。
  • ¥15 双VSG并网系统,matlab,状态变量稳态值求解
  • ¥15 关于#Stata#的问题:数据是面板数据,SPSS里面不能控制年份和时间,所以只能用Stata做
  • ¥20 基于基于NioEventLoop线程阻塞问题
  • ¥20 我需要"hill48屈服模型 等向强化 非线性硬化"的abaqus本构子程序(umat或者vumat)对应的理论推导过程。
  • ¥15 基于ucc28019的pfc电路中芯片一直不工作