与朝阳同醒 2022-04-30 09:26 采纳率: 84.6%
浏览 79
已结题

对以下代码结合实际写个详细的注释,阐述其中的原理?

这是个人脸表情识别的训练代码,databace_face是个存放了jpg文件的数据集,注释最好是行注释,因为好多库都不熟悉。


# k-means_opt.py
# -*- coding: utf-8 -*-
import shutil
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
from sklearn.cluster import KMeans


def predict(image, model):
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = transform(image)
    img = torch.unsqueeze(img, 0)
    output = model(img)[0].data.numpy()
    return output


if __name__ == '__main__':
    class_names = 7
    model_path = 'model/weights.pth'
    # build model
    model_ft = models.resnet34(pretrained=False)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, 7),
        nn.Softmax(dim=1)
    )
    model_ft.load_state_dict(torch.load(model_path))
    img_path = 'database_face'
    img_list = os.listdir(img_path)
    model_ft.eval()
    list_a = []
    for value in img_list:
        try:
            image = Image.open(os.path.join(img_path, value))
            output = predict(image, model_ft)
            list_a.append(output)
        except:
            pass

    b = np.array(list_a)
    print(b)

    k = 7
    # 聚类
    kmodel = KMeans(n_clusters=k, init='k-means++', random_state=0)
    kmodel.fit(b)

    dir_path = os.listdir('database_face')
    for img_name in dir_path:
        try:
            image = Image.open('database_face/' + img_name)
            output = predict(image, model_ft)

            min_num = 9999
            min_index = 10
            for index, value in enumerate(kmodel.cluster_centers_):
                dis = np.linalg.norm(value - output)  # 欧氏距离,计算最近的点
                if dis < min_num:
                    min_num = dis
                    min_index = index
            shutil.copyfile('database_face/' + img_name, 'output/' + str(min_index) + '/' + img_name)
        except:
            pass

    # 训练完得到的7个中心点
    for index, value in enumerate(kmodel.cluster_centers_):
        print(value)

'

  • 写回答

2条回答 默认 最新

  • lazyn 2022-04-30 11:23
    关注
    # k-means_opt.py
    # -*- coding: utf-8 -*-
    # shutil模块是对os模块的补充,主要针对文件的拷贝、删除、移动、压缩和解压操作
    import shutil
    # 用来存储和处理大型矩阵
    import numpy as np
    # 深度学习库
    import torch
    import torch.nn as nn
    # torchvision由流行的数据集、模型架构和用于计算机视觉的常见图像转换组成
    from torchvision import models, transforms
    # 图像库
    from PIL import Image
    import os
    # KMeans聚类算法
    from sklearn.cluster import KMeans
    
    
    # 定义预测函数,传入待预测图象及使用的模型
    def predict(image, model):
        # 用Compose把多个步骤整合到一起
        # Resize用于调整图像尺寸,将原图像调整为224×224
        # ToTensor()将图像数据转换为tensor的
        # Normalize对图像进行标准化,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],是从imagenet训练集中抽样算出来的
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # 调用transform对图像进行处理
        img = transform(image)
        # torch.unsqueeze(input, dim, out=None)扩展维度,返回一个新的张量,对输入的既定位置插入维度1
        img = torch.unsqueeze(img, 0)
        # 调用模型,传入图像进行预测
        output = model(img)[0].data.numpy()
        return output
    
    
    if __name__ == '__main__':
        # 应该是定义类别数为7,下文并未用到此变量
        class_names = 7
        # 定义要导入的训练好的模型路径
        model_path = 'model/weights.pth'
        # build model
        # 调用resnet34模型,不使用预训练
        model_ft = models.resnet34(pretrained=False)
        # 得到模型分类层个数,即原模型的分类类别数
        num_ftrs = model_ft.fc.in_features
        # 重写分类层参数,nn.Sequential一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行
        # nn.dropout是为了防止或减轻过拟合数值为不保留节点数的比例
        # nn.Linear重新定义输出层,将项目分类为7个类别
        # nn.Softmax分类层激活函数,dim用来指定哪一维度相加为1,具体参考http://www.zzvips.com/article/207118.html
        model_ft.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_ftrs, 7),
            nn.Softmax(dim=1)
        )
        # 从本地载入已经训练好的模型参数
        model_ft.load_state_dict(torch.load(model_path))
        # 定义图片路径
        img_path = 'database_face'
        # 得到该路径下的所有图片列表
        img_list = os.listdir(img_path)
        # 设置模型为评估模型,用于预测
        model_ft.eval()
        # 定义用于存储预测结果的列表
        list_a = []
        # 循环调用图片列表中的每一张图片
        for value in img_list:
            try:
                # 读取本地图片到内存中
                image = Image.open(os.path.join(img_path, value))
                # 调用模型对图片进行预测
                output = predict(image, model_ft)
                # 将预测结果添加到列表中
                list_a.append(output)
            except:
                pass
        # 将列表转换为数组并打印出来
        b = np.array(list_a)
        print(b)
        # 定义簇的个数为7
        k = 7
        # 聚类
        # 调用KMeans聚类模型,聚类簇个数为7
        kmodel = KMeans(n_clusters=k, init='k-means++', random_state=0)
        # 使用聚类模型对预测结果进行聚类
        kmodel.fit(b)
    
        dir_path = os.listdir('database_face')
        for img_name in dir_path:
            try:
                image = Image.open('database_face/' + img_name)
                output = predict(image, model_ft)
    
                min_num = 9999
                min_index = 10
                # 循环读取聚类中心的下标及值
                for index, value in enumerate(kmodel.cluster_centers_):
                    dis = np.linalg.norm(value - output)  # 欧氏距离,计算最近的点
                    # 判断距离是否为最小距离
                    if dis < min_num:
                        # 定义最小距离为当前距离,下标为当前下标
                        min_num = dis
                        min_index = index
                # 将database_face中的图片复制到output/str(min_index)文件夹下
                shutil.copyfile('database_face/' + img_name, 'output/' + str(min_index) + '/' + img_name)
            except:
                pass
    
        # 训练完得到的7个中心点,打印值
        for index, value in enumerate(kmodel.cluster_centers_):
            print(value)
    

    如果有帮助的话望采纳,谢谢

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 4月30日
  • 已采纳回答 4月30日
  • 赞助了问题酬金10元 4月30日
  • 赞助了问题酬金30元 4月30日
  • 展开全部

悬赏问题

  • ¥15 #MATLAB仿真#车辆换道路径规划
  • ¥15 java 操作 elasticsearch 8.1 实现 索引的重建
  • ¥15 数据可视化Python
  • ¥15 要给毕业设计添加扫码登录的功能!!有偿
  • ¥15 kafka 分区副本增加会导致消息丢失或者不可用吗?
  • ¥15 微信公众号自制会员卡没有收款渠道啊
  • ¥100 Jenkins自动化部署—悬赏100元
  • ¥15 关于#python#的问题:求帮写python代码
  • ¥20 MATLAB画图图形出现上下震荡的线条
  • ¥15 关于#windows#的问题:怎么用WIN 11系统的电脑 克隆WIN NT3.51-4.0系统的硬盘