keras 提示出错 初学者 不明白为什么

from keras.layers import Input, Dense, merge
from keras.models import Model
from keras import backend as K

a = Input(shape=(2,), name='a')
b = Input(shape=(2,), name='b')

a_rotated = Dense(2, activation='linear')(a)

def cosine(x):
axis = len(x[0]._keras_shape)-1
dot = lambda a, b: K.batch_dot(a, b, axes=axis)
return dot(x[0], x[1]) / K.sqrt(dot(x[0], x[0]) * dot(x[1], x[1]))

cosine_sim = merge([a_rotated, b], mode=cosine, output_shape=lambda x: x[:-1])

model = Model(input=[a, b], output=[cosine_sim])
model.compile(optimizer='sgd', loss='mse')

import numpy as np

a_data = np.asarray([[0, 1], [1, 0], [0, -1], [-1, 0]])
b_data = np.asarray([[1, 0], [0, -1], [-1, 0], [0, 1]])
targets = np.asarray([1, 1, 1, 1])

model.fit([a_data, b_data], [targets], nb_epoch=1000)
print(model.layers[2].W.get_value())
这段代码有问题

3个回答

oydxxynu
oydxxynu 你好 在你的电脑上运行没有错吗
3 年多之前 回复

你好 在你的电脑上运行没有错吗

将你自己写的merge层和cosine函数改为

    from keras.layer import dot
    axis=len(a_rotated._keras_shape)-1
cosine_sim=dot([a_rotated,b],axes=axis)
Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
立即提问