import torch
from PIL import Image
import cv2
from torchvision import transforms
from ultralytics import YOLO
import matplotlib.pyplot as plt
# 建议将这些值作为参数传递给函数或设置为全局常量
ARROW_LABEL_ID = '0'
TARGET_LABEL_ID = '1'
CLASS_LABELS = {'0': 'Arrow', '1': 'Target'}
def draw_arrows_and_targets(image_path, arrows, targets):
img = Image.open(image_path).convert("RGB")
fig, ax = plt.subplots()
ax.imshow(img)
for arrow_start, arrow_end in arrows:
ax.arrow(arrow_start[0], arrow_start[1], arrow_end[0] - arrow_start[0], arrow_end[1] - arrow_start[1],
head_width=0.2, head_length=0.4, fc='r', ec='r')
for target in targets:
circle = plt.Circle(target, radius=5, color='g', fill=False)
ax.add_artist(circle)
ax.set_xlim([0, img.width])
ax.set_ylim([img.height, 0])
plt.show()
def load_model(model_path):
try:
model = YOLO(model_path)
return model
except FileNotFoundError:
print("文件不存在,请检查模型路径。")
except Exception as e:
print(f"加载模型时发生错误:{e}")
def load_image(image_path):
try:
image = Image.open(image_path)
return image
except FileNotFoundError:
print("文件不存在,请检查图像路径。")
except Exception as e:
print(f"加载图像时发生错误:{e}")
def process_results(results, class_labels, arrow_label_id, target_label_id):
arrows = []
targets = []
if 'Arrow' not in class_labels or 'Target' not in class_labels:
raise ValueError("class_labels 必须包含 'Arrow' 和 'Target'")
for result in results:
if result is not None and len(result) > 0 and len(result) >= 6:
x1, y1, x2, y2, conf, cls = result
if cls == class_labels['Arrow']:
arrows.append([x1, y1, x2, y2, conf])
elif cls == class_labels['Target']:
targets.append([x1, y1, x2, y2, conf])
return arrows, targets
def draw_bboxes(image, bbox_list, color):
for bbox in bbox_list:
x1, y1, x2, y2 = bbox
cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
def main(model_weights_path: str, input_image_path: str, class_labels: dict):
# 加载模型
model = load_model(model_weights_path)
if model is None:
return
# 加载图像
image = load_image(input_image_path)
if image is None:
return
# 图像预处理
transform = transforms.Compose([
transforms.ToTensor(),
])
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)
# 运行模型进行目标检测
results = model(image_tensor)
# 处理检测结果
arrows, targets = process_results(results, class_labels, ARROW_LABEL_ID, TARGET_LABEL_ID)
draw_arrows_and_targets(input_image_path, arrows, targets)
if __name__ == "__main__":
MODEL_WEIGHTS_PATH = 'E:/arrow/train3/weights/best.pt'
INPUT_IMAGE_PATH = 'E:/BaiduNetdiskDownload/val/403.jpeg'
CLASS_LABELS = {0: 'Arrow', 1: 'Target'}
main(MODEL_WEIGHTS_PATH, INPUT_IMAGE_PATH, CLASS_LABELS)
![](https://profile-avatar.csdnimg.cn/a9d733ab1a74412583433f517be9a118_zjz_cn.jpg!4)
关于#python#的问题:不知道什么情况,请大家看一下!
- 写回答
- 好问题 0 提建议
- 追加酬金
- 关注问题
- 邀请回答
-
15条回答 默认 最新
- Kwan的解忧杂货铺 2024-04-02 22:39关注
晚上好🌙🌙🌙
本答案参考ChatGPT-3.5这段代码主要实现了使用YOLO模型进行目标检测,并通过matplotlib库将检测结果可视化出来。下面是对代码的分析和解决方案:
-
代码中的
ARROW_LABEL_ID
和TARGET_LABEL_ID
以及CLASS_LABELS
变量被定义为常量,建议将这些值作为函数参数传递或设置为全局常量,以便代码的可维护性和灵活性。- 解决方案:将这些值作为参数传递给
draw_arrows_and_targets
函数和process_results
函数。
- 解决方案:将这些值作为参数传递给
-
load_model
函数用于加载YOLO模型,如果模型文件不存在会抛出FileNotFoundError
异常,如果有其他异常会直接打印错误信息。建议在函数中添加对其他异常的处理,以便能够更好地调试模型加载错误的问题。- 解决方案:在
except Exception as e
中打印错误信息并返回None。
- 解决方案:在
-
load_image
函数用于加载图像,如果图像文件不存在会抛出FileNotFoundError
异常,如果有其他异常会直接打印错误信息。建议在函数中添加对其他异常的处理,以便能够更好地调试图像加载错误的问题。- 解决方案:在
except Exception as e
中打印错误信息并返回None。
- 解决方案:在
-
process_results
函数用于处理检测结果,提取出箭头和目标的位置信息。如果class_labels
字典中不包含'Arrow'或'Target'的键,会抛出ValueError
异常。建议在抛出异常之前先检查字典中是否包含这些键。- 解决方案:在抛出异常之前添加判断语句,检查字典中是否包含'Arrow'和'Target'。
-
draw_bboxes
函数用于在图像上绘制边界框,但在代码中并没有调用该函数。如果需要绘制边界框,可以在draw_arrows_and_targets
函数中调用该函数。- 解决方案:在
draw_arrows_and_targets
函数中添加调用draw_bboxes
函数的代码。
- 解决方案:在
-
main
函数用于加载模型、图像预处理、运行模型进行检测,并最终绘制检测结果。其中,模型路径、图像路径、类别标签通过函数参数传递,但在函数内部又重新定义了MODEL_WEIGHTS_PATH
、INPUT_IMAGE_PATH
和CLASS_LABELS
变量。建议直接使用函数参数传递的值,而不是重新定义变量。- 解决方案:将
MODEL_WEIGHTS_PATH
、INPUT_IMAGE_PATH
和CLASS_LABELS
的定义移除,直接使用函数参数传递的值。
- 解决方案:将
-
最后,在
if __name__ == "__main__":
中可以直接调用main
函数进行测试,而不需要再定义一次常量和类别标签。可以直接传入所需的参数。- 解决方案:将
main
函数的调用直接写在if __name__ == "__main__":
块中,并传入适当的参数。
- 解决方案:将
修改后的代码如下:
import torch from PIL import Image import cv2 from torchvision import transforms from ultralytics import YOLO import matplotlib.pyplot as plt def draw_arrows_and_targets(image_path, arrows, targets): img = Image.open(image_path).convert("RGB") fig, ax = plt.subplots() ax.imshow(img) for arrow_start, arrow_end in arrows: ax.arrow(arrow_start[0], arrow_start[1], arrow_end[0] - arrow_start[0], arrow_end[1] - arrow_start[1], head_width=0.2, head_length=0.4, fc='r', ec='r') for target in targets: circle = plt.Circle(target, radius=5, color='g', fill=False) ax.add_artist(circle) ax.set_xlim([0, img.width]) ax.set_ylim([img.height, 0]) plt.show() def load_model(model_path): try: model = YOLO(model_path) return model except FileNotFoundError: print("文件不存在,请检查模型路径。") return None except Exception as e: print(f"加载模型时发生错误:{e}") return None def load_image(image_path): try: image = Image.open(image_path) return image except FileNotFoundError: print("文件不存在,请检查图像路径。") return None except Exception as e: print(f"加载图像时发生错误:{e}") return None def process_results(results, class_labels, arrow_label_id, target_label_id): arrows = [] targets = [] if 'Arrow' not in class_labels or 'Target' not in class_labels: raise ValueError("class_labels 必须包含 'Arrow' 和 'Target'") for result in results: if result is not None and len(result) >= 6: x1, y1, x2, y2, conf, cls = result if cls == class_labels['Arrow']: arrows.append([x1, y1, x2, y2, conf]) elif cls == class_labels['Target']: targets.append([x1, y1, x2, y2, conf]) return arrows, targets def draw_bboxes(image, bbox_list, color): for bbox in bbox_list: x1, y1, x2, y2 = bbox cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2) def main(model_weights_path: str, input_image_path: str, class_labels: dict): # 加载模型 model = load_model(model_weights_path) if model is None: return # 加载图像 image = load_image(input_image_path) if image is None: return # 图像预处理 transform = transforms.Compose([ transforms.ToTensor(), ]) image_tensor = transform(image) image_tensor = image_tensor.unsqueeze(0) # 运行模型进行目标检测 results = model(image_tensor) # 处理检测结果 arrows, targets = process_results(results, class_labels, ARROW_LABEL_ID, TARGET_LABEL_ID) draw_arrows_and_targets(input_image_path, arrows, targets) if __name__ == "__main__": model_weights_path = 'E:/arrow/train3/weights/best.pt' input_image_path = 'E:/BaiduNetdiskDownload/val/403.jpeg' class_labels = {0: 'Arrow', 1: 'Target'} main(model_weights_path, input_image_path, class_labels)
希望以上解决方案对你有帮助!
解决 无用评论 打赏 举报 编辑记录 -
悬赏问题
- ¥15 matlab数据降噪处理,提高数据的可信度,确保峰值信号的不损失?
- ¥15 怎么看我在bios每次修改的日志
- ¥15 python+mysql图书管理系统
- ¥15 Questasim Error: (vcom-13)
- ¥15 船舶旋回实验matlab
- ¥30 SQL 数组,游标,递归覆盖原值
- ¥15 为什么我的数据接收的那么慢呀有没有完整的 hal 库并 代码呀有的话能不能发我一份并且我用 printf 函数显示处理之后的数据,用 debug 就不能运行了呢
- ¥20 gitlab 中文路径,无法下载
- ¥15 用动态规划算法均分纸牌
- ¥30 udp socket,bind 0.0.0.0 ,如何自动选取用户访问的服务器IP来回复数据