用吉他弹奏摇滚乐 2022-11-03 10:21 采纳率: 0%
浏览 12

tensorflow学习笔记字母预测问题

遇到的问题

学习慕课的tensorflow学习笔记,在RNN实现字母预测中,课程中先将字母映射到数值id的词典,再对数值ID进行独热编码;我想尝试一下,用数值ID不进行独热编码,直接进行模型训练,但是结果无法正常训练,找不到报错的具体地方

尝试的字母处理方式
input_word = "abcde"
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}  # 单词映射到数值id的词典
# id_to_onehot = {0: [1., 0., 0., 0., 0.], 1: [0., 1., 0., 0., 0.], 2: [0., 0., 1., 0., 0.], 3: [0., 0., 0., 1., 0.],
#                 4: [0., 0., 0., 0., 1.]}  # id编码为one-hot

# x_train = [
#     [id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']]],
#     [id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]],
#     [id_to_onehot[w_to_id['c']], id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']]],
#     [id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']]],
#     [id_to_onehot[w_to_id['e']], id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']]],
# ]

x_train = [
    [w_to_id['a'],w_to_id['b'],w_to_id['c'],w_to_id['d']],
    [w_to_id['b'],w_to_id['c'],w_to_id['d'],w_to_id['e']],
    [w_to_id['c'],w_to_id['d'],w_to_id['e'],w_to_id['a']],
    [w_to_id['d'],w_to_id['e'],w_to_id['a'],w_to_id['b']],
    [w_to_id['e'],w_to_id['a'],w_to_id['b'],w_to_id['c']],
]

y_train = [w_to_id['e'], w_to_id['a'], w_to_id['b'], w_to_id['c'], w_to_id['d']]

报错内容
Traceback (most recent call last):
  File "D:/code_edit/AI/class6/p21_rnn_onehot_4pre1.py", line 63, in <module>
    history = model.fit(x_train, y_train, batch_size=32, epochs=1, callbacks=[cp_callback])
  File "D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\keras\engine\training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\keras\engine\training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:
    D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\keras\engine\training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:951 run  **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    D:\Anaconda\envs\tensorflow1\lib\site-packages\tensorflow\python\framework\op_def_library.py:503 _apply_op_helper
        raise TypeError(

    TypeError: Input 'b' of 'MatMul' Op has type float32 that does not match type int32 of argument 'a'.

我想要达到的结果

请问大家有遇到这种问题吗?求指教!

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-11-03 12:28
    关注
    评论

报告相同问题?

问题事件

  • 创建了问题 11月3日

悬赏问题

  • ¥20 机器学习能否像多层线性模型一样处理嵌套数据
  • ¥20 西门子S7-Graph,S7-300,梯形图
  • ¥50 用易语言http 访问不了网页
  • ¥50 safari浏览器fetch提交数据后数据丢失问题
  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效