与朝阳同醒
2022-04-30 09:26
采纳率: 66.7%
浏览 76

对以下代码结合实际写个详细的注释,阐述其中的原理?

这是个人脸表情识别的训练代码,database_face是个存放了jpg文件的数据集,注释最好是行注释,因为好多库都不熟悉。


# k-means_opt.py
# -*- coding: utf-8 -*-
import shutil
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
from sklearn.cluster import KMeans


def predict(image, model):
    """Run one PIL image through ``model`` and return its output as a NumPy array.

    Preprocessing pipeline:
      1. ``Resize(224)``: scale so the *shorter* side is 224 px (aspect ratio
         is kept, so the result is not necessarily 224x224).
      2. ``ToTensor()``: PIL image -> float tensor of shape (C, H, W) in [0, 1].
      3. ``Normalize(mean, std)``: per-channel ``(x - mean) / std`` using the
         ImageNet statistics that torchvision's ResNet models use.
    NOTE(review): this presumably matches the preprocessing used when the
    loaded weights were trained; confirm against the training script.

    Args:
        image: a PIL image. Any mode ('L', 'P', 'RGBA', ...) is converted to
            RGB first, because ``Normalize`` below expects exactly 3 channels.
        model: a torch module that accepts a (N, 3, H, W) batch. The caller is
            responsible for putting it in ``eval()`` mode; this function does
            not change the model's mode.

    Returns:
        ``model(batch)[0]`` as a NumPy array, i.e. the output row for this single
        image. With the Linear(…, 7) + Softmax head built in ``__main__`` this
        is a vector of 7 class probabilities.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Normalize() carries 3 means/stds, so the tensor must have 3 channels.
    # Without this conversion, grayscale ('L') or RGBA images raise a RuntimeError.
    img = transform(image.convert('RGB'))
    # Models consume batches: add a batch dimension, (3, H, W) -> (1, 3, H, W).
    batch = torch.unsqueeze(img, 0)
    # Inference only: no_grad() skips building the autograd graph, which saves
    # memory and time. The result then needs no .data/.detach() before .numpy().
    with torch.no_grad():
        output = model(batch)[0]  # [0] drops the batch dimension again
    return output.numpy()


if __name__ == '__main__':
    # Number of expression classes the classifier head outputs. It is also used
    # as the number of k-means clusters (the original hard-coded 7 twice).
    num_classes = 7
    model_path = 'model/weights.pth'
    img_dir = 'database_face'
    out_dir = 'output'

    # ---- 1. Build the network and load the trained weights -------------
    # ResNet-34 from torchvision, created with random weights
    # (pretrained=False) because the real weights are loaded from model_path.
    model_ft = models.resnet34(pretrained=False)
    # Input size of ResNet's original final fully connected layer, i.e. the
    # length of the feature vector coming out of the backbone.
    num_ftrs = model_ft.fc.in_features
    # Replace the 1000-class ImageNet head with a num_classes-way head:
    #   Dropout(0.5): randomly zeroes features, only in train mode.
    #   Linear: maps features to one score per class.
    #   Softmax(dim=1): turns each row of scores into probabilities summing to 1.
    model_ft.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, num_classes),
        nn.Softmax(dim=1),
    )
    # map_location='cpu' lets weights that were saved from a GPU load on a
    # CPU-only machine. predict() also works on CPU tensors.
    model_ft.load_state_dict(torch.load(model_path, map_location='cpu'))
    # eval(): disables Dropout and makes BatchNorm use its stored running
    # statistics, so each image always yields the same output.
    model_ft.eval()

    # ---- 2. Run every image through the model once ---------------------
    names = []      # file names that were processed successfully
    features = []   # the matching model outputs (one vector per image)
    # os.listdir order is arbitrary. Sorting makes the k-means++ seeding
    # (random_state=0) reproducible across machines.
    for img_name in sorted(os.listdir(img_dir)):
        src = os.path.join(img_dir, img_name)
        try:
            # `with` closes the file handle. convert('RGB') gives predict()
            # the 3 channels its Normalize step needs.
            with Image.open(src) as im:
                image = im.convert('RGB')
            output = predict(image, model_ft)
        except (OSError, RuntimeError, ValueError) as err:
            # Not an image, unreadable, or wrong shape: skip it, but report
            # it instead of failing silently.
            print(f'skipping {src}: {err}')
            continue
        names.append(img_name)
        features.append(output)

    # Stack into an (n_images, num_classes) matrix: one row per image.
    b = np.array(features)
    print(b)

    k = num_classes
    if len(b) < k:
        # KMeans needs at least k samples. Fail with a clear message instead
        # of an opaque sklearn error.
        raise SystemExit(
            f'need at least {k} readable images in {img_dir!r}, got {len(b)}')

    # ---- 3. Cluster ----------------------------------------------------
    # k-means alternates between two steps: (a) assign each row to the
    # nearest centre, and (b) move each centre to the mean of its rows.
    # 'k-means++' spreads the initial centres out; random_state fixes the seed.
    kmodel = KMeans(n_clusters=k, init='k-means++', random_state=0)
    kmodel.fit(b)
    # For each row, the index of the centre with the smallest Euclidean
    # distance. This is the same rule the original hand-written loop applied.
    labels = kmodel.predict(b)

    # ---- 4. Copy each image into output/<cluster index>/ ---------------
    for img_name, label in zip(names, labels):
        dst_dir = os.path.join(out_dir, str(label))
        # copyfile() does not create folders; without this every copy failed.
        os.makedirs(dst_dir, exist_ok=True)
        try:
            shutil.copyfile(os.path.join(img_dir, img_name),
                            os.path.join(dst_dir, img_name))
        except OSError as err:
            print(f'could not copy {img_name}: {err}')

    # The k cluster centres found by k-means (each has num_classes values).
    for center in kmodel.cluster_centers_:
        print(center)

'

2条回答 默认 最新

相关推荐 更多相似问题