好吃的是生煎 2024-04-15 17:42 Acceptance rate: 100%
Views: 8
Closed

During training, the validation loss keeps rising, and the accuracy keeps rising too

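The training and validation code is on the left, and the training results are on the right.
Could anyone help check what is wrong with the code? The validation numbers look stranger the longer I look at them.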

[Image: training and validation code (left), training results (right)]


4 answers


    Good morning ☀️☀️☀️
    This answer was written with reference to ChatGPT-3.5.

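    Question: during training, the validation loss keeps rising while the validation accuracy also keeps rising. What is wrong with the code?

    Suggestions:

    1. The validation numbers look odd, so first confirm that the validation data itself is sound: the labels should match the training classes, and the preprocessing should match what the model sees during training.

    2. Add regularization, such as weight decay and Dropout, to reduce overfitting and improve validation performance. Accuracy and loss rising together often means the model is becoming overconfident: cross-entropy penalizes a confident wrong prediction heavily, so a few such samples can raise the mean loss even while more samples are classified correctly.

    3. Try adjusting hyperparameters such as the learning rate and batch size. After each change, retrain the network and re-evaluate it on the validation set.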
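
    To make the symptom in the question concrete, here is a minimal sketch of how mean cross-entropy and accuracy can rise at the same time. The probabilities are hypothetical values chosen for illustration, not numbers from the asker's run, and the helper names `cross_entropy` and `summarize` are made up for this example.

```python
import math

def cross_entropy(p_true):
    """Cross-entropy loss for one sample, given the probability assigned to its true class."""
    return -math.log(p_true)

def summarize(p_true_per_sample):
    """Return (mean loss, accuracy) for a two-class problem.

    A sample counts as correct when its true class receives probability > 0.5.
    """
    losses = [cross_entropy(p) for p in p_true_per_sample]
    correct = [p > 0.5 for p in p_true_per_sample]
    return sum(losses) / len(losses), sum(correct) / len(correct)

# Hypothetical probabilities assigned to the true class of four validation samples.
early = [0.55, 0.55, 0.45, 0.45]    # hesitant model: 2 of 4 correct
late = [0.99, 0.99, 0.99, 0.001]    # confident model: 3 of 4 correct, one confident miss

early_loss, early_acc = summarize(early)
late_loss, late_acc = summarize(late)
print(f"early: loss={early_loss:.2f}, acc={early_acc:.2f}")
print(f"late:  loss={late_loss:.2f}, acc={late_acc:.2f}")
```

    In this toy case the accuracy goes from 0.50 to 0.75 while the mean loss goes from about 0.70 to about 1.73, because the single confident miss dominates the average. If the real curves follow this pattern, the likely cause is overfitting or miscalibration rather than a bug in the accuracy computation.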

    Modified code (with L2 regularization added):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
    # Random augmentation is applied to the training set only.
    train_transform = transforms.Compose([
        transforms.Resize(256),  # shorter side becomes 256, so the 224 crop always fits
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Validation uses a deterministic pipeline so its metrics are comparable across epochs.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    
    train_dataset = ImageFolder(root='train', transform=train_transform)
    val_dataset = ImageFolder(root='val', transform=val_transform)
    
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
    
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            # Three 2x2 poolings reduce a 224x224 input to 28x28.
            self.fc1 = nn.Linear(256 * 28 * 28, 1024)
            self.fc2 = nn.Linear(1024, 176)
            self.dropout = nn.Dropout(p=0.5)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = self.pool(F.relu(self.conv3(x)))
            x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
            x = F.relu(self.fc1(x))
            x = self.dropout(x)
            x = self.fc2(x)
            return x
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Net().to(device)
    
    criterion = nn.CrossEntropyLoss()
    # weight_decay is the L2 regularization; no extra norm term is added to the loss,
    # so the logged training loss stays comparable to the validation loss.
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
    # mode='max' because the scheduler is stepped with validation accuracy.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)
    
    num_epochs = 50
    for epoch in range(num_epochs):
        # Reset the accumulators every epoch; otherwise the printed values are
        # running averages over all epochs so far.
        train_loss_sum, train_correct, train_count = 0.0, 0, 0
        val_loss_sum, val_correct, val_count = 0.0, 0, 0
    
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(imgs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            # Weight each batch by its size so a smaller last batch is not over-counted.
            train_loss_sum += loss.item() * labels.size(0)
            train_correct += (output.argmax(dim=-1) == labels).sum().item()
            train_count += labels.size(0)
    
        model.eval()
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                output = model(imgs)
                loss = criterion(output, labels)
                val_loss_sum += loss.item() * labels.size(0)
                val_correct += (output.argmax(dim=-1) == labels).sum().item()
                val_count += labels.size(0)
    
        train_loss_mean = train_loss_sum / train_count
        train_acc_mean = train_correct / train_count
        val_loss_mean = val_loss_sum / val_count
        val_acc_mean = val_correct / val_count
    
        scheduler.step(val_acc_mean)
    
        print(f"[Epoch {epoch+1:03d}/{num_epochs:03d}] train_loss: {train_loss_mean:.5f}, train_acc: {train_acc_mean:.5f}, val_loss: {val_loss_mean:.5f}, val_acc: {val_acc_mean:.5f}")
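
    One thing worth checking in any loop of this shape is whether the metric lists are created inside or outside the epoch loop. The sketch below uses hypothetical per-batch losses, and `epoch_mean` is a helper name invented for this example. It shows how a list that is never reset turns the printed value into a running average over all epochs, which hides the real per-epoch trend.

```python
def epoch_mean(batch_values, batch_sizes):
    """Sample-weighted mean of per-batch values for a single epoch."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical per-batch validation losses for three epochs, two batches each.
val_loss_by_epoch = [[2.0, 2.0], [1.0, 1.0], [0.5, 0.5]]
batch_sizes = [64, 64]

history = []      # never reset: mimics a list created outside the epoch loop
cumulative = []   # what gets printed when the list is never reset
per_epoch = []    # what gets printed when the accumulators are reset every epoch

for batch_losses in val_loss_by_epoch:
    history.extend(batch_losses)
    cumulative.append(sum(history) / len(history))
    per_epoch.append(epoch_mean(batch_losses, batch_sizes))

print("cumulative:", [round(v, 3) for v in cumulative])
print("per epoch: ", per_epoch)
```

    Here the per-epoch means are 2.0, 1.0 and 0.5, while the never-reset average lags behind them. Comparing the two series on real logs is a quick way to tell whether an odd-looking curve comes from the model or from the bookkeeping.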
    
    The asker selected this answer as the best answer.

Question events

  • Closed by the system on April 29
  • Answer accepted on April 21
  • Question created on April 15
