穆穆青风至 2022-10-25 10:18 采纳率: 97.4%
浏览 88
已结题

用pytorch写了一个经典的鸢尾花分类

损失降不下来甚至有升高的趋势是怎么回事

import torch
import numpy as np
from sklearn import datasets
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

x = datasets.load_iris().data
y = datasets.load_iris().target

x = torch.from_numpy(np.float32(x))
y = torch.from_numpy(np.float32(y))
x_data = x[:120]
x_test = x[120:]
y_data = x[:120]
y_test = x[120:]

train_data = TensorDataset(x_data, y_data)
train_data = DataLoader(train_data, batch_size=16, shuffle=True)

test_data = TensorDataset(x_test, y_test)
test_data = DataLoader(test_data, batch_size=16)


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(4, 16)
        self.out = torch.nn.Linear(16, 4)

    def forward(self, x):
        x = torch.nn.functional.relu(self.hidden(x))
        y = self.out(x)
        return y


model = MyModel()
lea = 0.001
cost = torch.nn.functional.cross_entropy
optimizer = torch.optim.SGD(model.parameters(), lr=lea)


def test_mse(datasets):
    loss = 0
    for data, label in datasets:
        batch_loss = cost(model(data), label)
        loss += batch_loss
    return loss


# 一般在训练模型时加上model.train(),这样会正常使用BatchNormalizationDropout
# 测试的时候一般选择model.eval(),这样就不会使用BatchNormalizationDropout
for epoch in range(100):
    model.train()
    for data, label in test_data:
        batch_loss = cost(model(data), label)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    model.eval()
    if (epoch + 1) % 10 == 0:
        print('训练集loss {}\t测试集loss {}'.format(test_mse(train_data)/120,test_mse(test_data)/30))



输出结果如下

img

  • 写回答

2条回答 默认 最新

  • m0_61899108 2022-10-25 10:44
    关注

    尝试把batch size调小,或者学习率调小点试试。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 11月4日
  • 已采纳回答 10月27日
  • 创建了问题 10月25日

悬赏问题

  • ¥15 关于#c语言#的问题:构成555单稳态触发器,采用LED指示灯延时时间,对延时时间进行测量并显示(如楼道声控延时灯)需要Proteus仿真图和C语言代码
  • ¥50 神舟笔记本,没有linux的驱动,装的Ubuntu系统,想把风扇速度调到最大
  • ¥15 workstation加载centos进入emergency模式,查看日志报警如图,怎样解决呢?
  • ¥50 如何用单纯形法寻优不能精准找不到给定的参数,并联机构误差识别,给定误差有7个?matlab
  • ¥15 workstation加载centos进入emergency模式,查看日志报警如图,没有XFS,怎样解决呢?
  • ¥15 应用商店如何检测在架应用内容是否违规?
  • ¥15 Ubuntu系统配置PX4
  • ¥50 nw.js调用activex
  • ¥15 数据库获取信息反馈出错,直接查询了ref字段并且还使用了User文档的_id而不是自己的
  • ¥15 将安全信息用到以下对象时发生以下错误:c:dumpstack.log.tmp 另一个程序正在使用此文件,因此无法访问