def attention_3d_block(inputs, name_prefix=''):
    """Apply a learned soft-attention weighting over the time axis of ``inputs``.

    A Dense softmax layer scores every time step separately for each feature.
    The scores are averaged over the feature axis, so a single attention vector
    is shared by all features. That vector is repeated back to the input shape
    and multiplied element-wise with ``inputs``.

    Args:
        inputs: 3-D Keras tensor, assumed to be
            ``(batch_size, time_steps, input_dim)``. ``time_steps`` and
            ``input_dim`` must be statically known because they size the
            Dense and RepeatVector layers.
        name_prefix: Optional string prepended to the fixed layer names
            ``'dim_reduction'`` and ``'attention_vec'``. Pass a distinct
            prefix when calling this function more than once in the same
            model, because Keras rejects duplicate layer names. The default
            ``''`` keeps the original names, so code that looks layers up by
            those names keeps working.

    Returns:
        Tensor with the same shape as ``inputs``: the inputs scaled
        element-wise by the attention probabilities.

    Note:
        Every call creates a new Dense layer with trainable weights (e.g.
        ``.../dense_1/bias``). A checkpoint saved before this block was added
        does not contain those variables, so restoring it with a Saver over
        *all* variables fails with "Key ... not found in checkpoint". Either
        retrain, or restore only the variables present in the checkpoint.
    """
    # Static sizes of the time and feature axes.
    time_steps = int(inputs.shape[1])
    input_dim = int(inputs.shape[2])

    # (batch, time_steps, input_dim) -> (batch, input_dim, time_steps) so the
    # Dense layer below operates along the time axis.
    a = Permute((2, 1))(inputs)
    # Softmax over the last axis: for each feature, a distribution over time.
    a = Dense(time_steps, activation='softmax')(a)
    # Average over the feature axis -> one shared attention vector per sample,
    # shape (batch, time_steps).
    a = Lambda(lambda x: K.mean(x, axis=1),
               name=name_prefix + 'dim_reduction')(a)
    # Repeat it once per feature: (batch, input_dim, time_steps).
    a = RepeatVector(input_dim)(a)
    # Back to (batch, time_steps, input_dim) so it lines up with ``inputs``.
    a_probs = Permute((2, 1), name=name_prefix + 'attention_vec')(a)
    return Multiply()([inputs, a_probs])
我将这段代码（及其新增的变量）加入到原有代码中后，在用 `saver.restore(sess, module_file)` 恢复参数时报错：`Key lstm_1/dense_1/bias not found in checkpoint`。