怎么在TensorFlow上导入ImageNet数据进行试验?

请问一下,将数据集转为tfrecord格式之后,自己load数据的时候经常跑到一半报错

tensorflow.python.framework.errors_impl.OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 1, current size 0)
     [[Node: shuffle_batch = QueueDequeueUpToV2[component_types=[DT_UINT8, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]

怎么回事,我这部分的代码大致是这样的:

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt 
import tensorflow.contrib.slim as slim
from PIL import Image
tfrecord_paths = "./ImageNet_validate.tfrecord"
def read_and_decode(filename):
    #根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'image' : tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img,[1,433200])
    # img = tf.reshape(img, [380, 380, 3])
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)

    return img, label

img, label = read_and_decode(tfrecord_paths)

img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=1, capacity=1000,
                                                num_threads = 512,
                                                allow_smaller_final_batch=True,
                                                min_after_dequeue=1)

global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run(global_init)
    sess.run(local_init)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(1000):
        print(i)
        print("image:",img_batch.get_shape().as_list())
        print("label:",label_batch.get_shape().as_list())
        val, l= sess.run([img_batch,label_batch])
        print(val.shape, l)
qq_27278153
qq_27278153 博主,请问,你有没转换好数据集呀。我现在也遇到这个问题。能否加我QQ 571205937 指导下,谢谢!
接近 2 年之前 回复

2个回答

需要生成 TFRcords 才可以用 TensorFlow 读取

qq_27278153
qq_27278153 你好。我现在也遇到这个问题。能否加我QQ 571205937 指导下,谢谢!
接近 2 年之前 回复

不转也行,只是训练速度慢而已

Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
立即提问
相关内容推荐