Apupil 2019-11-07 12:37 采纳率: 0%
浏览 1214

Faster-RCNN-TensorFlow-Python3-master训练后,如何得到AP,mAP的结果

查了很多资料,tf-faster-rcnn和caffe-faster-rcnn里都是用test__net.py
来评估训练结果。但是我用的是Faster-RCNN-TensorFlow-Python3-master,里面没有test_net.py。那要怎么获得AP和mAP的结果呢?

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-10-25 19:25
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:
    • 给你找了一篇非常好的博客,你可以看看是否有帮助,链接:tensorflow版本faster rcnn的demo.py代码详解
    • 除此之外, 这篇博客: 对Faster-RCNN-TensorFlow-Python3.5-master训练模型的评价mAP中的 一、新建test_net.py文件 部分也许能够解决你的问题, 你可以仔细阅读以下内容或者直接跳转源博客中阅读:

      放Faster-RCNN-TensorFlow-Python3.5-master 根文件夹。

      #!/usr/bin/env python
      
      # --------------------------------------------------------
      # Tensorflow Faster R-CNN
      # Licensed under The MIT License [see LICENSE for details]
      # Written by Xinlei Chen, based on code from Ross Girshick
      # --------------------------------------------------------
      
      """
      Demo script showing detections in sample images.
      See README.md for installation instructions before running.
      """
      from __future__ import absolute_import
      from __future__ import division
      from __future__ import print_function
      
      import argparse
      import os
      
      import tensorflow as tf
      from lib.nets.vgg16 import vgg16
      from lib.datasets.factory import get_imdb
      from lib.utils.test import test_net
      
      # NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
      NETS = {'vgg16': ('vgg16_faster_rcnn_iter_40000.ckpt',)}   #训练输出模型
      DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
      
      
      
      def parse_args():
          """Parse input arguments."""
          parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN test')
          parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                              choices=NETS.keys(), default='vgg16')
          parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                              choices=DATASETS.keys(), default='pascal_voc')
          args = parser.parse_args()
      
          return args
      
      
      if __name__ == '__main__':
          args = parse_args()
      
          # model path
          demonet = args.demo_net
          dataset = args.dataset
          tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])  #模型路径
          # 获得模型文件名称
          filename = (os.path.splitext(tfmodel)[0]).split('\\')[-1]
          filename = 'default' + '/' + filename
          imdb = get_imdb("voc_2007_test")  # 得到
          imdb.competition_mode('competition mode')
          if not os.path.isfile(tfmodel + '.meta'):
              print(tfmodel)
              raise IOError(('{:s} not found.\nDid you download the proper networks from '
                             'our server and place them properly?').format(tfmodel + '.meta'))
      
          # set config
          tfconfig = tf.ConfigProto(allow_soft_placement=True)
          tfconfig.gpu_options.allow_growth = True
      
          # init session
          sess = tf.Session(config=tfconfig)
          # load network
          if demonet == 'vgg16':
              net = vgg16(batch_size=1)
          # elif demonet == 'res101':
          # net = resnetv1(batch_size=1, num_layers=101)
          else:
              raise NotImplementedError
          net.create_architecture(sess, "TEST", 9,  #  记得修改第3个参数为:类别数量+1
                                  tag='default', anchor_scales=[8, 16, 32])
          saver = tf.train.Saver()
          saver.restore(sess, tfmodel)
      
          print('Loaded network {:s}'.format(tfmodel))
          test_net(sess, net, imdb, filename, max_per_image=100)
          sess.close()
      

      需要自行修改的地方:

      1.改为自己训练输出对ckpt文件名

      NETS = {'vgg16': ('vgg16_faster_rcnn_iter_40000.ckpt',)}   #训练输出模型

      2.修改第3个参数为:类别数量+1

       net.create_architecture(sess, "TEST", 9,  #  记得修改第3个参数为:类别数量+1
                                  tag='default', anchor_scales=[8, 16, 32])

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

悬赏问题

  • ¥20 fluent无法启动
  • ¥15 孟德尔随机化r语言运行问题
  • ¥15 pyinstaller编译的时候出现No module named 'imp'
  • ¥15 nirs_kit中打码怎么看(打码文件是csv格式)
  • ¥15 怎么把多于硬盘空间放到根目录下
  • ¥15 Matlab问题解答有两个问题
  • ¥15 LCD12864中文显示
  • ¥15 在使用CH341SER.EXE时不小心把所有驱动文件删除了怎么解决
  • ¥15 gsoap生成onvif框架
  • ¥15 有关sql server business intellige安装,包括SSDT、SSMS。