我在xception最后一层特征后加了个空间注意力层,训练好模型后,我想要输入一张图片,然后可以得到图片对应的注意力矩阵,请问怎么写代码?
代码如下,我想要在测试时输入一张图片得到这个图片对应输出的 "cbam_feature"(空间注意力层的注意力矩阵)
class LMV(tf.keras.Model):
    """Xception backbone + CBAM-style spatial attention + sigmoid binary head.

    After every forward pass the spatial-attention map is stored in
    ``self.last_attention`` so it can be retrieved at test time:

        model = LMV()
        prob = model(img_batch, training=False)      # img_batch: (B, 224, 224, 3)
        att = model.last_attention                   # (B, H', W', 1), values in (0, 1)

    (For a 224x224 input the Xception feature map is 7x7, so the attention
    map is (B, 7, 7, 1) — presumably what the author wants to visualize.)
    """

    def __init__(self, shape=224):
        super(LMV, self).__init__()
        self.dense1 = Dense(1024)
        self.dense2 = Dense(1, activation='sigmoid')
        self.base = tf.keras.applications.Xception(
            include_top=False, input_shape=(shape, shape, 3))
        # NOTE(review): unused random constant kept only for interface
        # compatibility — candidate for removal.
        self.att = tf.constant(np.random.randint(0, 255, [16, 7, 7, 1]), name='att')
        # BUG FIX: trainable layers must be created ONCE here.  The original
        # code built a fresh Conv2D inside spatial_attention() on every call,
        # so its kernel was re-initialized each forward pass and never trained.
        with tf.name_scope('sptical_attention'):  # (sic) original scope name kept
            self.att_conv = Conv2D(filters=1,
                                   kernel_size=7,
                                   activation='sigmoid',
                                   strides=1,
                                   padding='same',
                                   kernel_initializer='he_normal',
                                   use_bias=False)
        self.flatten = Flatten()
        self.dropout = Dropout(0.5)
        # Most recent spatial-attention map, shape (batch, H, W, 1); None
        # until the first forward pass.
        self.last_attention = None

    def spatial_attention(self, input_feature):
        """Apply CBAM spatial attention; return channel-wise re-weighted features.

        Side effect: stores the (batch, H, W, 1) sigmoid attention map in
        ``self.last_attention`` for later inspection.
        """
        if K.image_data_format() == "channels_first":
            cbam_feature = Permute((2, 3, 1))(input_feature)
        else:
            cbam_feature = input_feature
        # Collapse channels into two (batch, H, W, 1) spatial descriptors.
        # keepdims=True guarantees the trailing singleton axis, so the old
        # shape asserts (stripped under -O anyway) are unnecessary.
        avg_pool = K.mean(cbam_feature, axis=3, keepdims=True)
        max_pool = K.max(cbam_feature, axis=3, keepdims=True)
        concat = Concatenate(axis=3)([avg_pool, max_pool])
        # Single-channel sigmoid map in (0, 1): the spatial attention matrix.
        cbam_feature = self.att_conv(concat)
        if K.image_data_format() == "channels_first":
            cbam_feature = Permute((3, 1, 2))(cbam_feature)
        self.last_attention = cbam_feature
        return multiply([input_feature, cbam_feature])

    def call(self, inputs, training=None, mask=None):
        """Forward pass: preprocess -> Xception -> spatial attention -> head."""
        x = preprocess_input(inputs)
        x = self.base(x, training=training)
        x = self.spatial_attention(x)
        x = self.flatten(x)
        x = self.dense1(x)
        # Pass training explicitly so dropout is active only while training.
        x = self.dropout(x, training=training)
        return self.dense2(x)