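I wrote a class that trains the model in `fit` and evaluates the trained result in `predict`. How should I write `predict`? I tried saving the computation graph to a file and loading it back in `predict`, but I keep getting an error. I'm using TensorFlow 1.13.1.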
Here is the relevant code:
```python
def fit(self, x, y, learning_rate=0.1, iter_num=10):
    (m, n_h, n_w, n_c) = x.shape
    (m, n_y) = y.shape
    self.learning_rate = learning_rate
    self.x_ = tf.placeholder(tf.float32, [None, n_h, n_w, n_c])
    self.y_ = tf.placeholder(tf.float32, [None, n_y])
    self.init_paras()
    self.f_p(self.x_, self.y_)
    self.b_p()
    init = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init)
        for cnt in range(iter_num):
            _, cost = session.run([self.cache["optimizer"], self.cache["cost"]], feed_dict={self.x_: x, self.y_: y})
            if cnt % 5 == 0:
                print("Cost at iteration {}: {}".format(cnt, cost))
        saver = tf.train.Saver()
        saver.save(session, "./checkpoint_dir/model")

def predict(self, x, y):
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, "./checkpoint_dir/model.meta")
        a = self.cache["output"]
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(a, axis=1), tf.argmax(y, axis=1)), tf.float32))
        sess.run(accuracy, feed_dict={self.x_: x, self.y_: y})
        print(accuracy)
```
And here is the error:
```text
tensorflow.python.framework.errors_impl.DataLossError: Unable to open table file ./checkpoint_dir/model.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
[[{{node save_1/RestoreV2}}]]
DataLossError (see above for traceback): Unable to open table file ./checkpoint_dir/model.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
[[node save_1/RestoreV2 (defined at ~/PycharmProjects/pythonProject1/test.py:67) ]]
```
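The error message points at the path passed to `saver.restore()`. In TensorFlow 1.x, `saver.save(session, "./checkpoint_dir/model")` writes several files that share the prefix `./checkpoint_dir/model` (`model.index`, `model.data-00000-of-00001`, and by default `model.meta`, which stores the graph definition), plus a small `checkpoint` bookkeeping file. `Saver.restore()` expects that prefix, i.e. `saver.restore(sess, "./checkpoint_dir/model")`; the `.meta` file is not a variable checkpoint and is intended for `tf.train.import_meta_graph()`. `tf.train.latest_checkpoint("./checkpoint_dir")` also returns the prefix. The helper below is a hypothetical pure-Python sketch (`checkpoint_prefix` is not a TensorFlow function) that maps any of those file names back to the prefix `restore()` wants, assuming the default V2 checkpoint naming.

```python
import re

# Assumed suffixes of the files tf.train.Saver (V2 format) writes next to the
# prefix given to saver.save(): ".index", ".data-NNNNN-of-MMMMM" and ".meta".
_CKPT_SUFFIX = re.compile(r"\.(meta|index|data-\d+-of-\d+)$")


def checkpoint_prefix(path):
    """Hypothetical helper: strip a checkpoint file suffix so the result can be
    passed to saver.restore(). A bare prefix is returned unchanged."""
    return _CKPT_SUFFIX.sub("", path)


for name in ["./checkpoint_dir/model.meta",
             "./checkpoint_dir/model.index",
             "./checkpoint_dir/model.data-00000-of-00001",
             "./checkpoint_dir/model"]:
    # Every file written by one saver.save() call maps to the same prefix.
    print(name, "->", checkpoint_prefix(name))
```

Separately, in `predict` the value returned by `sess.run(accuracy, ...)` is discarded, so `print(accuracy)` prints the Tensor object rather than the accuracy; assigning the result (for example `acc = sess.run(...)`) and printing `acc` would show the number.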