tiaya01 2024-04-18 16:50 采纳率: 85.7%
浏览 6

VGG16得到的混淆矩阵错误

VGG16得到的混淆矩阵错误

img

这是main代码:

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import json

import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable

from model import vgg


class ConfusionMatrix(object):
    """Accumulate a multi-class confusion matrix, then summarise and plot it.

    Layout: ``matrix[p, t]`` counts samples predicted as class ``p`` whose
    true class is ``t`` -- rows are predictions, columns are ground truth.
    ``summary`` (FP from row sums, FN from column sums) and ``plot`` (x axis
    labelled 'True Labels') both depend on this orientation.
    """

    def __init__(self, num_classes: int, labels: list):
        # Counts are stored as float64 (the np.zeros default).
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        # Display names, indexed by class id, used in the summary table.
        self.labels = labels

    def update(self, preds, labels):
        """Add one batch of predictions to the matrix.

        Args:
            preds: predicted class indices (array-like of ints).
            labels: ground-truth class indices, same length as ``preds``.

        Raises:
            ValueError: if ``preds`` and ``labels`` differ in length.
        """
        preds = np.asarray(preds).ravel()
        labels = np.asarray(labels).ravel()
        if preds.shape != labels.shape:
            raise ValueError("preds and labels differ in length: {} vs {}"
                             .format(preds.size, labels.size))
        if preds.size == 0:
            return
        # np.add.at accumulates repeated (p, t) pairs; a plain fancy-indexed
        # "+= 1" would count each duplicated pair only once per call.
        np.add.at(self.matrix, (preds, labels), 1)

    def _class_metrics(self, i):
        """Return (precision, recall, f1, specificity) for class ``i``.

        Each value is rounded to 3 decimals; any ratio whose denominator is
        zero is reported as 0.0 instead of raising / producing nan.
        """
        TP = self.matrix[i, i]
        FP = np.sum(self.matrix[i, :]) - TP  # predicted i, truly another class
        FN = np.sum(self.matrix[:, i]) - TP  # truly i, predicted as another class
        TN = np.sum(self.matrix) - TP - FP - FN
        precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
        recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
        specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
        # F1 uses the rounded precision/recall (as before) and needs its own
        # guard: a class never predicted and never present gives P + R == 0.
        if precision + recall != 0:
            f1_score = round((2 * precision * recall) / (precision + recall), 3)
        else:
            f1_score = 0.
        return precision, recall, f1_score, specificity

    def summary(self):
        """Print overall accuracy and a per-class metrics table."""
        total = np.sum(self.matrix)
        # The diagonal holds the correctly classified samples.
        acc = np.trace(self.matrix) / total if total != 0 else 0.
        print("the model accuracy is ", acc)

        table = PrettyTable()
        table.field_names = ["label", "Precision", "Recall", "F1-score", "Specificity"]
        for i in range(self.num_classes):
            precision, recall, f1_score, specificity = self._class_metrics(i)
            table.add_row([self.labels[i], precision, recall, f1_score, specificity])
        print(table)

    def plot(self, tick_labels=None):
        """Show the matrix as a heat map annotated with per-cell counts.

        Args:
            tick_labels: names for the axis ticks. Defaults to
                1..num_classes; pass ``self.labels`` to show class names.
        """
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        if tick_labels is None:
            tick_labels = list(range(1, self.num_classes + 1))
        plt.xticks(range(self.num_classes), tick_labels, rotation=45)
        plt.yticks(range(self.num_classes), tick_labels)
        plt.colorbar()
        # Columns are true classes and rows are predictions (see class docstring).
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # Annotate every cell with its count; dark cells get white text.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # imshow draws matrix[row, col] at (x=col, y=row), hence [y, x].
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # Number of outputs of the trained network; the checkpoint and the class
    # index file must agree with it.
    num_classes = 4

    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    # Dataset root directory
    data_root = r"C:\Users\yingnuo.DESKTOP-9E5CS2I\Desktop\T1-data"
    assert os.path.exists(data_root), "data path {} does not exist.".format(data_root)

    validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "test"),
                                            transform=data_transform)

    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)
    net = vgg(model_name="vgg16", num_classes=num_classes)
    # load trained weights
    model_weight_path = "./weights/best_model.pth"
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    weight_dict = torch.load(model_weight_path, map_location=device)
    # strict=False silently skips every key that does not match, leaving those
    # layers at their random initial values -- the evaluation then produces a
    # meaningless confusion matrix. Extra keys in the file are tolerated, but
    # model parameters that were never loaded are treated as an error.
    incompatible = net.load_state_dict(weight_dict, strict=False)
    if incompatible.missing_keys:
        raise RuntimeError(
            "weights in {} do not cover the model; missing keys: {} "
            "(a 'module.' prefix or a checkpoint dict wrapping the state_dict "
            "are common causes)".format(model_weight_path, incompatible.missing_keys))
    if incompatible.unexpected_keys:
        print("WARNING: ignoring unexpected keys in checkpoint:",
              incompatible.unexpected_keys)
    net.to(device)

    # read class_indict
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
    with open(json_label_path, 'r') as json_file:
        class_indict = json.load(json_file)

    # Order names by their integer index so labels[i] names model output i
    # instead of relying on the key order in the file.
    # Assumes the file maps "0", "1", ... to class names -- TODO confirm.
    labels = [class_indict[key] for key in sorted(class_indict, key=int)]
    if len(labels) != num_classes:
        raise ValueError("{} lists {} classes, expected {}".format(
            json_label_path, len(labels), num_classes))
    # Ground-truth indices come from ImageFolder (sorted sub-folder names);
    # if they disagree with the training-time mapping, cells are misaligned.
    if labels != validate_dataset.classes:
        print("WARNING: class order in {} {} differs from test folders {}".format(
            json_label_path, labels, validate_dataset.classes))

    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels)
    net.eval()
    with torch.no_grad():
        for val_images, val_labels in tqdm(validate_loader):
            outputs = net(val_images.to(device))
            # softmax is monotonic, so argmax of the raw logits is the same class.
            preds = torch.argmax(outputs, dim=1)
            confusion.update(preds.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()
    confusion.summary()


  • 写回答

4条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-04-18 16:50
    关注

    引自免费微信小程序:皆我百晓生

    根据您提供的代码和混淆矩阵,看起来存在一些逻辑上的错误。首先,让我们修复这些问题并重新运行您的代码。

    问题1:模型输出处理

    for val_data in tqdm(validate_loader):循环里,您应该将模型输出转换为预测标签,而不是取softmax后取最大值。因为混淆矩阵是基于原始类别预测的,而不是概率分布。

    修复方法:

    outputs = torch.argmax(outputs, dim=1)
    

    改为

    outputs = outputs.to("cpu").numpy().argmax(axis=1)
    

    问题2:类别的表示

    混淆矩阵中的标签显示为1、2、3、4,而实际的标签列表是通过读取json文件得到的。为了保持一致,我们需要确保混淆矩阵中的标签与实际标签列表匹配。

    修复方法:confusion.plot()函数内部,将x轴和y轴的标签从整数更改为对应的字符串标签。

    修复后的代码:

    def plot(self):
        """Draw the confusion matrix as a heat map with class-name ticks.

        Columns are drawn along the x axis (true labels) and rows along the
        y axis (predicted labels); each cell is annotated with its count.
        """
        counts = self.matrix
        print(counts)
        plt.imshow(counts, cmap=plt.cm.Blues)

        positions = range(self.num_classes)
        plt.xticks(positions, self.labels, rotation=45)
        plt.yticks(positions, self.labels)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # Cells above half the maximum count get white text for contrast.
        half_max = counts.max() / 2
        for col in positions:
            for row in positions:
                # imshow places counts[row, col] at (x=col, y=row).
                value = int(counts[row, col])
                if value > half_max:
                    shade = "white"
                else:
                    shade = "black"
                plt.text(col, row, value,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color=shade)
        plt.tight_layout()
        plt.show()
    

    问题3:计算精度、召回率和F1分数时

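    当您计算这些指标时,您没有考虑到特殊情况,即分母可能为零的情况。在这种情况下,您可以使用scipy.stats库中的binary_confusion_matrix函数来自动处理这种情况,并返回精确度、召回率和F1分数。(注:scipy.stats 中并不存在 binary_confusion_matrix 函数,从该模块导入它会抛出 ImportError。原代码中 Precision、Recall、Specificity 已经对分母为零做了判断。真正缺少保护的是 F1 分数:当 Precision 与 Recall 都为 0 时,(2*P*R)/(P+R) 会除以零。)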

    修复方法: 添加from scipy.stats import binary_confusion_matrix到顶部,并替换summary函数为以下内容:

    # scipy.stats provides no `binary_confusion_matrix`, so importing it raises
    # ImportError; the metrics below are computed directly from self.matrix.
    def summary(self):
        """Print overall accuracy and a per-class Precision / Recall / F1 table.

        Expects ``self.matrix[p, t]`` to count samples predicted as class ``p``
        with true class ``t`` (rows = predictions), which is how ``update`` in
        the question's ConfusionMatrix fills it. Any ratio whose denominator is
        zero is reported as 0.0 instead of dividing by zero.
        """
        total = np.sum(self.matrix)
        # The diagonal holds the correctly classified samples.
        acc = np.trace(self.matrix) / total if total != 0 else 0.
        print("the model accuracy is ", acc)

        table = PrettyTable()
        table.field_names = ["label", "Precision", "Recall", "F1-score"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP  # predicted i, truly another class
            FN = np.sum(self.matrix[:, i]) - TP  # truly i, predicted as another class
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            # The harmonic mean needs its own guard: P + R == 0 divides by zero.
            if Precision + Recall != 0:
                f1_score = round((2 * Precision * Recall) / (Precision + Recall), 3)
            else:
                f1_score = 0.
            table.add_row([self.labels[i], Precision, Recall, f1_score])
        print(table)
    

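    现在,您只需按照上述更改更新代码,并运行程序,您应该会得到一个正确的混淆矩阵和相应的统计摘要。如果混淆矩阵仍然异常,请检查 net.load_state_dict(weight_dict, strict=False):strict=False 会静默跳过名称不匹配的参数,这些层将保留随机初始化的权重。可以查看该调用返回值中的 missing_keys 和 unexpected_keys,确认权重是否真正加载。另外,请确认 class_indices.json 中的类别顺序与 test 目录下 ImageFolder 的 class_to_idx 一致。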

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 4月18日

悬赏问题

  • ¥300 寻抓云闪付tn组成网页付款链接
  • ¥15 请问Ubuntu要怎么安装chrome呀?
  • ¥15 视频编码 十六进制问题
  • ¥15 Xsheii7我安装这个文件的时候跳出来另一个文件已锁定文件的无一部分进程无法访问。这个该怎么解决
  • ¥15 unity terrain打包后地形错位,跟建筑不在同一个位置,怎么办
  • ¥15 FileNotFoundError 解决方案
  • ¥15 uniapp实现如下图的图表功能
  • ¥15 u-subsection如何修改相邻两个节点样式
  • ¥30 vs2010开发 WFP(windows filtering platform)
  • ¥15 服务端控制goose报文控制块的发布问题