秋声studio 2025-01-14 09:23 采纳率: 0%
浏览 12
已结题

RAG召回效果不佳以及分块不佳

pdf_processor.py
这个文件负责处理PDF文档,提取文本和图像,并进行OCR(光学字符识别)。

功能:

validate_file_path(file_path): 验证文件路径是否存在。

detect_file_encoding(file_path): 检测文件的编码格式。

extract_text_from_docx(docx_path): 从DOCX文件中提取文本。

extract_images_from_pdf(pdf_path): 从PDF文件中提取图像。

ocr_on_image(image): 对图像进行OCR处理,提取文本。

extract_text_from_pdf(pdf_path): 使用PyPDFLoader从PDF文件中提取文本。

save_text_to_file(text, file_path): 将文本保存到文件中。

process_pdf(pdf_path, output_dir): 处理PDF文件,提取文本和图像,合并结果并保存。

依赖:

fitz (PyMuPDF): 用于处理PDF文件。

PIL (Pillow): 用于处理图像。

pytesseract: 用于OCR处理。

langchain.document_loaders.PyPDFLoader: 用于从PDF中提取文本。

    import fitz  # PyMuPDF
from PIL import Image
import pytesseract
from langchain.document_loaders import PyPDFLoader
import os
import traceback
from docx import Document
import chardet

# 设置 Tesseract 的路径
from dotenv import load_dotenv
load_dotenv("config/local.env")
TESSERACT_PATH = os.getenv("TESSERACT_PATH")
pytesseract.pytesseract.tesseract_cmd = TESSERACT_PATH

def validate_file_path(file_path):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

def detect_file_encoding(file_path):
    with open(file_path, 'rb') as file:
        raw_data = file.read()
        result = chardet.detect(raw_data)
        return result['encoding']

def extract_text_from_docx(docx_path):
    doc = Document(docx_path)
    return "\n".join([para.text for para in doc.paragraphs])

def extract_images_from_pdf(pdf_path):
    pdf_document = fitz.open(pdf_path)
    images_list = []
    for page_num in range(len(pdf_document)):
        page = pdf_document.load_page(page_num)
        image_list = page.get_images(full=True)
        for img_index, img in enumerate(image_list):
            xref = img[0]
            base_image = pdf_document.extract_image(xref)
            image_bytes = base_image["image"]
            image_ext = base_image["ext"]

            if not image_bytes:
                print(f"Warning: Empty image data found on page {page_num + 1}, image index {img_index + 1}")
                continue

            try:
                # 确保 colorspace 是字符串类型
                colorspace = str(base_image["colorspace"])
                if colorspace not in ["RGB", "L", "CMYK"]:  # 仅支持常见颜色模式
                    print(f"Skipping image on page {page_num + 1}, image index {img_index + 1}: unsupported colorspace {colorspace}")
                    continue

                image = Image.frombytes(colorspace, [base_image["width"], base_image["height"]], image_bytes)
                images_list.append((image, f"page{page_num + 1}_img{img_index + 1}.{image_ext}"))
            except Exception as e:
                print(f"Error processing image on page {page_num + 1}, image index {img_index + 1}: {e}")
    return images_list

def ocr_on_image(image):
    return pytesseract.image_to_string(image)

def extract_text_from_pdf(pdf_path):
    loader = PyPDFLoader(pdf_path)
    pages = loader.load_and_split()
    return "\n".join([page.page_content for page in pages])

def save_text_to_file(text, file_path):
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(text)

