Jean.L 2022-11-15 22:12 采纳率: 0%
浏览 0

TensorFlow 的 session 中有多个模型时,可以分别保存各自的参数吗?

目前想把 TensorFlow 的模型转换成 PyTorch,但在保存模型参数时遇到了问题。
如题,我的代码里面的模型是这样的:

class EWE_Resnet:
    """ResNet classifier graph (TF1 graph mode) with soft-nearest-neighbor (SNNL) terms.

    The whole graph is built eagerly in ``__init__``; the resulting tensors/ops
    are exposed as attributes: ``prediction``, ``error``, ``snnl_loss``,
    ``ce_loss``, ``optimize``, ``snnl_trigger`` and ``ce_trigger``.

    NOTE(review): ``snnl``, ``snnl_gradient`` and ``ce_gradient`` are called from
    ``__init__`` but are not part of this excerpt; they are assumed to be
    defined on this class elsewhere.
    """

    def __init__(self, image, label, w_label, bs, num_class, lr, factors, temperatures, target, is_training, metric,
                 layers):
        """Store hyperparameters and build the full training/inference graph.

        Args:
            image: input image tensor, fed to the first conv layer in ``pred``.
            label: label tensor; ``argmax`` over axis 1 is taken, so presumably
                one-hot encoded (confirm against the caller).
            w_label: tensor whose mean gates the SNNL term in ``optimizer``
                (SNNL is applied only when ``mean(w_label) > 0``).
            bs: batch size, stored as ``batch_size``.
            num_class: number of output logits.
            lr: Adam learning rate.
            factors: three weights for the three SNNL terms.
            temperatures: SNNL temperature(s); ``optimizer`` also returns
                gradients with respect to this value.
            target: stored as ``self.target``.
            is_training: value compared with 0 to choose train-mode vs.
                inference-mode blocks (presumably a scalar placeholder).
            metric: distance metric forwarded to the module-level ``snnl`` function.
            layers: ResNet depth; depths above 34 use bottleneck blocks.
        """
        self.num_class = num_class
        self.batch_size = bs
        self.lr = lr
        # Adam hyperparameters (beta1, beta2, epsilon).
        self.b1 = 0.9
        self.b2 = 0.99
        self.epsilon = 1e-5
        self.target = target
        self.w = w_label
        self.x = image
        self.Y = label
        self.y = tf.argmax(self.Y, 1)
        self.layers = layers
        self.temp = temperatures
        self.snnl_func = functools.partial(snnl, metric=metric)
        self.factor_1 = factors[0]
        self.factor_2 = factors[1]
        self.factor_3 = factors[2]
        self.is_training = is_training
        # Graph construction. Order matters: later steps read attributes set
        # by earlier ones (e.g. optimizer() uses self.ce_loss).
        self.prediction = self.pred()
        self.error = self.error_rate()
        self.snnl_loss = self.snnl()
        self.ce_loss = self.cross_entropy()
        self.optimize = self.optimizer()
        self.snnl_trigger = self.snnl_gradient()
        self.ce_trigger = self.ce_gradient()

    def _cond_block(self, block_fn, x, channels, downsample, scope):
        """Apply ``block_fn`` in train or inference mode depending on ``self.is_training``.

        Both branches are built with the same ``scope``, so they are expected
        to share variables through the enclosing ``AUTO_REUSE`` variable scope.
        """
        return tf.cond(tf.greater(self.is_training, 0),
                       lambda: block_fn(x, channels=channels, is_training=True, downsample=downsample,
                                        scope=scope),
                       lambda: block_fn(x, channels=channels, is_training=False, downsample=downsample,
                                        scope=scope))

    def pred(self, reuse=tf.compat.v1.AUTO_REUSE):
        """Build the ResNet forward pass under the ``"network"`` variable scope.

        Variables created here presumably all carry the ``network/`` prefix. That
        prefix is what allows them to be selected (e.g. for a scope-filtered
        ``Saver`` var_list) separately from other models in the same session.

        Scope strings are kept exactly as originally written, including the
        ``'resblock_3_'`` spelling that differs from the other stages, because
        they likely determine variable names in existing checkpoints.

        Returns:
            A list of tensors, in this order:
            - the output of stage 2;
            - the output of each stage-3 block except the last;
            - the globally pooled features;
            - the logits (last element).
        """
        res = []
        with tf.variable_scope("network", reuse=reuse):
            if self.layers > 34:
                residual_block = resnet.bottle_resblock
            else:
                residual_block = resnet.resblock
            residual_list = resnet.get_residual_layer(self.layers)

            ch = 64
            x = resnet.conv(self.x, channels=ch, kernel=3, stride=1, scope='conv')

            # Stage 0: no downsampling.
            for i in range(residual_list[0]):
                x = self._cond_block(residual_block, x, channels=ch, downsample=False,
                                     scope='resblock0_' + str(i))

            # Stage 1: first block downsamples.
            x = self._cond_block(residual_block, x, channels=ch * 2, downsample=True, scope='resblock1_0')
            for i in range(1, residual_list[1]):
                x = self._cond_block(residual_block, x, channels=ch * 2, downsample=False,
                                     scope='resblock1_' + str(i))

            # Stage 2: first block downsamples.
            x = self._cond_block(residual_block, x, channels=ch * 4, downsample=True, scope='resblock2_0')
            for i in range(1, residual_list[2]):
                x = self._cond_block(residual_block, x, channels=ch * 4, downsample=False,
                                     scope='resblock2_' + str(i))
            res.append(x)

            # Stage 3: first block downsamples; intermediate outputs are collected.
            x = self._cond_block(residual_block, x, channels=ch * 8, downsample=True, scope='resblock_3_0')
            for i in range(1, residual_list[3]):
                res.append(x)
                x = self._cond_block(residual_block, x, channels=ch * 8, downsample=False,
                                     scope='resblock_3_' + str(i))

            # Head: batch norm -> ReLU -> global average pooling -> logits.
            x = tf.cond(tf.greater(self.is_training, 0),
                        lambda: resnet.batch_norm(x, True, scope='batch_norm'),
                        lambda: resnet.batch_norm(x, False, scope='batch_norm'), name='conv_4_2')
            x = resnet.relu(x)
            x = resnet.global_avg_pooling(x)
            res.append(x)
            x = resnet.fully_conneted(x, units=self.num_class, scope='logit')
            res.append(x)
            return res

    def error_rate(self):
        """Return the fraction of examples whose predicted class differs from ``argmax(self.Y)``."""
        mistakes = tf.not_equal(tf.argmax(self.Y, 1), tf.argmax(self.prediction[-1], 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

    def cross_entropy(self):
        """Return the softmax cross-entropy summed (not averaged) over the batch.

        ``1e-12`` is added inside the log to avoid ``log(0)``.
        """
        log_prob = tf.math.log(tf.nn.softmax(self.prediction[-1]) + 1e-12)
        return -tf.reduce_sum(self.Y * log_prob)

    def optimizer(self):
        """Build the Adam training op.

        The minimized objective is ``ce_loss - gate * (f1*s0 + f2*s1 + f3*s2)``,
        where:
        - ``s0..s2`` are the first three entries returned by ``self.snnl()``;
        - ``f1..f3`` are the configured factors;
        - ``gate`` is 1.0 when ``mean(self.w) > 0`` and 0.0 otherwise.

        Minimizing it therefore pushes the weighted SNNL up whenever the gate is on.

        Returns:
            A tuple ``(train_op, grads)`` where
            ``grads = tf.gradients(snnl_terms, self.temp)``.
        """
        optimizer = tf.train.AdamOptimizer(self.lr, self.b1, self.b2, self.epsilon)
        # NOTE(review): this rebuilds the SNNL subgraph even though __init__
        # already stored it in self.snnl_loss. It is kept so the graph (and any
        # names it produces) stays identical to what existing checkpoints expect.
        snnl_terms = self.snnl()
        soft_nearest_neighbor = (self.factor_1 * snnl_terms[0]
                                 + self.factor_2 * snnl_terms[1]
                                 + self.factor_3 * snnl_terms[2])
        gate = tf.cast(tf.greater(tf.math.reduce_mean(self.w), 0), tf.float32)
        soft_nearest_neighbor = gate * soft_nearest_neighbor
        return optimizer.minimize(self.ce_loss - soft_nearest_neighbor), tf.gradients(snnl_terms, self.temp)

我想请教两个问题:一是怎样才能把这个模型的所有参数导出来(之后再转换成 PyTorch 模型);二是把参数加载(load)回来之后,还能调用类里的这些函数吗?
或者有没有其他更便捷的方法?
万分感谢大家的解答!

  • 写回答

1条回答

      报告相同问题?

      相关推荐 更多相似问题

      问题事件

      • 修改了问题 11月16日
      • 创建了问题 11月15日

      悬赏问题

      • ¥15 普罗米修斯Prometheus监控系统的几个问题调研
      • ¥15 pmp项目管理干系人分析
      • ¥15 请问DenseNet图像输入大小是否是固定的?
      • ¥15 template模板的参数问题
      • ¥15 查找处理学生信息问题,含多个文件,显示问题是无法调用其中一个文件
      • ¥15 simulink生成代码后提示告警
      • ¥16 jieba提取高频词,生成文件是空的
      • ¥15 怎么读取服务器中的文件去配置mongo的连接
      • ¥20 Python如何统计文本中两字及以上的词语个数
      • ¥15 MapReduce自定义对象怎么写