XJTU_Ironboy 2017-09-04 14:30 采纳率: 0%
浏览 4900

怎么在TensorFlow上导入ImageNet数据进行试验?

请问一下,将数据集转为tfrecord格式之后,自己load数据的时候经常跑到一半报错

tensorflow.python.framework.errors_impl.OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 1, current size 0)
     [[Node: shuffle_batch = QueueDequeueUpToV2[component_types=[DT_UINT8, DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch/random_shuffle_queue, shuffle_batch/n)]]

怎么回事,我这部分的代码大致是这样的:

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt 
import tensorflow.contrib.slim as slim
from PIL import Image
tfrecord_paths = "./ImageNet_validate.tfrecord"
def read_and_decode(filename):
    #根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'image' : tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img,[1,433200])
    # img = tf.reshape(img, [380, 380, 3])
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)

    return img, label

img, label = read_and_decode(tfrecord_paths)

img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=1, capacity=1000,
                                                num_threads = 512,
                                                allow_smaller_final_batch=True,
                                                min_after_dequeue=1)

global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run(global_init)
    sess.run(local_init)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(1000):
        print(i)
        print("image:",img_batch.get_shape().as_list())
        print("label:",label_batch.get_shape().as_list())
        val, l= sess.run([img_batch,label_batch])
        print(val.shape, l)
  • 写回答

2条回答

  • xiaoyaoyao17 2018-08-06 06:48
    关注

    需要生成 TFRcords 才可以用 TensorFlow 读取

    评论

报告相同问题?

悬赏问题

  • ¥17 pro*C预编译“闪回查询”报错SCN不能识别
  • ¥15 微信会员卡接入微信支付商户号收款
  • ¥15 如何获取烟草零售终端数据
  • ¥15 数学建模招标中位数问题
  • ¥15 phython路径名过长报错 不知道什么问题
  • ¥15 深度学习中模型转换该怎么实现
  • ¥15 HLs设计手写数字识别程序编译通不过
  • ¥15 Stata外部命令安装问题求帮助!
  • ¥15 从键盘随机输入A-H中的一串字符串,用七段数码管方法进行绘制。提交代码及运行截图。
  • ¥15 TYPCE母转母,插入认方向