问题遇到的现象和发生背景
本人在用tensorflow2 TFrecord读取时看到一些写的脚本都是将其读取成batch里面包含了数据(图像矩阵)和标签,我想将数据和标签从batch分离出来使用,应该如何实现?
问题相关代码,请勿粘贴截图
tfrecord_file = ''
dataset = tf.data.TFRecordDataset(tfrecord_file) # 读取 TFRecord 文件
features = { # 定义Feature结构,告诉解码器每个Feature的类型是什么
'label': tf.io.FixedLenFeature([], tf.float32),
'img_raw': tf.io.FixedLenFeature([], tf.string)
}
def read_and_decode(example_string):
'''
从TFrecord格式文件中读取数据
'''
feature_dict = tf.io.parse_single_example(example_string, features)
image = tf.io.decode_raw(feature_dict['img_raw'], tf.uint8)
label = tf.cast(feature_dict['label'], tf.float32)
image = tf.reshape(image, [224, 224,3])
# image = tf.cast(image, dtype='float32') / 255. # 在流中抛出img张量
# label = tf.cast(feature_description['label'], tf.int64) # 在流中抛出label张量
image_batch = tf.cast(image,dtype=tf.float32)
return image_batch, label
dataset = dataset.repeat(2) # 重复数据集
dataset = dataset.map(read_and_decode) # 解析数据
dataset = dataset.shuffle(buffer_size=100) # 在缓冲区中随机打乱数据
test_batch = dataset.batch(batch_size=64) # 每10条数据为一个batch,生成一个新的Datasets
#print(test_batch)
运行结果及报错内容
<BatchDataset shapes: ((None, 224, 224, 3), (None,)), types: (tf.float32, tf.float32)>
我的解答思路和尝试过的方法
曾尝试过用循环读取出来但是害怕数据和标签对不起,还有要是将循环内容存入列表怕内存爆了
for img,label in test_batch:
labels.append(label)
print(np.array(labels))
我想要达到的结果
希望可以将BatchDataset分成数据和标签的形式