m0_58447185 2023-03-14 21:53 采纳率: 41.7%
浏览 30
已结题

如何将三通道改为一通道

使用pytorch预训练alexnet网络,如何将三通道调整为一通道

from torchvision import models
from torchvision.models import AlexNet_Weights



model = models.alexnet(weights=AlexNet_Weights.DEFAULT)
print(model)

import torch.nn as nn  #nn设置网络结构详细参数
from torchvision import models
#torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,
# # 分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)
# # 和预定义好的数据增强方法(比如Resize、ToTensor等)。

#模型预训练

class BuildAlexNet(nn.Module):
    def __init__(self, model_type, n_output):
        super(BuildAlexNet, self).__init__()
        self.model_type = model_type
        if model_type == 'pre':        #定义两种model类型,一个直接从alexnet中继承这个参数和结构,定义名称为‘pre’。
                                       # 另一个是自己设定的网络结构,定义为'new'
            model = models.alexnet(weights=AlexNet_Weights.DEFAULT)
            #加载alexnet模型,pretrained为真,则加载网络结构和预训练参数。否则,只加载网络结构[2]
            self.features = model.features
            # 因为只要求更改最后的分类数,所以feature类直接从预训练网络中继承classifier类除了要更改的分类层,其他的也从原网络中定义好
            fc1 = nn.Linear(9216, 4096)
            # fc1和fc2继承原网络的classifier参数
            fc1.bias = model.classifier[1].bias
            fc1.weight = model.classifier[1].weight

            fc2 = nn.Linear(4096, 4096)
            fc2.bias = model.classifier[4].bias
            fc2.weight = model.classifier[4].weight

            self.classifier = nn.Sequential(
            # 定义新的classifier层,前两层保持不变,底端分类层分类数用n_output代替
                nn.Dropout(),
                fc1,
                nn.ReLU(inplace=True),
                nn.Dropout(),
                fc2,
                nn.ReLU(inplace=True),
                nn.Linear(4096, n_output))
            # 或者直接修改为
        #            model.classifier[6]==nn.Linear(4096,n_output)
        #            self.classifier = model.classifier
        if model_type == 'new':
        # 这是自己定义的网络模型(feature,classifier)
            self.features = nn.Sequential(
                #nn.Conv2d(3, 64, 11, 4, 2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, 2, 0),
                nn.Conv2d(64, 192, 5, 1, 2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, 2, 0),
                nn.Conv2d(192, 384, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(384, 256, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, 2, 0))
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(9216, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, n_output))




    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        out = self.classifier(x)
        return out


#测试维度是否正确

import numpy as np           #从这里开始就是验证函数
from torch.autograd import Variable
import torch

if __name__ == '__main__':
    model_type = 'pre'
    n_output = 2
    alexnet = BuildAlexNet(model_type, n_output)            #调用函数buildAlexnet,网络选项是预训练模型,输出是2,也就是分两类
    model.conv1 = nn.Conv2d(1, 64, 4, 2)
    model.conv1.weight = alexnet.conv1.weight[:, 1, :, :]
    print(model.conv1)
   # print(alexnet)

    x = np.random.rand(1, 3, 224, 224)            #随机创建一个224*224,通道为3的数组,模拟三通道的图片
    x = x.astype(np.float32)
    x_ts = torch.from_numpy(x)                 #转换成torch能用的张量形式
    x_in = Variable(x_ts)
    y = alexnet(x_in)                     #输出结果



尝试过在model_type中添加


           #w = layers[0].weight
            #layers[0] = nn.Conv2d(1, 64, 11, 4, 2 )
            #layers[0].weight = torch.nn.Parameter(w[:, :1, :, :])
            #torch.save(alexnet, save_weights)

仍然报错

  • 写回答

3条回答 默认 最新

  • kingbqx 2023-03-14 22:04
    关注

    若要将AlexNet的三通道调整为一通道,可以在模型定义时修改第一层卷积层的输入通道数。具体地,在BuildAlexNet类中,找到self.features,即AlexNet的特征提取部分,将其中的第一层卷积层从原来的nn.Conv2d(3, 64, 11, 4, 2)修改为nn.Conv2d(1, 64, 11, 4, 2)即可。

    修改后的BuildAlexNet类代码如下:

    class BuildAlexNet(nn.Module):
        def __init__(self, model_type, n_output):
            super(BuildAlexNet, self).__init__()
            self.model_type = model_type
            if model_type == 'pre':
                model = models.alexnet(weights=AlexNet_Weights.DEFAULT)
                self.features = model.features
                fc1 = nn.Linear(9216, 4096)
                fc1.bias = model.classifier[1].bias
                fc1.weight = model.classifier[1].weight
    
                fc2 = nn.Linear(4096, 4096)
                fc2.bias = model.classifier[4].bias
                fc2.weight = model.classifier[4].weight
    
                self.classifier = nn.Sequential(
                    nn.Dropout(),
                    fc1,
                    nn.ReLU(inplace=True),
                    nn.Dropout(),
                    fc2,
                    nn.ReLU(inplace=True),
                    nn.Linear(4096, n_output))
            if model_type == 'new':
                self.features = nn.Sequential(
                    nn.Conv2d(1, 64, 11, 4, 2), # 修改这一行,将输入通道数从3改为1
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(3, 2, 0),
                    nn.Conv2d(64, 192, 5, 1, 2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(3, 2, 0),
                    nn.Conv2d(192, 384, 3, 1, 1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(384, 256, 3, 1, 1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(3, 2, 0))
                self.classifier = nn.Sequential(
                    nn.Dropout(),
                    nn.Linear(9216, 4096),
                    nn.ReLU(inplace=True),
                    nn.Dropout(),
                    nn.Linear(4096, 4096),
                    nn.ReLU(inplace=True),
                    nn.Linear(4096, n_output))
    
        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            out = self.classifier(x)
            return out
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(2条)

报告相同问题?

问题事件

  • 系统已结题 3月24日
  • 已采纳回答 3月16日
  • 创建了问题 3月14日

悬赏问题

  • ¥15 我需要全国每个城市的最新小区名字等数据。
  • ¥15 开发一个小区生态的小程序
  • ¥15 MddBootstrapInitialize2失败
  • ¥15 LCD Flicker
  • ¥15 Spring MVC项目,访问不到相应的控制器方法
  • ¥15 esp32在micropython环境下使用ssl/tls连接mqtt服务器出现以下报错Connected on 192.168.154.223发生意外错误: 5无法连接到 MQTT 代理,如何解决?
  • ¥15 关于#genesiscsheel#的问题,如何解决?
  • ¥15 Android aidl for hal
  • ¥15 STM32CubeIDE下载程序报错
  • ¥15 微信好友如何转变为会员系统?(相关搜索:小程序)