m0_75161151 2024-04-18 12:46

ID assignment problem in object tracking

Gentlemen, I wrote an object-tracking algorithm based on the HSV color space. When tracking multiple targets I assign an ID to each object, but the IDs keep getting mixed up: an ID does not stay locked to its original object and will jump to another tracked object. Could you help me fix this? Below is my tracking and ID-assignment code:

class VideoThread(threading.Thread):
    def __init__(self, video, fps, red_speeds, time_stamps):
        super().__init__()
        self.video = video
        self.fps = fps
        self.red_speeds = red_speeds
        self.time_stamps = time_stamps
        self.is_running = True
        self.start_time = time.time()
        self.object_tracks = {}
        self.object_id_map = {}
        self.kalman_filters = {}  # store one Kalman filter per object ID

    def track_object(self, rect, frame):
        x, y, w, h = rect
        center = np.array([[x + w // 2], [y + h // 2]], dtype=np.float32)
        object_position = str(center)

        if object_position not in self.object_id_map:
            object_id = len(self.object_id_map) + 1
            self.object_id_map[object_position] = object_id
            self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
            self.kalman_filters[object_id] = KalmanFilter(state_dim=4, measurement_dim=2)  # create a Kalman filter for the new object

        else:
            object_id = self.object_id_map[object_position]

        kalman_filter = self.kalman_filters[object_id]
        estimated_state = kalman_filter.update(center)

        speed = estimated_state[2, 0] / self.fps  # the speed is the 3rd element of the state vector
        self.object_tracks[object_id]['speeds'].append(speed)
        current_time = time.time() - self.start_time
        self.object_tracks[object_id]['time_stamps'].append(current_time)

        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
        cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

Here is the full code:

import tkinter as tk
from tkinter import messagebox
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import time
import threading
import os
# Initialize video capture
video = cv2.VideoCapture(0)

# Define the red range (HSV color space)
lower_red = np.array([0, 127, 130], dtype=np.uint8)
upper_red = np.array([5, 255, 255], dtype=np.uint8)

# Get the video frame rate
fps = video.get(cv2.CAP_PROP_FPS) if video.isOpened() else 0

# Initialize the speed and timestamp lists
red_speeds = []
time_stamps = []

# Global flag used to control the video loop
is_running = False

def on_closing():
    global is_running
    if messagebox.askokcancel("Quit", "Do you want to quit?"):
        is_running = False
        if video.isOpened():
            video.release()
        cv2.destroyAllWindows()
        root.destroy()


class KalmanFilter:
    def __init__(self, state_dim, measurement_dim):
        self.kalman = cv2.KalmanFilter(state_dim, measurement_dim)
        self.kalman.transitionMatrix = np.eye(state_dim)
        self.kalman.processNoiseCov = 0.01 * np.eye(state_dim)
        self.kalman.measurementNoiseCov = 0.1 * np.eye(measurement_dim)
        self.kalman.statePost = np.zeros((state_dim, 1), dtype=np.float32)
        self.kalman.errorCovPost = np.eye(state_dim, dtype=np.float32)

    def update(self, measurement):
        prediction = self.kalman.predict()
        estimated = self.kalman.correct(measurement)
        return estimated

class VideoThread(threading.Thread):
    def __init__(self, video, fps, red_speeds, time_stamps):
        super().__init__()
        self.video = video
        self.fps = fps
        self.red_speeds = red_speeds
        self.time_stamps = time_stamps
        self.is_running = True
        self.start_time = time.time()
        self.object_tracks = {}
        self.object_id_map = {}
        self.kalman_filters = {}  # store one Kalman filter per object ID

    def track_object(self, rect, frame):
        x, y, w, h = rect
        center = np.array([[x + w // 2], [y + h // 2]], dtype=np.float32)
        object_position = str(center)

        if object_position not in self.object_id_map:
            object_id = len(self.object_id_map) + 1
            self.object_id_map[object_position] = object_id
            self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
            self.kalman_filters[object_id] = KalmanFilter(state_dim=4, measurement_dim=2)  # create a Kalman filter for the new object

        else:
            object_id = self.object_id_map[object_position]

        kalman_filter = self.kalman_filters[object_id]
        estimated_state = kalman_filter.update(center)

        speed = estimated_state[2, 0] / self.fps  # the speed is the 3rd element of the state vector
        self.object_tracks[object_id]['speeds'].append(speed)
        current_time = time.time() - self.start_time
        self.object_tracks[object_id]['time_stamps'].append(current_time)

        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
        cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    def stop_video(self):
        self.is_running = False
    def run(self):
        global is_running
        next_object_id = 1  # initialize next_object_id here
        try:
            while self.is_running:
                ret, frame = self.video.read()
                if not ret:
                    self.stop_video_with_error("Failed to grab frame")
                    break

                hsv_img = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
                mask = cv2.inRange(hsv_img, lower_red, upper_red)
                mask = cv2.GaussianBlur(mask, (5, 5), 0)
                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

                self.process_contours(contours, frame, next_object_id)

                cv2.imshow("Tracking", frame)
                if cv2.waitKey(int(1000 / self.fps)) & 0xFF == ord('q'):
                    break
        except Exception as e:
            self.stop_video_with_error(f"An error occurred: {e}")
        finally:
            self.cleanup_video()

    def process_contours(self, contours, frame, next_object_id):
        for cnt in contours:
            rect = cv2.boundingRect(cnt)
            if rect[2] * rect[3] > 100:
                self.track_object(rect, frame, next_object_id)
                next_object_id += 1

    def track_object(self, rect, frame, object_id):
        x, y, w, h = rect
        center = (x + w // 2, y + h // 2)
        if object_id not in self.object_tracks:
            self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
        speed = center[0] / self.fps
        self.object_tracks[object_id]['speeds'].append(speed)
        current_time = time.time() - self.start_time
        self.object_tracks[object_id]['time_stamps'].append(current_time)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
        cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    def stop_video_with_error(self, error_message):
        messagebox.showerror("Error", error_message)
        self.is_running = False
    def cleanup_video(self):
        self.is_running = False
        if self.video.isOpened():
            self.video.release()
        cv2.destroyAllWindows()
        is_running = False

# Create the main window
root = tk.Tk()
root.title("Red Object Tracker")
root.protocol("WM_DELETE_WINDOW", on_closing)

# Create the start button and define the function that starts video capture
def start_video():
    global video_thread, is_running
    if not is_running:
        is_running = True
        video_thread = VideoThread(video, fps, red_speeds, time_stamps)
        video_thread.start()
    else:
        messagebox.showinfo("Info", "Video is already running.")

start_button = tk.Button(root, text="Start Video", command=start_video)
start_button.pack()

# Create the stop button and define the function that stops video capture
def stop_video():
    global video_thread, is_running
    if is_running:
        video_thread.stop_video()
        video_thread.join()
        is_running = False
    else:
        messagebox.showinfo("Info", "Video is not running.")

stop_button = tk.Button(root, text="Stop Video", command=stop_video)
stop_button.pack()


# Create the update-chart button
def update_chart():
    plt.close('all')  # close all previously opened figure windows
    for object_id, track in video_thread.object_tracks.items():
        if not track['speeds'] or not track['time_stamps']:
            continue

        # Create the speed plot
        fig_speed = plt.figure(figsize=(5, 5))
        plt.plot(track['time_stamps'], track['speeds'], label=f'Red Object {object_id} Speed')
        plt.xlabel('Time (seconds)')
        plt.ylabel('Speed (pixels per second)')
        plt.title(f'Speed of Red Object {object_id} Over Time')
        plt.legend()
        plt.grid(True)

        # Compute the Fourier transform
        fft_speed = np.fft.fft(track['speeds'])
        freq = np.fft.fftfreq(len(track['speeds']))

        # Create the frequency-domain plot
        fig_fft = plt.figure(figsize=(5, 5))
        plt.plot(freq, np.abs(fft_speed))
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Amplitude')
        plt.title(f'Frequency Domain of {object_id}')
        plt.grid(True)

        # Show the speed plot and the frequency-domain plot on the left and right sides of the window
        speed_frame = tk.Frame(root)
        speed_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)

        fft_frame = tk.Frame(root)
        fft_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=1)

        speed_canvas = FigureCanvasTkAgg(fig_speed, master=speed_frame)
        speed_canvas.draw()
        speed_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

        fft_canvas = FigureCanvasTkAgg(fig_fft, master=fft_frame)
        fft_canvas.draw()
        fft_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

update_chart_button = tk.Button(root, text="Update Chart", command=update_chart)
update_chart_button.pack()

# Create the export-data button
def export_data():
    if not video_thread.object_tracks:  # check whether any tracking data exists
        messagebox.showerror("Error", "No tracking data available.")
        return

    try:
        all_data = pd.DataFrame()  # empty DataFrame that will collect all speed data
        for object_id, track in video_thread.object_tracks.items():
            df = pd.DataFrame({'Time (s)': track['time_stamps'], f'Speed_Object_{object_id} (px/s)': track['speeds']})
            all_data = pd.concat([all_data, df], axis=1)  # merge each object's speed data into one DataFrame

        with pd.ExcelWriter('tracking_data.xlsx') as writer:
            all_data.to_excel(writer, sheet_name='All_Speed_Data', index=False)  # export all speed data to a single worksheet

        messagebox.showinfo("Success", "Tracking data exported to tracking_data.xlsx.")

        # Open the exported Excel file
        os.startfile('tracking_data.xlsx')
    except Exception as e:
        messagebox.showerror("Error", f"Failed to export tracking data: {e}")

export_button = tk.Button(root, text="Export Data to Excel", command=export_data)
export_button.pack()

# Create the show-statistics button
def show_stats():
    if not video_thread.object_tracks:  # check whether any tracking data exists
        messagebox.showerror("Error", "No tracking data available.")
        return

    stats_message = ""
    for object_id, track in video_thread.object_tracks.items():
        if not track['speeds'] or not track['time_stamps']:
            continue

        avg_speed = np.mean(track['speeds'])
        speed_variance = np.var(track['speeds'])
        correlation = np.corrcoef(track['time_stamps'], track['speeds'])[0, 1]
        stats_message += f"Cockscomb {object_id} - average speed: {avg_speed:.2f} px/s, speed variance: {speed_variance:.2f}, correlation: {correlation:.2f}\n"

    if stats_message:
        messagebox.showinfo("Statistics", stats_message)
    else:
        messagebox.showerror("Error", "No valid tracking data for statistics.")
    print("Statistics:")
    print(stats_message)
stats_button = tk.Button(root, text="Show Statistics", command=show_stats)
stats_button.pack()

# Main loop
root.mainloop()


4 answers

  • 专家-赤兔 2024-04-18 12:46


    In your code, the ID confusion most likely comes from the ID mapping not being managed correctly when targets move, overlap, or disappear and reappear (the mapping is keyed on the exact detection position, which changes from frame to frame). To address this, you can take the following steps:

    1. Use more stable features to identify each target, for example contour features or bounding boxes from an object-detection method.
    2. When a target disappears, keep its ID for a while and wait for it to reappear.
    3. When several targets overlap, distinguish them by color, shape, or other features (see the color-histogram sketch right after this list).
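
    For point 3, one simple way to separate overlapping or reappearing red targets is to compare the color distribution inside each bounding box with a histogram stored per ID. Below is a minimal sketch of that idea; the helper names roi_histogram and best_histogram_match, the histogram bin counts, and the 0.7 similarity threshold are illustrative choices, not part of the original code:

    import cv2

    def roi_histogram(frame, rect):
        """Normalized Hue-Saturation histogram of a detection's region of interest."""
        x, y, w, h = rect
        hsv_roi = cv2.cvtColor(frame[y:y + h, x:x + w], cv2.COLOR_BGR2HSV)
        # 30 hue bins and 32 saturation bins are assumed values; tune them for your scene
        hist = cv2.calcHist([hsv_roi], [0, 1], None, [30, 32], [0, 180, 0, 256])
        cv2.normalize(hist, hist, 0, 1, cv2.NORM_MINMAX)
        return hist

    def best_histogram_match(hist, known_hists, min_similarity=0.7):
        """Return the ID whose stored histogram is most similar, or None if nothing is close enough."""
        best_id, best_score = None, min_similarity  # min_similarity is an assumed threshold
        for object_id, stored in known_hists.items():
            score = cv2.compareHist(hist, stored, cv2.HISTCMP_CORREL)
            if score > best_score:
                best_id, best_score = object_id, score
        return best_id

    A detection whose histogram matches a stored ID above the threshold keeps that ID even if its center has moved; a detection that matches nothing becomes a new object.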

    Here is one possible improvement: use the bounding-box center together with the width-to-height ratio as the unique key, which copes somewhat better with overlapping targets. In addition, add a last_seen dictionary that records where and when each ID was last observed, so that an ID can be restored after a target disappears briefly.

    First, modify the track_object function to take the width-to-height ratio into account and to update the last_seen dictionary:

    def track_object(self, rect, frame):
        x, y, w, h = rect
        center = np.array([[x + w // 2], [y + h // 2]], dtype=np.float32)
        aspect_ratio = w / h
        object_key = f"{center[0, 0]}, {center[1, 0]}, {aspect_ratio}"
    
        if object_key not in self.object_id_map:
            object_id = len(self.object_id_map) + 1
            self.object_id_map[object_key] = object_id
            self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
            self.kalman_filters[object_id] = KalmanFilter(state_dim=4, measurement_dim=2)  # create a Kalman filter for the new object
            self.last_seen[object_id] = (center, time.time())
    
        else:
            object_id = self.object_id_map[object_key]
            last_center, last_time = self.last_seen[object_id]
            distance = np.linalg.norm(np.array(last_center) - center)
            elapsed_time = time.time() - last_time
            if distance < 5 and elapsed_time < 1:  # adjustable distance and time thresholds
                # Close to the previous position within a short interval: treat it as the same target
                self.last_seen[object_id] = (center, time.time())
            else:
                # Otherwise treat it as a new target: create a new ID and refresh last_seen
                del self.object_id_map[object_key]
                del self.object_tracks[object_id]
                del self.kalman_filters[object_id]
                object_id = len(self.object_id_map) + 1
                self.object_id_map[object_key] = object_id
                self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
                self.kalman_filters[object_id] = KalmanFilter(state_dim=4, measurement_dim=2)
                self.last_seen[object_id] = (center, time.time())
    
        kalman_filter = self.kalman_filters[object_id]
        estimated_state = kalman_filter.update(center)
    
        speed = estimated_state[2, 0] / self.fps  # the speed is the 3rd element of the state vector
        self.object_tracks[object_id]['speeds'].append(speed)
        current_time = time.time() - self.start_time
        self.object_tracks[object_id]['time_stamps'].append(current_time)
    
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
        cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
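
    Note that the method above reads and writes self.last_seen, which the original __init__ never creates. If you adopt this approach, initialize the dictionary next to the other per-track state in VideoThread.__init__ (only the last line is new; the comment wording is ours):

        self.object_tracks = {}
        self.object_id_map = {}
        self.kalman_filters = {}  # one Kalman filter per object ID
        self.last_seen = {}       # object_id -> (last known center, last seen timestamp)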
    

    Then, in process_contours, handle targets that disappear: when an ID has not been seen for too long, mark it as gone and clean it up:

    def process_contours(self, contours, frame, next_object_id):
        for cnt in contours:
            rect = cv2.boundingRect(cnt)
            if rect[2] * rect[3] > 100:
                # IDs are now assigned inside track_object, so next_object_id is not needed here
                self.track_object(rect, frame)

        # Handle targets that were not detected: expire tracks that have been unseen for too long
        for object_id, (_, last_time) in list(self.last_seen.items()):
            if time.time() - last_time > 5:  # adjustable time threshold
                stale_keys = [key for key, oid in self.object_id_map.items() if oid == object_id]
                for key in stale_keys:
                    del self.object_id_map[key]
                del self.object_tracks[object_id]
                del self.kalman_filters[object_id]
                del self.last_seen[object_id]

    With these changes the ID confusion should be resolved. Note that this is only a basic solution; a real application may also need to account for the target's motion model, occlusion, and so on. For complex scenes you may need a more sophisticated tracking algorithm such as DeepSORT or SORT.
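
    As a further pointer in that direction: the core idea behind SORT-style trackers is explicit per-frame data association. Collect all detection centers for the frame first, match them against the previous frame's tracks by distance, reuse the IDs of matched tracks, and create new IDs only for detections that match nothing. Below is a minimal, dependency-free sketch of greedy nearest-neighbor matching; the function name match_detections and the max_distance gate are illustrative, not part of the original code:

    import numpy as np

    def match_detections(tracks, detections, max_distance=50.0):
        """Greedily match current detections to existing tracks by centroid distance.

        tracks:      dict mapping object_id -> last known center, e.g. np.array([cx, cy])
        detections:  list of centers detected in the current frame
        Returns (matches, unmatched): matches is a list of (object_id, detection_index)
        pairs, unmatched lists detection indices that should become new IDs.
        max_distance is a tunable pixel gate; 50 px is an assumed default.
        """
        # Build all candidate (track, detection) pairs within the gating distance.
        pairs = []
        for object_id, last_center in tracks.items():
            for j, det in enumerate(detections):
                dist = float(np.linalg.norm(np.asarray(last_center, dtype=float).ravel()
                                            - np.asarray(det, dtype=float).ravel()))
                if dist <= max_distance:
                    pairs.append((dist, object_id, j))
        # Match closest pairs first, using each track and each detection at most once.
        pairs.sort(key=lambda p: p[0])
        matches, used_tracks, used_dets = [], set(), set()
        for dist, object_id, j in pairs:
            if object_id in used_tracks or j in used_dets:
                continue
            matches.append((object_id, j))
            used_tracks.add(object_id)
            used_dets.add(j)
        unmatched = [j for j in range(len(detections)) if j not in used_dets]
        return matches, unmatched

    In process_contours this would replace the per-contour key lookup: gather all centers for the frame, call match_detections against the stored last-seen centers, feed each matched center into that ID's Kalman filter, and only then assign new IDs to the unmatched detections. SORT builds on the same pattern with Hungarian assignment and IoU-based costs; DeepSORT additionally matches on appearance features.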
