学习小仙子 2021-03-17 15:43

Python: generating a .tfrecord file from my own data. The script runs without raising an error, but none of the data is read in

def main(_):
  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

  # load groundtruth file
  groundtruth_file = os.path.join(FLAGS.data_dir, 'labels.txt')
  with open(groundtruth_file, 'r') as f:
    groundtruth_lines = f.readlines()

  num_images = len(groundtruth_lines) - FLAGS.start_index
  if FLAGS.num_images > 0:
    num_images = min(num_images, FLAGS.num_images)

  indices = list(range(FLAGS.start_index, FLAGS.start_index + num_images))
  if FLAGS.shuffle:
    random.shuffle(indices)

  # a test decode pipeline for validating image
  image_jpeg_input = tf.placeholder(
    dtype=tf.string,
    shape=[]
  )
  image = tf.image.decode_jpeg(
    image_jpeg_input,
    channels=3,
    try_recover_truncated=False,
    acceptable_fraction=1
  )

  with tf.Session() as sess:
    for index in tqdm(indices):
      image_rel_path = groundtruth_lines[index].split(' ')[0]
      image_path = os.path.join(FLAGS.data_dir, image_rel_path)

      # validate image
      valid = True
      image_jpeg = None
      try:
        with open(image_path, 'rb') as f:
          image_jpeg = f.read()
          image_output = sess.run(image, feed_dict={
            image_jpeg_input: image_jpeg
          })
          if (image_output.ndim != 3 or
              image_output.shape[0] == 0 or
              image_output.shape[1] == 0 or
              image_output.shape[2] != 3):
            valid = False
      except:
        valid = False
      
      if not valid:
        logging.warning('Skip invalid image {}'.format(image_rel_path))
        continue

      # extract groundtruth
      groundtruth_text = image_rel_path.split('_')[1]

      # write example
      example = tf.train.Example(features=tf.train.Features(feature={
        fields.TfExampleFields.image_encoded: \
          dataset_util.bytes_feature(image_jpeg),
        fields.TfExampleFields.image_format: \
          dataset_util.bytes_feature('jpeg'.encode('utf-8')),
        fields.TfExampleFields.filename: \
          dataset_util.bytes_feature(image_rel_path.encode('utf-8')),
        fields.TfExampleFields.channels: \
          dataset_util.int64_feature(3),
        fields.TfExampleFields.colorspace: \
          dataset_util.bytes_feature('rgb'.encode('utf-8')),
        fields.TfExampleFields.transcript: \
          dataset_util.bytes_feature(groundtruth_text.encode('utf-8'))
      }))
      writer.write(example.SerializeToString())

  writer.close()

Printed output:

1 answer

  • 学习小仙子 2021-03-17 15:46

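    The printed output shows that every single image was skipped. The labels are formatted like this: ./1/1/1_stoled oversend.jpg

    Any pointers would be greatly appreciated!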
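
    One possible cause, given that label format: the script takes `groundtruth_lines[index].split(' ')[0]` as the image path, so a line such as `./1/1/1_stoled oversend.jpg` is cut at the space and the script tries to open `./1/1/1_stoled`. If that file does not exist, `open` raises, the bare `except:` swallows the error, and the image is reported as invalid. The sketch below is only an illustration under that assumption. The helper names `image_path_from_label_line` and `read_image_bytes` are hypothetical and not part of the original script, and the rule "the path ends at the last `.jpg` or `.jpeg`" is a guess about the label file, not something the post confirms.

```python
def image_path_from_label_line(line, extensions=('.jpg', '.jpeg')):
    """Return the image path from one line of labels.txt.

    Assumption: a line is either '<path>' or '<path> <label>', and the
    path may itself contain spaces. The path is taken to end at the last
    occurrence of a known image extension.
    """
    stripped = line.strip()  # drop the trailing newline left by readlines()
    lowered = stripped.lower()
    for ext in extensions:
        end = lowered.rfind(ext)
        if end != -1:
            return stripped[:end + len(ext)]
    # No known extension found: fall back to the original behaviour.
    return stripped.split(' ')[0]


def read_image_bytes(image_path):
    """Return (data, None) on success or (None, message) on failure.

    Unlike a bare `except:`, this keeps the reason for the failure so it
    can be logged next to the "Skip invalid image" message.
    """
    try:
        with open(image_path, 'rb') as f:
            return f.read(), None
    except OSError as e:
        return None, '{}: {}'.format(type(e).__name__, e)


if __name__ == '__main__':
    line = './1/1/1_stoled oversend.jpg\n'
    print(repr(line.split(' ')[0]))                # './1/1/1_stoled'
    print(repr(image_path_from_label_line(line)))  # './1/1/1_stoled oversend.jpg'
```

    Logging the message returned by `read_image_bytes` for the first few skipped images should show whether the failure is a missing file or a JPEG decoding problem. Note also that `image_rel_path.split('_')[1]` assumes file names with at least two underscores-separated parts before the extension, so with a name like `1_stoled oversend.jpg` the extracted text would still include `.jpg`.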

     
