yinshuai5757 2025-05-25 23:44

ChromaDB + SIGLIP: why is the similarity between clearly related text and images abnormally low (0.01)?

Code description
・This code vectorizes images with SIGLIP and stores the vectors in ChromaDB; user-input text is vectorized with SIGLIP as well, and ChromaDB is then queried for the most similar images.
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Question
・The text I enter is clearly highly relevant to the image, yet the similarity is only 0.01. What could make the similarity of clearly related content this low? Is it a model problem, a preprocessing problem, or some other technical issue?
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Additional information
・My ChromaDB returns cosine distances
・All vectors are L2-normalized (norm = 1)
・Conversion formula used: similarity = 1 - cosine distance (see the sketch after this list)
・The SIGLIP model is used to vectorize both images and text
・Both image and text vectors are 768-dimensional
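
A minimal sanity check of the conversion formula, independent of the app (synthetic unit vectors; for unit vectors, cosine distance = 1 - cos, while ChromaDB's default "l2" space returns the squared Euclidean distance 2 - 2·cos, so applying "1 - distance" to the wrong space yields values near zero, much like the 0.01 observed here):

```python
import numpy as np

# Two random unit vectors standing in for a SIGLIP text/image pair.
rng = np.random.default_rng(0)
a = rng.normal(size=768); a /= np.linalg.norm(a)
b = rng.normal(size=768); b /= np.linalg.norm(b)

cos = float(a @ b)
cosine_distance = 1.0 - cos                         # what a "cosine" collection returns
squared_l2_distance = float(((a - b) ** 2).sum())   # default "l2" space: 2 - 2*cos

print(f"cos = {cos:.4f}")
print(f"1 - cosine_distance     = {1.0 - cosine_distance:.4f}  (recovers cos)")
print(f"1 - squared_l2_distance = {1.0 - squared_l2_distance:.4f}  (wrong space: near zero or negative)")
```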
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Expected result
・Text and images that are clearly related should have a fairly high similarity (say 0.7+), but the actual value is only 0.01.
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Technical details
・Model: SIGLIP-256
・Vector dimension: 768
・Normalization: L2 normalization
・Distance metric: cosine distance (see the metadata check after this list)
・Database: ChromaDB
・Deployment environment: Docker
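
A classic cause of exactly this symptom: if the collection was first created before `metadata={"hnsw:space": "cosine"}` was added, it silently keeps the default `l2` index, and `get_collection()` cannot change it afterwards. A quick diagnostic against the paths used in the code below (a sketch only, not part of the app):

```python
import chromadb

# Diagnostic only: check which distance metric the existing collection uses.
client = chromadb.PersistentClient(path="/app/vectordb")
collection = client.get_collection("unified_image_vectors")

# A collection created without metadata={"hnsw:space": "cosine"} defaults to
# "l2" and keeps that setting for its lifetime.
print((collection.metadata or {}).get("hnsw:space", "l2"))
```

If this prints `l2`, the collection must be recreated with the cosine setting; the `/recreate-collection` endpoint in the code below does exactly that.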
------------------------------------------------------------------------------------------------------------------------------------------------------------------------


```python
from flask import Flask, request, jsonify, send_file
from transformers import AutoModel, AutoProcessor, SiglipModel
from PIL import Image
import torch
import ftfy
import html
import re
import os
import logging
import chromadb
from chromadb.utils.embedding_functions import EmbeddingFunction
from urllib.parse import quote, unquote
from datetime import datetime
import glob
import hashlib
from tqdm import tqdm
import uuid
import time
import threading
import numpy as np


app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Database path
db_path = "/app/vectordb"
# Make sure the directory exists
os.makedirs(db_path, exist_ok=True)

# Global state
client = None
collection = None
model = None
processor = None
collection_name = "unified_image_vectors"  # default collection name

# Initialize the model and processor
def init_model():
    global model, processor

    try:
        logger.info("Loading model and processor...")
        model = AutoModel.from_pretrained(
            "/app/models/siglip256",
            trust_remote_code=True
        ).to('cuda')
        processor = AutoProcessor.from_pretrained("/app/models/siglip256")
        logger.info("Model and processor loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        return False

# Initialize the ChromaDB client
def init_db():
    global client, collection

    try:
        logger.info(f"Initializing ChromaDB, storage path: {db_path}")
        client = chromadb.PersistentClient(path=db_path)
        logger.info("ChromaDB client initialized successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize ChromaDB client: {str(e)}")
        return False

# Get or create the collection
def get_or_create_collection():
    global client, collection, model, processor, collection_name

    try:
        # Create the custom embedding function
        embedding_function = SiglipTextEmbeddingFunction(model, processor)

        # Try to fetch an existing collection first
        try:
            collection = client.get_collection(
                name=collection_name,
                embedding_function=embedding_function
            )
            logger.info(f"Connected to existing collection: {collection_name}, containing {collection.count()} vectors")
            logger.info(f"Collection distance metric: {collection.metadata.get('hnsw:space', 'l2')}")
        except Exception as e:
            logger.warning(f"Failed to get collection: {str(e)}")
            # Create a new collection
            collection = client.create_collection(
                name=collection_name,
                embedding_function=embedding_function,
                metadata={"hnsw:space": "cosine"}
            )
            logger.info(f"Created new collection: {collection_name}, using cosine distance")

        return True
    except Exception as e:
        logger.error(f"Failed to get or create collection: {str(e)}")
        return False

# 🔧 Manual collection rebuild API
@app.route('/recreate-collection', methods=['POST'])
def recreate_collection():
    global collection_name, collection, client

    try:
        # Delete the existing collection
        try:
            client.delete_collection(collection_name)
            logger.info(f"Deleted existing collection: {collection_name}")
        except:
            logger.info(f"Collection {collection_name} does not exist")

        # Reset the global variable
        collection = None

        # Recreate the collection
        if get_or_create_collection():
            return jsonify({
                'success': True,
                'message': f'Successfully recreated collection {collection_name} using cosine distance',
                'metadata': collection.metadata,
                'count': collection.count()
            })
        else:
            return jsonify({'error': 'Failed to recreate collection'}), 500

    except Exception as e:
        return jsonify({'error': str(e)}), 500

# Text cleaning helpers
def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

# Image ID generation
def get_image_id(image_path: str) -> str:
    """Generate a unique identifier for an image from its file name and modification time"""
    file_name = os.path.basename(image_path)
    file_mtime = os.path.getmtime(image_path)
    # Combine file name and modification time into a unique ID
    image_id = hashlib.md5(f"{file_name}_{file_mtime}".encode()).hexdigest()[:10]
    return image_id

# Custom embedding function - for text queries
class SiglipTextEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model, processor):
        self.model = model
        self.processor = processor

    def __call__(self, texts):
        text_embeddings = []
        for text in texts:
            # Clean the text
            cleaned_text = whitespace_clean(basic_clean(text))
            # 🔧 Fix: call the processor correctly
            inputs = self.processor(
                text=cleaned_text,  # pass the text= argument explicitly
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=64
            ).to('cuda')

            # Extract text features (fix: use self.model, not the global model)
            with torch.no_grad():
                text_embedding = self.model.get_text_features(**inputs)
                text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)

                embedding_np = text_embedding.cpu().numpy().flatten()

                # 🔧 Validation
                logger.debug(f"Text embedding dim: {len(embedding_np)}, norm: {np.linalg.norm(embedding_np):.6f}")

                text_embeddings.append(embedding_np.tolist())

        return text_embeddings

# Custom embedding function - for image ingestion
class SiglipEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device

    def __call__(self, texts):
        """
        ChromaDB requires this interface, but we never actually use it for text;
        we always pass in precomputed vectors instead.
        """
        # Return dummy vectors of dimension 768 to match SIGLIP's output dimension
        return [[0.0] * 768] * len(texts)

# Application initialization
def initialize_app():
    global model, processor, client, collection

    # Load the model and processor if they are not loaded yet
    if model is None or processor is None:
        if not init_model():
            return False

    # Initialize the database client if it is not initialized yet
    if client is None:
        if not init_db():
            return False

    # Get or create the collection if it is not available yet
    if collection is None:
        if not get_or_create_collection():
            return False

    return True

# Make sure the app is initialized
initialize_app()

# API endpoint: process images and add them to the vector database
@app.route('/process-images', methods=['POST'])
def api_process_images():
    try:
        # Make sure the app is initialized
        if not initialize_app():
            return jsonify({'error': 'Initialization failed, cannot process images'}), 500

        # Parse the request
        data = request.json
        input_dir = data.get('input_dir')
        batch_name = data.get('batch_name', f"batch_{int(time.time())}")
        user_mail = data.get('user_mail')
        original_file_path = data.get('original_file_path')
        selected_department = data.get('selected_department')
        timestamp = data.get('timestamp')

        # Delegate to the processing function
        result = process_images(
            input_dir=input_dir,
            batch_name=batch_name,
            timestamp=timestamp,
            user_mail=user_mail,
            original_file_path=original_file_path,
            selected_department=selected_department
        )

        return jsonify(result)

    except Exception as e:
        logger.error(f"process-images API error: {str(e)}")
        return jsonify({'error': str(e)}), 500

# Image processing function
def process_images(input_dir=None, batch_name=None, timestamp=None, custom_logger=None,
                  user_mail=None, original_file_path=None, selected_department=None):
    """
    Process images, compute feature vectors, and save them to the ChromaDB collection
    """
    global collection, model, processor

    # Make sure the app is initialized
    if not initialize_app():
        return {'error': 'Initialization failed, cannot process images'}

    # Generate a timestamp if none was provided
    if timestamp is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Batch name
    batch_name = f"{batch_name}_{timestamp}" if batch_name else f"batch_{timestamp}"

    # Use the custom logger if provided, otherwise the default one
    log = custom_logger or logger

    log.info(f"=== Start processing images, batch: {batch_name} ===")

    # Resolve the image directory path
    if input_dir:
        image_dir = input_dir
    else:
        image_dir = f"./extracted_content/{batch_name}/images_{timestamp}" if batch_name else "./extracted_content/images"

    log.info(f"Image source directory: {image_dir}")

    # Collect all image files and log the total count
    image_files = glob.glob(f"{image_dir}/*")
    total_images = len([f for f in image_files if f.endswith(('.jpg', '.png', '.jpeg'))])
    log.info(f"Found {total_images} image files to process")

    # Track processing results
    processed_images = []

    # Process in batches for efficiency
    batch_size = 10
    for i in range(0, len(image_files), batch_size):
        batch_files = image_files[i:i+batch_size]
        batch_ids = []
        batch_embeddings = []
        batch_metadatas = []
        batch_documents = []

        for image_path in batch_files:
            # Skip non-image files
            if not image_path.endswith(('.jpg', '.png', '.jpeg')):
                log.debug(f"Skipping non-image file: {image_path}")
                continue

            # Basic file information
            filename = os.path.basename(image_path)
            image_id = get_image_id(image_path)  # generate a stable ID
            image_type = os.path.splitext(filename)[1].lower()
            image_size = os.path.getsize(image_path)

            log.info(f"Processing image: {filename} (ID: {image_id})")

            try:
                # Load the image
                image = Image.open(image_path)
                img_width, img_height = image.size
                log.debug(f"Image loaded. Size: {image.size}, mode: {image.mode}")

                # Preprocess the image
                processed_image = processor(images=image, return_tensors="pt").to('cuda')

                with torch.no_grad():
                    # Extract the feature vector
                    embedding = model.get_image_features(**processed_image)
                    embedding = embedding / embedding.norm(dim=-1, keepdim=True)
                    embedding_np = embedding.cpu().numpy().flatten().tolist()

                # Check that the vector dimension is correct
                embedding_dim = len(embedding_np)
                if embedding_dim != 768:
                    log.warning(f"Unexpected vector dimension: {embedding_dim} (expected 768)")

                # Unique ID format: batch_name/image_id
                doc_id = f"{batch_name}/{image_id}"

                # Extract the page number from the file name (format like page1_image_1_1.png)
                page_number = None
                if filename.startswith("page"):
                    try:
                        # Extract the page number with a regular expression
                        page_match = re.match(r'page(\d+)_', filename)
                        if page_match:
                            page_number = int(page_match.group(1))
                    except:
                        # If parsing fails, leave the page number as None
                        pass

                # Full metadata
                metadata = {
                    "image_id": image_id,
                    "batch_name": batch_name,
                    "filename": filename,
                    "path": image_path,
                    "original_file_path": original_file_path,
                    "image_type": image_type,
                    "image_size": image_size,
                    "width": img_width,
                    "height": img_height,
                    "user_mail": user_mail,
                    "selected_department": selected_department,
                    "page_number": page_number,  # page number metadata
                    "processed_at": datetime.now().isoformat()
                }

                # Append to the batch lists
                batch_ids.append(doc_id)
                batch_embeddings.append(embedding_np)
                batch_metadatas.append(metadata)
                batch_documents.append(f"Image: {filename}")

                # Record the processing result
                processed_images.append({
                    "id": doc_id,
                    "filename": filename,
                    "image_id": image_id
                })

                log.info(f"Image processed successfully: {filename}, ID: {doc_id}")

            except Exception as e:
                log.error(f"Error while processing image {filename}: {str(e)}")
                continue

        # Add the batch to ChromaDB
        if batch_ids:
            try:
                # If documents with the same IDs already exist, delete them first
                existing_ids = []
                for doc_id in batch_ids:
                    try:
                        result = collection.get(ids=[doc_id])
                        if result['ids']:
                            existing_ids.append(doc_id)
                    except:
                        pass

                if existing_ids:
                    log.info(f"Found {len(existing_ids)} duplicate IDs, deleting them first")
                    collection.delete(ids=existing_ids)

                # Add the new batch
                collection.add(
                    ids=batch_ids,
                    embeddings=batch_embeddings,
                    metadatas=batch_metadatas,
                    documents=batch_documents
                )
                log.info(f"Successfully added {len(batch_ids)} image vectors to ChromaDB")
            except Exception as e:
                log.error(f"Error while adding batch to ChromaDB: {str(e)}")

    # Gather statistics
    count = collection.count()
    batch_count = len(processed_images)
    log.info(f"Processing finished. Added {batch_count} image vectors; the collection now holds {count} vectors in total")
    log.info("=== Processing complete ===")

    # Return the processing result
    result = {
        'collection_name': collection_name,
        'batch_name': batch_name,
        'processed_count': batch_count,
        'total_count': count,
        'processed_images': processed_images,
        'timestamp': timestamp
    }

    return result

# Helper: convert a cosine distance into a cosine similarity
def cosine_distance_to_similarity(cosine_distance):
    """Convert a cosine distance into a cosine similarity"""
    # For cosine distance: similarity = 1 - distance
    cosine_similarity = 1 - cosine_distance
    return max(0, min(1, cosine_similarity))  # clamp to the [0, 1] range

# API endpoint: search images
@app.route('/search', methods=['POST'])
def search_images():
    try:
        # Make sure the app is initialized
        if not initialize_app():
            return jsonify({'error': 'Initialization failed, cannot search images'}), 500

        data = request.get_json()
        query = data.get('query', '')
        user_mail = data.get('user_mail', '')
        selected_department = data.get('selected_department', [])

        logger.info(f"Handling query: {query}, user: {user_mail}, departments: {selected_department}")

        # Compute the query vector
        embedding_function = SiglipTextEmbeddingFunction(model, processor)
        query_embedding = embedding_function([query])[0]

        # 🔍 Debug information
        logger.info(f"Query vector dimension: {len(query_embedding)}")
        logger.info(f"Query vector norm: {np.linalg.norm(query_embedding):.6f}")

        # Build the where condition
        where_condition = build_where_condition(selected_department, user_mail)
        logger.info(f"Where condition: {where_condition}")

        # Search ChromaDB for similar images
        try:
            results = collection.query(
                query_embeddings=[query_embedding],
                n_results=100,
                where=where_condition,
                include=["metadatas", "distances"]
            )
            logger.info(f"Query succeeded, got {len(results['ids'][0]) if results['ids'] else 0} results")
        except Exception as query_error:
            logger.error(f"Query error: {str(query_error)}")
            raise

        # Handle empty results
        if not results["ids"] or not results["ids"][0]:
            logger.info("No matching results found")
            return jsonify({'results': []})

        # 🔧 Convert cosine distances to cosine similarities and filter
        filtered_results = []

        for i, result_id in enumerate(results["ids"][0]):
            metadata = results["metadatas"][0][i]
            cosine_distance = results["distances"][0][i]
            logger.info(f"Distance returned by Chroma: {cosine_distance}")

            # Convert to cosine similarity
            cosine_similarity = cosine_distance_to_similarity(cosine_distance)

            # Apply a similarity threshold of 0.7
            if cosine_similarity < 0.7:
                logger.info(f"Stopping at result {i+1}, similarity {cosine_similarity:.3f} is below the 0.7 threshold")
                break

            filtered_results.append({
                'id': result_id,
                'metadata': metadata,
                'similarity': cosine_similarity,
                'cosine_distance': cosine_distance,  # keep the raw distance for debugging
                'meta_user_mail': metadata.get("user_mail", ""),
                'meta_department': metadata.get("selected_department", "")
            })

        logger.info(f"Found {len(filtered_results)} results with similarity >= 0.7")

        # Build the final results and check file existence
        processed_results = []
        file_not_found_count = 0

        for result in filtered_results:
            metadata = result['metadata']
            filename = metadata.get("filename")

            if not filename:
                logger.warning(f"Result {result['id']} has no filename")
                continue

            # Resolve the image path
            image_path = metadata.get("path", os.path.join("/app/extracted_content/images", filename))

            # Verify the file exists
            if not os.path.exists(image_path):
                logger.warning(f"Image file does not exist: {image_path}")
                file_not_found_count += 1
                continue

            # Build the image URL
            image_url = f"/image/{filename}?path={quote(image_path)}"

            processed_results.append({
                'id': result['id'],
                'filename': filename,
                'similarity': round(result['similarity'], 3),  # keep 3 decimal places
                'cosine_distance': round(result['cosine_distance'], 3),  # raw distance for debugging
                'path': image_path,
                'image_url': image_url,
                'user_mail': result['meta_user_mail'],
                'selected_department': result['meta_department'],
                'metadata': metadata
            })

        logger.info(f"File existence check - missing files: {file_not_found_count}, final results returned: {len(processed_results)}")
        logger.info(f"Collection metadata: {collection.metadata}")
        return jsonify({'results': processed_results})

    except Exception as e:
        logger.error(f"search_images error: {str(e)}")
        return jsonify({'error': str(e)}), 500

def build_where_condition(selected_department, user_mail):
    """Build the where condition from the selected departments"""
    if not selected_department:
        return None  # search across all knowledge bases

    # Make sure we have a list
    if isinstance(selected_department, str):
        departments = [selected_department]  # wrap a single department in a list
    else:
        departments = selected_department

    # Drop empty values and duplicates
    departments = list(set([dept.strip() for dept in departments if dept and dept.strip()]))

    if not departments:
        return None

    if len(departments) == 1:
        # Single-department case
        dept = departments[0]
        if dept == "個人ナレッジ":
            return {
                "$and": [
                    {"selected_department": "個人ナレッジ"},
                    {"user_mail": user_mail}
                ]
            }
        else:
            return {"selected_department": dept}

    else:
        # Multi-department case
        if "個人ナレッジ" in departments:
            # The personal knowledge base is included
            other_departments = [d for d in departments if d != "個人ナレッジ"]
            conditions = []

            # Conditions for the other departments
            if other_departments:
                conditions.append({"selected_department": {"$in": other_departments}})

            # Personal knowledge base condition (must also match the user's mail)
            conditions.append({
                "$and": [
                    {"selected_department": "個人ナレッジ"},
                    {"user_mail": user_mail}
                ]
            })

            return {"$or": conditions}
        else:
            # No personal knowledge base involved, query directly with $in
            return {"selected_department": {"$in": departments}}

# API endpoint: delete image vectors
@app.route('/delete-image-vectors', methods=['POST'])
def delete_image_vectors():
    """Delete image vector data matching an original_file_path"""
    global collection_name  # uses the global variable

    try:
        data = request.get_json()
        original_file_path = data.get('original_file_path')

        # Temporarily save the original collection_name
        original_collection_name = collection_name

        # If the request specifies a collection name, temporarily override the global
        if data.get('collection_name'):
            collection_name = data.get('collection_name')
        else:
            # Default to the image vector collection
            collection_name = 'unified_image_vectors'

        if not original_file_path:
            # Restore the original collection_name
            collection_name = original_collection_name
            return jsonify({'error': 'Missing original_file_path parameter'}), 400

        logger.info(f"Received delete-image-vectors request: original_file_path={original_file_path}, collection={collection_name}")

        # get_or_create_collection reads the global collection_name, so no arguments are needed
        if not get_or_create_collection():
            # Restore the original collection_name
            collection_name = original_collection_name
            return jsonify({'error': f'Cannot get collection: {collection_name}'}), 404

        # At this point the global collection has been updated by get_or_create_collection

        # Count the matching documents
        try:
            matching_docs = collection.get(
                where={"original_file_path": original_file_path}
            )

            if not matching_docs['ids']:
                # Restore the original collection_name
                collection_name = original_collection_name
                return jsonify({
                    'success': True,
                    'message': f'No matching image vector data found: {original_file_path}',
                    'deleted_count': 0
                })

            # Record the number of matching documents
            matching_count = len(matching_docs['ids'])
            logger.info(f"Found {matching_count} matching image vectors")

            # Perform the deletion
            collection.delete(
                where={"original_file_path": original_file_path}
            )

            # Get the current total number of vectors in the collection
            current_count = collection.count()

            # Restore the original collection_name
            collection_name = original_collection_name

            return jsonify({
                'success': True,
                'message': f'Successfully deleted image vector data: {original_file_path}',
                'deleted_count': matching_count,
                'current_total': current_count
            })

        except Exception as e:
            logger.error(f"Error while deleting image vector data: {str(e)}")
            # Restore the original collection_name
            collection_name = original_collection_name
            return jsonify({'error': f'Error while deleting image vector data: {str(e)}'}), 500

    except Exception as e:
        logger.error(f"delete-image-vectors API error: {str(e)}")
        # Make sure the original collection_name is restored
        collection_name = original_collection_name if 'original_collection_name' in locals() else collection_name
        return jsonify({'error': str(e)}), 500

# API endpoint: refresh the collection
@app.route('/refresh', methods=['GET'])
def refresh_collection():
    try:
        # Make sure the app is initialized
        if not initialize_app():
            return jsonify({'error': 'Initialization failed, cannot refresh collection'}), 500

        # Get the collection status
        count = collection.count()

        return jsonify({
            'success': True,
            'count': count,
            'message': f'Collection refreshed, currently holding {count} vectors',
            'collection_name': collection_name
        })
    except Exception as e:
        logger.error(f"Failed to refresh collection: {str(e)}")
        return jsonify({'error': str(e)}), 500

# API endpoint: get status
@app.route('/status', methods=['GET'])
def get_status():
    try:
        # Make sure the app is initialized
        if not initialize_app():
            return jsonify({'error': 'Initialization failed, cannot get status'}), 500

        # Get the collection status
        count = collection.count()

        # Get all collection names
        collection_names = client.list_collections()

        # Get the count for each collection
        collection_info = []
        for name in collection_names:
            try:
                coll = client.get_collection(name)
                collection_info.append({
                    'name': name,
                    'count': coll.count()
                })
            except Exception as e:
                collection_info.append({
                    'name': name,
                    'error': str(e)
                })

        # Inspect the storage directory
        db_files = []
        if os.path.exists(db_path):
            db_files = os.listdir(db_path)

        return jsonify({
            'status': 'running',
            'collection_name': collection_name,
            'vector_count': count,
            'model_loaded': model is not None,
            'processor_loaded': processor is not None,
            'all_collections': collection_info,
            'db_path': db_path,
            'db_files': db_files,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        })
    except Exception as e:
        logger.error(f"Failed to get status: {str(e)}")
        return jsonify({'error': str(e)}), 500

# API endpoint: serve an image
@app.route('/image/<path:filename>')
def serve_image(filename):
    try:
        # Get the image path from the URL parameters
        image_path = request.args.get('path')
        logger.info(f"Path from URL parameter: {image_path}")

        if image_path:
            # URL-decode the path
            image_path = unquote(image_path)

            if os.path.exists(image_path):
                return send_file(image_path, mimetype='image/jpeg')

        # If not found, try looking it up in ChromaDB
        if collection is not None:
            results = collection.query(
                query_texts=[filename],  # use the file name as the query text
                n_results=1,
                where={"filename": filename}  # exact match on the file name
            )

            if results["ids"] and len(results["ids"][0]) > 0:
                metadata = results["metadatas"][0][0]
                image_path = metadata.get("path")
                if image_path and os.path.exists(image_path):
                    return send_file(image_path, mimetype='image/jpeg')

        return jsonify({'error': 'Image not found'}), 404
    except Exception as e:
        logger.error(f"serve_image error: {str(e)}")
        return jsonify({'error': str(e)}), 404

# API endpoint: clear all image vectors
@app.route('/clear-all-vectors', methods=['POST'])
def clear_all_vectors():
    """Clear all image vector data from the collection"""
    global collection_name  # uses the global variable

    try:
        data = request.get_json() or {}

        # Temporarily save the original collection_name
        original_collection_name = collection_name

        # If the request specifies a collection name, temporarily override the global
        if data.get('collection_name'):
            collection_name = data.get('collection_name')
        else:
            # Default to the image vector collection
            collection_name = 'unified_image_vectors'

        logger.info(f"Received clear-all-vectors request: collection={collection_name}")

        # get_or_create_collection reads the global collection_name, so no arguments are needed
        if not get_or_create_collection():
            # Restore the original collection_name
            collection_name = original_collection_name
            return jsonify({'error': f'Cannot get collection: {collection_name}'}), 404

        # Number of vectors before clearing
        count_before = collection.count()
        logger.info(f"Collection holds {count_before} vectors before clearing")

        # Delete all documents in the collection
        try:
            # Get all document IDs
            all_docs = collection.get()

            if all_docs['ids']:
                # Delete all documents
                collection.delete(ids=all_docs['ids'])
                logger.info(f"Successfully deleted {len(all_docs['ids'])} vectors")
            else:
                logger.info("No vectors in the collection to delete")

            # Number of vectors after clearing
            count_after = collection.count()

            # Restore the original collection_name
            collection_name = original_collection_name

            return jsonify({
                'success': True,
                'message': 'Successfully cleared all vector data from the collection',
                'deleted_count': count_before,
                'current_total': count_after
            })

        except Exception as e:
            logger.error(f"Error while clearing vector data: {str(e)}")
            # Restore the original collection_name
            collection_name = original_collection_name
            return jsonify({'error': f'Error while clearing vector data: {str(e)}'}), 500

    except Exception as e:
        logger.error(f"clear-all-vectors API error: {str(e)}")
        # Make sure the original collection_name is restored
        collection_name = original_collection_name if 'original_collection_name' in locals() else collection_name
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Make sure everything is initialized before starting the app
    initialize_app()
    logger.info(f"App initialized, collection {collection_name} holds {collection.count()} vectors")

    # Start the app
    app.run(host='0.0.0.0', port=7860)
```


5 answers

  • 阿里嘎多学长 2025-05-25 23:44

    Answer

    You are using ChromaDB and SIGLIP for image vectorization and text-image similarity search, but the similarity comes out abnormally low (0.01). Possible causes include:

    1. Image embedding quality: SIGLIP may not be embedding the images as intended, which skews the similarity scores. You can try a different image embedding model or adjust how SIGLIP is applied (a direct check that bypasses ChromaDB is sketched right after this list).
    2. Text embedding quality: likewise, the text side may not be embedded as intended. Try a different text embedding approach or adjust the SIGLIP text preprocessing.
    3. ChromaDB configuration: a misconfigured collection, for example the wrong distance metric or indexing settings, will distort the similarity values. Check the collection's distance setting and index parameters.
    4. Data quality: low-quality images or text will also depress the similarity scores. Test the similarity computation with known high-quality image-text pairs.
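
    To separate a model or preprocessing problem (causes 1-2) from a database problem (cause 3), here is a minimal sketch that scores one text-image pair directly with the model, bypassing ChromaDB entirely. It reuses the model path from the question; the image path and query text are hypothetical placeholders. If this direct score is also around 0.01, the issue is on the model or preprocessing side; if it is clearly higher, inspect the collection's distance configuration:

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("/app/models/siglip256").to("cuda").eval()
processor = AutoProcessor.from_pretrained("/app/models/siglip256")

image = Image.open("example.png").convert("RGB")  # hypothetical test image
text = "a photo of a cat"                         # hypothetical test query

with torch.no_grad():
    img_inputs = processor(images=image, return_tensors="pt").to("cuda")
    txt_inputs = processor(text=text, padding="max_length",
                           truncation=True, max_length=64,
                           return_tensors="pt").to("cuda")
    img_emb = model.get_image_features(**img_inputs)
    txt_emb = model.get_text_features(**txt_inputs)
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    cos = (img_emb @ txt_emb.T).item()

# This is exactly the value that "1 - cosine_distance" should reproduce
# when the same image is stored in a cosine-space ChromaDB collection.
print(f"direct cosine similarity: {cos:.4f}")
```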

    Suggested fixes

    1. Re-check the image and text vectorization steps to make sure both are actually converted into the intended vectors.
    2. Try a different image embedding model, or adjust SIGLIP's parameters, to improve the image embedding quality.
    3. Try a different text embedding model, or adjust SIGLIP's parameters, to improve the text embedding quality.
    4. Review the ChromaDB parameters, for example the similarity threshold and the index configuration (see the note and sketch below).
    5. Test the similarity computation with known high-quality image and text data.

    If you need more help, please provide more code and data details, and I will do my best to help you solve the problem.
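
    One more SIGLIP-specific point that often explains low absolute numbers: SigLIP is trained with a sigmoid loss rather than CLIP's softmax, so raw cosine similarities for well-matched pairs are commonly far below 0.7 even when everything works, and scores are normally calibrated through the model's learned logit scale and bias. A minimal sketch, assuming the checkpoint at /app/models/siglip256 is a standard SiglipModel exposing logits_per_image (image path and texts are placeholders):

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("/app/models/siglip256").to("cuda").eval()
processor = AutoProcessor.from_pretrained("/app/models/siglip256")

image = Image.open("example.png").convert("RGB")  # hypothetical test image
texts = ["a photo of a cat", "a photo of a dog"]  # hypothetical test queries

with torch.no_grad():
    inputs = processor(text=texts, images=image, padding="max_length",
                       truncation=True, max_length=64,
                       return_tensors="pt").to("cuda")
    outputs = model(**inputs)
    # logits_per_image = logit_scale.exp() * cosine + logit_bias
    probs = torch.sigmoid(outputs.logits_per_image)

# A matched pair can score a high probability here even when the raw
# cosine similarity is only around 0.1-0.3.
print(probs)
```

    If the sigmoid probabilities look reasonable while the raw cosines stay low, consider ranking by distance instead of applying an absolute 0.7 cosine threshold, or threshold on the sigmoid probability.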

