pdf_processor.py
This file handles PDF documents: it extracts text and images and runs OCR (optical character recognition).
Functions:
validate_file_path(file_path): Validates that the file path exists.
detect_file_encoding(file_path): Detects a file's character encoding.
extract_text_from_docx(docx_path): Extracts the text from a DOCX file.
extract_images_from_pdf(pdf_path): Extracts the images from a PDF file.
ocr_on_image(image): Runs OCR on an image and returns the extracted text.
extract_text_from_pdf(pdf_path): Extracts text from a PDF file using PyPDFLoader.
save_text_to_file(text, file_path): Saves text to a file.
process_pdf(pdf_path, output_dir): Processes a PDF file: extracts text and images, merges the results, and saves them (see the usage sketch after the code).
Dependencies:
fitz (PyMuPDF): PDF handling.
PIL (Pillow): image handling.
pytesseract: OCR.
langchain.document_loaders.PyPDFLoader: text extraction from PDFs.
import fitz  # PyMuPDF
import io
from PIL import Image
import pytesseract
from langchain.document_loaders import PyPDFLoader
import os
import traceback
from docx import Document
import chardet
# Configure the Tesseract executable path
from dotenv import load_dotenv
load_dotenv("config/local.env")
TESSERACT_PATH = os.getenv("TESSERACT_PATH")
pytesseract.pytesseract.tesseract_cmd = TESSERACT_PATH
def validate_file_path(file_path):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

def detect_file_encoding(file_path):
    with open(file_path, 'rb') as file:
        raw_data = file.read()
    result = chardet.detect(raw_data)
    return result['encoding']

def extract_text_from_docx(docx_path):
    doc = Document(docx_path)
    return "\n".join([para.text for para in doc.paragraphs])
def extract_images_from_pdf(pdf_path):
    pdf_document = fitz.open(pdf_path)
    images_list = []
    for page_num in range(len(pdf_document)):
        page = pdf_document.load_page(page_num)
        image_list = page.get_images(full=True)
        for img_index, img in enumerate(image_list):
            xref = img[0]
            base_image = pdf_document.extract_image(xref)
            image_bytes = base_image["image"]
            image_ext = base_image["ext"]
            if not image_bytes:
                print(f"Warning: Empty image data found on page {page_num + 1}, image index {img_index + 1}")
                continue
            try:
                # extract_image() returns encoded image data (PNG/JPEG/...), not raw pixels,
                # and base_image["colorspace"] is a component count (an int), not a PIL mode
                # string, so decode with Image.open() rather than Image.frombytes().
                image = Image.open(io.BytesIO(image_bytes))
                if image.mode not in ("RGB", "L", "CMYK"):  # only common color modes are supported
                    print(f"Skipping image on page {page_num + 1}, image index {img_index + 1}: unsupported mode {image.mode}")
                    continue
                images_list.append((image, f"page{page_num + 1}_img{img_index + 1}.{image_ext}"))
            except Exception as e:
                print(f"Error processing image on page {page_num + 1}, image index {img_index + 1}: {e}")
    return images_list
def ocr_on_image(image):
    # For Chinese source documents, pytesseract also accepts a lang argument,
    # e.g. image_to_string(image, lang='chi_sim'), if that tessdata pack is installed.
    return pytesseract.image_to_string(image)

def extract_text_from_pdf(pdf_path):
    loader = PyPDFLoader(pdf_path)
    pages = loader.load_and_split()
    return "\n".join([page.page_content for page in pages])

def save_text_to_file(text, file_path):
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(text)
def process_pdf(pdf_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    try:
        # Validate the input path
        validate_file_path(pdf_path)
        # Extract the embedded text layer
        pdf_text = extract_text_from_pdf(pdf_path)
        # Extract images and run OCR on them
        images = extract_images_from_pdf(pdf_path)
        ocr_texts = [ocr_on_image(image) for image, _ in images]
        # Merge the OCR results with the extracted text
        combined_text = "\n".join(ocr_texts) + "\n" + pdf_text
        # Save the merged text
        output_file = os.path.join(output_dir, "combined_text.txt")
        save_text_to_file(combined_text, output_file)
        print(f"Saved combined text to: {output_file}")
    except Exception as e:
        print(f"Error processing PDF: {e}")
        traceback.print_exc()
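
A minimal usage sketch (the sample path is hypothetical; config/local.env is assumed to define TESSERACT_PATH, e.g. TESSERACT_PATH=/usr/bin/tesseract):

from modules.pdf_processor import process_pdf

# Extracts the PDF's text layer, OCRs every embedded image, and writes the
# merged result to data/txt/combined_text.txt.
process_pdf("data/pdf/sample.pdf", "data/txt")
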
recall_module.py
This file handles recalling relevant information from the vector database and from Redis.
Functions:
recall_from_vector_db(query, top_k=3, score_threshold=0.7): Recalls content related to the query from the vector database.
recall_from_redis(query_keywords): Recalls content matching the query keywords from Redis.
hybrid_recall(query, top_k=3): Combines the vector-database and Redis recall results and returns the most relevant content (see the usage sketch after the code).
Dependencies:
langchain.document_loaders.TextLoader: loading text files.
langchain.text_splitter.CharacterTextSplitter: splitting text into chunks.
langchain.embeddings.HuggingFaceEmbeddings: generating text embeddings.
langchain.vectorstores.FAISS: building and querying the vector database.
redis: interacting with the Redis database.
sklearn.metrics.pairwise.cosine_similarity: computing text similarity.
import os
import json
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
import redis
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from langchain.schema import Document
# Load configuration from local.env
load_dotenv("config/local.env")
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_DB = int(os.getenv("REDIS_DB", 0))
class RecallModule:
    def __init__(self):
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    def recall_from_vector_db(self, query, top_k=3, score_threshold=0.7):
        """
        Recall relevant content from the vector database.
        :param query: user query
        :param top_k: number of text chunks to return
        :param score_threshold: distance cutoff; results scoring above it are dropped
        :return: recalled questions, each carrying its source chunk in metadata
        """
        try:
            # Load the corrected text chunks and their generated questions
            with open("data/txt/corrected_combined_text.json", "r", encoding="utf-8") as f:
                chunk_questions = json.load(f)
            # Wrap each generated question in a Document, keeping the source chunk in its metadata
            documents = []
            for chunk, questions in chunk_questions:
                for question in questions:
                    metadata = {"chunk": chunk, "questions": questions}
                    documents.append(Document(page_content=question, metadata=metadata))
            # Local HuggingFace embedding model
            embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
            # Build the vector index (rebuilt on every call; caching it would help repeated queries)
            db = FAISS.from_documents(documents=documents, embedding=embeddings)
            # Retrieve and filter. FAISS returns L2 distances, where a lower score means
            # more similar, so keep results at or below the threshold.
            results = db.similarity_search_with_score(query, k=top_k)
            filtered_results = [doc for doc, score in results if score <= score_threshold]
            return filtered_results
        except Exception as e:
            print(f"Error loading or processing documents: {e}")
            return []
    def recall_from_redis(self, query_keywords):
        """
        Recall relevant content from Redis.
        :param query_keywords: list of keywords from the user query
        :return: list of recalled (key, answer) pairs
        """
        recalled_qa = []
        cursor = 0
        try:
            while True:
                cursor, keys = self.redis_client.scan(cursor, match='qa:*', count=100)
                for key in keys:
                    tags = self.redis_client.hget(key, 'tags')
                    if tags:
                        # Score the stored tags against the query keywords
                        tags_list = tags.split(",")
                        vectorizer = CountVectorizer().fit(tags_list + query_keywords)
                        tags_vec = vectorizer.transform(tags_list)
                        query_vec = vectorizer.transform(query_keywords)
                        similarity = cosine_similarity(tags_vec, query_vec).mean()
                        if similarity > 0.3:  # similarity threshold
                            answer = self.redis_client.hget(key, 'answer')
                            recalled_qa.append((key, answer, similarity))
                if cursor == 0:
                    break
            # Sort by similarity, highest first
            recalled_qa.sort(key=lambda x: x[2], reverse=True)
            return [(qa[0], qa[1]) for qa in recalled_qa]
        except Exception as e:
            print(f"Error recalling from Redis: {e}")
            return []
    def hybrid_recall(self, query, top_k=3):
        """
        Hybrid recall: combine results from the vector database and Redis.
        :param query: user query
        :param top_k: number of results to return per source
        :return: dict holding the recalled vector and Redis results
        """
        # Expand the query with hardcoded domain keywords (abdominal pain, digestive
        # disease) for this medical corpus
        expanded_query = query + " 腹痛 消化系统疾病"
        # Recall from the vector database
        vector_results = self.recall_from_vector_db(expanded_query, top_k)
        # Recall from Redis
        redis_results = self.recall_from_redis(expanded_query.split())
        # Merge and rank the results
        combined_results = []
        for doc in vector_results:
            combined_results.append(("vector", doc.page_content, doc.metadata.get("chunk", ""), doc.metadata.get("questions", []), 1.0))  # vector hits get a fixed score of 1.0
        for qa in redis_results:
            combined_results.append(("redis", f"Query: {qa[0]}\nAnswer: {qa[1]}", "", [], 0.8))  # Redis hits get a fixed score of 0.8
        # Sort by score
        combined_results.sort(key=lambda x: x[4], reverse=True)
        # Return the top_k results per source
        return {
            "vector_results": [result[1:4] for result in combined_results if result[0] == "vector"][:top_k],
            "redis_results": [result[1:4] for result in combined_results if result[0] == "redis"][:top_k]
        }
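
A minimal usage sketch. recall_from_vector_db expects data/txt/corrected_combined_text.json to hold a list of [chunk, questions] pairs, the shape produced by text_corrector.save_corrected_text below; the query here is hypothetical:

# Expected JSON shape: [["<corrected chunk text>", ["Q1", "Q2", "Q3"]], ...]
recall = RecallModule()
results = recall.hybrid_recall("心力衰竭的诊断标准是什么?", top_k=3)
for question, chunk, questions in results["vector_results"]:
    print(f"Matched question: {question}")
    print(f"Source chunk: {chunk[:80]}")
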
redis_manager.py
This file manages the question-answer pairs stored in Redis.
Functions:
store_qa_pair(query, answer, tags): Stores a QA pair in Redis.
get_qa_pair(query): Fetches a QA pair from Redis.
clean_redis_data(max_age_days=30): Removes stale Redis entries.
acquire_lock(lock_key, timeout=10): Acquires a Redis lock.
release_lock(lock_key): Releases a Redis lock.
Dependencies:
redis: interacting with the Redis database.
import redis
from dotenv import load_dotenv
import os
import time

# Load local.env
load_dotenv("config/local.env")
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_DB = int(os.getenv("REDIS_DB", 0))
class RedisManager:
    def __init__(self):
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)

    def store_qa_pair(self, query, answer, tags):
        """
        Store a QA pair in Redis.
        :param query: question
        :param answer: answer
        :param tags: list of tags
        """
        key = 'qa:' + query
        self.redis_client.hset(key, 'answer', answer)
        self.redis_client.hset(key, 'tags', ",".join(tags))
        self.redis_client.hset(key, 'last_accessed', time.time())
    def get_qa_pair(self, query):
        """
        Fetch a QA pair from Redis.
        :param query: question
        :return: the stored answer, or None if absent
        """
        return self.redis_client.hget('qa:' + query, 'answer')
    def clean_redis_data(self, max_age_days=30):
        """
        Remove stale entries.
        :param max_age_days: maximum retention period in days
        """
        cursor = 0
        while True:
            cursor, keys = self.redis_client.scan(cursor, match='qa:*', count=100)
            for key in keys:
                last_accessed = self.redis_client.hget(key, 'last_accessed')
                if last_accessed and (time.time() - float(last_accessed)) > max_age_days * 86400:
                    self.redis_client.delete(key)
            if cursor == 0:
                break
    def acquire_lock(self, lock_key, timeout=10):
        """
        Acquire a lock.
        :param lock_key: lock key
        :param timeout: lock expiry in seconds
        :return: True if the lock was acquired
        """
        # SET with nx+ex so a crashed holder cannot leave the lock stuck forever
        # (plain SETNX sets no expiry, so the stored deadline was never enforced).
        return bool(self.redis_client.set(lock_key, time.time() + timeout, nx=True, ex=timeout))

    def release_lock(self, lock_key):
        """
        Release a lock.
        :param lock_key: lock key
        """
        self.redis_client.delete(lock_key)
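
A minimal usage sketch (the QA content is hypothetical):

manager = RedisManager()
manager.store_qa_pair("什么是心力衰竭?", "心力衰竭是一种临床综合征。", ["心衰", "定义"])
print(manager.get_qa_pair("什么是心力衰竭?"))

# Guard the cleanup job with the lock so only one process runs it at a time.
if manager.acquire_lock("lock:clean_redis_data", timeout=30):
    try:
        manager.clean_redis_data(max_age_days=30)
    finally:
        manager.release_lock("lock:clean_redis_data")
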
text_corrector.py
This file corrects text and generates related questions.
Functions:
correct_text_chunk(user_message): Corrects a text chunk via the DeepSeek API.
generate_questions_for_chunk(chunk): Generates related questions for a text chunk.
save_corrected_text(chunk_questions, output_dir): Saves the corrected text and the generated questions.
split_text_into_chunks(text, max_chunk_size=400): Splits text into chunks.
process_file(file_path, output_dir): Processes a text file: chunks, corrects, generates questions, and saves (see the usage sketch after the code).
Dependencies:
openai: calling the DeepSeek API.
multiprocessing.Pool: parallelizing text correction.
import time
import os
import re
import traceback
import json
from multiprocessing import Pool, cpu_count
import openai
from dotenv import load_dotenv
# Load deepseek.env
load_dotenv("config/deepseek.env")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
def correct_text_chunk(user_message):
    """
    Correct a text chunk via the DeepSeek API.
    :param user_message: text chunk to correct
    :return: (corrected text, elapsed seconds); on repeated failure, the original chunk and 0
    """
    max_retries = 3
    retry_delay = 2
    for attempt in range(max_retries):
        try:
            # Load the system prompt from file
            with open("tests/text_correction3.txt", "r", encoding="utf-8") as f:
                prompt = f.read().strip()
            start_time = time.time()
            completion = openai.ChatCompletion.create(
                model="deepseek-chat",
                messages=[
                    {'role': 'system', 'content': prompt},  # system prompt comes from the file above
                    {"role": "user", "content": user_message}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            return completion.choices[0].message.content, time.time() - start_time
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
                continue
            print(f"Error after {max_retries} attempts: {e}")
            return user_message, 0
def generate_questions_for_chunk(chunk):
    """
    Generate three questions for a text chunk.
    :param chunk: text chunk
    :return: list of generated questions (empty on repeated failure)
    """
    max_retries = 3
    retry_delay = 2
    for attempt in range(max_retries):
        try:
            completion = openai.ChatCompletion.create(
                model="deepseek-chat",
                messages=[
                    {'role': 'system', 'content': "你是一个医学文献助手,请为以下文本生成三个相关问题。"},
                    {"role": "user", "content": chunk}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            questions = completion.choices[0].message.content.split("\n")
            questions = [q.strip() for q in questions if q.strip()]  # drop blank lines and surrounding whitespace
            return questions[:3]  # keep the first three questions
        except Exception as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
                continue
            print(f"Error generating questions for chunk: {e}")
            return []
def save_corrected_text(chunk_questions, output_dir):
    """
    Save the corrected text chunks and their generated questions.
    :param chunk_questions: list of (corrected chunk, questions) pairs
    :param output_dir: output directory
    """
    output_file = os.path.join(output_dir, "corrected_combined_text.json")
    try:
        # Persist each chunk together with its generated questions as formatted JSON
        with open(output_file, 'w', encoding='utf-8') as file:
            json.dump(chunk_questions, file, ensure_ascii=False, indent=4)
        print(f"Saved corrected text to: {output_file}")
    except Exception as e:
        print(f"Error saving corrected text: {e}")
def split_text_into_chunks(text, max_chunk_size=400):
    """
    Split text into chunks along paragraph and sentence boundaries, capping the chunk size.
    :param text: original text
    :param max_chunk_size: maximum characters per chunk
    :return: list of text chunks
    """
    chunks = []
    current_chunk = ""
    # Split into paragraphs (assumed to be separated by blank lines)
    paragraphs = text.split("\n\n")
    for paragraph in paragraphs:
        # Split each paragraph into sentences on sentence-ending punctuation
        sentences = re.split(r'(?<=[。!?])', paragraph)
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= max_chunk_size:
                current_chunk += sentence
            else:
                if current_chunk.strip():  # avoid empty chunks
                    chunks.append(current_chunk.strip())
                current_chunk = sentence
        if current_chunk.strip():  # flush the remainder of the paragraph
            chunks.append(current_chunk.strip())
            current_chunk = ""
    return chunks
def process_file(file_path, output_dir):
    """
    Process a text file: chunk, correct, generate questions, and save.
    :param file_path: path to the text file
    :param output_dir: output directory
    :return: list of (corrected chunk, questions) pairs
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            original_text = file.read().strip()
        if not original_text:
            print(f"Skipping empty file: {file_path}")
            return []
        # Split the text into chunks
        chunks = split_text_into_chunks(original_text)
        # Correct the chunks in parallel
        start_time = time.time()
        with Pool(cpu_count()) as pool:
            results = pool.map(correct_text_chunk, chunks)
        print(f'Processing {file_path} took {time.time() - start_time:.2f} seconds')
        # Keep only the corrected text, dropping the per-chunk timings
        corrected_chunks = [result for result, _ in results]
        # Generate questions for each chunk
        chunk_questions = []
        for chunk in corrected_chunks:
            questions = generate_questions_for_chunk(chunk)
            chunk_questions.append((chunk, questions))
        # Save the corrected chunks with their questions
        save_corrected_text(chunk_questions, output_dir)
        return chunk_questions
    except Exception as e:
        print(f"Error processing file {file_path}: {e}")
        traceback.print_exc()
        return []
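
A minimal usage sketch (the input path matches the default written by pdf_processor, and tests/text_correction3.txt must contain the correction prompt):

chunk_questions = process_file("data/txt/combined_text.txt", "data/txt")
# chunk_questions is a list of (corrected_chunk, [q1, q2, q3]) tuples; the same data
# is written to data/txt/corrected_combined_text.json, the file that
# recall_module.RecallModule.recall_from_vector_db later loads.
for chunk, questions in chunk_questions[:2]:
    print(chunk[:60], questions)
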
main.py
This file is the project's main entry point; it wires the modules together.
Functions:
setup_logging(): Configures logging.
main(): Main function: processes the PDF, corrects the text, generates questions, and runs retrieval.
Flow:
Process the PDF file, extracting text and images.
Correct the text and generate related questions.
Initialize the recall module.
Read the user's query.
Run hybrid recall to fetch relevant results from the vector database and Redis.
Print the recalled results.
Dependencies:
logging: logging.
dotenv: loading environment variables.
recall_module.RecallModule: recalling relevant information.
pdf_processor.process_pdf: processing PDF files.
text_corrector.process_file: correcting text and generating questions.
import logging
from logging.handlers import RotatingFileHandler
from dotenv import load_dotenv
import os
from modules.recall_module import RecallModule
from modules.pdf_processor import process_pdf
from modules.text_corrector import process_file
def setup_logging():
    """
    Configure logging.
    """
    # Make sure the log directory exists before attaching the file handler
    os.makedirs("logs", exist_ok=True)
    log_file = "logs/app.log"
    handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
def main():
    """
    Entry point that runs the whole pipeline.
    """
    setup_logging()
    logging.info("Starting the application")
    # Load configuration
    load_dotenv("config/local.env")
    PDF_PATH = os.getenv("PDF_PATH", "data/pdf/中国心力衰竭诊断和治疗指南2024.pdf")
    TXT_DIR = os.getenv("TXT_DIR", "data/txt")
    CORRECTED_TXT_DIR = os.getenv("CORRECTED_TXT_DIR", "data/txt")
    # 1. Process the PDF file
    logging.info("Processing PDF file...")
    process_pdf(PDF_PATH, TXT_DIR)
    logging.info("PDF processing completed.")
    # 2. Correct the text and generate questions
    combined_text_path = os.path.join(TXT_DIR, "combined_text.txt")
    logging.info("Correcting text and generating questions...")
    chunk_questions = process_file(combined_text_path, CORRECTED_TXT_DIR)
    logging.info(f"Corrected {len(chunk_questions)} chunks and generated questions.")
    logging.info("Text correction and question generation completed.")
    # 3. Initialize the recall module
    recall_module = RecallModule()
    # 4. Read the user's query
    query = input("Enter your query: ")
    top_k = int(input("Enter the number of results to retrieve (top_k): "))
    # 5. Hybrid recall
    logging.info("Performing hybrid recall...")
    results = recall_module.hybrid_recall(query, top_k=top_k)
    logging.info("Hybrid recall completed.")
    # 6. Print the recalled results
    print("Vector Results:")
    for question, chunk, questions in results["vector_results"]:
        print(f"Question: {question}")
        print(f"Chunk: {chunk}")
        print(f"Generated Questions: {questions}")
        print()
    print("\nRedis Results:")
    # Each Redis entry's first field already holds the formatted "Query/Answer" text
    for text, _, _ in results["redis_results"]:
        print(text + "\n")
if __name__ == "__main__":
    main()
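
For reference, a sketch of the two env files the modules load; every key appears in the code above, but the values here are hypothetical:

# config/local.env
TESSERACT_PATH=/usr/bin/tesseract
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_DB=0
PDF_PATH=data/pdf/中国心力衰竭诊断和治疗指南2024.pdf
TXT_DIR=data/txt
CORRECTED_TXT_DIR=data/txt

# config/deepseek.env
OPENAI_API_KEY=sk-...
OPENAI_API_BASE=https://api.deepseek.com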