太空的旅行者 2023-07-12 20:10 采纳率: 28.6%
浏览 6
已结题

tensorflow中如何将numpy数组存入tfrecords

我使用tensorflow将数据集转换为tfrecords格式。数据集主要是包括两个部分,一个就是jpg图像,这个图像直接使用tf.io.read file进行读取,读成bytes就可以顺利的转化为tfrecords,转换后的数据体积不会明显膨胀。另一部分是一个二进制文件,我不得不对他进行处理转换为numpy数组,我通过将np数组转化为byte存入tfrecords后,体积巨大,请问有没有什么好一点的方法能解决呢?

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-07-12 22:04
    关注
    • 这篇文章:tensorflow读取分类数据集,并随机将其分割为训练集和测试集,以tfrecords形式保存 也许有你想要的答案,你可以看看
    • 除此之外, 这篇博客: 黑马程序员3天带你玩转Python深度学习TensorFlow框架学习笔记中的 3.4.1、tfrecords文件存储 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
      1. 构造存储实例,tf.python_io.TFRecordWriter(path)
        • 写入tfrecords文件
        • path: TFRecords文件的路径
        • return:写文件
          • method方法
            • write(record):向文件中写入一个example
            • close():关闭文件写入器
      2. 循环将数据填入到Example协议内存块(protocol buffer)
      class Cifar(object):
      
         def __init__(self):
            # 初始化操作
            self.height=32
            self.width=32
            self.channels=3
      
            # 字节数
            self.image_bytes=self.height*self.width*self.channels # 图片像素数
            self.label_bytes=1 # 标签数
            self.all_bytes=self.label_bytes+self.image_bytes # 总字节数
      
         def read_and_decode(self,file_list):
            # 1、构造文件名队列
            file_queue=tf.train.string_input_producer(file_list)
      
            # 2、读取与解码
            # 读取阶段
            reader=tf.FixedLengthRecordReader(self.all_bytes)
            # key 文件名,value一个样本
            key,value=reader.read(file_queue)
      
            # 解码阶段
            decode=tf.decode_raw(value,tf.uint8)
            # 将目标值和特征值切片分开,即标签和通道分开。tf.slice(data,起始位置,个数)
            label=tf.slice(decode,[0],[self.label_bytes])
            image=tf.slice(decode,[self.label_bytes],[self.image_bytes])
            # 调整图片形状
            image_reshaped=tf.reshape(image,shape = [self.channels,self.height,self.width])
            # 转置,转成tf图片的表示格式 height,width,channels
            image_transposed=tf.transpose(image_reshaped,[1,2,0])
            # 跳转图像类型,uint8转为float32
            image_cast=tf.cast(image_transposed,tf.float32)
            
              # 3、批处理
            label_batch,image_batch=tf.train.batch([label,image_cast],batch_size = 100,num_threads = 1,capacity = 100)
      
            # 开启会话
            with tf.Session() as sess:
               print('------------------开启会话------------------')
               # 开启线程
               coord=tf.train.Coordinator() # 协调器
               threads=tf.train.start_queue_runners(sess=sess,coord = coord)
               label_batch_new,image_batch_new=sess.run([label_batch,image_batch])
               # 回收线程
               coord.request_stop()
               coord.join(threads)
            return label_batch_new,image_batch_new
      
         def write_to_tfrecords(self,label_batch,image_batch):
            # 将样本的特征值和目标值写入tfrecords文件
            with tf.python_io.TFRecordWriter('./temp/cifar10/cifar10.tfrecords') as tfWriter:
               # 循环构造example对象,并序列化写入文件
               for i in range(label_batch.size):
                  image=image_batch[i].tostring() # 序列化
                  label=label_batch[i][0] # [i][0]取出一维数组的值
                  example = tf.train.Example(features = tf.train.Features(feature = {
                     "image": tf.train.Feature(bytes_list = tf.train.BytesList(value=[image])),
                     "label": tf.train.Feature(int64_list = tf.train.Int64List(value=[label]))
                  }))
                  # 将序列化后的example写入到cifar10.tfrecords文件中
                  tfWriter.write(example.SerializeToString())
                  
      if __name__ == '__main__':
      	file_name=os.listdir('./data/cifar-10-batches-bin')
      	# 构造路径 + 文件名的列表
      	file_list=[os.path.join('./data/cifar-10-batches-bin',file) for file in file_name if file[-3:]=='bin']
      	print('file_llist: ',file_list)
      	#实例化Cifar类
      	cifar=Cifar()
      	label_batch,image_batch=cifar.read_and_decode(file_list)
      	cifar.write_to_tfrecords(label_batch,image_batch)            
      
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 7月20日
  • 创建了问题 7月12日

悬赏问题

  • ¥15 用stata实现聚类的代码
  • ¥15 请问paddlehub能支持移动端开发吗?在Android studio上该如何部署?
  • ¥170 如图所示配置eNSP
  • ¥20 docker里部署springboot项目,访问不到扬声器
  • ¥15 netty整合springboot之后自动重连失效
  • ¥15 悬赏!微信开发者工具报错,求帮改
  • ¥20 wireshark抓不到vlan
  • ¥20 关于#stm32#的问题:需要指导自动酸碱滴定仪的原理图程序代码及仿真
  • ¥20 设计一款异域新娘的视频相亲软件需要哪些技术支持
  • ¥15 stata安慰剂检验作图但是真实值不出现在图上