qq_41978522 2021-10-21 11:21 采纳率: 71.4%
浏览 81
已结题

tensorflow怎么获取注意力矩阵生成注意力图呀?

我在xception最后一层特征后加了个空间注意力层,训练好模型后,我想要输入一张图片,然后可以得到图片对应的注意力矩阵,请问怎么写代码?
代码如下,我想要在测试时输入一张图片得到这个图片对应输出的”cbam_feature“(空间注意力层的注意力矩阵)

class LMV(tf.keras.Model):
    def __init__(self, shape=224):
        super(LMV, self).__init__()
        self.dense1 = Dense(1024)
        self.dense2 = Dense(1, activation='sigmoid')
        self.base = tf.keras.applications.Xception(include_top=False, input_shape=(shape, shape, 3))
        self.att = tf.constant(np.random.randint(0,255,[16, 7, 7, 1]), name='att')

    def spatial_attention(self, input_feature):
        kernel_size = 7

        if K.image_data_format() == "channels_first":
            channel = input_feature.shape[1]
            cbam_feature = Permute((2, 3, 1))(input_feature)
        else:
            channel = input_feature.shape[-1]
            cbam_feature = input_feature

        avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
        assert avg_pool.shape[-1] == 1
        max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
        assert max_pool.shape[-1] == 1
        concat = Concatenate(axis=3)([avg_pool, max_pool])
        assert concat.shape[-1] == 2
        with tf.name_scope('sptical_attention'):
            cbam_feature = Conv2D(filters=1,
                                  kernel_size=kernel_size,
                                  activation='sigmoid',
                                  strides=1,
                                  padding='same',
                                  kernel_initializer='he_normal',
                                  use_bias=False)(concat)
        assert cbam_feature.shape[-1] == 1

        if K.image_data_format() == "channels_first":
            cbam_feature = Permute((3, 1, 2))(cbam_feature)

        return multiply([input_feature, cbam_feature])

    def call(self, inputs, training=None, mask=None):
        x0 = preprocess_input(inputs)
        x0 = self.base(x0)
        #global_mask = self.global_attention(x0)
        x = self.spatial_attention(x0)

        x = Flatten()(x)
        # print(x.shape)
        x = self.dense1(x)
        x = Dropout(0.5)(x)
        x = self.dense2(x)
        return x

  • 写回答

1条回答 默认 最新

  • 野鹤无粮 2021-10-21 21:43
    关注

    没理解错的话,题主是要获取 cbam_feature 变量。那就 在spatial_attention这个方法中顺便输出 cbam_feature ,也就是

            return multiply([input_feature, cbam_feature]),cbam_feature
    

    后面 call() 中 x = self.spatial_attention(x0)变成 x,cbam_feature = self.spatial_attention(x0)
    call() 的输出 return x 变为 return x ,cbam_feature
    这样就可以输出得到注意力矩阵 cbam_feature 了。

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 10月28日
  • 修改了问题 10月21日
  • 创建了问题 10月21日

悬赏问题

  • ¥15 素材场景中光线烘焙后灯光失效
  • ¥15 请教一下各位,为什么我这个没有实现模拟点击
  • ¥15 执行 virtuoso 命令后,界面没有,cadence 启动不起来
  • ¥50 comfyui下连接animatediff节点生成视频质量非常差的原因
  • ¥20 有关区间dp的问题求解
  • ¥15 多电路系统共用电源的串扰问题
  • ¥15 slam rangenet++配置
  • ¥15 有没有研究水声通信方面的帮我改俩matlab代码
  • ¥15 ubuntu子系统密码忘记
  • ¥15 保护模式-系统加载-段寄存器