m0_58447185 2023-03-15 20:40
Resolved

AlexNet: channel-count error when training on my own dataset

AlexNet pretraining finished; training on my own dataset now raises a channel error.

Traceback (most recent call last):
  File "E:\data\aasuceessful\train.py", line 132, in <module>
    main()
  File "E:\data\aasuceessful\train.py", line 93, in main
    outputs = alexnet(images.to(device))
……

RuntimeError: Given groups=1, weight of size [64, 1, 11, 11], expected input[3, 3, 224, 224] to have 1 channels, but got 3 channels instead
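premodel.py and BuildAlexNet are not shown in the question. For context, a minimal sketch of what the 1-channel modification might look like (the helper name build_gray_alexnet below is made up for illustration and is not the asker's actual code; the weight shape [64, 1, 11, 11] in the traceback is consistent with this kind of change):

import torch.nn as nn
from torchvision import models

def build_gray_alexnet(n_output=2):
    # hypothetical stand-in for BuildAlexNet, not the asker's implementation
    net = models.alexnet(pretrained=True)
    # torchvision's AlexNet stem is Conv2d(3, 64, kernel_size=11, stride=4, padding=2);
    # swapping it for a 1-channel version yields the [64, 1, 11, 11] weight seen in the error
    net.features[0] = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
    # two-class head
    net.classifier[6] = nn.Linear(4096, n_output)
    return net

With a model like this, the DataLoader still has to deliver 1-channel tensors, which is exactly what the error message is complaining about: the images arrive with 3 channels.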

The code is below; the pretrained model's input channel count has already been changed to 1.

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from premodel import BuildAlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path (unused here: the absolute path below overrides it)
    image_path = os.path.join(data_root, r"C:\Users\ASUS\Desktop\L", "split")  # benign/malignant data set path; joining with an absolute path discards data_root
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {    "0": "benign","1": "malignant"}
    split_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in split_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 3
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=3, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    # alexnet = BuildAlexNet(model_type='new', n_output=2)
    # alexnet is created as a module-level global in the __main__ block below and used here
    alexnet.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(alexnet.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        alexnet.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = alexnet(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        alexnet.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = alexnet(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(alexnet.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    model_type = 'new'
    n_output = 2
    alexnet = BuildAlexNet(model_type, n_output)  # build the network: the pretrained variant with n_output=2, i.e. two classes
    print(alexnet)
    main()

2 answers

  • m0_58447185 2023-03-16 17:05

    Change both Compose pipelines to use the code below:

    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
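    Keeping the original crop/resize steps and adding the lines above, the two pipelines would look roughly like this (a sketch of the accepted fix, reusing the torchvision.transforms import already in the question's script; not verified against the asker's data):

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.Grayscale(1),   # force 1 channel to match the modified conv stem
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5,), (0.5,))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.Grayscale(1),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,), (0.5,))])}

    transforms.Grayscale(1) converts whatever PIL image ImageFolder loads into a single-channel image, so the tensor handed to the 1-channel first conv has shape [N, 1, 224, 224] instead of [N, 3, 224, 224].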
    
    Accepted as the best answer by the asker.