用 VGG16 模型得到的混淆矩阵结果不正确
以下是 main 代码：
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model import vgg
class ConfusionMatrix(object):
    """Accumulate a confusion matrix and report per-class metrics.

    Layout: ``matrix[p, t]`` counts samples predicted as class ``p`` whose
    true class is ``t`` -- rows are predictions, columns are ground truth.
    """

    def __init__(self, num_classes: int, labels: list):
        """Create an all-zero ``num_classes x num_classes`` matrix.

        Args:
            num_classes: Number of classes. Indices passed to :meth:`update`
                must be integers in ``[0, num_classes)``.
            labels: Display name for each class index, used by :meth:`summary`.
        """
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        """Add one count per (prediction, target) pair.

        Args:
            preds: Iterable of predicted class indices.
            labels: Iterable of true class indices, aligned with ``preds``.
        """
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1

    def accuracy(self):
        """Return overall accuracy (trace / total count), or 0.0 if empty."""
        total = np.sum(self.matrix)
        return np.trace(self.matrix) / total if total != 0 else 0.

    def _class_metrics(self, i):
        """Return unrounded (precision, recall, f1, specificity) for class ``i``.

        Any ratio whose denominator is zero is reported as 0.0.
        """
        tp = self.matrix[i, i]
        fp = np.sum(self.matrix[i, :]) - tp  # predicted i, truly another class
        fn = np.sum(self.matrix[:, i]) - tp  # truly i, predicted as another class
        tn = np.sum(self.matrix) - tp - fp - fn
        precision = tp / (tp + fp) if tp + fp != 0 else 0.
        recall = tp / (tp + fn) if tp + fn != 0 else 0.
        specificity = tn / (tn + fp) if tn + fp != 0 else 0.
        # When TP == 0, precision + recall == 0; without this guard the
        # division raised ZeroDivisionError (or produced nan).
        if precision + recall != 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.
        return precision, recall, f1, specificity

    def summary(self):
        """Print overall accuracy and a per-class metrics table.

        Values in the table are rounded to 3 decimals; F1 is computed from the
        unrounded precision and recall before rounding.
        """
        acc = self.accuracy()
        print("the model accuracy is ", acc)

        table = PrettyTable()
        table.field_names = ["label", "Precision", "Recall", "F1-score", "Specificity"]
        for i in range(self.num_classes):
            precision, recall, f1, specificity = self._class_metrics(i)
            table.add_row([self.labels[i],
                           round(precision, 3),
                           round(recall, 3),
                           round(f1, 3),
                           round(specificity, 3)])
        print(table)

    def plot(self):
        """Print the raw matrix and show it as an annotated heat map.

        Ticks are 1-based class numbers (1..num_classes), not label names.
        The x axis is the true label (matrix column); the y axis is the
        predicted label (matrix row).
        """
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        ticks = range(self.num_classes)
        tick_names = list(range(1, self.num_classes + 1))
        plt.xticks(ticks, tick_names, rotation=45)
        plt.yticks(ticks, tick_names)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # Annotate every cell with its count; white text on the darker cells.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # imshow draws row index on y and column index on x,
                # so the cell at (x, y) is matrix[y, x], not matrix[x, y].
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Dataset root; test images are read from <data_root>/test/<class folder>/.
    data_root = r"C:\Users\yingnuo.DESKTOP-9E5CS2I\Desktop\T1-data"
    assert os.path.exists(data_root), "data path {} does not exist.".format(data_root)

    # Load the class-index mapping first: it defines which network output
    # unit corresponds to which class, so everything else must agree with it.
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
    with open(json_label_path, 'r', encoding='utf-8') as json_file:
        class_indict = json.load(json_file)
    # Order labels by their numeric index instead of relying on JSON key order.
    # NOTE(review): assumes keys are stringified ints ("0", "1", ...) -- confirm.
    sorted_items = sorted(class_indict.items(), key=lambda kv: int(kv[0]))
    labels = [label for _, label in sorted_items]
    num_classes = len(labels)

    validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "test"),
                                            transform=data_transform)
    # ImageFolder numbers classes by sorting the sub-folder names of *this*
    # directory. If that differs from the training mapping (a missing or
    # renamed folder), every target is shifted and the matrix is meaningless.
    # NOTE(review): assumes the JSON values are the class folder names.
    expected_class_to_idx = {label: int(idx) for idx, label in sorted_items}
    if validate_dataset.class_to_idx != expected_class_to_idx:
        raise RuntimeError(
            "test-set class mapping {} does not match {} mapping {}".format(
                validate_dataset.class_to_idx, json_label_path, expected_class_to_idx))

    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)

    net = vgg(model_name="vgg16", num_classes=num_classes)
    # Load the trained weights.
    model_weight_path = "./weights/best_model.pth"
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    weight_dict = torch.load(model_weight_path, map_location=device)
    # strict=False silently skips keys that don't match; any missing parameter
    # keeps its random initialisation, which produces a garbage confusion
    # matrix. Report mismatches instead of ignoring them.
    incompatible = net.load_state_dict(weight_dict, strict=False)
    if incompatible.missing_keys:
        raise RuntimeError(
            "{} model parameters are missing from {} (e.g. {}); unexpected keys in the "
            "file: {}".format(len(incompatible.missing_keys), model_weight_path,
                              incompatible.missing_keys[:5],
                              incompatible.unexpected_keys[:5]))
    if incompatible.unexpected_keys:
        print("warning: ignoring unexpected keys in weight file:",
              incompatible.unexpected_keys)
    net.to(device)

    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels)
    net.eval()
    with torch.no_grad():
        for val_images, val_labels in tqdm(validate_loader):
            outputs = net(val_images.to(device))
            # argmax over the logits equals argmax over softmax(logits).
            preds = torch.argmax(outputs, dim=1)
            confusion.update(preds.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()
    confusion.summary()