wenqinlong_1 2018-10-25 03:45 采纳率: 0%
浏览 4428
已结题

图片数据集的写入和读取 tfrecord, tfdata

代码

import os
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

cwd = 'E:\\Tensorflow\\Wenshan_Cai_Nanoletters\\classes\\'
classes = {'cats', 'dogs', 'horses', 'humans'}

def convert_to_tfrecord(classes_path, output):

    writer = tf.python_io.TFRecordWriter(output)
    for index, name in enumerate(classes):
        class_path = classes_path + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name   # 每个图片的地址

            img = Image.open(img_path)
            img = img.resize((64,64))
            img_raw = img.tobytes()       # 将图片转成二进制
            example = tf.train.Example(features = tf.train.Features(feature = {
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [index])),
                'img_raw': tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
            }))     # example 对象对label 和 image 数据进行封装
            writer.write(example.SerializeToString())   # 序列化为字符串
    writer.close()

def dataset_input_fn(tfrecord_name):

    dataset = tf.data.TFRecordDataset(tfrecord_name)

    def parser(record):
        keys_to_features = {
            'image_data': tf.FixedLenFeature((), tf.string, default_value  = ''),
            'label': tf.FixedLenFeature((), tf.int64,
                                        default_value= tf.zeros([], dtype = tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        image = tf.image.decode_jpeg(parsed['image_data'])
        image = tf.reshape(image, [64, 64, 3])
        label = tf.cast(parsed['label'], tf.int32)

        return image, label

    dataset = dataset.map(parser())
    dataset = dataset.shuffle(buffer_size = 10000)     # buffer_size > data_set_size a perfect uniform random shuffle is guaranteed.
    dataset = dataset.batch(batch_size = 2)
    dataset = dataset.repeat(1)              # 在每个epoch内将图片打乱组成大小为32的batch,并重复10次
    iterator = dataset.make_one_shot_iterator()   # 每次只访问一个元素


    images, labels = iterator.get_next()
    return images, labels             # return a tuple

tfrecord_fn = 'E:\\Tensorflow\\Wenshan_Cai_Nanoletters\\mytrain.tfrecords'
convert_to_tfrecord(cwd, 'mytrain.tfrecords')
output_file = dataset_input_fn(tfrecord_fn)


init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    images, labels = sess.run(output_file)
    for image, label in zip(images, labels):

        img_data = tf.image.decode_jpeg(image)
        plt.imshow(image)
        plt.show()
        print(label)

错误
Traceback (most recent call last):
File "E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py", line 64, in
images, labels = sess.run(output_file)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 887, in run
run_metadata_ptr)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1286, in _do_run
run_metadata)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1308, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got empty file
[[{{node DecodeJpeg}} = DecodeJpegacceptable_fraction=1, channels=0, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]
[[{{node IteratorGetNext}} = IteratorGetNextoutput_shapes=[[?,64,64,3], [?]], output_types=[DT_UINT8, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

求大神帮忙解答一下

  • 写回答

1条回答

  • dabocaiqq 2018-10-26 05:52
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 #MATLAB仿真#车辆换道路径规划
  • ¥15 java 操作 elasticsearch 8.1 实现 索引的重建
  • ¥15 数据可视化Python
  • ¥15 要给毕业设计添加扫码登录的功能!!有偿
  • ¥15 kafka 分区副本增加会导致消息丢失或者不可用吗?
  • ¥15 微信公众号自制会员卡没有收款渠道啊
  • ¥100 Jenkins自动化部署—悬赏100元
  • ¥15 关于#python#的问题:求帮写python代码
  • ¥20 MATLAB画图图形出现上下震荡的线条
  • ¥15 关于#windows#的问题:怎么用WIN 11系统的电脑 克隆WIN NT3.51-4.0系统的硬盘