请问如何直接使用 GitHub 上 Deformable DETR 已训练好的权重文件(checkpoint)进行物体检测?谢谢,非常期待您的回复。
1条回答 默认 最新
- Alaso_soso 2022-06-26 14:14关注
编写一个 detect.py 文件,加载预训练模型的权重进行推断。
https://www.jianshu.com/p/b364534fd0a7
上面时原作者的内容,可以进行参考,感觉很不错,代码可能需要改一点点,不多,很简单,希望可以帮到你```python import cv2 from PIL import Image import numpy as np import os import time import torch from torch import nn # from torchvision.models import resnet50 import torchvision.transforms as T from main import get_args_parser as get_main_args_parser from models import build_model torch.set_grad_enabled(False) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("[INFO] 当前使用{}做推断".format(device)) # 图像数据处理 transform = T.Compose([ T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 将xywh转xyxy def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=1) # 将0-1映射到图像 def rescale_bboxes(out_bbox, size): img_w, img_h = size b = box_cxcywh_to_xyxy(out_bbox) b = b.cpu().numpy() b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32) return b # plot box by opencv def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False): LABEL = ['all','hat', 'person', 'groundrod', 'vest', 'workclothes_clothes', 'workclothes_trousers', 'winter_clothes', 'winter_trousers', 'noworkclothes_clothes', 'noworkclothes_trousers', 'height', 'safteybelt', 'smoking', 'noheight', 'fire', 'extinguisher', 'roll_workclothes', 'roll_noworkclothes', 'insulating_gloves', 'car', 'fence', 'bottle', 'shorts', 'holes', 'single_ladder', 'down', 'double_ladder', 'oxygen_horizontally', 'oxygen_vertically', 'acetylene_vertically', 'acetylene_horizontally'] len(prob) opencvImage = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) if len(prob) == 0: print("[INFO] NO box detect !!! 
") if imwrite: if not os.path.exists("./result/pred_no"): os.makedirs("./result/pred_no") cv2.imwrite(os.path.join("./result/pred_no", save_name), opencvImage) return for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes): cl = p.argmax() label_text = '{}: {}%'.format(LABEL[cl], round(p[cl] * 100, 2)) cv2.rectangle(opencvImage, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 0), 2) cv2.putText(opencvImage, label_text, (int(xmin) + 10, int(ymin) + 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0), 2) if imshow: cv2.imshow('detect', opencvImage) cv2.waitKey(0) if imwrite: if not os.path.exists("./result/pred"): os.makedirs('./result/pred') cv2.imwrite('./result/pred/{}'.format(save_name), opencvImage) def load_model(model_path , args): model, _, _ = build_model(args) model.cuda() model.eval() state_dict = torch.load(model_path) # <-----------修改加载模型的路径 model.load_state_dict(state_dict["model"]) model.to(device) print("load model sucess") return model # 单张图像的推断 def detect(im, model, transform, prob_threshold=0.7): # mean-std normalize the input image (batch-size: 1) img = transform(im).unsqueeze(0) # demo model only support by default images with aspect ratio between 0.5 and 2 # if you want to use images with an aspect ratio outside this range # rescale your image so that the maximum size is at most 1333 for best results #assert img.shape[-2] <= 1600 and img.shape[ # -1] <= 1600, 'demo model only supports images up to 1600 pixels on each side' # propagate through the model img = img.to(device) start = time.time() outputs = model(img) #end = time.time() # keep only predictions with 0.7+ confidence # print(outputs['pred_logits'].softmax(-1)[0, :, :-1]) probas = outputs['pred_logits'].softmax(-1)[0, :, :-1] keep = probas.max(-1).values > prob_threshold #end = time.time() probas = probas.cpu().detach().numpy() keep = keep.cpu().detach().numpy() # convert boxes from [0; 1] to image scales bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size) end = 
time.time() return probas[keep], bboxes_scaled, end - start if __name__ == "__main__": main_args = get_main_args_parser().parse_args() #加载模型 dfdetr = load_model('exps/r50_deformable_detr/checkpoint0049.pth',main_args) files = os.listdir("coco/testdata/test2017") cn = 0 waste=0 for file in files: img_path = os.path.join("coco/testdata/test2017", file) im = Image.open(img_path) scores, boxes, waste_time = detect(im, dfdetr, transform) plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True) print("{} [INFO] {} time: {} done!!!".format(cn,file, waste_time)) cn+=1 waste+=waste_time waste_avg = waste/cn print(waste_avg)
```
本回答被题主选为最佳回答 , 对您是否有帮助呢?解决 无用评论 打赏 举报
悬赏问题
- ¥30 关于#java#的问题,请各位专家解答!
- ¥30 vue+element根据数据循环生成多个table,如何实现最后一列 平均分 合并
- ¥20 pcf8563时钟芯片不启振
- ¥20 pip2.40更新pip2.43时报错
- ¥15 换yum源但仍然用不了httpd
- ¥50 C# 使用DEVMOD设置打印机首选项
- ¥15 麒麟V10 arm安装gdal
- ¥20 OPENVPN连接问题
- ¥15 flask实现搜索框访问数据库
- ¥15 mrk3399刷完安卓11后投屏调试只能显示一个设备