别无所求_zjz 2024-04-02 22:39 采纳率: 20%
浏览 11
已结题

关于#python#的问题:不知道什么情况,请大家看一下!


import torch
from PIL import Image
import cv2
from torchvision import transforms
from ultralytics import YOLO
import matplotlib.pyplot as plt

# 建议将这些值作为参数传递给函数或设置为全局常量
ARROW_LABEL_ID = '0'
TARGET_LABEL_ID = '1'
CLASS_LABELS = {'0': 'Arrow', '1': 'Target'}

def draw_arrows_and_targets(image_path, arrows, targets):
    img = Image.open(image_path).convert("RGB")
    fig, ax = plt.subplots()
    ax.imshow(img)
    for arrow_start, arrow_end in arrows:
        ax.arrow(arrow_start[0], arrow_start[1], arrow_end[0] - arrow_start[0], arrow_end[1] - arrow_start[1],
                 head_width=0.2, head_length=0.4, fc='r', ec='r')
    for target in targets:
        circle = plt.Circle(target, radius=5, color='g', fill=False)
        ax.add_artist(circle)
    ax.set_xlim([0, img.width])
    ax.set_ylim([img.height, 0])  
    plt.show()

def load_model(model_path):
    try:
        model = YOLO(model_path)
        return model
    except FileNotFoundError:
        print("文件不存在,请检查模型路径。")
    except Exception as e:
        print(f"加载模型时发生错误:{e}")

def load_image(image_path):
    try:
        image = Image.open(image_path)
        return image
    except FileNotFoundError:
        print("文件不存在,请检查图像路径。")
    except Exception as e:
        print(f"加载图像时发生错误:{e}")

def process_results(results, class_labels, arrow_label_id, target_label_id):
    arrows = []
    targets = []
    
    if 'Arrow' not in class_labels or 'Target' not in class_labels:
        raise ValueError("class_labels 必须包含 'Arrow' 和 'Target'")
    
    for result in results:
        if result is not None and len(result) > 0 and len(result) >= 6:
            x1, y1, x2, y2, conf, cls = result
            if cls == class_labels['Arrow']:
                arrows.append([x1, y1, x2, y2, conf])
            elif cls == class_labels['Target']:
                targets.append([x1, y1, x2, y2, conf])
    return arrows, targets

def draw_bboxes(image, bbox_list, color):
    for bbox in bbox_list:
        x1, y1, x2, y2 = bbox
        cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

def main(model_weights_path: str, input_image_path: str, class_labels: dict):
    # 加载模型
    model = load_model(model_weights_path)
    if model is None:
        return

    # 加载图像
    image = load_image(input_image_path)
    if image is None:
        return

    # 图像预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    image_tensor = transform(image)
    image_tensor = image_tensor.unsqueeze(0)  

    # 运行模型进行目标检测
    results = model(image_tensor)

    # 处理检测结果
    arrows, targets = process_results(results, class_labels, ARROW_LABEL_ID, TARGET_LABEL_ID)
    draw_arrows_and_targets(input_image_path, arrows, targets)

if __name__ == "__main__":
    MODEL_WEIGHTS_PATH = 'E:/arrow/train3/weights/best.pt'
    INPUT_IMAGE_PATH = 'E:/BaiduNetdiskDownload/val/403.jpeg'
    CLASS_LABELS = {0: 'Arrow', 1: 'Target'}
    main(MODEL_WEIGHTS_PATH, INPUT_IMAGE_PATH, CLASS_LABELS)

img

  • 写回答

