facenet_pytorch进行人脸识别训练,总是提示错误,有没有办法帮我解决一下呢?
当前环境如下:
# 导入所需库,用于图像处理、人脸检测和人脸识别
from facenet_pytorch import MTCNN, InceptionResnetV1
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import os
import cv2
import urllib.request
import matplotlib.pyplot as plt
import numpy as np
# 下载预训练模型权重文件到本地
pretrained_model_url = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt' # 替换为实际预训练模型的 URL
pretrained_model_path = './model/20180402-114759-vggface2.pt' # 本地存储预训练模型的路径
if not os.path.exists(pretrained_model_path):
print("Downloading the VGGFace2 pretrained model...")
urllib.request.urlretrieve(pretrained_model_url, pretrained_model_path)
print("Download complete.")
else:
print("Pretrained model file exists.")
# 输入人名和图片路径
#person_name = input("请输入姓名:")
#image_path = input("请输入你要入库的图片路径:")
person_name = "lixiaochun"
image_path = r"d:\1.jpg"
# 检查人名对应的文件夹是否存在,不存在则创建
person_dir = "database/orgin/" + person_name
if not os.path.exists(person_dir):
os.makedirs(person_dir)
# 保存图片到指定目录
image = cv2.imread(image_path)
cv2.imwrite(os.path.join(person_dir, "1.jpg"), image)
# 设定设备、多线程工作数和初始化模型
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))
mtcnn = MTCNN(
image_size=160, margin=0, min_face_size=20,
thresholds=[0.8, 0.8, 0.9], factor=0.709, post_process=True,
device=device
)
resnet = InceptionResnetV1(pretrained=None).eval().to(device)
# 定义自定义数据集类
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, root_dir):
self.root_dir = root_dir
self.image_list = []
for root, dirs, files in os.walk(root_dir):
for file in files:
img_path = os.path.join(root, file)
# 先读取原始图像并进行预处理
img = cv2.imread(img_path)
resized_image = cv2.resize(img, (mtcnn.image_size, mtcnn.image_size)) # 将图片缩放到MTCNN所需的大小
tensor_img = ToTensor()(resized_image)
self.image_list.append((tensor_img, img_path))
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
tensor_img, img_path = self.image_list[idx]
person_name = os.path.basename(os.path.dirname(img_path))
return tensor_img, person_name
# 创建并使用自定义数据集
custom_dataset = CustomDataset('./database/orgin')
loader = DataLoader(custom_dataset, num_workers=0)
aligned_boxes = []
names = []
# 遍历数据集,进行人脸检测并提取特征向量
for index, (x, person_name) in enumerate(loader):
save_path = './database/aligned/{}/'.format(person_name) # 对齐后人脸的保存路径
if not os.path.exists(save_path):
os.makedirs(save_path)
print(x.shape)
# 不需要将图片从 (C, H, W) 转换为 (1, C, H, W),直接在原始图片上进行可视化
print(f"Processing image at index {index}...")
if x.shape[0] == 1: # 当 batch_size = 1 时
display_image = x.squeeze(0).permute(1, 2, 0).numpy()
else:
display_image = x.permute(1, 2, 0).numpy()
# 然后进行缩放操作
display_image = cv2.resize(display_image, (400, 400))
# 对于灰度图像或 RGB 图像分别处理显示
if x.shape[1] == 1: # 对于灰度图像
plt.imshow(display_image, cmap='gray')
elif x.shape[1] == 3: # 对于 RGB 图像
plt.imshow(display_image)
plt.show()
# 确保 x 在正确的设备上
x = x.to(device)
# 确保 x 是一个批次大小为1的四维张量
if len(x.shape) != 4 or x.shape[0] != 1:
print("Invalid input shape for MTCNN. Expected (1, C, H, W), got:", x.shape)
else:
results = mtcnn(x)
if not results:
print(f"No faces detected in the image at index {index}. Skipping.")
continue
boxes, probs, points = results
print(f"Number of faces detected: {len(boxes)}")
if len(boxes) > 0:
x_aligned = x[0][boxes[0]]
aligned_boxes.append(x_aligned)
save_path_with_index = save_path + '/{}.jpg'.format(index)
cv2.imwrite(save_path_with_index, x_aligned.permute(1, 2, 0).numpy())
names.append(person_name)
print(f"Face detected and saved at index {index}")
else:
print(f"No face detected in the image at index {index}")
# 确保仅对成功检测到人脸的样本进行后续处理
if aligned_boxes:
aligned_boxes = torch.stack(aligned_boxes).to(device)
embeddings = resnet(aligned_boxes).detach().cpu()
# 保存特征向量和对应的名字
torch.save(embeddings, './database/database.pt') # 将所有人脸的特征向量保存到一个文件中
torch.save(names, './database/names.pt') # 将所有人的名字列表保存到一个文件中
运行提示这个错误 怎么解决呢?
RESTART: C:\Users\soft-kaifa\Desktop\pytorch人脸识别\facenet_pytorch_ruku(离线版).py
20180402-114759-vggface2文件存在,程序即将开始运行!
请输入姓名:lixiaochun
请输入你要入库的图片路径:d:\1.jpg
Running on device: cpu
从本地路径加载预训练的人脸识别模型!
定义数据加载时的合并函数collate_fn
准备数据集
遍历数据集,进行人脸检测并提取特征向量
1-1 ./database/aligned/lixiaochun/
Traceback (most recent call last):
File "C:\Users\soft-kaifa\Desktop\pytorch人脸识别\facenet_pytorch_ruku(离线版).py", line 128, in <module>
results = mtcnn(x.unsqueeze(0), return_prob=True, save_path=save_path)
File "C:\Python36\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\soft-kaifa\Desktop\pytorch人脸识别\facenet_pytorch\models\mtcnn.py", line 258, in forward
batch_boxes, batch_probs, batch_points = self.detect(img, landmarks=True)
File "C:\Users\soft-kaifa\Desktop\pytorch人脸识别\facenet_pytorch\models\mtcnn.py", line 317, in detect
self.device
File "C:\Users\soft-kaifa\Desktop\pytorch人脸识别\facenet_pytorch\models\utils\detect_face.py", line 84, in detect_face
boxes = torch.cat(boxes, dim=0)
NotImplementedError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat. This usually means that this function requires a non-empty list of Tensors, or that you (the operator writer) forgot to register a fallback function. Available functions are [CPU, QuantizedCPU, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].
CPU: registered at aten\src\ATen\RegisterCPU.cpp:18433 [kernel]
QuantizedCPU: registered at aten\src\ATen\RegisterQuantizedCPU.cpp:1068 [kernel]
BackendSelect: fallthrough registered at ..\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ..\aten\src\ATen\core\PythonFallbackKernel.cpp:47 [backend fallback]
Named: registered at ..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ..\aten\src\ATen\ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at ..\aten\src\ATen\native\NegateFallback.cpp:18 [backend fallback]
ADInplaceOrView: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradCPU: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradCUDA: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradXLA: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradLazy: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradXPU: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradMLC: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradHPU: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradNestedTensor: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradPrivateUse1: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradPrivateUse2: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
AutogradPrivateUse3: registered at ..\torch\csrc\autograd\generated\VariableType_3.cpp:10215 [autograd kernel]
Tracer: registered at ..\torch\csrc\autograd\generated\TraceType_3.cpp:11593 [kernel]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:466 [backend fallback]
Autocast: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:305 [backend fallback]
Batched: registered at ..\aten\src\ATen\BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ..\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]