Issssac 2022-02-16 18:53
浏览 18
已结题

我想用自己的数据搭建一个CNN(卷积神经网络),但遇到了一些问题

import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
import cv2,os
import torch.utils.data
import torch
from PIL import Image
from torchvision import transforms

# Number of full passes over the training loader performed by main().
epoches = 2
# Number of samples per mini-batch handed to the DataLoader.
batch_size = 50
# Learning rate passed to the Adam optimizer.
learning_rate = 0.001
class Coaldataset(torch.utils.data.Dataset):
    """In-memory coal/stone image dataset built from an annotation list file.

    Every non-blank line of the list file is split on whitespace: the first
    field is an image path that is appended to the dataset root directory,
    the second field is the label string. All images are read with OpenCV
    and resized to 64x64 once, in the constructor.

    Args:
        train: Selects which list file is read (see the NOTE in ``__init__``).
        train_transform: Callable applied to the PIL image in ``__getitem__``.
            Any non-callable value (``None``, ``True``, ...) selects the
            default pipeline: random horizontal flip followed by ``ToTensor``.
        target_transform: Optional callable applied to the integer label.
        mode: Stored on the instance; not used by this class.
        stage: Stored on the instance; not used by this class.

    Raises:
        ValueError: If a non-blank line has fewer than two fields.
        FileNotFoundError: If OpenCV cannot read a listed image.
    """

    def __init__(self, train=True, train_transform=None, target_transform=None,mode = None,stage = None  ):
        self.mode = mode
        self.stage = stage
        self.target_transform = target_transform
        if callable(train_transform):
            # Honour a caller-supplied pipeline instead of silently discarding it.
            self.train_transform = train_transform
        else:
            # Existing callers pass None or True; both keep the default pipeline.
            self.train_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
            ])
        self.train = train
        if self.train:
            data_file = "c:/Users/admin/Desktop/Coal_Stone_Classification/AI_Dict_GPU/DataBase_Train_Test/Coal_data/image_Train.txt"
        else:
            # NOTE(review): this branch reads the same list as the training
            # branch, which looks like a copy-paste slip. The path is left
            # unchanged because the test list's filename is not known here.
            data_file = "c:/Users/admin/Desktop/Coal_Stone_Classification/AI_Dict_GPU/DataBase_Train_Test/Coal_data/image_Train.txt"
        self.data = []
        self.targets = []

        # The context manager guarantees the list file is closed, even if a
        # bad entry raises part-way through loading.
        with open(data_file) as list_file:
            for line_no, line in enumerate(list_file, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        "{}:{}: expected '<image path> <label>', got {!r}".format(
                            data_file, line_no, line
                        )
                    )

                img_path = "c:/Users/admin/Desktop/Coal_Stone_Classification/AI_Dict_GPU" + fields[0]
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread signals failure by returning None; fail here
                    # with the offending path rather than inside cv2.resize.
                    raise FileNotFoundError(
                        "cannot read image {!r} (listed at {}:{})".format(
                            img_path, data_file, line_no
                        )
                    )

                # Bring every image to one common size.
                img = cv2.resize(img, (64,64), interpolation=cv2.INTER_AREA)
                self.data.append(img)
                self.targets.append(fields[1])

                # Progress output every 100 loaded images.
                if len(self.data) % 100 == 0:
                    print(len(self.data))

    @staticmethod
    def _encode_label(label):
        """Map a label string to its class index: "coal" -> 1, anything else -> 0."""
        return 1 if label == "coal" else 0

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` after applying the transforms."""
        target = self._encode_label(self.targets[index])

        # NOTE(review): cv2.imread yields BGR channel order and
        # Image.fromarray does not reorder channels - confirm this is intended.
        img = Image.fromarray(self.data[index])
        if self.train_transform is not None:
            img = self.train_transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of images loaded into memory."""
        return len(self.data)

class CNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel 64x64 images.

    Each stage is Conv2d -> ReLU -> MaxPool2d(2). Both convolutions use a 5x5
    kernel with padding 2 and stride 1, which keeps the spatial size, so only
    the pooling layers shrink the feature map: 64 -> 32 -> 16. The flattened
    32 * 16 * 16 features feed one linear layer producing 10 logits.
    """

    def __init__(self):
        super().__init__()

        # (N, 3, 64, 64) -> (N, 16, 32, 32)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # (N, 16, 32, 32) -> (N, 32, 16, 16)
        # kernel_size must be 5 here: a 6x6 kernel with padding 2 gives a
        # 31x31 map (15x15 after pooling), which no longer matches the
        # 32 * 16 * 16 input size of the linear layer below.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.output = nn.Linear(in_features=32 * 16 * 16, out_features=10)

    def forward(self, x):
        """Return the 10 logits per sample for a batch of 3x64x64 images."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Pin the batch dimension while flattening, so a wrong feature-map
        # size raises in the linear layer instead of silently changing the
        # batch size.
        x = x.view(x.size(0), -1)
        return self.output(x)


# Build the training dataset once at module level and wrap it in a shuffling
# mini-batch loader that main() iterates over.
train_dataset = Coaldataset(
    train=True,
    train_transform=True,
    target_transform=None,
)
train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
def main():
    """Train a fresh CNN on the module-level ``train_loader``.

    Prints the model structure, then runs ``epoches`` passes over the loader
    with Adam and cross-entropy loss, printing the loss tensor of every batch.
    """
    model = CNN()
    print(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epoches):
        print("进行第{}个epoch".format(epoch))
        for images, labels in train_loader:
            # Clear gradients left over from the previous step before the
            # new backward pass accumulates into them.
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            print(loss)


if __name__ == "__main__":
    main()

这是我自己搭建的最简易的CNN(卷积神经网络),调试(debug)时看到的图片尺寸(size)如下:

img

出现的是这个问题,但是我不知道该怎么修改了

img

  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 2月24日
    • 创建了问题 2月16日

    悬赏问题

    • ¥15 系统 24h2 专业工作站版,浏览文件夹的图库,视频,图片之类的怎样删除?
    • ¥15 怎么把512还原为520格式
    • ¥15 MATLAB的动态模态分解出现错误,以CFX非定常模拟结果为快照
    • ¥15 求高通平台Softsim调试经验
    • ¥15 canal如何实现将mysql多张表(月表)采集入库到目标表中(一张表)?
    • ¥15 wpf ScrollViewer实现冻结左侧宽度w范围内的视图
    • ¥15 栅极驱动低侧烧毁MOSFET
    • ¥30 写segy数据时出错3
    • ¥100 linux下qt运行QCefView demo报错
    • ¥50 F1C100S下的红外解码IR_RX驱动问题