15条回答 默认 最新

  • Kwan的解忧杂货铺 2024-04-02 22:39
    关注

    晚上好🌙🌙🌙
    本答案参考ChatGPT-3.5

    这段代码主要实现了使用YOLO模型进行目标检测,并通过matplotlib库将检测结果可视化出来。下面是对代码的分析和解决方案:

    1. 代码中的ARROW_LABEL_IDTARGET_LABEL_ID以及CLASS_LABELS变量被定义为常量,建议将这些值作为函数参数传递或设置为全局常量,以便代码的可维护性和灵活性。

      • 解决方案:将这些值作为参数传递给draw_arrows_and_targets函数和process_results函数。
    2. load_model函数用于加载YOLO模型,如果模型文件不存在会抛出FileNotFoundError异常,如果有其他异常会直接打印错误信息。建议在函数中添加对其他异常的处理,以便能够更好地调试模型加载错误的问题。

      • 解决方案:在except Exception as e中打印错误信息并返回None。
    3. load_image函数用于加载图像,如果图像文件不存在会抛出FileNotFoundError异常,如果有其他异常会直接打印错误信息。建议在函数中添加对其他异常的处理,以便能够更好地调试图像加载错误的问题。

      • 解决方案:在except Exception as e中打印错误信息并返回None。
    4. process_results函数用于处理检测结果,提取出箭头和目标的位置信息。如果class_labels字典中不包含'Arrow'或'Target'的键,会抛出ValueError异常。建议在抛出异常之前先检查字典中是否包含这些键。

      • 解决方案:在抛出异常之前添加判断语句,检查字典中是否包含'Arrow'和'Target'。
    5. draw_bboxes函数用于在图像上绘制边界框,但在代码中并没有调用该函数。如果需要绘制边界框,可以在draw_arrows_and_targets函数中调用该函数。

      • 解决方案:在draw_arrows_and_targets函数中添加调用draw_bboxes函数的代码。
    6. main函数用于加载模型、图像预处理、运行模型进行检测,并最终绘制检测结果。其中,模型路径、图像路径、类别标签通过函数参数传递,但在函数内部又重新定义了MODEL_WEIGHTS_PATHINPUT_IMAGE_PATHCLASS_LABELS变量。建议直接使用函数参数传递的值,而不是重新定义变量。

      • 解决方案:将MODEL_WEIGHTS_PATHINPUT_IMAGE_PATHCLASS_LABELS的定义移除,直接使用函数参数传递的值。
    7. 最后,在if __name__ == "__main__":中可以直接调用main函数进行测试,而不需要再定义一次常量和类别标签。可以直接传入所需的参数。

      • 解决方案:将main函数的调用直接写在if __name__ == "__main__":块中,并传入适当的参数。

    修改后的代码如下:

    import torch
    from PIL import Image
    import cv2
    from torchvision import transforms
    from ultralytics import YOLO
    import matplotlib.pyplot as plt
    
    def draw_arrows_and_targets(image_path, arrows, targets):
        img = Image.open(image_path).convert("RGB")
        fig, ax = plt.subplots()
        ax.imshow(img)
        for arrow_start, arrow_end in arrows:
            ax.arrow(arrow_start[0], arrow_start[1], arrow_end[0] - arrow_start[0], arrow_end[1] - arrow_start[1],
                     head_width=0.2, head_length=0.4, fc='r', ec='r')
        for target in targets:
            circle = plt.Circle(target, radius=5, color='g', fill=False)
            ax.add_artist(circle)
        ax.set_xlim([0, img.width])
        ax.set_ylim([img.height, 0])  
        plt.show()
    
    def load_model(model_path):
        try:
            model = YOLO(model_path)
            return model
        except FileNotFoundError:
            print("文件不存在,请检查模型路径。")
            return None
        except Exception as e:
            print(f"加载模型时发生错误:{e}")
            return None
    
    def load_image(image_path):
        try:
            image = Image.open(image_path)
            return image
        except FileNotFoundError:
            print("文件不存在,请检查图像路径。")
            return None
        except Exception as e:
            print(f"加载图像时发生错误:{e}")
            return None
    
    def process_results(results, class_labels, arrow_label_id, target_label_id):
        arrows = []
        targets = []
        
        if 'Arrow' not in class_labels or 'Target' not in class_labels:
            raise ValueError("class_labels 必须包含 'Arrow' 和 'Target'")
        
        for result in results:
            if result is not None and len(result) >= 6:
                x1, y1, x2, y2, conf, cls = result
                if cls == class_labels['Arrow']:
                    arrows.append([x1, y1, x2, y2, conf])
                elif cls == class_labels['Target']:
                    targets.append([x1, y1, x2, y2, conf])
        return arrows, targets
    
    def draw_bboxes(image, bbox_list, color):
        for bbox in bbox_list:
            x1, y1, x2, y2 = bbox
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
    
    def main(model_weights_path: str, input_image_path: str, class_labels: dict):
        # 加载模型
        model = load_model(model_weights_path)
        if model is None:
            return
    
        # 加载图像
        image = load_image(input_image_path)
        if image is None:
            return
    
        # 图像预处理
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        image_tensor = transform(image)
        image_tensor = image_tensor.unsqueeze(0)  
    
        # 运行模型进行目标检测
        results = model(image_tensor)
    
        # 处理检测结果
        arrows, targets = process_results(results, class_labels, ARROW_LABEL_ID, TARGET_LABEL_ID)
        draw_arrows_and_targets(input_image_path, arrows, targets)
    
    if __name__ == "__main__":
        model_weights_path = 'E:/arrow/train3/weights/best.pt'
        input_image_path = 'E:/BaiduNetdiskDownload/val/403.jpeg'
        class_labels = {0: 'Arrow', 1: 'Target'}
        main(model_weights_path, input_image_path, class_labels)
    

    希望以上解决方案对你有帮助!

    评论 编辑记录

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 4月3日
  • 创建了问题 4月2日

悬赏问题

  • ¥15 matlab数据降噪处理,提高数据的可信度,确保峰值信号的不损失?
  • ¥15 怎么看我在bios每次修改的日志
  • ¥15 python+mysql图书管理系统
  • ¥15 Questasim Error: (vcom-13)
  • ¥15 船舶旋回实验matlab
  • ¥30 SQL 数组,游标,递归覆盖原值
  • ¥15 为什么我的数据接收的那么慢呀有没有完整的 hal 库并 代码呀有的话能不能发我一份并且我用 printf 函数显示处理之后的数据,用 debug 就不能运行了呢
  • ¥20 gitlab 中文路径,无法下载
  • ¥15 用动态规划算法均分纸牌
  • ¥30 udp socket,bind 0.0.0.0 ,如何自动选取用户访问的服务器IP来回复数据