Blucoris 2022-07-22 12:31 采纳率: 75%
浏览 91
已结题

pytorch的resnet猫狗大战跑不完不收敛

问题遇到的现象和发生背景

我用这个代码跑了一下猫狗大战数据集,为什么跑到一半就跑不下去了,只能跑到第二代,而且感觉跑得非常慢。

问题相关代码,请勿粘贴截图
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 22 10:26:33 2022

# 11:34第一代
# 11.48第二代
# 12:14 跑到一半不跑了
@author: 19544
"""

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

EPOCH=5
BATCH_SIZE=40
LR=0.01

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])   

train_dataset = datasets.ImageFolder(
        'D:\\项目实验文件夹\\猫狗大战数据集\\dogcat_2',
        transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
                ]))

train_loader = Data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True)

test_loader = Data.DataLoader(
        datasets.ImageFolder(
                'D:\\项目实验文件夹\\猫狗大战数据集\\dogcat_2', 
                transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                        ])),
        batch_size=BATCH_SIZE, shuffle=False,)

model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(in_features=512, out_features=5, bias=True)

fc_params = list(map(id, model.fc.parameters())) # map函数是将fc.parameters()的id返回并组成一个列表
base_params = filter(lambda p: id(p) not in fc_params, model.parameters()) # filter函数是将model.parameters()中地址不在fc.parameters的id中的滤出来
optimizer = torch.optim.SGD([ {'params': base_params}, {'params': model.fc.parameters(), 'lr': LR * 100}], lr=LR,)
loss_func=nn.CrossEntropyLoss()

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
            
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res           
            
train_losses = AverageMeter('TrainLoss', ':.4e')
train_top1 = AverageMeter('TrainAccuracy', ':6.2f')
test_losses = AverageMeter('TestLoss', ':.4e')
test_top1 = AverageMeter('TestAccuracy', ':6.2f')

for epoch in range(EPOCH):
    
    model.train()
    for i,(images,target) in enumerate(train_loader):
        output=model(images)
        loss= loss_func(output,target)
        
        acc1, = accuracy(output, target, topk=(1,))
        train_losses.update(loss.item(), images.size(0))
        train_top1.update(acc1[0], images.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('Epoch[{}/{}],TrainLoss:{}, TrainAccuracy:{}'.format(epoch,EPOCH,train_losses.val, train_top1.val))
           
    model.eval()
    with torch.no_grad():
        for i,(images,target) in enumerate(test_loader):
            output=model(images)
            loss= loss_func(output,target)
            
            acc1, = accuracy(output, target, topk=(1,))
            test_losses.update(loss.item(), images.size(0))
            test_top1.update(acc1[0], images.size(0))
            
    print('TestLoss:{}, TestAccuracy:{}'.format(test_losses.avg, test_top1.avg))


运行结果及报错内容

Epoch[2/5],TrainLoss:0.7035315036773682, TrainAccuracy:47.5
Epoch[2/5],TrainLoss:0.7905141711235046, TrainAccuracy:47.5
Epoch[2/5],TrainLoss:0.7110738158226013, TrainAccuracy:47.5
Epoch[2/5],TrainLoss:0.709513783454895, TrainAccuracy:47.5
Epoch[2/5],TrainLoss:0.6796354055404663, TrainAccuracy:60.0
Epoch[2/5],TrainLoss:0.6862636804580688, TrainAccuracy:55.0

我的解答思路和尝试过的方法

我尝试过改了下代数,可是还是算得太长,而且正确率不收敛。

我想要达到的结果

希望帮忙让它跑完并收敛。

  • 写回答

3条回答 默认 最新

  • 迪菲赫尔曼 人工智能领域优质创作者 2022-07-22 13:14
    关注

    跑得慢是因为没有使用cuda加速

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 7月31日
  • 已采纳回答 7月23日
  • 创建了问题 7月22日

悬赏问题

  • ¥20 西门子S7-Graph,S7-300,梯形图
  • ¥50 用易语言http 访问不了网页
  • ¥50 safari浏览器fetch提交数据后数据丢失问题
  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效
  • ¥15 悬赏!微信开发者工具报错,求帮改