晚上好🌙🌙🌙
本答案参考ChatGPT-3.5
这段代码主要实现了使用YOLO模型进行目标检测,并通过matplotlib库将检测结果可视化出来。下面是对代码的分析和解决方案:
-
代码中的ARROW_LABEL_ID和TARGET_LABEL_ID以及CLASS_LABELS变量被定义为常量,建议将这些值作为函数参数传递或设置为全局常量,以便代码的可维护性和灵活性。
- 解决方案:将这些值作为参数传递给
draw_arrows_and_targets函数和process_results函数。
-
load_model函数用于加载YOLO模型,如果模型文件不存在会抛出FileNotFoundError异常,如果有其他异常会直接打印错误信息。建议在函数中添加对其他异常的处理,以便能够更好地调试模型加载错误的问题。
- 解决方案:在
except Exception as e中打印错误信息并返回None。
-
load_image函数用于加载图像,如果图像文件不存在会抛出FileNotFoundError异常,如果有其他异常会直接打印错误信息。建议在函数中添加对其他异常的处理,以便能够更好地调试图像加载错误的问题。
- 解决方案:在
except Exception as e中打印错误信息并返回None。
-
process_results函数用于处理检测结果,提取出箭头和目标的位置信息。如果class_labels字典中不包含'Arrow'或'Target'的键,会抛出ValueError异常。建议在抛出异常之前先检查字典中是否包含这些键。
- 解决方案:在抛出异常之前添加判断语句,检查字典中是否包含'Arrow'和'Target'。
-
draw_bboxes函数用于在图像上绘制边界框,但在代码中并没有调用该函数。如果需要绘制边界框,可以在draw_arrows_and_targets函数中调用该函数。
- 解决方案:在
draw_arrows_and_targets函数中添加调用draw_bboxes函数的代码。
-
main函数用于加载模型、图像预处理、运行模型进行检测,并最终绘制检测结果。其中,模型路径、图像路径、类别标签通过函数参数传递,但在函数内部又重新定义了MODEL_WEIGHTS_PATH、INPUT_IMAGE_PATH和CLASS_LABELS变量。建议直接使用函数参数传递的值,而不是重新定义变量。
- 解决方案:将
MODEL_WEIGHTS_PATH、INPUT_IMAGE_PATH和CLASS_LABELS的定义移除,直接使用函数参数传递的值。
-
最后,在if __name__ == "__main__":中可以直接调用main函数进行测试,而不需要再定义一次常量和类别标签。可以直接传入所需的参数。
- 解决方案:将
main函数的调用直接写在if __name__ == "__main__":块中,并传入适当的参数。
修改后的代码如下:
import torch
from PIL import Image
import cv2
from torchvision import transforms
from ultralytics import YOLO
import matplotlib.pyplot as plt
def draw_arrows_and_targets(image_path, arrows, targets):
img = Image.open(image_path).convert("RGB")
fig, ax = plt.subplots()
ax.imshow(img)
for arrow_start, arrow_end in arrows:
ax.arrow(arrow_start[0], arrow_start[1], arrow_end[0] - arrow_start[0], arrow_end[1] - arrow_start[1],
head_width=0.2, head_length=0.4, fc='r', ec='r')
for target in targets:
circle = plt.Circle(target, radius=5, color='g', fill=False)
ax.add_artist(circle)
ax.set_xlim([0, img.width])
ax.set_ylim([img.height, 0])
plt.show()
def load_model(model_path):
try:
model = YOLO(model_path)
return model
except FileNotFoundError:
print("文件不存在,请检查模型路径。")
return None
except Exception as e:
print(f"加载模型时发生错误:{e}")
return None
def load_image(image_path):
try:
image = Image.open(image_path)
return image
except FileNotFoundError:
print("文件不存在,请检查图像路径。")
return None
except Exception as e:
print(f"加载图像时发生错误:{e}")
return None
def process_results(results, class_labels, arrow_label_id, target_label_id):
arrows = []
targets = []
if 'Arrow' not in class_labels or 'Target' not in class_labels:
raise ValueError("class_labels 必须包含 'Arrow' 和 'Target'")
for result in results:
if result is not None and len(result) >= 6:
x1, y1, x2, y2, conf, cls = result
if cls == class_labels['Arrow']:
arrows.append([x1, y1, x2, y2, conf])
elif cls == class_labels['Target']:
targets.append([x1, y1, x2, y2, conf])
return arrows, targets
def draw_bboxes(image, bbox_list, color):
for bbox in bbox_list:
x1, y1, x2, y2 = bbox
cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
def main(model_weights_path: str, input_image_path: str, class_labels: dict):
# 加载模型
model = load_model(model_weights_path)
if model is None:
return
# 加载图像
image = load_image(input_image_path)
if image is None:
return
# 图像预处理
transform = transforms.Compose([
transforms.ToTensor(),
])
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)
# 运行模型进行目标检测
results = model(image_tensor)
# 处理检测结果
arrows, targets = process_results(results, class_labels, ARROW_LABEL_ID, TARGET_LABEL_ID)
draw_arrows_and_targets(input_image_path, arrows, targets)
if __name__ == "__main__":
model_weights_path = 'E:/arrow/train3/weights/best.pt'
input_image_path = 'E:/BaiduNetdiskDownload/val/403.jpeg'
class_labels = {0: 'Arrow', 1: 'Target'}
main(model_weights_path, input_image_path, class_labels)
希望以上解决方案对你有帮助!