关于Tensorflow的TFRecords读取问题 10C

生成TFRecords核心代码(图片处理成224 * 224 * 3)

 with tf.Session() as sess:
        for i in range(len(img_path_0)):
            # 获得图片的路径和类型
            img_path = img_path_0[i]
            label = label_0[i]


            # 读取图片
            image = tf.gfile.FastGFile(img_path, 'rb').read()
            # 解码图片(如果是 png 格式就使用 decode_png)
            image = tf.image.decode_jpeg(image)
            image_size = 224
                        # 图像预处理
            image = ima_preprocess.preprocess_for_train(image, image_size, image_size)
            # 转换数据类型
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            # resize 224 * 224 * 3
            image = tf.image.resize_images(image, [width, height], method=0)
            # 执行 op: image
            image = sess.run(image)
            #print(image)
            # print(image.shape)
            # plt.imshow(image)
            # plt.show()

            # 将其图片矩阵转换成 tostring,tobytes
            image_raw = image.tostring()

            # 将数据整理成 TFRecord 需要的数据结构
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(label),
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
            }))

            # 写 TFRecord
            writer.write(example.SerializeToString())
            print(i, label)

    writer.close()

在读取TFRecords的时候,由于之前使用tostring,导致使用代码

 decode_image = tf.decode_raw(features['image'], tf.uint8)

产生的矩阵比原来大了4倍,即 224 * 224 * 3 * 4
没办法使用之后的reshape[224, 224, 3]

请问,tf下有什么函数可以把TFRecords内的数据decode到原始的图片的矩阵??

1个回答

xajlzz
xajlzz 还是在解码之后,直接reshape的,在实际运行代码的时候,会报错。因为tostring之后,1个数字会被拆分为4个string。然后,解码为uint8之后,会变成4个整数。这比之前的图片矩阵整整大了4倍,没办法直接reshape!!!
接近 2 年之前 回复
xajlzz
xajlzz #对于图像数据需要使用decode_raw解码,同样也需要set_shape # dense data image = tf.decode_raw(features['image_raw'], tf.uint8) image_shape = tf.stack([height, width, 3]) image = tf.reshape(image, image_shape) image.set_shape([cfg.TRAIN.IMG_SIZE[0], cfg.TRAIN.IMG_SIZE[1], 3])
接近 2 年之前 回复
xajlzz
xajlzz 你给的链接里面
接近 2 年之前 回复
Csdn user default icon
上传中...
上传图片
插入图片
抄袭、复制答案,以达到刷声望分或其他目的的行为,在CSDN问答是严格禁止的,一经发现立刻封号。是时候展现真正的技术了!
立即提问
相关内容推荐