目前想把 TensorFlow 的模型转成 PyTorch,但在模型保存这一步遇到了问题。
如题,我代码里的模型定义如下:
class EWE_Resnet:
    """TF-1.x graph-mode ResNet wrapper used for Entangled Watermark Embedding (EWE).

    The entire graph — forward pass, cross-entropy loss, SNNL loss, optimizer op
    and trigger gradients — is built eagerly in ``__init__`` (classic TF1 idiom),
    so the instance attributes are graph tensors/ops, not callables.
    """

    def __init__(self, image, label, w_label, bs, num_class, lr, factors, temperatures, target, is_training, metric,
                 layers):
        # --- hyper-parameters and input placeholders -------------------------
        self.num_class = num_class
        self.batch_size = bs
        self.lr = lr
        self.b1 = 0.9          # Adam beta1
        self.b2 = 0.99         # Adam beta2
        self.epsilon = 1e-5    # Adam epsilon
        self.target = target
        self.w = w_label       # watermark indicator labels (gates the SNNL term)
        self.x = image
        self.Y = label         # one-hot labels
        self.y = tf.argmax(self.Y, 1)
        self.layers = layers   # ResNet depth selector (>34 switches to bottleneck blocks)
        self.temp = temperatures
        self.snnl_func = functools.partial(snnl, metric=metric)
        # Per-layer weighting factors for the three SNNL terms.
        self.factor_1 = factors[0]
        self.factor_2 = factors[1]
        self.factor_3 = factors[2]
        self.is_training = is_training
        # --- build the graph up front: forward pass, then dependent ops ------
        self.prediction = self.pred()
        self.error = self.error_rate()
        self.snnl_loss = self.snnl()
        self.ce_loss = self.cross_entropy()
        self.optimize = self.optimizer()
        self.snnl_trigger = self.snnl_gradient()  # defined elsewhere in the class
        self.ce_trigger = self.ce_gradient()
def pred(self, reuse=tf.compat.v1.AUTO_REUSE):
res = []
with tf.variable_scope("network", reuse=reuse):
if self.layers > 34:
residual_block = resnet.bottle_resblock
else:
residual_block = resnet.resblock
residual_list = resnet.get_residual_layer(self.layers)
ch = 64
x = self.x
# print(x)
x = resnet.conv(x, channels=ch, kernel=3, stride=1, scope='conv')
for i in range(residual_list[0]):
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch, is_training=True, downsample=False,
scope='resblock0_' + str(i)),
lambda: residual_block(x, channels=ch, is_training=False, downsample=False,
scope='resblock0_' + str(i)))
########################################################################################################
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch * 2, is_training=True, downsample=True,
scope='resblock1_0'),
lambda: residual_block(x, channels=ch * 2, is_training=False, downsample=True,
scope='resblock1_0'))
for i in range(1, residual_list[1]):
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch * 2, is_training=True, downsample=False,
scope='resblock1_' + str(i)),
lambda: residual_block(x, channels=ch * 2, is_training=False, downsample=False,
scope='resblock1_' + str(i)))
########################################################################################################
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch * 4, is_training=True, downsample=True,
scope='resblock2_0'),
lambda: residual_block(x, channels=ch * 4, is_training=False, downsample=True,
scope='resblock2_0'))
for i in range(1, residual_list[2]):
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch * 4, is_training=True, downsample=False,
scope='resblock2_' + str(i)),
lambda: residual_block(x, channels=ch * 4, is_training=False, downsample=False,
scope='resblock2_' + str(i)))
########################################################################################################
res.append(x)
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch * 8, is_training=True, downsample=True,
scope='resblock_3_0'),
lambda: residual_block(x, channels=ch * 8, is_training=False, downsample=True,
scope='resblock_3_0'))
for i in range(1, residual_list[3]):
res.append(x)
x = tf.cond(tf.greater(self.is_training, 0),
lambda: residual_block(x, channels=ch * 8, is_training=True, downsample=False,
scope='resblock_3_' + str(i)),
lambda: residual_block(x, channels=ch * 8, is_training=False, downsample=False,
scope='resblock_3_' + str(i)))
########################################################################################################
x = tf.cond(tf.greater(self.is_training, 0),
lambda: resnet.batch_norm(x, True, scope='batch_norm'),
lambda: resnet.batch_norm(x, False, scope='batch_norm'), name='conv_4_2')
# conv_4_2 = tf.Variable(x, name='conv_4_2')
# conv_4_2 = x
x = resnet.relu(x)
# print(x)
x = resnet.global_avg_pooling(x)
res.append(x)
# print(x)
x = resnet.fully_conneted(x, units=self.num_class, scope='logit')
res.append(x)
# print(x)
return res
    def error_rate(self):
        """Mean misclassification rate: fraction of rows where
        argmax of the logits (``prediction[-1]``) differs from argmax of Y."""
        mistakes = tf.not_equal(tf.argmax(self.Y, 1), tf.argmax(self.prediction[-1], 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))
    def cross_entropy(self):
        """Cross-entropy summed over the whole batch (not averaged).

        Computed manually as -sum(Y * log softmax(logits)); the 1e-12 term
        guards log() against exact zeros in the softmax output.
        """
        log_prob = tf.math.log(tf.nn.softmax(self.prediction[-1]) + 1e-12)
        cross_entropy = - tf.reduce_sum(self.Y * log_prob)
        return cross_entropy
    def optimizer(self):
        """Build the training op.

        Minimizes ``ce_loss - weighted_SNNL``; the SNNL term is multiplied by
        0 or 1 depending on whether the batch contains watermark labels
        (mean(w) > 0), so plain batches train on cross-entropy alone.

        Returns:
            tuple: (Adam minimize op, gradients of the SNNL terms w.r.t. the
            temperature variables).
        """
        optimizer = tf.train.AdamOptimizer(self.lr, self.b1, self.b2, self.epsilon)
        # NOTE(review): this calls self.snnl() again even though __init__ already
        # stored the result as self.snnl_loss — presumably duplicates graph
        # nodes; confirm against the (unseen) snnl() implementation.
        snnl = self.snnl()
        # Weighted sum of the three per-layer SNNL terms.
        soft_nearest_neighbor = self.factor_1 * snnl[0] + self.factor_2 * snnl[1] + self.factor_3 * snnl[2]
        # Gate: 1.0 when watermark labels are present in the batch, else 0.0.
        soft_nearest_neighbor = tf.cast(tf.greater(tf.math.reduce_mean(self.w), 0), tf.float32) * soft_nearest_neighbor
        return optimizer.minimize(self.ce_loss - soft_nearest_neighbor), tf.gradients(snnl, self.temp)
我想请教一下:怎样才能把这个模型的所有参数(权重)提取出来,以便后续转成 PyTorch 模型?另外,把参数 load 进来之后,还能像原来那样调用这些成员函数吗?
或者有没有其他更便捷的方法呢?
万分感谢大家的解答!