图片数据集的写入和读取 tfrecord, tfdata 5C

代码

import os
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

cwd = 'E:\\Tensorflow\\Wenshan_Cai_Nanoletters\\classes\\'
classes = {'cats', 'dogs', 'horses', 'humans'}

def convert_to_tfrecord(classes_path, output):

    writer = tf.python_io.TFRecordWriter(output)
    for index, name in enumerate(classes):
        class_path = classes_path + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name   # 每个图片的地址

            img = Image.open(img_path)
            img = img.resize((64,64))
            img_raw = img.tobytes()       # 将图片转成二进制
            example = tf.train.Example(features = tf.train.Features(feature = {
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [index])),
                'img_raw': tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
            }))     # example 对象对label 和 image 数据进行封装
            writer.write(example.SerializeToString())   # 序列化为字符串
    writer.close()

def dataset_input_fn(tfrecord_name):

    dataset = tf.data.TFRecordDataset(tfrecord_name)

    def parser(record):
        keys_to_features = {
            'image_data': tf.FixedLenFeature((), tf.string, default_value  = ''),
            'label': tf.FixedLenFeature((), tf.int64,
                                        default_value= tf.zeros([], dtype = tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        image = tf.image.decode_jpeg(parsed['image_data'])
        image = tf.reshape(image, [64, 64, 3])
        label = tf.cast(parsed['label'], tf.int32)

        return image, label

    dataset = dataset.map(parser())
    dataset = dataset.shuffle(buffer_size = 10000)     # buffer_size > data_set_size a perfect uniform random shuffle is guaranteed.
    dataset = dataset.batch(batch_size = 2)
    dataset = dataset.repeat(1)              # 在每个epoch内将图片打乱组成大小为32的batch,并重复10次
    iterator = dataset.make_one_shot_iterator()   # 每次只访问一个元素


    images, labels = iterator.get_next()
    return images, labels             # return a tuple

tfrecord_fn = 'E:\\Tensorflow\\Wenshan_Cai_Nanoletters\\mytrain.tfrecords'
convert_to_tfrecord(cwd, 'mytrain.tfrecords')
output_file = dataset_input_fn(tfrecord_fn)


init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    images, labels = sess.run(output_file)
    for image, label in zip(images, labels):

        img_data = tf.image.decode_jpeg(image)
        plt.imshow(image)
        plt.show()
        print(label)

错误
Traceback (most recent call last):
File "E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py", line 64, in
images, labels = sess.run(output_file)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 887, in run
run_metadata_ptr)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1286, in _do_run
run_metadata)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1308, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got empty file
[[{{node DecodeJpeg}} = DecodeJpegacceptable_fraction=1, channels=0, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]
[[{{node IteratorGetNext}} = IteratorGetNextoutput_shapes=[[?,64,64,3], [?]], output_types=[DT_UINT8, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

求大神帮忙解答一下

1个回答

Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
立即提问
相关内容推荐