yj1015341081
十一个中文
2021-01-25 14:13

tensorflow训练完成生成了三个文件,如何转换为pb文件

100
  • python
  • tensorflow
  • 机器学习
  • 神经网络

我训练完数据,生成三个文件(data,index,meta),我使用脚本生成pb文件,但是在使用的时候,一直报:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node conv3_1/weights/Assign was passed float from conv3_1/weights:0 incompatible with expected float_ref.



我的生成pb脚本如下:

def get_pb_file(meta_path):
    with tf.Session() as sess:
        # Restore the graph
        saver = tf.train.import_meta_graph(meta_path)

        # Load weights
        saver.restore(sess, tf.train.latest_checkpoint('../checkpoint/'))

        # Output nodes
        output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
        print(str(output_node_names))
        for output_node in output_node_names:
            print(output_node)

        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names)

        # Save the frozen graph
        with open('output_graph-1.pb', 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())

请问这个是什么情况,有没有大神遇到

  • 点赞
  • 回答
  • 收藏
  • 复制链接分享

9条回答

为你推荐

换一换