Eticos_CZ 2023-03-25 12:47 采纳率: 75%
浏览 31
已结题

pytorch如何提取子层的特征图输出

在复现STDN网络的过程中,需要提取最后一个denseblock的中间某6层特征图输出用作预测和回归。
我所编写网络的部分打印结果如下:

…………
(transition3): Sequential(
    (transition_bn): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (transition_relu): ReLU(inplace=True)
    (transition_conv): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
    (transition_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (denseblock4): Sequential(
    (dense_0): denselayer(
      (bn1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (dense_1): denselayer(
      (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
…………

其中(transition3)为(denseblock4)前的过渡层,(denseblock4)内有dense_0 - 31共32个密集连接层。现在我需要提取在(denseblock4)其中的dense_4 9 14 19 24 31的输出特征图出来,请问应该怎么做?
网上找寻的方法只能提取整个(denseblock4)的输出,没法再深入到(denseblock4)其中的子层。
感谢!!

  • 写回答

2条回答 默认 最新

  • Eticos_CZ 2023-03-28 11:06
    关注

    在forward中,通过子层名提取出所需要的特征图结果输出:

        def forward(self, x):
            …… ……
            feature_list = []
            for (name, module) in self.denseblock4.named_children():
                name_list = ['dense_4', 'dense_9', 'dense_14', 'dense_19', 'dense_24', 'dense_31']
                x = module(x)
                if name in name_list:
                    feature_list.append(x)
            return x, feature_list
    
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 4月5日
  • 已采纳回答 3月28日
  • 创建了问题 3月25日

悬赏问题

  • ¥15 CVRP 图论 物流运输优化
  • ¥15 Tableau online 嵌入ppt失败
  • ¥100 支付宝网页转账系统不识别账号
  • ¥15 基于单片机的靶位控制系统
  • ¥15 真我手机蓝牙传输进度消息被关闭了,怎么打开?(关键词-消息通知)
  • ¥15 下图接收小电路,谁知道原理
  • ¥15 装 pytorch 的时候出了好多问题,遇到这种情况怎么处理?
  • ¥20 IOS游览器某宝手机网页版自动立即购买JavaScript脚本
  • ¥15 手机接入宽带网线,如何释放宽带全部速度
  • ¥30 关于#r语言#的问题:如何对R语言中mfgarch包中构建的garch-midas模型进行样本内长期波动率预测和样本外长期波动率预测