pumpkindd 2024-05-06 14:34

Problem encountered when running DeepSort tracking

When running yolov5 + deepsort, the motion trajectory I get from DeepSort looks like this:

[image: DeepSort trajectory output]


But the detection result from YOLOv5 alone looks like this:

[image: YOLOv5 detection result]


Where is the mismatch? Where could the problem be?
The DeepSort code is attached below:

import sys
sys.path.insert(0, './yolov5')

from yolov5.utils.datasets import LoadImages, LoadStreams
from yolov5.utils.general import check_img_size, non_max_suppression, scale_coords
from yolov5.utils.torch_utils import select_device, time_synchronized
from deep_sort_pytorch.utils.parser import get_config
from deep_sort_pytorch.deep_sort import DeepSort
import argparse
import os
import platform
import shutil
import time
from pathlib import Path
import cv2
import torch
import torch.backends.cudnn as cudnn
import numpy as np


palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)


def bbox_rel(*xyxy):
    """" Calculates the relative bounding box from absolute pixel values. """
    bbox_left = min([xyxy[0].item(), xyxy[2].item()])
    bbox_top = min([xyxy[1].item(), xyxy[3].item()])
    bbox_w = abs(xyxy[0].item() - xyxy[2].item())
    bbox_h = abs(xyxy[1].item() - xyxy[3].item())
    x_c = (bbox_left + bbox_w / 2)
    y_c = (bbox_top + bbox_h / 2)
    w = bbox_w
    h = bbox_h
    return x_c, y_c, w, h


def compute_color_for_labels(label):
    """
    Simple function that adds fixed color depending on the class
    """
    color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
    return tuple(color)


def draw_boxes(img, bbox, identities=None, offset=(0, 0)):
    for i, box in enumerate(bbox):
        x1, y1, x2, y2 = [int(i) for i in box]
        x1 += offset[0]
        x2 += offset[0]
        y1 += offset[1]
        y2 += offset[1]
        # box text and bar
        id = int(identities[i]) if identities is not None else 0
        color = compute_color_for_labels(id)
        label = '{}{:d}'.format("", id)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
        cv2.rectangle(
            img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
        cv2.putText(img, label, (x1, y1 +
                                 t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
    return img


# xyxy2tlwh: box-format conversion helper (usually shipped with the repo)
def xyxy2tlwh(x):
    '''
    (top left x, top left y,width, height)
    '''
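    # Example: x = [[100, 50, 300, 250]] (xyxy) -> y = [[100, 50, 200, 200]] (tlwh).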
    y = torch.zeros_like(x) if isinstance(x,
                                          torch.Tensor) else np.zeros_like(x)
    y[:, 0] = x[:, 0]
    y[:, 1] = x[:, 1]
    y[:, 2] = x[:, 2] - x[:, 0]
    y[:, 3] = x[:, 3] - x[:, 1]
    return y

def detect(opt, save_img=False):
    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source == '0' or source.startswith(
        'rtsp') or source.startswith('http') or source.endswith('.txt')

    # initialize deepsort
    cfg = get_config()
    cfg.merge_from_file(opt.config_deepsort)
    deepsort = DeepSort(r'D:\yolo\yolov5\yolov5-master\Yolov5_DeepSort_Pytorch\deep_sort_pytorch\deep_sort\deep\checkpoint/ckpt5.t7',
                        max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
                        nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
                        max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
                        use_cuda=True)
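    # the first positional argument is the ReID checkpoint DeepSort uses to extract appearance features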

    # Initialize
    device = select_device(opt.device)
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = torch.load(weights, map_location=device)[
        'model'].float()  # load to FP32
    model.to(device).eval()
    if half:
        model.half()  # to FP16

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        view_img = True
        save_img = True
        dataset = LoadImages(source, img_size=imgsz)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names

    # Run inference
    t0 = time.time()
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    # run once
    _ = model(img.half() if half else img) if device.type != 'cpu' else None

    save_path = str(Path(out))
    txt_path = str(Path(out)) + '/results.txt'
    dict_box=dict()

    for frame_idx, (path, img, im0s, vid_cap) in enumerate(dataset):
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(
            pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s

            s += '%gx%g ' % img.shape[2:]  # print string
            save_path = str(Path(out) / Path(p).name)

            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                #det[:, :4] = scale_coords(
                    #img.shape[2:], det[:, :4], im0.shape).round()
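                # NOTE: scale_coords maps the boxes from the letterboxed inference
                # size (img.shape[2:]) back to the original frame size (im0.shape);
                # with it commented out, det stays in inference-image coordinates.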

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                bbox_xywh = []
                confs = []

                # Adapt detections to deep sort input format
                for *xyxy, conf, cls in det:
                    x_c, y_c, bbox_w, bbox_h = bbox_rel(*xyxy)
                    obj = [x_c, y_c, bbox_w, bbox_h]
                    bbox_xywh.append(obj)
                    confs.append([conf.item()])

                xywhs = torch.Tensor(bbox_xywh)
                confss = torch.Tensor(confs)

                # Pass detections to deepsort
                outputs = deepsort.update(xywhs, confss, im0)
                # outputs = [x1, y1, x2, y2, track_id]
                if len(outputs) > 0:
                    bbox_xyxy = outputs[:, :4]  # first four columns: box coordinates
                    identities = outputs[:, -1]  # last column: track ID
                    box_xywh = xyxy2tlwh(bbox_xyxy)
                    # xyxy2tlwh converts x1, y1, x2, y2 to top-left x, top-left y, w, h (see the helper defined above)
                    for j in range(len(box_xywh)):
                        x_center = box_xywh[j][0] + box_xywh[j][2] / 2  # box center x
                        y_center = box_xywh[j][1] + box_xywh[j][3] / 2  # box center y
                        id = outputs[j][-1]
                        center = [x_center, y_center]
                        dict_box.setdefault(id, []).append(center)  # dict_box is initialized before the loop: dict_box = dict()
                    # Draw trajectories by connecting the box centers of the same ID across consecutive frames
                    if frame_idx > 2:
                        for key, value in dict_box.items():
                            for a in range(len(value) - 1):
                                # color = COLORS_10[key % len(COLORS_10)]
                                index_start = a
                                index_end = index_start + 1
                                cv2.line(im0, tuple(map(int, value[index_start])), tuple(map(int, value[index_end])),
                                         # map(int, ...) casts the float center coordinates to ints, as cv2.line requires
                                        (255, 0, 0), thickness=2, lineType=8)

                # draw boxes for visualization
                if len(outputs) > 0:
                    bbox_xyxy = outputs[:, :4]
                    identities = outputs[:, -1]
                    draw_boxes(im0, bbox_xyxy, identities)

                # Write MOT compliant results to file
                if save_txt and len(outputs) != 0:
                    for j, output in enumerate(outputs):
                        bbox_left = output[0]
                        bbox_top = output[1]
                        bbox_w = output[2]
                        bbox_h = output[3]
                        identity = output[-1]
                        with open(txt_path, 'a') as f:
                            f.write(('%g ' * 10 + '\n') % (frame_idx, identity, bbox_left,
                                                           bbox_top, bbox_w, bbox_h, -1, -1, -1, -1))  # label format
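                        # note: outputs rows are [x1, y1, x2, y2, id], so output[2] and
                        # output[3] written above are the bottom-right corner, not w/h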

            else:
                deepsort.increment_ages()

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            if view_img:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            if save_img:
                print('saving img!')
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    print('saving video!')
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(
                            save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        print('Results saved to %s' % os.getcwd() + os.sep + out)
        if platform.system() == 'Darwin':  # macOS
            os.system('open ' + save_path)

    print('Done. (%.3fs)' % (time.time() - t0))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str,
                        default=r'D:\yolo\yolov5\yolov5-master/runs/train\exp15\weights/best.pt', help='model.pt path')
    # file/folder, 0 for webcam
    parser.add_argument('--source', type=str,
                        default=r'D:\yolo\yolov5\pic\mouse white.mov', help='source')
    parser.add_argument('--output', type=str, default='inference/output',
                        help='output folder')  # output folder
    parser.add_argument('--img-size', type=int, default=640,
                        help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float,
                        default=0.4, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float,
                        default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--fourcc', type=str, default='mp4v',
                        help='output video codec (verify ffmpeg support)')
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true',
                        help='display results')
    parser.add_argument('--save-txt', default=True, action='store_true',
                        help='save results to *.txt')
    # class 0 is person
    parser.add_argument('--classes', nargs='+', type=int,
                        default=[0], help='filter by class')
    parser.add_argument('--agnostic-nms', action='store_true',
                        help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true',
                        help='augmented inference')
    parser.add_argument("--config_deepsort", type=str,
                        default="deep_sort_pytorch/configs/deep_sort.yaml")
    args = parser.parse_args()
    args.img_size = check_img_size(args.img_size)
    print(args)

    with torch.no_grad():
        detect(args)


5 answers

  • 专家-赤兔 (quality creator, full-stack) 2024-05-06 14:34

    Quoted from the free WeChat mini-program: 皆我百晓生

    In the code you provided, the detect function processes the frames and applies DeepSort for object tracking. It looks like something is causing the tracking results to come out wrong, so the first step is to pin down where the problem is.

    Judging by your description, the problem is with the motion trajectory. Let's go through the code section by section, find the parts that could cause it, and suggest fixes.

    1. compute_color_for_labels and draw_boxes: these functions have nothing to do with the trajectory, so they are not the problem.
    2. xyxy2tlwh: this function converts bounding boxes from xyxy format (top-left and bottom-right corners) to tlwh format (top-left corner, width, height) and looks correct.
    3. detect:
      • When collecting the tracking results, you use the dict_box dictionary to store the center points for each ID. This step should be fine.
      • You then draw the trajectory lines with cv2.line inside the loop. The likely issue is that the lines are redrawn from scratch on every iteration instead of being accumulated and drawn continuously. To fix this, initialize an empty list outside the loop, append the new center points on each iteration, and draw all the lines in one pass, as in the modified code below.

    Here is the code, modified according to the analysis above:

    def detect(opt, save_img=False):
        ...
        dict_box = {}  # initialize the dictionary
        points_to_connect = []  # new: list of the points that need to be connected
        ...

        for frame_idx, ...:
            ...
            if len(outputs) > 0:
                ...
                for j in range(len(box_xywh)):
                    ...
                    center = [x_center, y_center]
                    dict_box.setdefault(id, []).append(center)
                    points_to_connect.append((center, id))  # store each center point with its ID as a tuple

            if frame_idx > 2:  # skip the first few frames; there are not enough points to draw yet
                for point, id in points_to_connect:  # connect all the points
                    for a in range(len(dict_box[id]) - 1):
                        start = dict_box[id][a]
                        end = dict_box[id][a + 1]
                        cv2.line(im0, tuple(map(int, start)), tuple(map(int, end)),
                                 (255, 0, 0), thickness=2, lineType=8)
                points_to_connect.clear()  # clear the list, ready for the next iteration

            ...


    These changes should resolve the trajectory problem. Note that the code above is a modification of your original code: apply the changes to your source, run the program again, and check whether the trajectory display is fixed. If you run into other problems or need further help, let me know.
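
    For reference, here is a minimal, self-contained sketch of the same idea: connect the per-ID box centers across frames. The helper name update_and_draw and the synthetic detections are made up for illustration; only the [x1, y1, x2, y2, track_id] row format matches what deepsort.update() returns.

    import cv2
    import numpy as np
    from collections import defaultdict

    # Per-ID history of box centers, filled in as frames arrive.
    trajectories = defaultdict(list)

    def update_and_draw(frame, outputs):
        """outputs: rows of [x1, y1, x2, y2, track_id], as returned by deepsort.update()."""
        for x1, y1, x2, y2, track_id in outputs:
            center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
            trajectories[int(track_id)].append(center)
        # Connect consecutive centers that belong to the same ID.
        for points in trajectories.values():
            for a, b in zip(points, points[1:]):
                cv2.line(frame, a, b, (255, 0, 0), thickness=2)
        return frame

    # Synthetic demo: one track moving diagonally across a blank canvas.
    canvas = np.zeros((240, 320, 3), dtype=np.uint8)
    for t in range(10):
        fake_outputs = np.array([[10 + 20 * t, 10 + 15 * t, 50 + 20 * t, 60 + 15 * t, 1]])
        canvas = update_and_draw(canvas, fake_outputs)
    cv2.imwrite('trajectory_demo.jpg', canvas)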


