方案是获取边界框四点坐标,并根据四点坐标去推衍出箭尖落到靶子上的坐标,并获取靶子上的环数,并进行计数计分,实现实时报靶的功能。
现在是这种情况,没办法正常去获取到绘制的图像及坐标,目前是处于这个报错状态。
请给出修改代码,及修改方案,并提出建议,谢谢!
import torch
from PIL import Image
import cv2
from torchvision import transforms
from ultralytics import YOLO
import matplotlib.pyplot as plt
# 建议将这些值作为参数传递给函数或设置为全局常量
ARROW_LABEL_ID = '0'
TARGET_LABEL_ID = '1'
CLASS_LABELS = {'0': 'Arrow', '1': 'Target'}
def draw_arrows_and_targets(image_path, arrows, targets):
img = Image.open(image_path).convert("RGB")
fig, ax = plt.subplots()
ax.imshow(img)
for arrow_start, arrow_end in arrows:
ax.arrow(arrow_start[0], arrow_start[1], arrow_end[0] - arrow_start[0], arrow_end[1] - arrow_start[1],
head_width=0.2, head_length=0.4, fc='r', ec='r')
for target in targets:
circle = plt.Circle(target, radius=5, color='g', fill=False)
ax.add_artist(circle)
ax.set_xlim([0, img.width])
ax.set_ylim([img.height, 0])
plt.show()
def load_model(model_path):
try:
model = YOLO(model_path)
return model
except FileNotFoundError:
print("文件不存在,请检查模型路径。")
except Exception as e:
print(f"加载模型时发生错误:{e}")
def load_image(image_path):
try:
image = Image.open(image_path)
return image
except FileNotFoundError:
print("文件不存在,请检查图像路径。")
except Exception as e:
print(f"加载图像时发生错误:{e}")
def process_results(results, class_labels, arrow_label_id, target_label_id):
arrows = []
targets = []
if 'Arrow' not in class_labels or 'Target' not in class_labels:
raise ValueError("class_labels 必须包含 'Arrow' 和 'Target'")
for result in results:
if result is not None and len(result) > 0 and len(result) >= 6:
x1, y1, x2, y2, conf, cls = result
if cls == class_labels['Arrow']:
arrows.append([x1, y1, x2, y2, conf])
elif cls == class_labels['Target']:
targets.append([x1, y1, x2, y2, conf])
return arrows, targets
def draw_bboxes(image, bbox_list, color):
for bbox in bbox_list:
x1, y1, x2, y2 = bbox
cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
def main(model_weights_path: str, input_image_path: str, class_labels: dict):
# 加载模型
model = load_model(model_weights_path)
if model is None:
return
# 加载图像
image = load_image(input_image_path)
if image is None:
return
# 图像预处理
transform = transforms.Compose([
transforms.ToTensor(),
])
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)
# 运行模型进行目标检测
results = model(image_tensor)
# 处理检测结果
arrows, targets = process_results(results, class_labels, ARROW_LABEL_ID, TARGET_LABEL_ID)
draw_arrows_and_targets(input_image_path, arrows, targets)
if __name__ == "__main__":
MODEL_WEIGHTS_PATH = 'E:/arrow/train3/weights/best.pt'
INPUT_IMAGE_PATH = 'E:/BaiduNetdiskDownload/val/403.jpeg'
CLASS_LABELS = {0: 'Arrow', 1: 'Target'}
main(MODEL_WEIGHTS_PATH, INPUT_IMAGE_PATH, CLASS_LABELS)