m0_58447185 2023-03-15 20:40 采纳率: 41.7%
浏览 22
已结题

alexnet训练自己的数据集通道数报错

alexnet预训练结束,训练自己的数据集,通道报错

Traceback (most recent call last):
  File "E:\data\aasuceessful\train.py", line 132, in <module>
    main()
  File "E:\data\aasuceessful\train.py", line 93, in main
    outputs = alexnet(images.to(device))
……

RuntimeError: Given groups=1, weight of size [64, 1, 11, 11], expected input[3, 3, 224, 224] to have 1 channels, but got 3 channels instead

代码如下,预训练模型通道数已改为1

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from premodel import BuildAlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, r"C:\Users\ASUS\Desktop\L", "split")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {    "0": "benign","1": "malignant"}
    split_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in split_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 3
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=3, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    #alexnet = BuildAlexNet(model_type='new',n_output=2)


    alexnet.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(alexnet.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        alexnet.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = alexnet(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        alexnet.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = alexnet(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(alexnet.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    model_type = 'new'
    n_output = 2
    alexnet = BuildAlexNet(model_type, n_output)  # 调用函数buildAlexnet,网络选项是预训练模型,输出是2,也就是分两类
    print(alexnet)
    main()
  • 写回答

2条回答 默认 最新

  • m0_58447185 2023-03-16 17:05
    关注

    将两个compose改为下面代码

    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 3月24日
  • 已采纳回答 3月16日
  • 修改了问题 3月15日
  • 修改了问题 3月15日
  • 展开全部

悬赏问题

  • ¥120 计算机网络的新校区组网设计
  • ¥20 完全没有学习过GAN,看了CSDN的一篇文章,里面有代码但是完全不知道如何操作
  • ¥15 使用ue5插件narrative时如何切换关卡也保存叙事任务记录
  • ¥20 海浪数据 南海地区海况数据,波浪数据
  • ¥20 软件测试决策法疑问求解答
  • ¥15 win11 23H2删除推荐的项目,支持注册表等
  • ¥15 matlab 用yalmip搭建模型,cplex求解,线性化处理的方法
  • ¥15 qt6.6.3 基于百度云的语音识别 不会改
  • ¥15 关于#目标检测#的问题:大概就是类似后台自动检测某下架商品的库存,在他监测到该商品上架并且可以购买的瞬间点击立即购买下单
  • ¥15 神经网络怎么把隐含层变量融合到损失函数中?