穆穆青风至 2022-10-25 10:18 采纳率: 97.4%
浏览 85
已结题

用pytorch写了一个经典的鸢尾花分类

损失降不下来甚至有升高的趋势是怎么回事

import torch
import numpy as np
from sklearn import datasets
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

x = datasets.load_iris().data
y = datasets.load_iris().target

x = torch.from_numpy(np.float32(x))
y = torch.from_numpy(np.float32(y))
x_data = x[:120]
x_test = x[120:]
y_data = x[:120]
y_test = x[120:]

train_data = TensorDataset(x_data, y_data)
train_data = DataLoader(train_data, batch_size=16, shuffle=True)

test_data = TensorDataset(x_test, y_test)
test_data = DataLoader(test_data, batch_size=16)


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(4, 16)
        self.out = torch.nn.Linear(16, 4)

    def forward(self, x):
        x = torch.nn.functional.relu(self.hidden(x))
        y = self.out(x)
        return y


model = MyModel()
lea = 0.001
cost = torch.nn.functional.cross_entropy
optimizer = torch.optim.SGD(model.parameters(), lr=lea)


def test_mse(datasets):
    loss = 0
    for data, label in datasets:
        batch_loss = cost(model(data), label)
        loss += batch_loss
    return loss


# 一般在训练模型时加上model.train(),这样会正常使用BatchNormalizationDropout
# 测试的时候一般选择model.eval(),这样就不会使用BatchNormalizationDropout
for epoch in range(100):
    model.train()
    for data, label in test_data:
        batch_loss = cost(model(data), label)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    model.eval()
    if (epoch + 1) % 10 == 0:
        print('训练集loss {}\t测试集loss {}'.format(test_mse(train_data)/120,test_mse(test_data)/30))



输出结果如下

img

  • 写回答

2条回答 默认 最新

  • m0_61899108 2022-10-25 10:44
    关注

    尝试把batch size调小,或者学习率调小点试试。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 11月4日
  • 已采纳回答 10月27日
  • 创建了问题 10月25日

悬赏问题

  • ¥15 微信会员卡等级和折扣规则
  • ¥15 微信公众平台自制会员卡可以通过收款码收款码收款进行自动积分吗
  • ¥15 随身WiFi网络灯亮但是没有网络,如何解决?
  • ¥15 gdf格式的脑电数据如何处理matlab
  • ¥20 重新写的代码替换了之后运行hbuliderx就这样了
  • ¥100 监控抖音用户作品更新可以微信公众号提醒
  • ¥15 UE5 如何可以不渲染HDRIBackdrop背景
  • ¥70 2048小游戏毕设项目
  • ¥20 mysql架构,按照姓名分表
  • ¥15 MATLAB实现区间[a,b]上的Gauss-Legendre积分