泗水之弦 2019-04-01 09:48 采纳率: 75%
浏览 417

如何把下面的代码改成能测试一个文件夹下所有的图片?

1.代码是猫狗大战改写的,测试图片的代码是get_one_image,每次只能测试目录下随机一张图片,那么怎么才能让它测试“test”文件夹下所有的图片呢?Python新人对这一块的方法很模糊,求大佬们解答
2附上test代码

# -*- coding: utf-8 -*-
"""
Created on Mon Mar 18 08:05:21 2019

@author: pc
"""
# 评估模型
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import input_data
import model
import training

def get_one_image(train):
    n = len(train)
    ind = np.random.randint(0, n)
    img_dir = train[ind]

    image = Image.open(img_dir)
    plt.imshow(image)
    plt.show()
    image = image.resize([224, 224])
    image = np.array(image)
    return image


def evaluate_one_image():
    train_dir = "D:\\python\\Anaconda\\envs\\tensorflow\\shuidao\\data\\train\\"
    train, train_label = input_data.get_files(train_dir)
    image_array = get_one_image(train)

    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 4

        image = tf.cast(image_array, tf.float32)
        image = tf.reshape(image, [1, 224, 224, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)

        x = tf.placeholder(tf.float32, shape=[224, 224, 3])

        logs_train_dir = "D:\\python\\Anaconda\\envs\\tensorflow\\shuidao\\logs_1\\"
        saver = tf.train.Saver()

        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Loading success, global_step is %s" % global_step)
            else:
                print("No checkpoint file found")

            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print("This is daowen with possibility %.6f" % prediction[:, 0])
            elif max_index == 1:
                print("This is baiye with possibility %.6f" % prediction[:, 1])
            elif max_index == 2:
                print("This is wenku with possibility %.6f" % prediction[:, 2])
            elif max_index == 3:
                print("This is emiao with possibility %.6f" % prediction[:, 3])

training.run_training()
evaluate_one_image()
  • 写回答

1条回答 默认 最新

  • 吃鸡王者 2019-04-01 10:39
    关注

    -*- coding: utf-8 -*-

    """
    Created on Mon Mar 18 08:05:21 2019

    @author: pc
    """

    评估模型

    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    import input_data
    import model
    import training
    import time

    def get_one_image(imag_dir):

    image = Image.open(img_dir)
    plt.imshow(image)
    plt.show()
    image = image.resize([224, 224])
    image = np.array(image)
    return image
    

    def evaluate_one_image(img_dir):
    image_array = get_one_image(img_dir)

    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 4
    
        image = tf.cast(image_array, tf.float32)
        image = tf.reshape(image, [1, 224, 224, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)
    
        x = tf.placeholder(tf.float32, shape=[224, 224, 3])
    
        logs_train_dir = "D:\\python\\Anaconda\\envs\\tensorflow\\shuidao\\logs_1\\"
        saver = tf.train.Saver()
    
        with tf.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Loading success, global_step is %s" % global_step)
            else:
                print("No checkpoint file found")
    
            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            if max_index == 0:
                print("This is daowen with possibility %.6f" % prediction[:, 0])
            elif max_index == 1:
                print("This is baiye with possibility %.6f" % prediction[:, 1])
            elif max_index == 2:
                print("This is wenku with possibility %.6f" % prediction[:, 2])
            elif max_index == 3:
                print("This is emiao with possibility %.6f" % prediction[:, 3])
    

    training.run_training()
    #evaluate_one_image()
    train_dir = "D:\python\Anaconda\envs\tensorflow\shuidao\data\train\"
    train, train_label = input_data.get_files(train_dir)
    for img_dir in train:
    evaluate_one_image(img_dir)
    time.sleep(60) #暂定一儿,便于查看结果
    plt.colse()

    这样试试,依次处理每个图片,但没有测试,但思路应该是没问题的。

    要想同时处理多个图片的话,可以参考train时的数据处理方式,过程是一样的。

    评论

报告相同问题?

悬赏问题

  • ¥20 求数据集和代码#有偿答复
  • ¥15 关于下拉菜单选项关联的问题
  • ¥20 java-OJ-健康体检
  • ¥15 rs485的上拉下拉,不会对a-b<-200mv有影响吗,就是接受时,对判断逻辑0有影响吗
  • ¥15 使用phpstudy在云服务器上搭建个人网站
  • ¥15 应该如何判断含间隙的曲柄摇杆机构,轴与轴承是否发生了碰撞?
  • ¥15 vue3+express部署到nginx
  • ¥20 搭建pt1000三线制高精度测温电路
  • ¥15 使用Jdk8自带的算法,和Jdk11自带的加密结果会一样吗,不一样的话有什么解决方案,Jdk不能升级的情况
  • ¥15 画两个图 python或R