Skrllex 2019-02-27 19:33 采纳率: 0%
浏览 2236

tensorflow重载模型继续训练得到的loss比原模型继续训练得到的loss大,是什么原因??

我使用tensorflow训练了一个模型,在第10个epoch时保存模型,然后在一个新的文件里重载模型继续训练,结果我发现重载的模型在第一个epoch的loss比原模型在epoch=11的loss要大,我感觉既然是重载了原模型,那么重载模型训练的第一个epoch应该是和原模型训练的第11个epoch相等的,一直找不到问题或者自己逻辑的问题,希望大佬能指点迷津。源代码和重载模型的代码如下:

原代码:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

mnist = input_data.read_data_sets("./",one_hot=True)
tf.reset_default_graph()

###定义数据和标签
n_inputs = 784
n_classes = 10
X = tf.placeholder(tf.float32,[None,n_inputs],name='X')
Y = tf.placeholder(tf.float32,[None,n_classes],name='Y')

###定义网络结构
n_hidden_1 = 256
n_hidden_2 = 256
layer_1 = tf.layers.dense(inputs=X,units=n_hidden_1,activation=tf.nn.relu,kernel_regularizer=tf.contrib.layers.l2_regularizer(0.01))
layer_2 = tf.layers.dense(inputs=layer_1,units=n_hidden_2,activation=tf.nn.relu,kernel_regularizer=tf.contrib.layers.l2_regularizer(0.01))
outputs = tf.layers.dense(inputs=layer_2,units=n_classes,name='outputs')
pred = tf.argmax(tf.nn.softmax(outputs,axis=1),axis=1)
print(pred.name)
err = tf.count_nonzero((pred - tf.argmax(Y,axis=1)))
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=outputs,labels=Y),name='cost')
print(cost.name)

###定义优化器
learning_rate = 0.001

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost,name='OP')
saver = tf.train.Saver()
checkpoint = 'softmax_model/dense_model.cpkt'
###训练
batch_size = 100
training_epochs = 11
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(training_epochs):
        batch_num = int(mnist.train.num_examples / batch_size)
        epoch_cost = 0
        sumerr = 0
        for i in range(batch_num):
            batch_x,batch_y = mnist.train.next_batch(batch_size)
            c,e = sess.run([cost,err],feed_dict={X:batch_x,Y:batch_y})
            _ = sess.run(optimizer,feed_dict={X:batch_x,Y:batch_y})
            epoch_cost += c / batch_num
            sumerr += e / mnist.train.num_examples
            if epoch == (training_epochs - 1):
                print('batch_cost = ',c)
        if epoch == (training_epochs - 2):
            saver.save(sess, checkpoint)
            print('test_error = ',sess.run(cost, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))

重载模型的代码:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os


os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

mnist = input_data.read_data_sets("./",one_hot=True)  #one_hot=True指对样本标签进行独热编码


file_path = 'softmax_model/dense_model.cpkt'

saver = tf.train.import_meta_graph(file_path + '.meta')
graph = tf.get_default_graph()


X = graph.get_tensor_by_name('X:0')
Y = graph.get_tensor_by_name('Y:0')
cost = graph.get_operation_by_name('cost').outputs[0]
train_op = graph.get_operation_by_name('OP')


training_epoch = 10
learning_rate = 0.001
batch_size = 100
with tf.Session() as sess:
    saver.restore(sess,file_path)
    print('test_cost = ',sess.run(cost, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))
    for epoch in range(training_epoch):
        batch_num = int(mnist.train.num_examples / batch_size)
        epoch_cost = 0
        for i in range(batch_num):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            c = sess.run(cost, feed_dict={X: batch_x, Y: batch_y})
            _ = sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
            epoch_cost += c / batch_num


        print(epoch_cost)

值得注意的是,我在原模型和重载模型里都计算了测试集的cost,两者的结果是一致的。说明参数载入应该是对的

  • 写回答

1条回答

  • threenewbee 2019-02-28 00:03
    关注

    排除你模型本身的原因,loss变大可能是过拟合了。

    评论

报告相同问题?

悬赏问题

  • ¥20 为什么我写出来的绘图程序是这样的,有没有lao哥改一下
  • ¥15 js,页面2返回页面1时定位进入的设备
  • ¥50 导入文件到网吧的电脑并且在重启之后不会被恢复
  • ¥15 (希望可以解决问题)ma和mb文件无法正常打开,打开后是空白,但是有正常内存占用,但可以在打开Maya应用程序后打开场景ma和mb格式。
  • ¥15 绘制多分类任务的roc曲线时只画出了一类的roc,其它的auc显示为nan
  • ¥20 ML307A在使用AT命令连接EMQX平台的MQTT时被拒绝
  • ¥20 腾讯企业邮箱邮件可以恢复么
  • ¥15 有人知道怎么将自己的迁移策略布到edgecloudsim上使用吗?
  • ¥15 错误 LNK2001 无法解析的外部符号
  • ¥50 安装pyaudiokits失败