穆穆青风至 2022-10-25 10:18 采纳率: 97.4%
浏览 93
已结题

用pytorch写了一个经典的鸢尾花分类

损失降不下来甚至有升高的趋势是怎么回事

import torch
import numpy as np
from sklearn import datasets
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

x = datasets.load_iris().data
y = datasets.load_iris().target

x = torch.from_numpy(np.float32(x))
y = torch.from_numpy(np.float32(y))
x_data = x[:120]
x_test = x[120:]
y_data = x[:120]
y_test = x[120:]

train_data = TensorDataset(x_data, y_data)
train_data = DataLoader(train_data, batch_size=16, shuffle=True)

test_data = TensorDataset(x_test, y_test)
test_data = DataLoader(test_data, batch_size=16)


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(4, 16)
        self.out = torch.nn.Linear(16, 4)

    def forward(self, x):
        x = torch.nn.functional.relu(self.hidden(x))
        y = self.out(x)
        return y


model = MyModel()
lea = 0.001
cost = torch.nn.functional.cross_entropy
optimizer = torch.optim.SGD(model.parameters(), lr=lea)


def test_mse(datasets):
    loss = 0
    for data, label in datasets:
        batch_loss = cost(model(data), label)
        loss += batch_loss
    return loss


# 一般在训练模型时加上model.train(),这样会正常使用BatchNormalizationDropout
# 测试的时候一般选择model.eval(),这样就不会使用BatchNormalizationDropout
for epoch in range(100):
    model.train()
    for data, label in test_data:
        batch_loss = cost(model(data), label)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    model.eval()
    if (epoch + 1) % 10 == 0:
        print('训练集loss {}\t测试集loss {}'.format(test_mse(train_data)/120,test_mse(test_data)/30))



输出结果如下

img

  • 写回答

2条回答 默认 最新

  • m0_61899108 2022-10-25 10:44
    关注

    尝试把batch size调小,或者学习率调小点试试。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 11月4日
  • 已采纳回答 10月27日
  • 创建了问题 10月25日

悬赏问题

  • ¥30 使用matlab将观测点聚合成多条目标轨迹
  • ¥15 Workbench中材料库无法更新,如何解决?
  • ¥20 如何推断此服务器配置
  • ¥15 关于github的项目怎么在pycharm上面运行
  • ¥15 内存地址视频流转RTMP
  • ¥100 有偿,谁有移远的EC200S固件和最新的Qflsh工具。
  • ¥15 有没有整苹果智能分拣线上图像数据
  • ¥20 有没有人会这个东西的
  • ¥15 cfx考虑调整“enforce system memory limit”参数的设置
  • ¥30 航迹分离,航迹增强,误差分析