生成TFRecords核心代码(图片处理成224 * 224 * 3)
with tf.Session() as sess:
for i in range(len(img_path_0)):
# 获得图片的路径和类型
img_path = img_path_0[i]
label = label_0[i]
# 读取图片
image = tf.gfile.FastGFile(img_path, 'rb').read()
# 解码图片(如果是 png 格式就使用 decode_png)
image = tf.image.decode_jpeg(image)
image_size = 224
# 图像预处理
image = ima_preprocess.preprocess_for_train(image, image_size, image_size)
# 转换数据类型
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# resize 224 * 224 * 3
image = tf.image.resize_images(image, [width, height], method=0)
# 执行 op: image
image = sess.run(image)
#print(image)
# print(image.shape)
# plt.imshow(image)
# plt.show()
# 将其图片矩阵转换成 tostring,tobytes
image_raw = image.tostring()
# 将数据整理成 TFRecord 需要的数据结构
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw': _bytes_feature(image_raw),
'label': _int64_feature(label),
'height': _int64_feature(height),
'width': _int64_feature(width),
'channels': _int64_feature(channels),
}))
# 写 TFRecord
writer.write(example.SerializeToString())
print(i, label)
writer.close()
在读取TFRecords的时候,由于之前使用tostring,导致使用代码
decode_image = tf.decode_raw(features['image'], tf.uint8)
产生的矩阵比原来大了4倍,即 224 * 224 * 3 * 4
没办法使用之后的reshape[224, 224, 3]
请问,tf下有什么函数可以把TFRecords内的数据decode到原始的图片的矩阵??