weixin_51793354 2022-08-01 00:41 采纳率: 54.5%
浏览 108
已结题

tensorflow 训练完后如何测试?尝试读取文件,报错了。

我写了一个类,想在fit里训练,在predict里测试训练结果,我的predict函数应该怎么写?我尝试把计算图保存到文件里,然后在predict函数里读取出来,但是总是报错。我用的是tensorflow1.13.1。
下面是相关的代码

    def fit(self, x, y, learning_rate=0.1, iter_num=10):
        (m, n_h, n_w, n_c) = x.shape
        (m, n_y) = y.shape
        self.learning_rate = learning_rate
        self.x_ = tf.placeholder(tf.float32, [None, n_h, n_w, n_c])
        self.y_ = tf.placeholder(tf.float32, [None, n_y])
        self.init_paras()
        self.f_p(self.x_, self.y_)
        self.b_p()
        init = tf.global_variables_initializer()
        with tf.Session() as session:
            session.run(init)
            for cnt in range(iter_num):
                _, cost = session.run([self.cache["optimizer"], self.cache["cost"]], feed_dict={self.x_: x, self.y_: y})
                if cnt % 5 == 0:
                    print("第{}次代价为:{}".format(cnt, cost))
            saver = tf.train.Saver()
            saver.save(session, "./checkpoint_dir/model")

    def predict(self, x, y):
        with tf.Session() as sess:
            saver = tf.train.Saver()
            saver.restore(sess, "./checkpoint_dir/model.meta")
            a = self.cache["output"]
            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(a, axis=1), tf.argmax(y, axis=1)), tf.float32))
            sess.run(accuracy, feed_dict={self.x_: x, self.y_: y})
            print(accuracy)****

下面是报错的结果

tensorflow.python.framework.errors_impl.DataLossError: Unable to open table file ./checkpoint_dir/model.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
     [[{{node save_1/RestoreV2}}]]

DataLossError (see above for traceback): Unable to open table file ./checkpoint_dir/model.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
     [[node save_1/RestoreV2 (defined at ~/PycharmProjects/pythonProject1/test.py:67) ]]
  • 写回答

1条回答 默认 最新

  • 海洋 之心 2022年度博客之星人工智能领域TOP 1 2022-08-01 12:20
    关注

    Set up your data format vector and pass it into the Model for inference

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 8月1日
  • 已采纳回答 8月1日
  • 创建了问题 8月1日

悬赏问题

  • ¥20 机器学习能否像多层线性模型一样处理嵌套数据
  • ¥20 西门子S7-Graph,S7-300,梯形图
  • ¥50 用易语言http 访问不了网页
  • ¥50 safari浏览器fetch提交数据后数据丢失问题
  • ¥15 matlab不知道怎么改,求解答!!
  • ¥15 永磁直线电机的电流环pi调不出来
  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效