m0_58447185 2023-03-15 20:40 采纳率: 41.7%
浏览 22
已结题

alexnet训练自己的数据集通道数报错

alexnet预训练结束,训练自己的数据集,通道报错

Traceback (most recent call last):
  File "E:\data\aasuceessful\train.py", line 132, in <module>
    main()
  File "E:\data\aasuceessful\train.py", line 93, in main
    outputs = alexnet(images.to(device))
……

RuntimeError: Given groups=1, weight of size [64, 1, 11, 11], expected input[3, 3, 224, 224] to have 1 channels, but got 3 channels instead

代码如下,预训练模型通道数已改为1

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from premodel import BuildAlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, r"C:\Users\ASUS\Desktop\L", "split")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {    "0": "benign","1": "malignant"}
    split_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in split_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 3
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=3, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    #alexnet = BuildAlexNet(model_type='new',n_output=2)


    alexnet.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(alexnet.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        alexnet.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = alexnet(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        alexnet.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = alexnet(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(alexnet.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    model_type = 'new'
    n_output = 2
    alexnet = BuildAlexNet(model_type, n_output)  # 调用函数buildAlexnet,网络选项是预训练模型,输出是2,也就是分两类
    print(alexnet)
    main()
  • 写回答

2条回答 默认 最新

  • m0_58447185 2023-03-16 17:05
    关注

    将两个compose改为下面代码

    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 3月24日
  • 已采纳回答 3月16日
  • 修改了问题 3月15日
  • 修改了问题 3月15日
  • 展开全部

悬赏问题

  • ¥30 sort cuteSV.vcf by bcftools用IGV可视化出现报错
  • ¥100 SOS!对STK中导出的天体图像进行质心提取有没有人做过啊
  • ¥15 python 欧式距离
  • ¥15 运行qteasy报错
  • ¥15 遗传算法解决有工序顺序约束的大规模FJSP问题
  • ¥15 企业消防水炮塔设计方案
  • ¥20 WORKBENCH网格划分
  • ¥60 急招师兄远程解决下载NPCAP的BUG!
  • ¥15 关于#51单片机#的问题:51单片机LM1602的数据只能显示一个字符,在使用矩形键盘送数据的时候不能显示出来而是在显示初始位置上,达不到密码锁的效果
  • ¥15 求旧版本ns-2软件