bllllll1010 2022-06-25 21:01 采纳率: 100%
浏览 64
已结题

deformable detr怎么直接使用GitHub上训练好的保存的那个权重的文件来进行物体检测,非常想得到你的回复

我想请问一下,deformable detr怎么直接使用GitHub上训练好的保存的那个权重的文件来进行物体检测,谢谢,非常想得到你的回复

  • 写回答

1条回答 默认 最新

  • Alaso_soso 2022-06-26 14:14
    关注

    编写一个detect.py文件,使用预训练模型。
    https://www.jianshu.com/p/b364534fd0a7
    上面时原作者的内容,可以进行参考,感觉很不错,代码可能需要改一点点,不多,很简单,希望可以帮到你

    
    ```python
    import cv2
    from PIL import Image
    import numpy as np
    import os
    import time
    
    import torch
    from torch import nn
    # from torchvision.models import resnet50
    import torchvision.transforms as T
    from main import get_args_parser as get_main_args_parser
    from models import build_model
    
    torch.set_grad_enabled(False)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("[INFO] 当前使用{}做推断".format(device))
    
    # 图像数据处理
    transform = T.Compose([
        T.Resize(800),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    
    # 将xywh转xyxy
    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x.unbind(1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)
    
    
    # 将0-1映射到图像
    def rescale_bboxes(out_bbox, size):
        img_w, img_h = size
        b = box_cxcywh_to_xyxy(out_bbox)
        b = b.cpu().numpy()
        b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
        return b
    
    
    # plot box by opencv
    def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False):
        LABEL = ['all','hat', 'person', 'groundrod', 'vest', 'workclothes_clothes', 'workclothes_trousers', 'winter_clothes',
                 'winter_trousers', 'noworkclothes_clothes', 'noworkclothes_trousers', 'height', 'safteybelt', 'smoking',
                 'noheight', 'fire', 'extinguisher', 'roll_workclothes', 'roll_noworkclothes', 'insulating_gloves', 'car',
                 'fence', 'bottle', 'shorts', 'holes', 'single_ladder', 'down', 'double_ladder', 'oxygen_horizontally',
                 'oxygen_vertically', 'acetylene_vertically', 'acetylene_horizontally']
    
        len(prob)
        opencvImage = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    
    
        if len(prob) == 0:
            print("[INFO] NO box detect !!! ")
            if imwrite:
                if not os.path.exists("./result/pred_no"):
                    os.makedirs("./result/pred_no")
                cv2.imwrite(os.path.join("./result/pred_no", save_name), opencvImage)
            return
    
        for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
            cl = p.argmax()
            label_text = '{}: {}%'.format(LABEL[cl], round(p[cl] * 100, 2))
    
            cv2.rectangle(opencvImage, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 0), 2)
            cv2.putText(opencvImage, label_text, (int(xmin) + 10, int(ymin) + 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                        (255, 255, 0), 2)
    
        if imshow:
            cv2.imshow('detect', opencvImage)
            cv2.waitKey(0)
    
        if imwrite:
            if not os.path.exists("./result/pred"):
                os.makedirs('./result/pred')
            cv2.imwrite('./result/pred/{}'.format(save_name), opencvImage)
    
    def load_model(model_path , args):
    
        model, _, _ = build_model(args)
        model.cuda()
        model.eval()
        state_dict = torch.load(model_path)  # <-----------修改加载模型的路径
        model.load_state_dict(state_dict["model"])
        model.to(device)
        print("load model sucess")
        return model
    
    # 单张图像的推断
    def detect(im, model, transform, prob_threshold=0.7):
        # mean-std normalize the input image (batch-size: 1)
        img = transform(im).unsqueeze(0)
    
        # demo model only support by default images with aspect ratio between 0.5 and 2
        # if you want to use images with an aspect ratio outside this range
        # rescale your image so that the maximum size is at most 1333 for best results
        
        #assert img.shape[-2] <= 1600 and img.shape[
        #                                     -1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
    
        # propagate through the model
        img = img.to(device)
        start = time.time()
        outputs = model(img)
        #end = time.time()
        # keep only predictions with 0.7+ confidence
        # print(outputs['pred_logits'].softmax(-1)[0, :, :-1])
        probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > prob_threshold
        #end = time.time()
    
        probas = probas.cpu().detach().numpy()
        keep = keep.cpu().detach().numpy()
    
        # convert boxes from [0; 1] to image scales
        bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
        end = time.time()
        return probas[keep], bboxes_scaled, end - start
    
    
    if __name__ == "__main__":
        
        main_args = get_main_args_parser().parse_args()
    #加载模型
        dfdetr = load_model('exps/r50_deformable_detr/checkpoint0049.pth',main_args)
    
        files = os.listdir("coco/testdata/test2017")
    
        cn = 0
        waste=0
        for file in files:
            img_path = os.path.join("coco/testdata/test2017", file)
            im = Image.open(img_path)
    
            scores, boxes, waste_time = detect(im, dfdetr, transform)
            plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True)
            print("{} [INFO] {} time: {} done!!!".format(cn,file, waste_time))
    
            cn+=1
            waste+=waste_time
        waste_avg = waste/cn
        print(waste_avg)
    
    
    
    

    ```

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 6月26日
  • 已采纳回答 6月26日
  • 创建了问题 6月25日

悬赏问题

  • ¥15 用土力学知识进行土坡稳定性分析与挡土墙设计
  • ¥70 PlayWright在Java上连接CDP关联本地Chrome启动失败,貌似是Windows端口转发问题
  • ¥15 帮我写一个c++工程
  • ¥30 Eclipse官网打不开,官网首页进不去,显示无法访问此页面,求解决方法
  • ¥15 关于smbclient 库的使用
  • ¥15 微信小程序协议怎么写
  • ¥15 c语言怎么用printf(“\b \b”)与getch()实现黑框里写入与删除?
  • ¥20 怎么用dlib库的算法识别小麦病虫害
  • ¥15 华为ensp模拟器中S5700交换机在配置过程中老是反复重启
  • ¥15 uniapp uview http 如何实现统一的请求异常信息提示?