Yuntian_Dong 2020-02-27 21:40
浏览 427

机器翻译训练模型时总是报错关于矩阵类型不匹配,求帮助

求帮助……
小弟最近在做关于机器翻译的毕业设计,选题一时爽做题一直坑,遇到了这样的问题:TypeError: Input ‘b’ of ‘MatMul’ Op has type float32 that does not match type int32 of argument ‘a’

我是现在github上找了一个使用tensorflow 和 tf.keras的代码学习,先跑一跑感受一下,使用的是tensorflow2.0.0版本,在训练模型时就会遇到这种情况,有没有大佬可以帮忙解答一下如何修改代码。非常感谢!

代码如下:

def simple_model(input_shape, output_sequence_length, english_vocab_size, french_vocab_size):
    """
    Build and train a basic RNN on x and y
    :param input_shape: Tuple of input shape
    :param output_sequence_length: Length of output sequence
    :param english_vocab_size: Number of unique English words in the dataset
    :param french_vocab_size: Number of unique French words in the dataset
    :return: Keras model built, but not trained
    """
    # TODO: Build the model

    learning_rate = 1e-3

    input_seq = Input(input_shape[1:])

    rnn = GRU(64, return_sequences=True)(input_seq)

    logits = TimeDistributed(Dense(french_vocab_size))(rnn)

    model = Model(input_seq, Activation('softmax')(logits))

    model.summary()

    model.compile(loss=sparse_categorical_crossentropy,
                  optimizer=Adam(learning_rate),
                  metrics=['accuracy'])

    return model

# Reshaping the input to work with a basic RNN
tmp_x = pad(preproc_english_sentences, max_french_sequence_length)
tmp_x = tmp_x.reshape((-1, preproc_french_sentences.shape[-2], 1))

print(tmp_x.shape)

# Train the neural network
simple_rnn_model = simple_model(
    tmp_x.shape,
    max_french_sequence_length,
    english_vocab_size,
    french_vocab_size)

simple_rnn_model.fit(tmp_x, preproc_french_sentences, batch_size=1024, epochs=50, validation_split=0.2)

# Print prediction(s)
print("")
print(logits_to_text(simple_rnn_model.predict(tmp_x[:1])[0], french_tokenizer))

  • 写回答

0条回答 默认 最新

    报告相同问题?

    悬赏问题

    • ¥15 如何在scanpy上做差异基因和通路富集?
    • ¥20 关于#硬件工程#的问题,请各位专家解答!
    • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
    • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
    • ¥30 截图中的mathematics程序转换成matlab
    • ¥15 动力学代码报错,维度不匹配
    • ¥15 Power query添加列问题
    • ¥50 Kubernetes&Fission&Eleasticsearch
    • ¥15 報錯:Person is not mapped,如何解決?
    • ¥15 c++头文件不能识别CDialog