枳井户 2021-06-22 21:25 采纳率: 100%
浏览 46
已采纳

TensorFlow-Lite Object-Detection转成tflite文件,代转

https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi/

我按照这篇文章的方法先是训练了一个模型,但是在转成tflite时各种遇到错误,在这一步

出现了

  • 写回答

1条回答 默认 最新

  • 王大师王文峰 企业官方账号 2021-06-24 12:01
    关注

    望采纳,谢谢

    TensorFlowLite 需要tflite文件模型,这个模型可以由TensorFlow训练的模型转换而成。所以首先需要知道如何保存训练好的TensorFlow模型。
    
    一般有这几种保存形式:
    
    Checkpoints
    HDF5
    SavedModel等
    保存与读取CheckPoint
    当模型训练结束,可以用以下代码把权重保存成checkpoint格式
    
    model.save_weights('./MyModel',True)
    
    checkpoints文件仅是保存训练好的权重,不带网络结构,所以做predict时需要结合model使用
    如:
    
    model = keras_segmentation.models.segnet.mobilenet_segnet(n_classes=2, input_height=224, input_width=224)
    model.load_weights('./MyModel')
    
    保存成H5
    把训练好的网络保存成h5文件很简单
    
    model.save('MyModel.h5')
    
    H5转换成TFLite
    这里是文章主要内容
    
    我习惯使用H5文件转换成tflite文件
    
    官网代码是这样的
    
    converter = tf.lite.TFLiteConverter.from_keras_model_file('newModel.h5')
    tflite_model = converter.convert()
    open("converted_model.tflite", "wb").write(tflite_model)
    
    但我用的keras 2.2.4版本会报下面错误,好像说是新版的keras把relu6改掉了,找不到方法
    ValueError: Unknown activation function:relu6
    
    于是需要自己定义一个relu6
    
    import tensorflow as tf
    from tensorflow.python.keras import backend as K
    from tensorflow.python.keras.utils import CustomObjectScope
    
    def relu6(x):
      return K.relu(x, max_value=6)
    
    with CustomObjectScope({'relu6': relu6}):
        converter = tf.lite.TFLiteConverter.from_keras_model_file('newModel.h5')
        tflite_model = converter.convert()
        open("newModel.tflite", "wb").write(tflite_model)
    
    
    看到生成的tflite文件表示保存成功了
    
    也可以这么查看tflite网络的输入输出
    
    import numpy as np
    import tensorflow as tf
    
    # Load TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path="newModel.tflite")
    interpreter.allocate_tensors()
    
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    print(input_details)
    print(output_details)
    
    输出了以下信息
    
    [{'name': 'input_1', 'index': 115, 'shape': array([  1, 224, 224,   3]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
    
    [{'name': 'activation_1/truediv', 'index': 6, 'shape': array([    1, 12544,     2]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
    
    两个shape分别表示输入输出的numpy数组结构,dtype是数据类型
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

悬赏问题

  • ¥15 用土力学知识进行土坡稳定性分析与挡土墙设计
  • ¥15 帮我写一个c++工程
  • ¥30 Eclipse官网打不开,官网首页进不去,显示无法访问此页面,求解决方法
  • ¥15 关于smbclient 库的使用
  • ¥15 微信小程序协议怎么写
  • ¥15 c语言怎么用printf(“\b \b”)与getch()实现黑框里写入与删除?
  • ¥20 怎么用dlib库的算法识别小麦病虫害
  • ¥15 华为ensp模拟器中S5700交换机在配置过程中老是反复重启
  • ¥15 java写代码遇到问题,求帮助
  • ¥15 uniapp uview http 如何实现统一的请求异常信息提示?