Eticos_CZ 2023-03-25 12:47 采纳率: 75%
浏览 31
已结题

pytorch如何提取子层的特征图输出

在复现STDN网络的过程中,需要提取最后一个denseblock的中间某6层特征图输出用作预测和回归。
我所编写网络的部分打印结果如下:

…………
(transition3): Sequential(
    (transition_bn): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (transition_relu): ReLU(inplace=True)
    (transition_conv): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
    (transition_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (denseblock4): Sequential(
    (dense_0): denselayer(
      (bn1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (dense_1): denselayer(
      (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
…………

其中(transition3)为(denseblock4)前的过渡层,(denseblock4)内有dense_0 - 31共32个密集连接层。现在我需要提取在(denseblock4)其中的dense_4 9 14 19 24 31的输出特征图出来,请问应该怎么做?
网上找寻的方法只能提取整个(denseblock4)的输出,没法再深入到(denseblock4)其中的子层。
感谢!!

  • 写回答

2条回答 默认 最新

  • Eticos_CZ 2023-03-28 11:06
    关注

    在forward中,通过子层名提取出所需要的特征图结果输出:

        def forward(self, x):
            …… ……
            feature_list = []
            for (name, module) in self.denseblock4.named_children():
                name_list = ['dense_4', 'dense_9', 'dense_14', 'dense_19', 'dense_24', 'dense_31']
                x = module(x)
                if name in name_list:
                    feature_list.append(x)
            return x, feature_list
    
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 4月5日
  • 已采纳回答 3月28日
  • 创建了问题 3月25日