def process_pdf(pdf_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    try:
        # 验证文件路径
        validate_file_path(pdf_path)

        # 提取文本
        pdf_text = extract_text_from_pdf(pdf_path)

        # 提取图像并进行 OCR
        images = extract_images_from_pdf(pdf_path)
        ocr_texts = [ocr_on_image(image) for image, _ in images]

        # 合并文本和 OCR 结果
        combined_text = "\n".join(ocr_texts) + "\n" + pdf_text

        # 保存合并后的文本
        output_file = os.path.join(output_dir, "combined_text.txt")
        save_text_to_file(combined_text, output_file)
        print(f"Saved combined text to: {output_file}")

    except Exception as e:
        print(f"Error processing PDF: {e}")
        traceback.print_exc()

recall_module.py
这个文件负责从向量数据库和Redis中召回相关信息。

功能:

recall_from_vector_db(query, top_k=3, score_threshold=0.7): 从向量数据库中召回与查询相关的内容。

recall_from_redis(query_keywords): 从Redis中召回与查询关键词相关的内容。

hybrid_recall(query, top_k=3): 结合向量数据库和Redis的召回结果,返回最相关的内容。

依赖:

langchain.document_loaders.TextLoader: 用于加载文本文件。

langchain.text_splitter.CharacterTextSplitter: 用于将文本分割成块。

langchain.embeddings.HuggingFaceEmbeddings: 用于生成文本的嵌入向量。

langchain.vectorstores.FAISS: 用于创建和管理向量数据库。

redis: 用于与Redis数据库交互。

sklearn.metrics.pairwise.cosine_similarity: 用于计算文本相似度。

    import os
import json
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
import redis
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from langchain.schema import Document

# 加载配置文件
load_dotenv("config/local.env")
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_DB = int(os.getenv("REDIS_DB", 0))

class RecallModule:
    def __init__(self):
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)

    def recall_from_vector_db(self, query, top_k=3, score_threshold=0.7):
        """
        从向量数据库中召回相关内容
        :param query: 用户查询
        :param top_k: 返回的文本块数量
        :param score_threshold: 相似度阈值
        :return: 召回的相关问题及其对应的块内容
        """
        try:
            # 读取修正后的文本块和生成的问题
            with open("data/txt/corrected_combined_text.json", "r", encoding="utf-8") as f:
                chunk_questions = json.load(f)

            # 将生成的问题和对应的块内容转换为 Document 对象
            documents = []
            for chunk, questions in chunk_questions:
                for question in questions:
                    metadata = {"chunk": chunk, "questions": questions}
                    documents.append(Document(page_content=question, metadata=metadata))

            # 使用 HuggingFace 的本地 Embedding 模型
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

            # 创建向量数据库
            db = FAISS.from_documents(documents=documents, embedding=embeddings)

            # 召回并排序
            results = db.similarity_search_with_score(query, k=top_k)
            filtered_results = [doc for doc, score in results if score >= score_threshold]

            # 返回最相似的问题及其对应的块内容
            return filtered_results
        except Exception as e:
            print(f"Error loading or processing documents: {e}")
            return []

    def recall_from_redis(self, query_keywords):
        """
        从 Redis 中召回相关内容
        :param query_keywords: 用户查询的关键词列表
        :return: 召回的相关问答对列表
        """
        recalled_qa = []
        cursor = 0
        try:
            while True:
                cursor, keys = self.redis_client.scan(cursor, match='qa:*', count=100)
                for key in keys:
                    tags = self.redis_client.hget(key, 'tags')
                    if tags:
                        # 计算标签与查询关键词的相似度
                        tags_list = tags.split(",")
                        vectorizer = CountVectorizer().fit(tags_list + query_keywords)
                        tags_vec = vectorizer.transform(tags_list)
                        query_vec = vectorizer.transform(query_keywords)
                        similarity = cosine_similarity(tags_vec, query_vec).mean()

                        if similarity > 0.3:  # 设置相似度阈值
                            answer = self.redis_client.hget(key, 'answer')
                            recalled_qa.append((key, answer, similarity))
                if cursor == 0:
                    break

            # 按相似度排序
            recalled_qa.sort(key=lambda x: x[2], reverse=True)
            return [(qa[0], qa[1]) for qa in recalled_qa]
        except Exception as e:
            print(f"Error recalling from Redis: {e}")
            return []

    def hybrid_recall(self, query, top_k=3):
        """
        混合召回:从向量数据库和 Redis 中召回相关内容
        :param query: 用户查询
        :param top_k: 返回的文本块数量
        :return: 召回的相关内容列表
        """
        # 扩展查询关键词
        expanded_query = query + " 腹痛 消化系统疾病"

        # 从向量数据库中召回
        vector_results = self.recall_from_vector_db(expanded_query, top_k)

        # 从 Redis 中召回
        redis_results = self.recall_from_redis(expanded_query.split())

        # 融合结果并排序
        combined_results = []
        for doc in vector_results:
            combined_results.append(("vector", doc.page_content, doc.metadata.get("chunk", ""), doc.metadata.get("questions", []), 1.0))  # 假设向量召回结果的分数为 1.0
        for qa in redis_results:
            combined_results.append(("redis", f"Query: {qa[0]}\nAnswer: {qa[1]}", "", [], 0.8))  # 假设 Redis 召回结果的分数为 0.8

        # 按分数排序
        combined_results.sort(key=lambda x: x[4], reverse=True)

        # 返回前 top_k 个结果
        return {
            "vector_results": [result[1:4] for result in combined_results if result[0] == "vector"][:top_k],
            "redis_results": [result[1:4] for result in combined_results if result[0] == "redis"][:top_k]
        }

