HAooZ 2023-04-06 17:27 采纳率: 50%
浏览 12

fasterrcnn目标预测时怎么过滤某个类别

faster rcnn目标识别,数据集是voc2012,有21个类别,怎么在预测代码里移除某一类目标。比如预测时忽略(不预测)aeroplan这一类别,或者不显示这类的边框。

import os
import time
import json

import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt

from torchvision import transforms
from network_files import FasterRCNN, FastRCNNPredictor, AnchorsGenerator
from backbone import resnet50_fpn_backbone, MobileNetV2
from draw_box_utils import draw_objs


def create_model(num_classes):

    # resNet50+fpn+faster_RCNN
    # 注意,这里的norm_layer要和训练脚本中保持一致
    backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d)
    model = FasterRCNN(backbone=backbone, num_classes=num_classes, rpn_score_thresh=0.5)

    return model


def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()


def main():
    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = create_model(num_classes=21)


    # load train weights
    weights_path = "./fasterrcnn_voc2012.pth"
    assert os.path.exists(weights_path), "{} file dose not exist.".format(weights_path)
    weights_dict = torch.load(weights_path, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict)
    model.to(device)

    # read class_indict
    label_json_path = './pascal_voc_classes.json'
    assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
    with open(label_json_path, 'r') as f:
        class_dict = json.load(f)

    category_index = {str(v): str(k) for k, v in class_dict.items()}

    # load image
    original_img = Image.open("./correction_image./2007_004459_fisheye_12.jpg")

    # from pil image to tensor, do not normalize image
    data_transform = transforms.Compose([transforms.ToTensor()])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # 进入验证模式
    with torch.no_grad():
        # init
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        predictions = model(img.to(device))[0]
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        predict_boxes = predictions["boxes"].to("cpu").numpy()
        predict_classes = predictions["labels"].to("cpu").numpy()
        predict_scores = predictions["scores"].to("cpu").numpy()

        if len(predict_boxes) == 0:
            print("没有检测到任何目标!")


        plot_img = draw_objs(original_img,
                             predict_boxes,
                             predict_classes,
                             predict_scores,
                             category_index=category_index,
                             box_thresh=0.5,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=20)
        plt.imshow(plot_img)
        plt.show()
        # 保存预测的图片结果
        plot_img.save("test_result.jpg")

if __name__ == '__main__':
    main()


img

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-04-07 03:16
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:
    • 这篇博客: 使用Faster RCNN训练自己的数据集中的 2.预训练模型编译 部分也许能够解决你的问题, 你可以仔细阅读以下内容或者直接跳转源博客中阅读:
      • 新建文件夹

      (注:本文将原文件夹重命名为faster-rcnn)在文件夹中新建data文件夹

      cd faster-rcnn && mkdir data

      data文件夹中新建pretrained_model文件夹

      mkdir pretrained_model
      • 下载预训练模型VGG16与ResNet-101

      预训练模型VGG16:

      VGG16

      预训练模型ResNet-101:ResNet-101

      将下载好的预训练模型放到pretrained_model文件夹中

      • 执行编译
      cd lib
      python setup.py build develop
      cd ..

      编译完成,如图所示

      如果执行编译后,训练自己的数据集仍然报错:

      ImportError: cannot import name '_mask'

      则是缺少COCO API,需要执行以下指令

      cd data
      
      git clone https://github.com/pdollar/coco.git 
      
      cd coco/PythonAPI
      
      make
      
      cd ../../..

      如图所示

      可以看到'_mask.o'已经编译成功

      • Scipy降版本

      使用pip查看已经安装的Python库

      pip list

      可以看到其中Scipy与Pillow版本分别问scipy==1.5.4与Pillow==8.2.0,由于Scipy版本自身的变动原因,需要对Scipy进行降版本,否则在训练中会报错

      ImportError: cannot import name 'imread' 

      首先卸载以上两个版本

      pip uninstall scipy
      
      pip uninstall pillow

      然后安装指定版本即可

      pip install scipy == 1.2.1
      pip install pillow == 6.1.0

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 创建了问题 4月6日

悬赏问题

  • ¥20 用keil,写代码解决两个问题,用库函数
  • ¥15 ChatGPT网络被篡改怎么办?
  • ¥50 ID中开关量采样信号通道、以及程序流程的设计
  • ¥15 U-Mamba/nnunetv2固定随机数种子
  • ¥15 vba使用jmail发送邮件正文里面怎么加图片
  • ¥15 vb6.0如何向数据库中添加自动生成的字段数据。
  • ¥20 在easyX库下编写C语言扑克游戏跑的快,能实现简单的人机对战
  • ¥15 svpwm波形异常求解答
  • ¥15 STM32——硬件IIC从机通信代码实现
  • ¥15 微生物组数据分析--微生物代谢物