代码
import os
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
cwd = 'E:\\Tensorflow\\Wenshan_Cai_Nanoletters\\classes\\'
classes = {'cats', 'dogs', 'horses', 'humans'}
def convert_to_tfrecord(classes_path, output):
writer = tf.python_io.TFRecordWriter(output)
for index, name in enumerate(classes):
class_path = classes_path + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每个图片的地址
img = Image.open(img_path)
img = img.resize((64,64))
img_raw = img.tobytes() # 将图片转成二进制
example = tf.train.Example(features = tf.train.Features(feature = {
'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [index])),
'img_raw': tf.train.Feature(bytes_list = tf.train.BytesList(value = [img_raw]))
})) # example 对象对label 和 image 数据进行封装
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()
def dataset_input_fn(tfrecord_name):
dataset = tf.data.TFRecordDataset(tfrecord_name)
def parser(record):
keys_to_features = {
'image_data': tf.FixedLenFeature((), tf.string, default_value = ''),
'label': tf.FixedLenFeature((), tf.int64,
default_value= tf.zeros([], dtype = tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
image = tf.image.decode_jpeg(parsed['image_data'])
image = tf.reshape(image, [64, 64, 3])
label = tf.cast(parsed['label'], tf.int32)
return image, label
dataset = dataset.map(parser())
dataset = dataset.shuffle(buffer_size = 10000) # buffer_size > data_set_size a perfect uniform random shuffle is guaranteed.
dataset = dataset.batch(batch_size = 2)
dataset = dataset.repeat(1) # 在每个epoch内将图片打乱组成大小为32的batch,并重复10次
iterator = dataset.make_one_shot_iterator() # 每次只访问一个元素
images, labels = iterator.get_next()
return images, labels # return a tuple
tfrecord_fn = 'E:\\Tensorflow\\Wenshan_Cai_Nanoletters\\mytrain.tfrecords'
convert_to_tfrecord(cwd, 'mytrain.tfrecords')
output_file = dataset_input_fn(tfrecord_fn)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
images, labels = sess.run(output_file)
for image, label in zip(images, labels):
img_data = tf.image.decode_jpeg(image)
plt.imshow(image)
plt.show()
print(label)
错误
Traceback (most recent call last):
File "E:/Tensorflow/Wenshan_Cai_Nanoletters/TFRecord.py", line 64, in
images, labels = sess.run(output_file)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 887, in run
run_metadata_ptr)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1286, in _do_run
run_metadata)
File "E:\Tensorflow\venv\lib\site-packages\tensorflow\python\client\session.py", line 1308, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got empty file
[[{{node DecodeJpeg}} = DecodeJpegacceptable_fraction=1, channels=0, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]
[[{{node IteratorGetNext}} = IteratorGetNextoutput_shapes=[[?,64,64,3], [?]], output_types=[DT_UINT8, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
求大神帮忙解答一下