qq_52333369
枳井户
2021-06-22 21:25
采纳率: 100%
浏览 30

TensorFlow-Lite Object-Detection转成tflite文件,代转

https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/

我按照这篇文章的方法先是训练了一个模型,但是在转成tflite时各种遇到错误,在这一步

出现了

  • 点赞
  • 写回答
  • 关注问题
  • 收藏
  • 邀请回答

1条回答 默认 最新

  • Feng_wwf
    NDSC专家-王文峯 2021-06-24 12:01
    已采纳

    望采纳,谢谢

    TensorFlowLite 需要tflite文件模型,这个模型可以由TensorFlow训练的模型转换而成。所以首先需要知道如何保存训练好的TensorFlow模型。
    
    一般有这几种保存形式:
    
    Checkpoints
    HDF5
    SavedModel等
    保存与读取CheckPoint
    当模型训练结束,可以用以下代码把权重保存成checkpoint格式
    
    model.save_weights('./MyModel',True)
    
    checkpoints文件仅是保存训练好的权重,不带网络结构,所以做predict时需要结合model使用
    如:
    
    model = keras_segmentation.models.segnet.mobilenet_segnet(n_classes=2, input_height=224, input_width=224)
    model.load_weights('./MyModel')
    
    保存成H5
    把训练好的网络保存成h5文件很简单
    
    model.save('MyModel.h5')
    
    H5转换成TFLite
    这里是文章主要内容
    
    我习惯使用H5文件转换成tflite文件
    
    官网代码是这样的
    
    converter = tf.lite.TFLiteConverter.from_keras_model_file('newModel.h5')
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    
    但我用的keras 2.2.4版本会报下面错误,好像说是新版的keras把relu6改掉了,找不到方法
    ValueError: Unknown activation function:relu6
    
    于是需要自己定义一个relu6
    
    import tensorflow as tf
    from tensorflow.python.keras import backend as K
    from tensorflow.python.keras.utils import CustomObjectScope
    
    def relu6(x):
      return K.relu(x, max_value=6)
    
    with CustomObjectScope({'relu6': relu6}):
        converter = tf.lite.TFLiteConverter.from_keras_model_file('newModel.h5')
        tflite_model = converter.convert()
        open("newModel.tflite", "wb").write(tflite_model)
    
    
    看到生成的tflite文件表示保存成功了
    
    也可以这么查看tflite网络的输入输出
    
    import numpy as np
    import tensorflow as tf
    
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path="newModel.tflite")
    interpreter.allocate_tensors()
    
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    print(input_details)
    print(output_details)
    
    输出了以下信息
    
    [{'name': 'input_1', 'index': 115, 'shape': array([  1, 224, 224,   3]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
    
    [{'name': 'activation_1/truediv', 'index': 6, 'shape': array([    1, 12544,     2]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
    
    两个shape分别表示输入输出的numpy数组结构,dtype是数据类型
    
    点赞 评论

相关推荐