redis_manager.py
这个文件负责管理Redis数据库中的问答对。

功能:

store_qa_pair(query, answer, tags): 将问答对存储到Redis中。

get_qa_pair(query): 从Redis中获取问答对。

clean_redis_data(max_age_days=30): 清理过期的Redis数据。

acquire_lock(lock_key, timeout=10): 获取Redis锁。

release_lock(lock_key): 释放Redis锁。

依赖:

redis: 用于与Redis数据库交互。


import redis
from dotenv import load_dotenv
import os
import time  # 导入 time 模块

# 加载 local.env
load_dotenv("config/local.env")
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_DB = int(os.getenv("REDIS_DB", 0))

class RedisManager:
    def __init__(self):
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)

    def store_qa_pair(self, query, answer, tags):
        """
        存储问答对到 Redis
        :param query: 问题
        :param answer: 答案
        :param tags: 标签列表
        """
        self.redis_client.hset('qa:' + query, 'answer', answer)
        self.redis_client.hset('qa:' + query, 'tags', ",".join(tags))
        self.redis_client.hset('qa:' + query, 'last_accessed', time.time())  # 使用 time.time()

    def get_qa_pair(self, query):
        """
        从 Redis 中获取问答对
        :param query: 问题
        :return: 答案
        """
        return self.redis_client.hget('qa:' + query, 'answer')

    def clean_redis_data(self, max_age_days=30):
        """
        清理过期数据
        :param max_age_days: 数据最大保留天数
        """
        cursor = 0
        while True:
            cursor, keys = self.redis_client.scan(cursor, match='qa:*', count=100)
            for key in keys:
                last_accessed = self.redis_client.hget(key, 'last_accessed')
                if last_accessed and (time.time() - float(last_accessed)) > max_age_days * 86400:
                    self.redis_client.delete(key)
            if cursor == 0:
                break

    def acquire_lock(self, lock_key, timeout=10):
        """
        获取锁
        :param lock_key: 锁的键
        :param timeout: 锁的超时时间
        :return: 是否成功获取锁
        """
        return self.redis_client.setnx(lock_key, time.time() + timeout)

    def release_lock(self, lock_key):
        """
        释放锁
        :param lock_key: 锁的键
        """
        self.redis_client.delete(lock_key)

text_corrector.py
这个文件负责修正文本并生成相关问题。

功能:

correct_text_chunk(user_message): 调用DeepSeek API修正文本块。

generate_questions_for_chunk(chunk): 为文本块生成相关问题。

save_corrected_text(chunk_questions, output_dir): 保存修正后的文本和生成的问题。

split_text_into_chunks(text, max_chunk_size=400): 将文本分割成块。

process_file(file_path, output_dir): 处理文本文件,分块、修正、生成问题并保存。

依赖:

openai: 用于调用DeepSeek API。

multiprocessing.Pool: 用于多进程处理文本修正。


import time
import os
import re
import traceback
import json
from multiprocessing import Pool, cpu_count
import openai
from dotenv import load_dotenv

# 加载 deepseek.env
load_dotenv("config/deepseek.env")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")

