2301_78417472 2024-01-21 23:43 · Acceptance rate: 100%
16 views
Closed

IndexError: list index out of range

In PyCharm, running a YOLOv3-based model: detection on images with my own trained model raises IndexError: list index out of range, but detection with the COCO pretrained model works without error.
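The error itself only means that some index is past the last valid position of a Python list. A minimal illustration with made-up values (not taken from the actual run):

classes = ["person", "car"]    # e.g. a short custom class-names list
cls_pred = 5                   # a class index predicted by the model
print(classes[int(cls_pred)])  # IndexError: list index out of range

The detection script is below, followed by the full traceback.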

from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # work around duplicate OpenMP runtime errors (common with PyTorch/MKL on Windows)
import sys
import time
import datetime
import argparse
import random  # used below by random.sample (bbox colors); otherwise only available via the wildcard imports
import numpy as np  # used below by np.linspace / np.where

from PIL import Image

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="data/samples", help="path to dataset")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
    parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.4, help="iou thresshold for non-maximum suppression")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_model", type=str, help="path to checkpoint model")
    opt = parser.parse_args()
    print(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)

    # Set up model
    model = Darknet(opt.model_def, img_size=opt.img_size).to(device)

    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(opt.weights_path))

    model.eval()  # Set in evaluation mode

    dataloader = DataLoader(
        ImageFolder(opt.image_folder, img_size=opt.img_size),
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.n_cpu,
    )

    classes = load_classes(opt.class_path)  # Extracts class labels from file

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    imgs = []  # Stores image paths
    img_detections = []  # Stores detections for each image index

    print("\nPerforming object detection:")
    prev_time = time.time()
    for batch_i, (img_paths, input_imgs) in enumerate(dataloader):
        # Configure input
        input_imgs = Variable(input_imgs.type(Tensor))

        # Get detections
        with torch.no_grad():
            detections = model(input_imgs)
            detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres)

        # Log progress
        current_time = time.time()
        inference_time = datetime.timedelta(seconds=current_time - prev_time)
        prev_time = current_time
        print("\t+ Batch %d, Inference Time: %s" % (batch_i, inference_time))

        # Save image and detections
        imgs.extend(img_paths)
        img_detections.extend(detections)

    # Bounding-box colors
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(i) for i in np.linspace(0, 1, 20)]

    print("\nSaving images:")
    # Iterate through images and save plot of detections
    for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):

        print("(%d) Image: '%s'" % (img_i, path))

        # Create plot
        img = np.array(Image.open(path))
        plt.figure()
        fig, ax = plt.subplots(1)
        ax.imshow(img)

        # Draw bounding boxes and labels of detections
        if detections is not None:
            # Rescale boxes to original image
            detections = rescale_boxes(detections, opt.img_size, img.shape[:2])
            unique_labels = detections[:, -1].cpu().unique()
            n_cls_preds = len(unique_labels)
            bbox_colors = random.sample(colors, n_cls_preds)
            for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:

                print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))

                box_w = x2 - x1
                box_h = y2 - y1

                color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0])]
                # Create a Rectangle patch
                bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2, edgecolor=color, facecolor="none")
                # Add the bbox to the plot
                ax.add_patch(bbox)
                # Add label
                plt.text(
                    x1,
                    y1,
                    s=classes[int(cls_pred)],
                    color="white",
                    verticalalignment="top",
                    bbox={"color": color, "pad": 0},
                )

        # Save generated image with detections
        plt.axis("off")
        plt.gca().xaxis.set_major_locator(NullLocator())
        plt.gca().yaxis.set_major_locator(NullLocator())
        filename = os.path.splitext(os.path.basename(path))[0]  # basename handles both "/" and "\" separators (the traceback shows Windows-style paths)
        plt.savefig(f"output/{filename}.png", bbox_inches="tight", pad_inches=0.0)
        plt.close()


Error output:


Saving images:
(0) Image: 'data/samples\1.jpg'
Traceback (most recent call last):
  File "D:\PyTorch-YOLOv3\PyTorch-YOLOv3\detect.py", line 116, in <module>
    print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))
IndexError: list index out of range



12 answers

  • 呈两面包夹芝士 2024-01-22 14:04

    This answer draws on iFLYTEK Spark; the method in the resulting code has been verified to work.
    The problem is that, while processing the image, the predicted class index falls outside the range of the class list. To avoid it, check that the value of cls_pred is within the valid range before indexing into classes. You can modify the code as follows:

    Replace this part of the code:

    print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))
    

    with:

    if 0 <= int(cls_pred) < len(classes):
        print("\t+ Label: %s, Conf: %.5f" % (classes[int(cls_pred)], cls_conf.item()))
    else:
        print("\t+ Label: %d, Conf: %.5f" % (int(cls_pred), cls_conf.item()))
    

    This way, when cls_pred is out of range, the code prints the raw class index instead of raising an IndexError.
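    The bounds check above silences the symptom. Since the error only appears with your own model, it is also worth confirming that the names file passed via --class_path matches the number of classes your cfg/weights were trained with; a common gotcha is a names file with fewer entries than the model predicts (for instance, some versions of load_classes in this repo drop the last line if the file does not end with a newline). A small stand-alone sanity check, using placeholder paths that you would replace with your own files:

    class_path = "data/custom/classes.names"   # assumed path to your custom names file
    cfg_path = "config/yolov3-custom.cfg"      # assumed path to your custom cfg

    with open(class_path) as f:
        names = [line.strip() for line in f if line.strip()]

    with open(cfg_path) as f:
        cfg_classes = {int(line.split("=")[1]) for line in f
                       if line.strip().replace(" ", "").startswith("classes=")}

    print("entries in names file:", len(names))
    print("classes= values in cfg:", cfg_classes)
    # If len(names) is smaller than the class count declared in the cfg,
    # classes[int(cls_pred)] can index past the end of the list and raise
    # the IndexError shown above.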

    This answer was accepted by the asker as the best answer.

