Uktttish 2024-04-18 09:21 采纳率: 22.2%
浏览 3

多模态特征融合模型中验证集可以用什么指标

拿到一份图片id和文本caption对应的csv文件,还有一份以图片id为文件名的图片文件夹,用百分之80做训练集训练完多模态特征融合图文检索模型,其中以余弦相似度做损失函数,想在每一次迭代中用验证集防止过拟合,具体要计算验证集的什么

  • 写回答

4条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2024-04-18 09:21
    关注

    引自免费微信小程序:皆我百晓生

    在多模态特征融合的场景中,常用的评估指标有Recall@K、Mean Average Precision (MAP)、Precision@K等。这些指标可以用于衡量模型在检索任务中的性能,即给定一个查询(通常是文本),模型能否正确地从候选图片集中找到与之相关的图片。

    以下是使用Python实现这些指标的方法以及如何在训练过程中监控验证集上的表现:

    首先,确保安装了必要的库,例如numpy, scipy, collections,以及可能的深度学习框架(如TensorFlow或PyTorch)。

    pip install numpy scipy collections
    

    接下来,定义一个通用的评估函数,它接受预测结果和实际标签,并计算各种指标:

    import numpy as np
    from scipy.spatial.distance import cosine
    from sklearn.metrics.pairwise import cosine_similarity
    from collections import defaultdict
    
    def evaluate(retrieved_indices, ground_truth_indices, k=5):
        """
        Compute Recall@k and Mean Average Precision (MAP) for a batch of queries.
        
        Args:
            retrieved_indices: List of lists containing the indices of the top-k retrieved items for each query.
            ground_truth_indices: List of lists containing the true relevant item indices for each query.
            k: Number of top results to consider for evaluation.
    
        Returns:
            recall_k: List of Recall@k scores for each query.
            map_score: Mean Average Precision score over all queries.
        """
        recall_k = []
        ap_scores = []
    
        # Iterate through queries
        for i in range(len(retrieved_indices)):
            relevant_items = set(ground_truth_indices[i])
            retrieved_items = set(retrieved_indices[i][:k])
    
            # Calculate Recall@k
            recall_k.append(len(relevant_items & retrieved_items) / len(relevant_items))
    
            # Calculate AP score
            precisions = []
            for j in range(k):
                if retrieved_indices[i][j] in relevant_items:
                    precisions.append(len(relevant_items & set(retrieved_indices[i][:j+1])) / (j + 1))
                else:
                    precisions.append(precisions[-1])
            ap_scores.append(np.sum(precisions) / len(relevant_items))
    
        # Compute mean AP
        map_score = np.mean(ap_scores)
    
        return recall_k, map_score
    

    在训练过程中,在每个epoch结束时,将验证数据集送入这个评估函数:

    # 假设train_loader和val_loader是你的训练和验证数据加载器
    # model是你已经训练好的多模态特征融合模型
    # device是你的设备(CPU或GPU)
    
    model.eval()
    total_recall_k = []
    total_map_score = []
    
    with torch.no_grad():
        for inputs_train, targets_train in train_loader:
            # ... 进行前向传播,获取预测结果 ...
            
        for inputs_val, targets_val in val_loader:
            # 将数据转移到设备
            inputs_val, targets_val = inputs_val.to(device), targets_val.to(device)
            
            # 进行前向传播,获取预测结果
            # 注意:这里不需要反向传播,因此通常会禁用梯度计算
            outputs_val = model(inputs_val)
            
            # 计算余弦相似度
            cos_similarities = cosine_similarity(outputs_val.cpu().numpy(), targets_val.cpu().numpy())
            
            # 获取top-k索引
            _, topk_indices = torch.topk(cos_similarities, k=5, dim=1)
            
            # 调用评估函数
            recall_k, map_score = evaluate(topk_indices.detach().cpu().numpy(), targets_val.numpy())
            total_recall_k.extend(recall_k)
            total_map_score.append(map_score)
    
        average_recall_k = np.mean(total_recall_k)
        average_map_score = np.mean(total_map_score)
    
    print(f"Average Recall@5 on validation set: {average_recall_k}")
    print(f"Average MAP on validation set: {average_map_score}")
    

    这样,每次迭代后你都可以得到验证集上的平均Recall@5和MAP值,从而监控模型的泛化能力并防止过拟合。你可以根据这两个指标调整超参数或提前停止训练。

    评论 编辑记录

报告相同问题?

问题事件

  • 创建了问题 4月18日

悬赏问题

  • ¥15 想咨询点问题,与算法转换,负荷预测,数字孪生有关
  • ¥15 C#中的编译平台的区别影响
  • ¥15 软件供应链安全是跟可靠性有关还是跟安全性有关?
  • ¥15 电脑蓝屏logfilessrtsrttrail问题
  • ¥20 关于wordpress建站遇到的问题!(语言-php)(相关搜索:云服务器)
  • ¥15 【求职】怎么找到一个周围人素质都很高不会欺负他人,并且未来月薪能够达到一万以上(技术岗)的工作?希望可以收到写有具体,可靠,已经实践过了的路径的回答?
  • ¥15 Java+vue部署版本反编译
  • ¥100 对反编译和ai熟悉的开发者。
  • ¥15 带序列特征的多输出预测模型
  • ¥15 Python 如何安装 distutils模块