我是小聪明 2020-12-04 00:39

Converting the WIDER FACE dataset (VOC format) to TFRecord

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import os

from lxml import etree
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

flags = tf.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                                    'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'widerface', 'Desired dataset (fddb or widerface).')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                                                          'difficult instances')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['fddb', 'widerface']


def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
    """Convert XML derived dict to tf.Example proto.
  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.
  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integers ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset  (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.
  Returns:
    example: The converted tf.Example.
  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """
    img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
    full_path = os.path.join(dataset_directory, img_path)
    with tf.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = PIL.Image.open(encoded_jpg_io)
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []
    if 'object' in data:
        for obj in data['object']:
            difficult = bool(int(obj['difficult']))
            if ignore_difficult_instances and difficult:
                continue

            difficult_obj.append(int(difficult))

            xmin.append(float(obj['bndbox']['xmin']) / width)
            ymin.append(float(obj['bndbox']['ymin']) / height)
            xmax.append(float(obj['bndbox']['xmax']) / width)
            ymax.append(float(obj['bndbox']['ymax']) / height)
            classes_text.append(obj['name'].encode('utf8'))
            classes.append(label_map_dict[obj['name']])
            truncated.append(int(obj['truncated']))
            poses.append(obj['pose'].encode('utf8'))

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }))
    return example


def main(_):
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))
    if FLAGS.year not in YEARS:
        raise ValueError('year must be in : {}'.format(YEARS))

    data_dir = FLAGS.data_dir
    years = ['fddb', 'widerface']
    if FLAGS.year != 'merged':
        years = [FLAGS.year]

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    for year in years:
        logging.info('Reading from PASCAL %s dataset.', year)
        examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                     FLAGS.set + '.txt')
        annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
        examples_list = dataset_util.read_examples_list(examples_path)
        for idx, example in enumerate(examples_list):
            if idx % 100 == 0:
                logging.info('On image %d of %d', idx, len(examples_list))
            path = os.path.join(annotations_dir, example + '.xml')
            with tf.gfile.GFile(path, 'r') as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

            tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                            FLAGS.ignore_difficult_instances)
            writer.write(tf_example.SerializeToString())

    writer.close()


if __name__ == '__main__':
    tf.app.run()

Running it fails with the following traceback:


  File "object_detection/dataset_tools/create_pascal_tf_record.py", line 179, in <module>
    tf.app.run()
  File "C:\Anaconda3\envs\tensorflow1\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
    _sys.exit(main(argv))
  File "object_detection/dataset_tools/create_pascal_tf_record.py", line 172, in main
    FLAGS.ignore_difficult_instances)
  File "object_detection/dataset_tools/create_pascal_tf_record.py", line 114, in dict_to_tf_example
    classes.append(label_map_dict[obj['name']])
KeyError: 'face'
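
The KeyError comes from classes.append(label_map_dict[obj['name']]): every object written by the WIDER FACE conversion is named 'face', but the label map passed via --label_map_path (the default data/pascal_label_map.pbtxt only lists the 20 PASCAL VOC classes) has no entry named 'face'. A minimal sanity check, assuming a hypothetical single-class label map data/face_label_map.pbtxt:

# Sketch of a label-map check; data/face_label_map.pbtxt is hypothetical and
# would contain a single entry:
#   item {
#     id: 1
#     name: 'face'
#   }
from object_detection.utils import label_map_util

label_map_dict = label_map_util.get_label_map_dict('data/face_label_map.pbtxt')
print(label_map_dict)  # expected: {'face': 1}
assert 'face' in label_map_dict, "label map passed via --label_map_path must define 'face'"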

1 answer

  • 淋风沐雨 2021-02-04 13:13

    The script below converts the raw WIDER FACE annotations (wider_face_*_bbx_gt.txt) into a PASCAL-VOC-style layout (JPEGImages/, Annotations/, ImageSets/Main/) that the TFRecord script above can consume:

    import os
    import shutil

    import cv2
     
    from xml.dom.minidom import Document
     
    def writexml(filename,saveimg,bboxes,xmlpath):
        doc = Document()
     
        annotation = doc.createElement('annotation')
     
        doc.appendChild(annotation)
     
        folder = doc.createElement('folder')
     
        folder_name = doc.createTextNode('widerface')
        folder.appendChild(folder_name)
        annotation.appendChild(folder)
        filenamenode = doc.createElement('filename')
        filename_name = doc.createTextNode(filename)
        filenamenode.appendChild(filename_name)
        annotation.appendChild(filenamenode)
        source = doc.createElement('source')
        annotation.appendChild(source)
        database = doc.createElement('database')
        database.appendChild(doc.createTextNode('wider face Database'))
        source.appendChild(database)
        annotation_s = doc.createElement('annotation')
        annotation_s.appendChild(doc.createTextNode('PASCAL VOC2007'))
        source.appendChild(annotation_s)
        image = doc.createElement('image')
        image.appendChild(doc.createTextNode('flickr'))
        source.appendChild(image)
        flickrid = doc.createElement('flickrid')
        flickrid.appendChild(doc.createTextNode('-1'))
        source.appendChild(flickrid)
        owner = doc.createElement('owner')
        annotation.appendChild(owner)
        flickrid_o = doc.createElement('flickrid')
        flickrid_o.appendChild(doc.createTextNode('yanyu'))
        owner.appendChild(flickrid_o)
        name_o = doc.createElement('name')
        name_o.appendChild(doc.createTextNode('yanyu'))
        owner.appendChild(name_o)
     
        size = doc.createElement('size')
        annotation.appendChild(size)
     
        width = doc.createElement('width')
        width.appendChild(doc.createTextNode(str(saveimg.shape[1])))
        height = doc.createElement('height')
        height.appendChild(doc.createTextNode(str(saveimg.shape[0])))
        depth = doc.createElement('depth')
        depth.appendChild(doc.createTextNode(str(saveimg.shape[2])))
     
        size.appendChild(width)
     
        size.appendChild(height)
        size.appendChild(depth)
        segmented = doc.createElement('segmented')
        segmented.appendChild(doc.createTextNode('0'))
        annotation.appendChild(segmented)
        for i in range(len(bboxes)):
            bbox = bboxes[i]
            objects = doc.createElement('object')
            annotation.appendChild(objects)
            object_name = doc.createElement('name')
            object_name.appendChild(doc.createTextNode('face'))
            objects.appendChild(object_name)
            pose = doc.createElement('pose')
            pose.appendChild(doc.createTextNode('Unspecified'))
            objects.appendChild(pose)
            truncated = doc.createElement('truncated')
            truncated.appendChild(doc.createTextNode('1'))
            objects.appendChild(truncated)
            difficult = doc.createElement('difficult')
            difficult.appendChild(doc.createTextNode('0'))
            objects.appendChild(difficult)
            bndbox = doc.createElement('bndbox')
            objects.appendChild(bndbox)
            xmin = doc.createElement('xmin')
            xmin.appendChild(doc.createTextNode(str(bbox[0])))
            bndbox.appendChild(xmin)
            ymin = doc.createElement('ymin')
            ymin.appendChild(doc.createTextNode(str(bbox[1])))
            bndbox.appendChild(ymin)
            xmax = doc.createElement('xmax')
            xmax.appendChild(doc.createTextNode(str(bbox[0] + bbox[2])))
            bndbox.appendChild(xmax)
            ymax = doc.createElement('ymax')
            ymax.appendChild(doc.createTextNode(str(bbox[1] + bbox[3])))
            bndbox.appendChild(ymax)
        f = open(xmlpath, "w")
        f.write(doc.toprettyxml(indent=''))
        f.close()
     
     
    rootdir = "***/wider_face"
     
     
    def convertimgset(img_set):
        imgdir = rootdir + "/WIDER_" + img_set + "/images"
        gtfilepath = rootdir + "/wider_face_split/wider_face_" + img_set + "_bbx_gt.txt"
     
        fwrite = open(rootdir + "/ImageSets/Main/" + img_set + ".txt", 'w')
     
        index = 0
     
        with open(gtfilepath, 'r') as gtfiles:
            while True:
                filename = gtfiles.readline().strip()
                if filename == "":
                    break
                imgpath = imgdir + "/" + filename

                img = cv2.imread(imgpath)

                if img is None:  # image missing or unreadable
                    break

                numbbox = int(gtfiles.readline())
                if numbbox == 0:
                    # entries with no faces still carry one placeholder "0 0 0 0 ..." line
                    gtfiles.readline()
     
                bboxes = []
     
                print(numbbox)
     
                for i in range(numbbox):
                    line = gtfiles.readline()
                    lines = line.split(" ")
                    lines = lines[0:4]
     
                    bbox = (int(lines[0]), int(lines[1]), int(lines[2]), int(lines[3]))
     
                    if int(lines[2]) < 40 or int(lines[3]) < 40:  # skip faces narrower or shorter than 40 px
                        continue
     
                    bboxes.append(bbox)
     
                    #cv2.rectangle(img, (bbox[0],bbox[1]),(bbox[0]+bbox[2],bbox[1]+bbox[3]),color=(255,255,0),thickness=1)
     
                filename = filename.replace("/", "_")
     
                if len(bboxes) == 0:
                    print("no face")
                    continue
                #cv2.imshow("img", img)
                #cv2.waitKey(0)
     
                cv2.imwrite("{}/JPEGImages/{}".format(rootdir,filename), img)
     
                fwrite.write(filename.split(".")[0] + "\n")
     
                xmlpath = "{}/Annotations/{}.xml".format(rootdir,filename.split(".")[0])
     
                writexml(filename, img, bboxes, xmlpath)
     
                print("success number is ", index)
                index += 1
     
        fwrite.close()
     
    if __name__=="__main__":
        img_sets = ["train","val"]
        for img_set in img_sets:
            convertimgset(img_set)
     
        shutil.move(rootdir + "/ImageSets/Main/" + "train.txt", rootdir + "/ImageSets/Main/" + "trainval.txt")
        shutil.move(rootdir + "/ImageSets/Main/" + "val.txt", rootdir + "/ImageSets/Main/" + "test.txt")
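
    Note that writexml() records <folder>widerface</folder>, so the VOC-style root built above should end up at <data_dir>/widerface (and the TFRecord script be run with --year widerface and --set trainval or test, matching the two list files produced by the shutil.move calls); otherwise dict_to_tf_example() will not find the images under <data_dir>/widerface/JPEGImages. As a quick sanity check (a sketch; the annotation filename below is hypothetical), one generated XML can be parsed back with the same helpers the TFRecord script uses:

    # Sketch: re-read one generated annotation with the helpers used by the
    # TFRecord script above. The XML path is hypothetical.
    from lxml import etree
    from object_detection.utils import dataset_util

    xml_path = "widerface/Annotations/0--Parade_0_Parade_marchingband_1_5.xml"
    with open(xml_path, "rb") as fid:
        xml = etree.fromstring(fid.read())
    data = dataset_util.recursive_parse_xml_to_dict(xml)["annotation"]
    print(data["folder"], data["size"], len(data["object"]))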

