先生们,我写了一个目标跟踪的算法,基于hsv颜色空间的,在进行多目标跟踪时,我给物体分配了ID,但总是出现ID错乱的情况,ID不能保证固定跟踪原有的物体,会跑到另外的被跟踪物体上,先生们能帮忙解决一下吗?下面是我的跟踪和分配id的代码:
class VideoThread(threading.Thread):
def __init__(self, video, fps, red_speeds, time_stamps):
super().__init__()
self.video = video
self.fps = fps
self.red_speeds = red_speeds
self.time_stamps = time_stamps
self.is_running = True
self.start_time = time.time()
self.object_tracks = {}
self.object_id_map = {}
self.kalman_filters = {} # 存储卡尔曼滤波器
def track_object(self, rect, frame):
x, y, w, h = rect
center = np.array([[x + w // 2], [y + h // 2]], dtype=np.float32)
object_position = str(center)
if object_position not in self.object_id_map:
object_id = len(self.object_id_map) + 1
self.object_id_map[object_position] = object_id
self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
self.kalman_filters[object_id] = KalmanFilter(state_dim=4, measurement_dim=2) # 创建卡尔曼滤波器
else:
object_id = self.object_id_map[object_position]
kalman_filter = self.kalman_filters[object_id]
estimated_state = kalman_filter.update(center)
speed = estimated_state[2, 0] / self.fps # 速度为状态向量的第3个元素
self.object_tracks[object_id]['speeds'].append(speed)
current_time = time.time() - self.start_time
self.object_tracks[object_id]['time_stamps'].append(current_time)
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
这里是完整代码:
import tkinter as tk
from tkinter import messagebox
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import time
import threading
import os
# 初始化视频捕获
video = cv2.VideoCapture(0)
# 定义红色范围(HSV颜色空间)
lower_red = np.array([0, 127, 130], dtype=np.uint8)
upper_red = np.array([5, 255, 255], dtype=np.uint8)
# 获取视频的帧率
fps = video.get(cv2.CAP_PROP_FPS) if video.isOpened() else 0
# 初始化速度列表和时间戳列表
red_speeds = []
time_stamps = []
# 全局变量,用于控制视频循环
is_running = False
def on_closing():
global is_running
if messagebox.askokcancel("Quit", "Do you want to quit?"):
is_running = False
if video.isOpened():
video.release()
cv2.destroyAllWindows()
root.destroy()
class KalmanFilter:
def __init__(self, state_dim, measurement_dim):
self.kalman = cv2.KalmanFilter(state_dim, measurement_dim)
self.kalman.transitionMatrix = np.eye(state_dim)
self.kalman.processNoiseCov = 0.01 * np.eye(state_dim)
self.kalman.measurementNoiseCov = 0.1 * np.eye(measurement_dim)
self.kalman.statePost = np.zeros((state_dim, 1), dtype=np.float32)
self.kalman.errorCovPost = np.eye(state_dim, dtype=np.float32)
def update(self, measurement):
prediction = self.kalman.predict()
estimated = self.kalman.correct(measurement)
return estimated
class VideoThread(threading.Thread):
def __init__(self, video, fps, red_speeds, time_stamps):
super().__init__()
self.video = video
self.fps = fps
self.red_speeds = red_speeds
self.time_stamps = time_stamps
self.is_running = True
self.start_time = time.time()
self.object_tracks = {}
self.object_id_map = {}
self.kalman_filters = {} # 存储卡尔曼滤波器
def track_object(self, rect, frame):
x, y, w, h = rect
center = np.array([[x + w // 2], [y + h // 2]], dtype=np.float32)
object_position = str(center)
if object_position not in self.object_id_map:
object_id = len(self.object_id_map) + 1
self.object_id_map[object_position] = object_id
self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
self.kalman_filters[object_id] = KalmanFilter(state_dim=4, measurement_dim=2) # 创建卡尔曼滤波器
else:
object_id = self.object_id_map[object_position]
kalman_filter = self.kalman_filters[object_id]
estimated_state = kalman_filter.update(center)
speed = estimated_state[2, 0] / self.fps # 速度为状态向量的第3个元素
self.object_tracks[object_id]['speeds'].append(speed)
current_time = time.time() - self.start_time
self.object_tracks[object_id]['time_stamps'].append(current_time)
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
def stop_video(self):
self.is_running = False
def run(self):
global is_running
next_object_id = 1 # 将 next_object_id 移到这里来初始化
try:
while self.is_running:
ret, frame = self.video.read()
if not ret:
self.stop_video_with_error("Failed to grab frame")
break
hsv_img = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
mask = cv2.inRange(hsv_img, lower_red, upper_red)
mask = cv2.GaussianBlur(mask, (5, 5), 0)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
self.process_contours(contours, frame, next_object_id)
cv2.imshow("Tracking", frame)
if cv2.waitKey(int(1000 / self.fps)) & 0xFF == ord('q'):
break
except Exception as e:
self.stop_video_with_error(f"An error occurred: {e}")
finally:
self.cleanup_video()
def process_contours(self, contours, frame, next_object_id):
for cnt in contours:
rect = cv2.boundingRect(cnt)
if rect[2] * rect[3] > 100:
self.track_object(rect, frame, next_object_id)
next_object_id += 1
def track_object(self, rect, frame, object_id):
x, y, w, h = rect
center = (x + w // 2, y + h // 2)
if object_id not in self.object_tracks:
self.object_tracks[object_id] = {'speeds': [], 'time_stamps': []}
speed = center[0] / self.fps
self.object_tracks[object_id]['speeds'].append(speed)
current_time = time.time() - self.start_time
self.object_tracks[object_id]['time_stamps'].append(current_time)
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 255), 2)
cv2.putText(frame, f"ID: {object_id}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
def stop_video_with_error(self, error_message):
messagebox.showerror("Error", error_message)
self.is_running = False
def cleanup_video(self):
self.is_running = False
if self.video.isOpened():
self.video.release()
cv2.destroyAllWindows()
is_running = False
# 创建主窗口
root = tk.Tk()
root.title("Red Object Tracker")
root.protocol("WM_DELETE_WINDOW", on_closing)
# 创建开始按钮,并定义开始视频捕捉的函数
def start_video():
global video_thread, is_running
if not is_running:
is_running = True
video_thread = VideoThread(video, fps, red_speeds, time_stamps)
video_thread.start()
else:
messagebox.showinfo("Info", "Video is already running.")
start_button = tk.Button(root, text="Start Video", command=start_video)
start_button.pack()
# 创建停止按钮,并定义停止视频捕捉的函数
def stop_video():
global video_thread, is_running
if is_running:
video_thread.stop_video()
video_thread.join()
is_running = False
else:
messagebox.showinfo("Info", "Video is not running.")
stop_button = tk.Button(root, text="Stop Video", command=stop_video)
stop_button.pack()
# 创建更新图表按钮
def update_chart():
plt.close('all') # 关闭所有之前的图表窗口
for object_id, track in video_thread.object_tracks.items():
if not track['speeds'] or not track['time_stamps']:
continue
# 创建速度图
fig_speed = plt.figure(figsize=(5, 5))
plt.plot(track['time_stamps'], track['speeds'], label=f'Red Object {object_id} Speed')
plt.xlabel('Time (seconds)')
plt.ylabel('Speed (pixels per second)')
plt.title(f'Speed of Red Object {object_id} Over Time')
plt.legend()
plt.grid(True)
# 进行傅里叶变换
fft_speed = np.fft.fft(track['speeds'])
freq = np.fft.fftfreq(len(track['speeds']))
# 创建频域图
fig_fft = plt.figure(figsize=(5, 5))
plt.plot(freq, np.abs(fft_speed))
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude')
plt.title(f'Frequency Domain of {object_id}')
plt.grid(True)
# 在界面左右两侧显示速度图和频域图
speed_frame = tk.Frame(root)
speed_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)
fft_frame = tk.Frame(root)
fft_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=1)
speed_canvas = FigureCanvasTkAgg(fig_speed, master=speed_frame)
speed_canvas.draw()
speed_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
fft_canvas = FigureCanvasTkAgg(fig_fft, master=fft_frame)
fft_canvas.draw()
fft_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
update_chart_button = tk.Button(root, text="Update Chart", command=update_chart)
update_chart_button.pack()
# 创建导出数据按钮
def export_data():
if not video_thread.object_tracks: # 检查是否有跟踪对象数据
messagebox.showerror("Error", "No tracking data available.")
return
try:
all_data = pd.DataFrame() # 创建一个空的 DataFrame 用于存储所有速度数据
for object_id, track in video_thread.object_tracks.items():
df = pd.DataFrame({'Time (s)': track['time_stamps'], f'Speed_Object_{object_id} (px/s)': track['speeds']})
all_data = pd.concat([all_data, df], axis=1) # 将每个物体的速度数据合并到一个 DataFrame 中
with pd.ExcelWriter('tracking_data.xlsx') as writer:
all_data.to_excel(writer, sheet_name='All_Speed_Data', index=False) # 将所有速度数据导出到一个工作表中
messagebox.showinfo("Success", "Tracking data exported to tracking_data.xlsx.")
# 唤醒 Excel 并打开导出的 Excel 文件
os.startfile('tracking_data.xlsx')
except Exception as e:
messagebox.showerror("Error", f"Failed to export tracking data: {e}")
export_button = tk.Button(root, text="Export Data to Excel", command=export_data)
export_button.pack()
# 创建显示统计信息的按钮
# 创建显示统计信息的按钮
def show_stats():
if not video_thread.object_tracks: # 检查是否有跟踪对象数据
messagebox.showerror("Error", "No tracking data available.")
return
stats_message = ""
for object_id, track in video_thread.object_tracks.items():
if not track['speeds'] or not track['time_stamps']:
continue
avg_speed = np.mean(track['speeds'])
speed_variance = np.var(track['speeds'])
correlation = np.corrcoef(track['time_stamps'], track['speeds'])[0, 1]
stats_message += f"鸡冠 {object_id} - 平均速度: {avg_speed:.2f} px/s, 速度方差: {speed_variance:.2f}, 相关性: {correlation:.2f}\n"
if stats_message:
messagebox.showinfo("Statistics", stats_message)
else:
messagebox.showerror("Error", "No valid tracking data for statistics.")
print("统计信息:")
print(stats_message)
stats_button = tk.Button(root, text="Show Statistics", command=show_stats)
stats_button.pack()
# 主循环
root.mainloop()