卡尔顿 2022-09-01 00:20 采纳率: 33.3%
浏览 41
已结题

CNN网络2分类失败

问题遇到的现象和发生背景

建立了一个二分类的CNN网络,但是没有识别能力。一个类是西方龙图片(很多相似),一个类是其他图片(差别很大).为什么会没有识别能力。很迷惑

问题相关代码,请勿粘贴截图

import torch
import torchvision
import os
from PIL import Image
from torch.utils.data import Dataset
import torch.utils.data as Data
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MyData(Dataset):
    def __init__(self,root,labal):
        self.dic={'dragon':[1,0],'else':[0,1]}
        self.root=root
        self.labal=labal
        self.path=os.path.join(self.root,self.labal)
        self.imgpass=os.listdir(self.path)
    def __getitem__(self, item):
        imgname=self.imgpass[item]
        imgitempass=os.path.join(self.root,self.labal,imgname)
        img =Image.open(imgitempass).convert('RGB')
        labal=self.dic[self.labal]
        trains = torchvision.transforms.Resize([256, 256])
        trains1 = torchvision.transforms.ToTensor()
        img1=trains(img)
        return trains1(img1),torch.tensor(labal,dtype=torch.float32)
    def __len__(self):
        return len(self.imgpass)

mydata=MyData(r'C:\Users\xiaob\Desktop\西方龙',r'dragon')
batch = 3

loader = Data.DataLoader(
    dataset=mydata,
    batch_size=batch,
    drop_last=True,
    num_workers=0
 )


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1=nn.Conv2d(3,32,5,padding=2,stride=1)
        self.maxpoo1=nn.MaxPool2d(2)
        self.conv2=nn.Conv2d(32,32,5,padding=2,stride=1)
        self.maxpoo2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2, stride=1)
        self.maxpoo3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(64, 64, 5, padding=2, stride=1)
        self.maxpoo4 = nn.MaxPool2d(2)
        self.maxpoo5 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.Linear1 = nn.Linear(64*64, 64)
        self.Linear2 = nn.Linear(64, 16)
        self.Linear3 = nn.Linear(16, 2)

    def forward(self,x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.maxpoo1(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = self.maxpoo2(x)
        x = self.conv3(x)
        x = nn.functional.relu(x)
        x = self.maxpoo3(x)
        x = self.conv4(x)
        x = nn.functional.relu(x)
        x = self.maxpoo4(x)
        x = self.maxpoo5(x)
        x = x.view(-1,4096)
        x = self.Linear1(x)
        x = self.Linear2(x)
        x = self.Linear3(x)
        return x


model=Model()

model = torch.load(r'C:\Users\xiaob\Desktop\西方龙\model\mo.pt')
model.eval()

print('mode is complit ,1 for trainning ,0 for test')
x=input(':')
if (int(x)==1) :
    loss = nn.CrossEntropyLoss()
    optim=torch.optim.SGD(model.parameters(),lr=0.01)
    for epoch in range(100):
        runningloss=0.0
        for data in loader:
            img,targets=data
            ouputs=model(img)
            result_loss=loss(ouputs,targets)
            optim.zero_grad()
            result_loss.backward()
            optim.step()
            runningloss=runningloss+result_loss
        print(runningloss)

    torch.save(model,r'C:\Users\xiaob\Desktop\西方龙\model\mo.pt')
else:
    x2 = input('root:')#测试图片地址
    img1=Image.open(x2).convert('RGB')
    print('openseccses')
    trains = torchvision.transforms.Resize([256, 256])
    trains1 = torchvision.transforms.ToTensor()
    img1=trains(img1)
    img1 = trains1(img1)
    print('trainsseccses')
    ouputs = model(img1)
    print('outsseccses')
    kk=torch.nn.functional.softmax(ouputs,dim=1)[0]
    kk=kk.detach().numpy().tolist()[0]
    print(kk)
运行结果及报错内容

不知道为什么,啥图片放上去都是0.99997的概率.白云图片,狗图片,放上去全是0.999997

我的解答思路和尝试过的方法

我觉得是因为只训练了龙图片,于是我又拿了一个狗类的图片去训练。然后当狗的识别率上升的时候,龙图片的识别率急速下降。反之亦然。我在想是不是激活函数的问题

我想要达到的结果

想知道如何构建一个识别西方龙图片的网络,以及我做错了什么。本人qq248167069

  • 写回答

2条回答 默认 最新

  • 万里鹏程转瞬至 人工智能领域优质创作者 2022-09-01 09:11
    关注

    因为你返回的数据都是只有一个类别的r'dragon',训练时的标签一直是[1,0],所以模型针对任何数据的输出都是1,0。你应该修改一下你的数据加载器。你可以先参考一下我的博客Pytorch 2 迁移学习 图像数据集分类_万里鹏程转瞬至的博客-CSDN博客

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 9月9日
  • 已采纳回答 9月1日
  • 创建了问题 9月1日

悬赏问题

  • ¥15 CARSIM前车变道设置
  • ¥50 三种调度算法报错 有实例
  • ¥15 关于#python#的问题,请各位专家解答!
  • ¥200 询问:python实现大地主题正反算的程序设计,有偿
  • ¥15 smptlib使用465端口发送邮件失败
  • ¥200 总是报错,能帮助用python实现程序实现高斯正反算吗?有偿
  • ¥15 对于squad数据集的基于bert模型的微调
  • ¥15 为什么我运行这个网络会出现以下报错?CRNN神经网络
  • ¥20 steam下载游戏占用内存
  • ¥15 CST保存项目时失败