def correct_text_chunk(user_message):
    """
    调用 DeepSeek API 进行文本修正
    :param user_message: 需要修正的文本块
    :return: 修正后的文本块和处理时间
    """
    max_retries = 3
    retry_delay = 2
    for attempt in range(max_retries):
        try:
            # 读取提示词文件
            with open("tests/text_correction3.txt", "r", encoding="utf-8") as f:
                prompt = f.read().strip()

            start_time = time.time()
            completion = openai.ChatCompletion.create(
                model="deepseek-chat",
                messages=[
                    {'role': 'system', 'content': prompt},  # 使用文件中的提示词
                    {"role": "user", "content": user_message}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            return completion.choices[0].message.content, time.time() - start_time
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
                continue
            print(f"Error after {max_retries} attempts: {e}")
            return user_message, 0

def generate_questions_for_chunk(chunk):
    """
    为每个块生成三个问题
    :param chunk: 文本块
    :return: 生成的问题列表
    """
    max_retries = 3
    retry_delay = 2
    for attempt in range(max_retries):
        try:
            start_time = time.time()
            completion = openai.ChatCompletion.create(
                model="deepseek-chat",
                messages=[
                    {'role': 'system', 'content': "你是一个医学文献助手,请为以下文本生成三个相关问题。"},
                    {"role": "user", "content": chunk}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            questions = completion.choices[0].message.content.split("\n")
            questions = [q.strip() for q in questions if q.strip()]  # 去除空行和多余空格
            return questions[:3]  # 返回前三个问题
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
                continue
            print(f"Error generating questions for chunk: {e}")
            return []

def save_corrected_text(chunk_questions, output_dir):
    """
    保存修正后的文本块和生成的问题
    :param chunk_questions: 修正后的文本块和生成的问题列表
      output_dir: 输出目录
    """
    # timestamp = time.strftime("%Y%m%d_%H%M%S") _{timestamp}
    output_file = os.path.join(output_dir, f"corrected_combined_text.json")  # 保存为 .json 文件
    try:
        # 将每个块和生成的问题保存为 JSON 格式
        with open(output_file, 'w', encoding='utf-8') as file:
            json.dump(chunk_questions, file, ensure_ascii=False, indent=4)  # 使用 indent 格式化 JSON
        print(f"Saved corrected text to: {output_file}")
    except Exception as e:
        print(f"Error saving corrected text: {e}")

def split_text_into_chunks(text, max_chunk_size=400):
    """
    将文本按句子和段落分块,并控制分块大小
    :param text: 原始文本
    :param max_chunk_size: 每个块的最大字符数
    :return: 分块后的文本列表
    """
    chunks = []
    current_chunk = ""

    # 按段落分割文本
    paragraphs = text.split("\n\n")  # 假设段落之间用空行分隔
    for paragraph in paragraphs:
        # 按句子分割段落
        sentences = re.split(r'(?<=[。!?])', paragraph)
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence
            else:
                if current_chunk.strip():  # 避免空块
                    chunks.append(current_chunk.strip())
                current_chunk = sentence
        if current_chunk.strip():  # 避免空块
            chunks.append(current_chunk.strip())
            current_chunk = ""

    return chunks

def process_file(file_path, output_dir):
    """
    处理文本文件,分块、修正、生成问题、保存
    :param file_path: 文本文件路径
    :param output_dir: 输出目录
    :return: 修正后的文本块和生成的问题
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            original_text = file.read().strip()

        if not original_text:
            print(f"Skipping empty file: {file_path}")
            return []

        # 分块
        chunks = split_text_into_chunks(original_text)

        # 多进程修正
        start_time = time.time()
        with Pool(cpu_count()) as pool:
            results = pool.map(correct_text_chunk, chunks)

        print(f'Processing {file_path} took {time.time() - start_time:.2f} seconds')

        # 提取修正后的文本块
        corrected_chunks = [result for result, _ in results]

        # 为每个块生成问题
        chunk_questions = []
        for chunk in corrected_chunks:
            questions = generate_questions_for_chunk(chunk)
            chunk_questions.append((chunk, questions))

        # 保存修正后的文本块和生成的问题
        save_corrected_text(chunk_questions, output_dir)

        return chunk_questions

    except Exception as e:
        print(f"Error processing file {file_path}: {e}")
        traceback.print_exc()
        return []

main.py
这个文件是项目的主入口,负责调用各个模块的功能。

功能:

setup_logging(): 设置日志记录。

main(): 主函数,依次处理PDF文件、修正文本、生成问题,并进行信息检索。

流程:

处理PDF文件,提取文本和图像。

修正文本并生成相关问题。

初始化召回模块。

用户输入查询。

执行混合召回,从向量数据库和Redis中获取相关结果。

输出召回结果。

依赖:

logging: 用于日志记录。

dotenv: 用于加载环境变量。

recall_module.RecallModule: 用于召回相关信息。

pdf_processor.process_pdf: 用于处理PDF文件。

text_corrector.process_file: 用于修正文本并生成问题。


import logging
from logging.handlers import RotatingFileHandler
from dotenv import load_dotenv
import os
from modules.recall_module import RecallModule
from modules.pdf_processor import process_pdf
from modules.text_corrector import process_file

def setup_logging():
    """
    设置日志记录
    """
    log_file = "logs/app.log"
    handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

def main():
    """
    主函数,调用整个项目的功能
    """
    setup_logging()
    logging.info("Starting the application")

    # 加载配置文件
    load_dotenv("config/local.env")
    PDF_PATH = os.getenv("PDF_PATH", "data/pdf/中国心力衰竭诊断和治疗指南2024.pdf")
    TXT_DIR = os.getenv("TXT_DIR", "data/txt")
    CORRECTED_TXT_DIR = os.getenv("CORRECTED_TXT_DIR", "data/txt")

    # 1. 处理 PDF 文件
    logging.info("Processing PDF file...")
    process_pdf(PDF_PATH, TXT_DIR)

    logging.info("PDF processing completed.")

    # 2. 修正文本并生成问题
    combined_text_path = os.path.join(TXT_DIR, "combined_text.txt")
    logging.info("Correcting text and generating questions...")
    chunk_questions = process_file(combined_text_path, CORRECTED_TXT_DIR)
    logging.info(f"Corrected {len(chunk_questions)} chunks and generated questions.")
    logging.info("Text correction and question generation completed.")

    # 3. 初始化召回模块
    recall_module = RecallModule()

    # 4. 用户查询
    query = input("Enter your query: ")
    top_k = int(input("Enter the number of results to retrieve (top_k): "))

    # 5. 混合召回
    logging.info("Performing hybrid recall...")
    results = recall_module.hybrid_recall(query, top_k=top_k)
    logging.info("Hybrid recall completed.")

    # 6. 输出召回结果
    print("Vector Results:")
    for question, chunk, questions in results["vector_results"]:
        print(f"Question: {question}")
        print(f"Chunk: {chunk}")
        print(f"Generated Questions: {questions}")
        print()

    print("\nRedis Results:")
    for qa in results["redis_results"]:
        print(f"Query: {qa[0]}\nAnswer: {qa[1]}\n")

if __name__ == "__main__":
    main()
  • 写回答

2条回答 默认 最新

  • 关注

    以下回复参考:皆我百晓生券券喵儿等免费微信小程序作答:

    从您提供的代码来看,您似乎正在开发一个涉及PDF处理、文本修正、问题生成以及信息检索的系统。您遇到了召回效果不佳和分块不佳的问题。

    1. 召回效果不佳

      • 检查查询和数据库内容的相关性:确保您的查询与数据库中存储的问题和答案具有足够的相关性。如果查询与数据库内容不匹配,召回的结果可能不会很好。
      • 调整相似度阈值:在recall_from_redis函数中,您使用了0.3的相似度阈值。尝试调整此阈值,看看是否可以改善召回结果。
      • 优化文本向量化:在recall_from_vector_db函数中,您使用了HuggingFaceEmbeddings来生成文本的嵌入向量。尝试使用不同的嵌入模型或调整嵌入模型的参数,以获取更好的文本表示。
      • 增加训练数据:如果您的向量数据库的数据量不够,可能无法提供足够的召回结果。考虑增加训练数据,以改善召回性能。
    2. 分块不佳

      • 调整分块大小:在split_text_into_chunks函数中,您使用了400个字符作为分块的最大大小。尝试调整此值,以找到最适合您的应用的分块大小。
      • 考虑句子长度和段落结构:在分块时,可以考虑句子的长度和段落结构,以确保每个块都具有逻辑上的完整性。
      • 使用更复杂的分块策略:您目前基于字符数进行分块。考虑使用更复杂的分块策略,如基于句子、段落或章节进行分块。

    除了上述建议,您还可以考虑以下优化:

    • 优化Redis存储:在redis_manager.py中,考虑使用哈希表(Hash)或集合(Set)来存储问答对,以便更高效地检索和清理数据。
    • 错误处理和日志记录:确保您的代码包含适当的错误处理和日志记录机制,以便在出现问题时能够诊断和解决问题。
    • 扩展功能:考虑将您的系统扩展到支持更多类型的文档格式(如Word、HTML等),以及更高级的文本修正和生成问题功能。

    希望这些建议能帮助您改善系统的召回效果和分块效果。如果还有其他问题,请随时告诉我!

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 1月15日
  • 创建了问题 1月14日