Ron_617 2024-04-02 16:29 采纳率: 0%
浏览 75
已结题

TensorFlow的.pb格式转换为ONNX格式

从TensorFlow的.pb格式转换为ONNX格式。务必确保你在转换时指定正确的输入和输出节点名称。如果你的模型包含BatchNormalization层且在训练时启用了moving averages,可能还需要在转换前关闭moving averages或者对其进行处理,以避免转换过程中的不兼容问题。
请问这个moving averages怎么关闭呢,
我的代码部分

`def _batch_norm(x, is_training, name=None):
    return tf.keras.layers.BatchNormalization(
        momentum=0.95,
        center=True,
        scale=True,
        trainable=is_training,
        name=(name + '_batch_norm')
    )(x, training=is_training)

  • 写回答

19条回答 默认 最新

  • 关注

    下午好🌅🌅🌅
    本答案参考ChatGPT-3.5

    关于关闭moving averages的问题,可以通过以下方式进行处理:

    • 对于训练阶段:在训练阶段,我们通常希望BatchNormalization层的moving averages是开启的,这对于模型的收敛和准确性很重要。因此,在训练阶段,不需要关闭moving averages

    • 对于转换阶段:在转换阶段,我们需要将TensorFlow的模型转换为ONNX格式,而ONNX对BatchNormalization的实现方式与TensorFlow有所不同。因此,在转换之前,我们需要关闭moving averages或进行其他处理,以确保转换过程中的兼容性问题得到解决。以下是几种可能的解决方案:

      1. 手动禁用BatchNormalization层的moving averages:在转换之前,可以通过设置trainable=False来禁用BatchNormalization层的moving averages。例如,在你的代码中,可以修改_batch_norm函数如下:
      def _batch_norm(x, is_training, name=None):
          return tf.keras.layers.BatchNormalization(
              momentum=0.95,
              center=True,
              scale=True,
              trainable=(not is_training),  # 禁用moving averages
              name=(name + '_batch_norm')
          )(x, training=is_training)
      

      这样,通过将is_training参数取反,可以在转换之前禁用moving averages

      1. 使用TensorFlow自带的tf.nn.batch_normalization函数:在转换之前,可以考虑使用TensorFlow自带的tf.nn.batch_normalization函数来手动实现BatchNormalization层,然后再转换为ONNX格式。这样可以避免moving averages的问题。例如,你可以修改_batch_norm函数如下:
      def _batch_norm(x, is_training, name=None):
          return tf.nn.batch_normalization(
              x,
              mean=tf.reduce_mean(x, axis=[0, 1, 2]),
              variance=tf.reduce_mean(tf.square(x - tf.reduce_mean(x, axis=[0, 1, 2]))),
              offset=0,
              scale=1,
              variance_epsilon=1e-5,
              name=(name + '_batch_norm')
          )
      
    • 完成上述修改后,你可以使用对应的节点名称进行模型转换,确保输入和输出节点名称指定正确的节点。具体的代码部分请参考你的实际情况进行修改。

    评论 编辑记录

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 4月7日
  • 创建了问题 4月2日

悬赏问题

  • ¥15 非科班怎么跑代码?如何导数据和调参
  • ¥15 福州市的全人群死因监测点死亡原因报表
  • ¥15 打开powerpont询问是否安装officeplus不小心点了不安装以后再也不提示是否安装了
  • ¥15 Altair EDEM中生成一个颗粒,并且各个方向没有初始速度
  • ¥15 系统2008r2 装机配置推荐一下
  • ¥500 服务器搭建cisco AnyConnect vpn
  • ¥15 用大厂网站防红自己网站
  • ¥15 悬赏Python-playwright部署在centos7上
  • ¥15 psoc creator软件有没有人能远程安装啊
  • ¥15 快速扫描算法求解Eikonal方程咨询