yj1015341081
十一个中文
采纳率50%
2021-01-25 14:13 阅读 234

tensorflow训练完成生成了三个文件,如何转换为pb文件

100

我训练完数据,生成三个文件(data,index,meta),我使用脚本生成pb文件,但是在使用的时候,一直报:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node conv3_1/weights/Assign was passed float from conv3_1/weights:0 incompatible with expected float_ref.



我的生成pb脚本如下:

def get_pb_file(meta_path):
    with tf.Session() as sess:
        # Restore the graph
        saver = tf.train.import_meta_graph(meta_path)

        # Load weights
        saver.restore(sess, tf.train.latest_checkpoint('../checkpoint/'))

        # Output nodes
        output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
        print(str(output_node_names))
        for output_node in output_node_names:
            print(output_node)

        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names)

        # Save the frozen graph
        with open('output_graph-1.pb', 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())

请问这个是什么情况,有没有大神遇到

  • 点赞
  • 写回答
  • 关注问题
  • 收藏
  • 复制链接分享

9条回答 默认 最新

  • 已采纳
    csdn_garbage1 csdn_garbage1 2021-02-01 15:22
    点赞 评论 复制链接分享
  • yj1015341081 十一个中文 2021-01-25 20:12
    TensorBoard的graphs图 TensorBoard 的graphs图

     

    点赞 评论 复制链接分享
  • qq_22475211 深白色的风 2021-01-25 23:09
  • qq_22475211 深白色的风 2021-01-25 23:13

    确保你的pb_file格式正确(类似这样),并尝试在import_graph_def()的'name'参数中设置一些值,以尝试覆盖“import”默认值,如下所示:

    点赞 评论 复制链接分享
  • qq_22475211 深白色的风 2021-01-25 23:14
    import tensorflow as tf
     
    from tensorflow.python.platform import gfile
    model_path="/tmp/frozen/dcgan.pb"
     
    # read graph definition
    f = gfile.FastGFile(model_path, "rb")
    gd = graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
     
    # fix nodes
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in xrange(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr: del node.attr['use_locking']
     
    # import graph into session
    tf.import_graph_def(graph_def, name='')
    tf.train.write_graph(graph_def, './', 'good_frozen.pb', as_text=False)
    tf.train.write_graph(graph_def, './', 'good_frozen.pbtxt', as_text=True)
    点赞 评论 复制链接分享
  • bill20100829 歇歇 2021-01-26 07:27

    如下

    import tensorflow as tf
    import os
    from tensorflow.python.tools import freeze_graph
    # 本来这个model本无需解释太多,但是这么多人不能耐下心来看,那么我简单的说一下吧
    # network是你们自己定义的模型结构而已
    # ps:
    # def network(input):
    #    return tf.layers.max_pooling2d(input, 2, 2)
    from model import network
    
    
    os.environ['CUDA_VISIBLE_DEVICES']='2'  #设置GPU
    
    
    model_path  = "path to /model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
    
    
    def main():
    
        tf.reset_default_graph()
    
        input_node = tf.placeholder(tf.float32, shape=(228, 304, 3)) #这个是你送入网络的图片大小,如果你是其他的大小自行修改
        input_node = tf.expand_dims(input_node, 0)
        flow = network(input_node)
        flow = tf.cast(flow, tf.uint8, 'out') #设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
    
        saver = tf.train.Saver()
        with tf.Session() as sess:
    
            saver.restore(sess, model_path)
    
            #保存图
            tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
            #把图和参数结构一起
            freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'out','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
    
        print("done")
    
    if __name__ == '__main__':
        main()

    这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。
    官方解释可参考:https://www.tensorflow.org/extend/tool_developers/#freezing
    这里我按我的理解翻译下,不对的地方请指正:
    有一点令我们为比较困惑的是,tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。
    freeze_graph.py是怎么做的呢?首行它先加载模型文件,再从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重 常量 (因为 常量 能随模型一起保存在同一个文件里),然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)

    文件目录:tensorflow/python/tools/free_graph.py
    测试文件:tensorflow/python/tools/free_graph_test.py 这个测试文件很有学习价值
    参数:
    总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):
    1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)
    2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。
    3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False
    4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。
    5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。
    6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all
    7、filename_tensor_name:(可选)已弃用。默认:save/Const:0
    8、output_graph:(必选)用来保存整合后的模型输出文件。
    9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)
    10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
    11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。
    用法:
    例:python tensorflow/python/tools/free_graph.py \
    –input_graph=some_graph_def.pb \ 注意:这里的pb文件是用tf.train.write_graph方法保存的
    –input_checkpoint=model.ckpt.1001 \ 注意:这里若是r12以上的版本,只需给.data-00000….前面的文件名,如:model.ckpt.1001.data-00000-of-00001,只需写model.ckpt.1001
    –output_graph=/tmp/frozen_graph.pb
    –output_node_names=softmax

    另外,如果模型文件是.meta格式的,也就是说用saver.Save方法和checkpoint一起生成的元模型文件,free_graph.py不适用,但可以改造下:
    1、copy free_graph.py为free_graph_meta.py
    2、修改free_graph.py,导入meta_graph:from tensorflow.python.framework import meta_graph
    3、将91行到97行换成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def

    这样改即可加载meta文件

    点赞 评论 复制链接分享
  • yj1015341081 十一个中文 2021-01-26 12:24
    import tensorflow.compat.v1 as tf
    #import horovod.tensorflow as hvd
    from tensorflow.python.tools import optimize_for_inference_lib
    
    print (tf.__version__)
    meta_path = './checkpoint/chinese-rec-model-51.meta'
    output_node_names = ['label_batch']
    input_node_name = ['conv3_1/image_batch']
    with tf.Session() as sess:
        # Restore the graph
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
        saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
    
    
        # Load weights
        saver.restore(sess, tf.train.latest_checkpoint('./checkpoint/'))
    
        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_node_names)
    
        # Save the frozen graph
        with open('model_infer_mixed.pb', 'wb') as f:
          f.write(frozen_graph_def.SerializeToString())
    
    
    
    
        outputGraph = optimize_for_inference_lib.optimize_for_inference(
                    frozen_graph_def,
                    input_node_name, # an array of the input node(s)
                    output_node_names, # an array of output nodes
                    tf.float32.as_datatype_enum
                    )
        with open('model_infer_mixed_opt.pb', 'wb') as f:
          f.write(outputGraph.SerializeToString())
    
    

    写了一个脚本,但是报错:
     

    Traceback (most recent call last):
      File "/Users/zl/codes/python/OCR/ocr_demo/ckp_to_pb.py", line 39, in <module>
        tf.float32.as_datatype_enum
      File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow_core/python/tools/optimize_for_inference_lib.py", line 111, in optimize_for_inference
        placeholder_type_enum)
      File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow_core/python/tools/strip_unused_lib.py", line 85, in strip_unused
        raise KeyError("The following input nodes were not found: %s" % not_found)
    KeyError: "The following input nodes were not found: {'conv3_1/image_batch'}"

    而且通过tensorBoard看 是有这个节点的
     

     

    点赞 评论 复制链接分享
  • yj1015341081 十一个中文 2021-01-27 22:29

    到底该怎么选择输出节点呢

    点赞 评论 复制链接分享
  • yj1015341081 十一个中文 2021-02-01 15:21
    点赞 评论 复制链接分享

相关推荐