
9.3 核心模块实现

从设计到代码:核心模块的落地

前两节我们完成了需求分析和架构设计,有了清晰的蓝图。现在要进入最实质性的阶段——把设计图变成可运行的代码。这一节会涵盖企业知识库系统最核心的几个模块:统一数据加载器、RAG 引擎主类、混合检索器、引用构建器和查询分类器。每个模块都会给出完整的实现代码,并解释关键的设计决策和容易踩的坑。

统一数据加载器:DataLoaderHub

在第二章中我们学习了 LlamaIndex 的各种 Connector,但在企业场景中不能每次都手动选择用哪个 Reader——我们需要一个统一的入口来管理所有数据源。DataLoaderHub 就是这个角色:

python
from abc import ABC, abstractmethod
from typing import Any, Generator
from pathlib import Path
from dataclasses import dataclass, field
from llama_index.core import Document
from llama_index.core.readers import SimpleDirectoryReader
from llama_index.core.schema import TextNode
import logging

logger = logging.getLogger(__name__)


@dataclass
class LoadResult:
    """数据加载结果"""
    source_id: int
    documents: list[Document] = field(default_factory=list)
    total_chars: int = 0
    file_count: int = 0
    errors: list[dict] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)


class BaseSourceLoader(ABC):
    """数据源加载器的基类"""

    def __init__(self, config: dict):
        self.config = config
        self.source_type = self.__class__.__name__

    @abstractmethod
    def load(self) -> LoadResult:
        """加载数据并返回 Document 列表"""
        ...

    @abstractmethod
    def detect_changes(self) -> list[str]:
        """检测自上次加载以来的变更文件/记录"""
        ...

    @abstractmethod
    def get_source_info(self) -> dict:
        """返回数据源的元信息(文档数量、总大小等)"""
        ...


class FileDirectoryLoader(BaseSourceLoader):
    """文件目录加载器——支持 PDF/Word/Markdown/HTML/Excel/PPT"""

    SUPPORTED_EXTENSIONS = {
        ".pdf": {"engine": "pymupdf", "description": "PDF 文档"},
        ".docx": {"engine": "python-docx", "description": "Word 文档"},
        ".doc": {"engine": "antiword", "description": "旧版 Word(需安装 antiword)"},
        ".md": {"engine": "default", "description": "Markdown 文件"},
        ".markdown": {"engine": "default", "description": "Markdown 文件"},
        ".html": {"engine": "default", "description": "HTML 网页"},
        ".htm": {"engine": "default", "description": "HTML 网页"},
        ".xlsx": {"engine": "openpyxl", "description": "Excel 表格"},
        ".xls": {"engine": "xlrd", "description": "旧版 Excel"},
        ".csv": {"engine": "default", "description": "CSV 数据"},
        ".pptx": {"engine": "python-pptx", "description": "PowerPoint"},
        ".txt": {"engine": "default", "description": "纯文本"},
        ".rst": {"engine": "default", "description": "reStructuredText"},
    }

    def __init__(self, config: dict):
        super().__init__(config)
        self.directory = Path(config["path"])
        self.file_patterns = config.get("file_patterns", ["**/*"])
        self.exclude_patterns = config.get(
            "exclude_patterns",
            ["**/.git/**", "**/node_modules/**", "**/__pycache__/**", "**/*.tmp"],
        )
        self.recursive = config.get("recursive", True)
        self.max_file_size_mb = config.get("max_file_size_mb", 50)

    def load(self) -> LoadResult:
        result = LoadResult(source_id=self.config.get("source_id", 0))

        if not self.directory.exists():
            result.errors.append({
                "type": "directory_not_found",
                "message": f"目录不存在: {self.directory}",
            })
            logger.error(f"目录不存在: {self.directory}")
            return result

        try:
            reader = SimpleDirectoryReader(
                input_dir=str(self.directory),
                required_exts=list(self.SUPPORTED_EXTENSIONS.keys()),
                recursive=self.recursive,
                exclude=self.exclude_patterns,
                num_files_limit=self.config.get("num_files_limit"),
                file_metadata=lambda filepath: {
                    "file_name": Path(filepath).name,
                    "file_path": str(filepath),
                    "file_size_bytes": Path(filepath).stat().st_size,
                },
            )

            documents = reader.load_data()
            result.documents = documents
            result.file_count = len(documents)
            result.total_chars = sum(len(doc.text) for doc in documents)

            # 检查是否有超大文件被跳过
            for doc in documents:
                size_mb = doc.metadata.get("file_size_bytes", 0) / (1024 * 1024)
                if size_mb > self.max_file_size_mb:
                    result.warnings.append(
                        f"大文件 {doc.metadata.get('file_name')} ({size_mb:.1f}MB)"
                        f" 超过建议大小上限 {self.max_file_size_mb}MB"
                    )

            logger.info(
                f"文件目录加载完成: {result.file_count} 个文档, "
                f"{result.total_chars} 字符"
            )

        except Exception as e:
            result.errors.append({"type": "load_error", "message": str(e)})
            logger.error(f"文件目录加载失败: {e}", exc_info=True)

        return result

    def detect_changes(self) -> list[str]:
        """通过文件修改时间检测变更"""
        from datetime import datetime

        changed_files = []
        cutoff = self.config.get("last_sync_time")

        if not cutoff:
            return [str(p) for p in self.directory.rglob("*") if p.is_file()]

        cutoff_dt = datetime.fromisoformat(cutoff) if isinstance(cutoff, str) else cutoff

        for pattern in self.file_patterns:
            for file_path in self.directory.glob(pattern):
                if file_path.is_file():
                    mtime = datetime.fromtimestamp(file_path.stat().st_mtime)
                    if mtime > cutoff_dt:
                        changed_files.append(str(file_path))

        return changed_files

    def get_source_info(self) -> dict:
        total_files = sum(1 for _ in self.directory.rglob("*") if _.is_file())
        total_size = sum(
            f.stat().st_size for f in self.directory.rglob("*") if f.is_file()
        )
        return {
            "type": "file_directory",
            "path": str(self.directory),
            "total_files": total_files,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "supported_formats": list(self.SUPPORTED_EXTENSIONS.keys()),
        }


class DatabaseLoader(BaseSourceLoader):
    """数据库加载器——将数据库记录转为 Document 对象"""

    def __init__(self, config: dict):
        super().__init__(config)
        self.db_type = config["db_type"]  # postgresql / mysql / sqlite
        self.connection_string = config["connection_string"]
        self.query_template = config.get("query")
        self.tables = config.get("tables", [])
        self.id_column = config.get("id_column", "id")
        self.content_column = config.get("content_column", "content")
        self.title_column = config.get("title_column", "title")
        self.metadata_columns = config.get("metadata_columns", [])
        self.batch_size = config.get("batch_size", 500)

    def _get_connection(self):
        """根据数据库类型创建连接"""
        if self.db_type == "postgresql":
            import psycopg2
            return psycopg2.connect(self.connection_string)
        elif self.db_type == "mysql":
            import mysql.connector
            return mysql.connector.connect(**self._parse_mysql_config())
        elif self.db_type == "sqlite":
            import sqlite3
            return sqlite3.connect(self.connection_string)
        else:
            raise ValueError(f"不支持的数据库类型: {self.db_type}")

    def _parse_mysql_config(self) -> dict:
        """从连接字符串解析 MySQL 配置"""
        import re
        match = re.match(
            r"mysql\+(?:pymysql|mysqlconnector)://([^:]+):([^@]+)@([^:]+):(\d+)/(.+)",
            self.connection_string,
        )
        if match:
            return {
                "user": match.group(1),
                "password": match.group(2),
                "host": match.group(3),
                "port": int(match.group(4)),
                "database": match.group(5),
            }
        raise ValueError(f"无法解析 MySQL 连接字符串: {self.connection_string}")

    def load(self) -> LoadResult:
        result = LoadResult(source_id=self.config.get("source_id", 0))

        try:
            conn = self._get_connection()
            cursor = conn.cursor()

            if self.query_template:
                query = self.query_template
            else:
                columns = [self.id_column, self.content_column]
                if self.title_column:
                    columns.append(self.title_column)
                columns.extend(self.metadata_columns)

                table_name = self.tables[0] if self.tables else "documents"
                query = f"SELECT {', '.join(columns)} FROM {table_name}"

            cursor.execute(query)

            documents = []
            while True:
                rows = cursor.fetchmany(self.batch_size)
                if not rows:
                    break

                for row in rows:
                    metadata = {}
                    content = ""  # 防止某行缺少内容列时引用到未定义变量
                    col_names = [desc[0] for desc in cursor.description]

                    for i, col_name in enumerate(col_names):
                        if col_name == self.content_column:
                            content = str(row[i]) if row[i] is not None else ""
                        elif col_name == self.title_column:
                            metadata["title"] = str(row[i]) if row[i] is not None else ""
                        elif col_name == self.id_column:
                            metadata["db_id"] = str(row[i])
                        elif col_name in self.metadata_columns or col_name not in (
                            self.id_column,
                            self.content_column,
                            self.title_column,
                        ):
                            metadata[f"db_{col_name}"] = row[i]

                    if content.strip():
                        doc = Document(text=content, metadata=metadata)
                        documents.append(doc)

            result.documents = documents
            result.file_count = len(documents)
            result.total_chars = sum(len(doc.text) for doc in documents)

            cursor.close()
            conn.close()

            logger.info(
                f"数据库加载完成 ({self.db_type}): "
                f"{result.file_count} 条记录, {result.total_chars} 字符"
            )

        except Exception as e:
            result.errors.append({"type": "db_error", "message": str(e)})
            logger.error(f"数据库加载失败: {e}", exc_info=True)

        return result

    def detect_changes(self) -> list[str]:
        """基于 updated_at 或 id 范围检测变更"""
        conn = self._get_connection()
        cursor = conn.cursor()

        last_sync_id = self.config.get("last_sync_id", "0")
        table = self.tables[0] if self.tables else "documents"

        try:
            # psycopg2 / mysql-connector 使用 %s 作为占位符,sqlite3 使用 ?
            placeholder = "?" if self.db_type == "sqlite" else "%s"
            cursor.execute(
                f"SELECT {self.id_column} FROM {table} "
                f"WHERE {self.id_column} > {placeholder} ORDER BY {self.id_column}",
                (last_sync_id,),
            )
            changed_ids = [str(row[0]) for row in cursor.fetchall()]
            cursor.close()
            conn.close()
            return changed_ids
        except Exception as e:
            logger.warning(f"变更检测失败: {e}")
            cursor.close()
            conn.close()
            return []

    def get_source_info(self) -> dict:
        conn = self._get_connection()
        cursor = conn.cursor()
        table = self.tables[0] if self.tables else "documents"

        try:
            cursor.execute(f"SELECT COUNT(*) FROM {table}")
            count = cursor.fetchone()[0]

            cursor.close()
            conn.close()

            return {
                "type": f"database_{self.db_type}",
                "table": table,
                "total_records": count,
                "columns": self.metadata_columns + [
                    self.id_column,
                    self.content_column,
                    self.title_column,
                ],
            }
        except Exception:
            cursor.close()
            conn.close()
            return {"type": f"database_{self.db_type}", "error": "无法获取表信息"}


class DataLoaderHub:
    """统一数据加载中心——注册和管理所有数据源加载器"""

    _loaders: dict[str, type[BaseSourceLoader]] = {}

    @classmethod
    def register(cls, name: str, loader_class: type[BaseSourceLoader]):
        cls._loaders[name] = loader_class
        logger.info(f"已注册数据源加载器: {name}")

    @classmethod
    def get_loader(cls, source_type: str, config: dict) -> BaseSourceLoader:
        if source_type not in cls._loaders:
            raise ValueError(f"未知的数据源类型: {source_type}, 可用: {list(cls._loaders.keys())}")
        return cls._loaders[source_type](config)

    @classmethod
    def load_all(cls, sources: list[dict]) -> dict[int, LoadResult]:
        """批量加载多个数据源"""
        results = {}
        for source_config in sources:
            source_type = source_config["source_type"]
            loader = cls.get_loader(source_type, source_config)
            result = loader.load()
            results[source_config.get("source_id", 0)] = result
        return results


# 注册内置加载器
DataLoaderHub.register("file_directory", FileDirectoryLoader)
DataLoaderHub.register("database_postgres", DatabaseLoader)
DataLoaderHub.register("database_mysql", DatabaseLoader)

这个 DataLoaderHub 的设计采用了策略模式(Strategy Pattern)+ 工厂模式(Factory Pattern)的组合。每种数据源类型对应一个 Loader 类(策略),DataLoaderHub 作为工厂根据 source_type 创建对应的实例。新增一种数据源只需要:(1)继承 BaseSourceLoader 实现新类;(2)调用 DataLoaderHub.register() 注册。不需要修改任何已有代码——这符合开闭原则(Open-Closed Principle)。
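下面用一个假想的 WikiApiLoader 做个最小示意,演示"继承 + 注册"这两步(数据直接取自 config 里的 pages 字段,仅用于说明扩展机制,真实实现应调用你自己的 Wiki API):

python
class WikiApiLoader(BaseSourceLoader):
    """示例加载器:演示如何扩展 DataLoaderHub(字段与数据均为示意)"""

    def load(self) -> LoadResult:
        result = LoadResult(source_id=self.config.get("source_id", 0))
        # 用 config["pages"] 模拟 API 返回的页面列表
        for page in self.config.get("pages", []):
            result.documents.append(
                Document(text=page.get("body", ""), metadata={"title": page.get("title", "")})
            )
        result.file_count = len(result.documents)
        result.total_chars = sum(len(d.text) for d in result.documents)
        return result

    def detect_changes(self) -> list[str]:
        return []  # 示例中省略变更检测

    def get_source_info(self) -> dict:
        return {"type": "wiki_api", "page_count": len(self.config.get("pages", []))}


# 注册后即可通过 DataLoaderHub.load_all() 统一调度,不需要改动任何已有代码
DataLoaderHub.register("wiki_api", WikiApiLoader)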

有一个容易被忽略的细节是 FileDirectoryLoader 中的 max_file_size_mb 检查。在实际项目中经常有人上传几百 MB 的 PDF 扫描件(比如整本产品手册的扫描版),这种文件不仅解析极慢,而且 OCR 后的质量通常很差——满篇错别字和乱码。与其让它在解析阶段卡住整个流程,不如提前检测并发出警告。50MB 的阈值是根据经验设定的——正常的产品 PDF 手册一般在 5-20MB,超过 50MB 的很可能是扫描件或包含大量高清图片的非文本 PDF。
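如果希望更彻底一些——在把文件交给 SimpleDirectoryReader 之前就按大小预筛、直接跳过超限文件——可以参考下面的小片段(目录路径、阈值均为示例,按自己的项目调整):

python
from pathlib import Path
from llama_index.core.readers import SimpleDirectoryReader

def collect_files_within_limit(directory: str, max_size_mb: float = 50) -> list[str]:
    """遍历目录,只保留不超过大小上限的文件,超限文件记录下来供人工处理"""
    selected, skipped = [], []
    for p in Path(directory).rglob("*"):
        if not p.is_file():
            continue
        if p.stat().st_size / (1024 * 1024) > max_size_mb:
            skipped.append(str(p))
        else:
            selected.append(str(p))
    if skipped:
        print(f"跳过 {len(skipped)} 个超大文件,例如: {skipped[:3]}")
    return selected

# 用 input_files 替代 input_dir,只解析通过预筛的文件
reader = SimpleDirectoryReader(input_files=collect_files_within_limit("/data/docs"))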

RAG 引擎主类:EnterpriseRAGEngine

这是整个系统的大脑——它把前面学过的所有 LlamaIndex 组件组装成一个完整的问答流水线:

python
import time
import logging
from typing import Optional
from dataclasses import dataclass, field
from enum import Enum

from llama_index.core import (
    VectorStoreIndex,
    StorageContext,
    Settings,
    Document,
    QueryBundle,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from llama_index.core.postprocessor import SimilarityPostprocessor, LongContextReorder
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

from app.data.loaders import DataLoaderHub
from app.core.retriever import HybridRetriever
from app.core.citation_builder import CitationBuilder
from app.core.query_classifier import QueryClassifier

logger = logging.getLogger(__name__)


class RetrievalMode(Enum):
    VECTOR_ONLY = "vector_only"
    HYBRID = "hybrid"
    KEYWORD_ONLY = "keyword_only"


@dataclass
class RAGConfig:
    """RAG 引擎配置"""
    embedding_model: str = "text-embedding-3-large"
    llm_model: str = "gpt-4o"
    chunk_size: int = 512
    chunk_overlap: int = 50
    similarity_top_k: int = 15
    rerank_top_k: int = 5
    similarity_cutoff: float = 0.5
    retrieval_mode: RetrievalMode = RetrievalMode.HYBRID
    use_hyde: bool = False
    use_reranker: bool = True
    reranker_model: str = "rerank-v3.5"
    cohere_api_key: Optional[str] = None
    response_mode: ResponseMode = ResponseMode.REFINE
    enable_citation: bool = True
    max_context_tokens: int = 8000


@dataclass
class QueryResult:
    """查询结果"""
    answer: str
    sources: list[dict]
    citations: list[dict]
    confidence: float
    latency_ms: float
    nodes_used: int
    tokens_used: int
    retrieval_details: dict = field(default_factory=dict)


class EnterpriseRAGEngine:
    """
    企业级 RAG 引擎

    将 LlamaIndex 的各组件组装为完整的问答流水线,
    支持混合检索、权限过滤、重排序、引用溯源等功能。
    """

    def __init__(self, config: Optional[RAGConfig] = None):
        self.config = config or RAGConfig()
        self._initialize_settings()
        self._index: Optional[VectorStoreIndex] = None
        self._retriever: Optional[HybridRetriever] = None
        self._citation_builder = CitationBuilder()
        self._query_classifier = QueryClassifier()

    def _initialize_settings(self):
        """初始化 LlamaIndex 全局设置"""
        Settings.llm = OpenAI(model=self.config.llm_model, temperature=0.1)
        Settings.embed_model = OpenAIEmbedding(
            model=self.config.embedding_model,
            dimensions=1024 if "large" in self.config.embedding_model else None,
        )
        Settings.node_parser = SentenceSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
        )

    def build_index_from_sources(self, sources: list[dict]) -> dict:
        """从多个数据源构建索引"""
        start = time.perf_counter()

        all_documents = []
        load_results = DataLoaderHub.load_all(sources)

        total_docs = 0
        total_errors = 0
        for source_id, result in load_results.items():
            all_documents.extend(result.documents)
            total_docs += result.file_count
            total_errors += len(result.errors)
            if result.errors:
                for err in result.errors:
                    logger.warning(f"数据源 {source_id} 加载错误: {err}")

        if not all_documents:
            return {"status": "error", "message": "没有成功加载任何文档"}

        logger.info(f"共加载 {len(all_documents)} 个文档, 开始构建向量索引...")

        qdrant_client = QdrantClient(url="http://qdrant:6333")
        vector_store = QdrantVectorStore(
            client=qdrant_client,
            collection_name="kb_main",
        )
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        self._index = VectorStoreIndex.from_documents(
            documents=all_documents,
            storage_context=storage_context,
            show_progress=True,
        )

        self._setup_retriever()

        elapsed = time.perf_counter() - start
        logger.info(f"索引构建完成, 耗时 {elapsed:.1f}s, 文档数: {len(all_documents)}")

        return {
            "status": "success",
            "document_count": len(all_documents),
            "node_count": len(list(self._index.docstore.docs.values())),
            "build_time_seconds": round(elapsed, 2),
            "errors": total_errors,
        }

    def _setup_retriever(self):
        """配置检索器"""
        if self.config.retrieval_mode == RetrievalMode.HYBRID:
            self._retriever = HybridRetriever(
                index=self._index,
                similarity_top_k=self.config.similarity_top_k,
                sparse_top_k=20,
            )
        else:
            self._retriever = self._index.as_retriever(
                similarity_top_k=self.config.similarity_top_k
            )

    async def query(
        self,
        question: str,
        user_context: Optional[dict] = None,
        chat_history: Optional[list[dict]] = None,
        allowed_doc_ids: Optional[list[str]] = None,
        stream: bool = False,
    ) -> QueryResult:
        """
        执行完整 RAG 查询

        Args:
            question: 用户问题
            user_context: 用户上下文(部门、角色等)
            chat_history: 多轮对话历史
            allowed_doc_ids: 用户有权访问的文档ID列表(权限过滤)
            stream: 是否流式输出
        """
        start = time.perf_counter()

        if not self._index or not self._retriever:
            raise RuntimeError("索引尚未构建,请先调用 build_index_from_sources()")

        # Step 1: 问题预处理与分类
        classification = self._query_classifier.classify(question)
        enriched_query = self._enrich_query(question, chat_history, user_context)

        # Step 2: 检索
        raw_nodes = await self._retriever.aretrieve(enriched_query)

        # Step 3: 权限过滤
        if allowed_doc_ids:
            filtered_nodes = self._filter_by_permission(raw_nodes, allowed_doc_ids)
        else:
            filtered_nodes = raw_nodes

        # Step 4: 重排序
        if self.config.use_reranker and filtered_nodes:
            reranked_nodes = await self._rerank(filtered_nodes, enriched_query)
        else:
            reranked_nodes = filtered_nodes[: self.config.rerank_top_k]

        # Step 5: 后处理
        postprocessed = self._postprocess(reranked_nodes)

        # Step 6: 回答合成
        synthesizer = get_response_synthesizer(
            response_mode=self.config.response_mode,
            use_async=True,
        )

        # stream=True 时可以换用流式 synthesizer 逐 token 返回;
        # 这里为保持示例清晰,统一取完整回答文本
        response = await synthesizer.asynthesize(
            query=enriched_query,
            nodes=postprocessed,
        )
        answer = str(response)

        # Step 7: 构建引用
        citations = []
        if self.config.enable_citation and postprocessed:
            citations = self._citation_builder.build(answer, postprocessed)

        # Step 8: 计算置信度
        confidence = self._compute_confidence(postprocessed, answer)

        latency = (time.perf_counter() - start) * 1000

        return QueryResult(
            answer=answer,
            sources=[
                {
                    "doc_id": n.node.metadata.get("doc_id", ""),
                    "title": n.node.metadata.get("file_name", n.node.metadata.get("title", "")),
                    "score": round(n.score, 4),
                    "snippet": n.node.text[:200],
                }
                for n in postprocessed[:3]
            ],
            citations=citations,
            confidence=confidence,
            latency_ms=round(latency, 1),
            nodes_used=len(postprocessed),
            tokens_used=(getattr(response, "metadata", None) or {}).get("total_tokens", 0),
            retrieval_details={
                "query_type": query_type.value,
                "raw_retrieved": len(raw_nodes),
                "after_permission_filter": len(filtered_nodes),
                "after_rerank": len(reranked_nodes),
                "after_postprocess": len(postprocessed),
            },
        )

    def _enrich_query(
        self, question: str, chat_history: list[dict], user_context: dict
    ) -> str:
        """丰富查询上下文"""
        parts = []

        if chat_history and len(chat_history) > 0:
            context_str = "\n".join([
                f"Q: {h['question']}\nA: {h['answer']}"
                for h in chat_history[-4:]  # 只取最近4轮
            ])
            parts.append(f"[对话历史]\n{context_str}")

        if user_context:
            dept = user_context.get("department", "")
            role = user_context.get("role", "")
            if dept or role:
                parts.append(f"[用户背景] 部门={dept}, 角色={role}")

        parts.append(f"[当前问题] {question}")
        return "\n\n".join(parts)

    def _filter_by_permission(
        self, nodes: list[NodeWithScore], allowed_doc_ids: list[str]
    ) -> list[NodeWithScore]:
        """基于文档 ID 过滤无权访问的节点"""
        allowed_set = set(allowed_doc_ids)
        filtered = [
            n for n in nodes
            if n.node.metadata.get("doc_id", "") in allowed_set
            or n.node.metadata.get("source_doc_id", "") in allowed_set
        ]

        if len(filtered) < len(nodes):
            dropped = len(nodes) - len(filtered)
            logger.debug(f"权限过滤: 移除了 {dropped} 个无权访问的节点")

        return filtered

    async def _rerank(
        self, nodes: list[NodeWithScore], query: str
    ) -> list[NodeWithScore]:
        """使用 Cross-Encoder 进行重排序"""
        try:
            from llama_index.postprocessor.cohere_rerank import CohereRerank

            reranker = CohereRerank(
                api_key=self.config.cohere_api_key,
                top_n=self.config.rerank_top_k,
                model=self.config.reranker_model,
            )
            reranked = reranker.postprocess_nodes(nodes, query_str=query)
            return reranked
        except ImportError:
            logger.warning("Cohere Rerank 不可用,使用分数截断作为替代")
            return sorted(nodes, key=lambda x: x.score, reverse=True)[
                : self.config.rerank_top_k
            ]
        except Exception as e:
            logger.error(f"重排序失败: {e},回退到原始排序")
            return nodes[: self.config.rerank_top_k]

    def _postprocess(self, nodes: list[NodeWithScore]) -> list[NodeWithScore]:
        """后处理:相似度过滤 + Lost in the Middle 重排序"""
        if not nodes:
            return nodes

        # 相似度阈值过滤
        above_threshold = [
            n for n in nodes if n.score >= self.config.similarity_cutoff
        ]
        if len(above_threshold) < len(nodes):
            logger.debug(
                f"相似度过滤: {len(nodes)}{len(above_threshold)}"
                f" (阈值={self.config.similarity_cutoff})"
            )

        # Lost in the Middle: 把得分最高的节点放到首尾,较弱的留在中间
        if len(above_threshold) >= 3:
            sorted_by_score = sorted(
                above_threshold, key=lambda x: x.score, reverse=True
            )
            # 偶数名次按顺序排在前半段,奇数名次倒序排在后半段,
            # 这样 rank1 在开头、rank2 在结尾、得分最低的落在中间
            front = sorted_by_score[0::2]
            back = sorted_by_score[1::2]
            return front + back[::-1]

        return above_threshold

    def _compute_confidence(
        self, nodes: list[NodeWithScore], answer: str
    ) -> float:
        """计算回答置信度"""
        if not nodes:
            return 0.0

        avg_score = sum(n.score for n in nodes) / len(nodes)
        max_score = max(n.score for n in nodes)

        avg_weight = 0.6
        max_weight = 0.4
        confidence = avg_score * avg_weight + max_score * max_weight

        if len(answer) < 20:
            confidence *= 0.7  # 太短的回答可能不够充分

        return round(min(confidence, 1.0), 3)

    def incremental_index(self, source_id: int, changed_docs: list[str]) -> dict:
        """增量更新索引"""
        if not self._index:
            return {"status": "error", "message": "索引未初始化"}

        inserted = 0
        deleted = 0

        for doc_ref in changed_docs:
            try:
                # 注意: load_single 需要在 DataLoaderHub 上补充实现——
                # 根据 source_id 找到对应数据源配置,只加载单个文件/记录(上文未展示)
                new_docs = DataLoaderHub.load_single(doc_ref, source_id)
                for doc in new_docs:
                    self._index.insert(doc)
                    inserted += 1
            except Exception as e:
                logger.warning(f"增量插入失败 {doc_ref}: {e}")

        return {
            "status": "success",
            "inserted": inserted,
            "deleted": deleted,
        }

    def get_stats(self) -> dict:
        """获取引擎统计信息"""
        if not self._index:
            return {"status": "not_initialized"}

        try:
            docstore = self._index.docstore
            index_store = self._index.index_store
            return {
                "status": "ready",
                "total_documents": len(docstore.docs),
                "total_nodes": len(docstore.docs),
                "index_count": len(index_store.index_structs),
                "config": {
                    "embedding_model": self.config.embedding_model,
                    "llm_model": self.config.llm_model,
                    "chunk_size": self.config.chunk_size,
                    "retrieval_mode": self.config.retrieval_mode.value,
                    "use_reranker": self.config.use_reranker,
                },
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}

EnterpriseRAGEngine 的 query() 方法就是上一节架构图中那个 11 步流水线的具体实现。有几个值得深入讨论的实现细节:

权限过滤的位置选择

你可能会问:为什么权限过滤放在检索之后而不是之前?能不能在 Qdrant 查询时就通过 filter 参数直接排除掉不允许的文档?答案是:两种方式各有优劣,我们在实现中选择了"先粗检后细滤"的两阶段策略。

如果在 Qdrant 查询时就做严格的权限过滤(把 allowed_doc_ids 作为 filter 条件传入),好处是减少网络传输的数据量——Qdrant 只返回允许访问的节点。但问题是:当用户的权限范围很大(比如 admin 可以看所有文档,allowed_doc_ids 可能有几千个),把这个列表传给 Qdrant 的 filter 会显著降低查询性能。而且 Qdrant 的 filter 不支持超大规模的 IN 列表(通常建议控制在 100 以内)。
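如果确实要在检索阶段做过滤(比如普通用户的 allowed_doc_ids 很短),可以用 LlamaIndex 的 MetadataFilters 把文档 ID 列表下推给向量库。下面是一个示意,假设节点元数据里有 doc_id 字段(与上文的约定一致),具体键名以你的索引为准:

python
from llama_index.core.vector_stores import (
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

def build_permission_filtered_retriever(index, allowed_doc_ids: list[str], top_k: int = 15):
    """允许列表较短时,把权限条件作为元数据过滤下推到向量库"""
    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="doc_id",
                value=allowed_doc_ids,
                operator=FilterOperator.IN,
            )
        ]
    )
    return index.as_retriever(similarity_top_k=top_k, filters=filters)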

所以我们采用的策略是:检索时不做过滤(或少做粗粒度过滤),拿到候选集后在应用层做精确的权限过滤。对于普通用户(允许访问的文档较少),可以在检索时加上 filter;对于管理员(允许访问全部文档),直接跳过过滤步骤。这个判断逻辑可以进一步优化:

python
def _smart_permission_filter(
    self, nodes: list[NodeWithScore], allowed_doc_ids: list[str]
) -> list[NodeWithScore]:
    """智能权限过滤:根据允许列表大小选择过滤策略"""
    FILTER_THRESHOLD = 100  # 允许列表不超过此值时,适合直接下推给向量库做 filter

    if len(allowed_doc_ids) <= FILTER_THRESHOLD:
        # 小范围权限:也可以在检索阶段下推 filter;这里为保持示例简单,仍在应用层过滤
        return self._filter_by_permission(nodes, allowed_doc_ids)
    elif len(allowed_doc_ids) > len(nodes) * 10:
        # 超大范围权限(如 admin 几乎可见全部文档):过滤意义不大,直接返回
        return nodes
    else:
        # 中等范围:应用层精确过滤
        return self._filter_by_permission(nodes, allowed_doc_ids)

置信度计算的启发式方法

_compute_confidence() 方法中的置信度计算是一个简单的启发式公式:0.6 × 平均分 + 0.4 × 最高分。为什么这样设计?因为平均分反映了整体检索质量(如果大部分节点的相关度都不高,说明系统对这个问题的理解可能有问题),而最高分代表了最好的一条证据有多强(即使平均分一般,只要有一条非常匹配的证据,回答仍然可能是好的)。举个例子:平均分 0.62、最高分 0.88 时,置信度为 0.6 × 0.62 + 0.4 × 0.88 ≈ 0.72。0.6 和 0.4 的权重分配是基于经验调优的结果——你可以根据自己的数据和反馈调整这些参数。

另外注意那个"短回答惩罚":如果最终生成的回答少于 20 个字符,置信度会被乘以 0.7。这是因为太短的回答(如"是的"、"不清楚"、"请参考手册")往往意味着检索没有找到足够的信息,或者模型选择了保守的回避策略。

混合检索器:HybridRetriever

第五章中我们详细学习了 Hybrid Search 的原理。下面是在项目中实际使用的混合检索器实现:

python
import logging
import math
import re
from typing import List, Optional

import jieba  # 中文分词
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

logger = logging.getLogger(__name__)


class BM25Retriever(BaseRetriever):
    """BM25 关键词检索器(内存索引实现)"""

    def __init__(
        self,
        nodes: List[TextNode],
        similarity_top_k: int = 20,
        k1: float = 1.5,
        b: float = 0.75,
    ):
        super().__init__()
        self._nodes = nodes
        self._similarity_top_k = similarity_top_k
        self.k1 = k1
        self.b = b
        self._index = self._build_bm25_index()

    def _build_bm25_index(self) -> dict:
        """构建 BM25 倒排索引"""
        from collections import defaultdict
        import math

        df = defaultdict(int)  # document frequency
        tf_list = []  # term frequency per document
        doc_lengths = []

        for node in self._nodes:
            tokens = self._tokenize(node.text)
            tf = defaultdict(int)
            for token in tokens:
                tf[token] += 1
            tf_list.append(dict(tf))
            doc_lengths.append(len(tokens))

            for token in set(tokens):
                df[token] += 1

        avg_dl = sum(doc_lengths) / len(doc_lengths) if doc_lengths else 1
        N = len(self._nodes)

        # 预计算 IDF
        idf = {}
        for term, freq in df.items():
            idf[term] = math.log((N - freq + 0.5) / (freq + 0.5) + 1)

        return {
            "tf_list": tf_list,
            "df": dict(df),
            "idf": idf,
            "avg_dl": avg_dl,
            "doc_lengths": doc_lengths,
            "N": N,
        }

    def _tokenize(self, text: str) -> list[str]:
        """中文分词 + 英文小写化"""
        text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9]', ' ', text)
        chinese_tokens = list(jieba.cut_for_search(text))
        english_tokens = re.findall(r'[a-zA-Z]+', text)
        tokens = [t.lower() for t in chinese_tokens + english_tokens if len(t.strip()) > 0]
        return tokens

    def _bm25_score(self, query_tokens: list, tf: dict, dl: int, idx: dict) -> float:
        """计算单个文档的 BM25 分数"""
        score = 0.0
        for token in query_tokens:
            if token not in idx["idf"]:
                continue
            freq = tf.get(token, 0)
            if freq == 0:
                continue
            idf = idx["idf"][token]
            tf_component = (freq * (self.k1 + 1)) / (freq + self.k1 * (1 - self.b + self.b * dl / idx["avg_dl"]))
            score += idf * tf_component
        return score

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        query_str = query_bundle.query_str
        query_tokens = self._tokenize(query_str)
        idx = self._index

        scores = []
        for i, node in enumerate(self._nodes):
            score = self._bm25_score(query_tokens, idx["tf_list"][i], idx["doc_lengths"][i], idx)
            scores.append((score, node))

        scores.sort(key=lambda x: x[0], reverse=True)
        return [
            NodeWithScore(node=node, score=score)
            for score, node in scores[: self._similarity_top_k]
        ]


class HybridRetriever(BaseRetriever):
    """混合检索器:融合向量搜索和 BM25 关键词搜索"""

    def __init__(
        self,
        index: VectorStoreIndex,
        similarity_top_k: int = 15,
        sparse_top_k: int = 20,
        rrf_k: int = 60,  # RRF 平滑常数
    ):
        super().__init__()
        self._index = index
        self._similarity_top_k = similarity_top_k
        self._vector_retriever = index.as_retriever(similarity_top_k=similarity_top_k)
        self._sparse_top_k = sparse_top_k
        self._rrf_k = rrf_k
        self._bm25_retriever: Optional[BM25Retriever] = None
        self._all_nodes: Optional[List[TextNode]] = None

    def _ensure_bm25_ready(self):
        """懒初始化 BM25 索引(只在第一次查询时构建)"""
        if self._bm25_retriever is None:
            all_nodes = list(self._index.docstore.docs.values())
            self._all_nodes = all_nodes
            self._bm25_retriever = BM25Retriever(
                nodes=all_nodes,
                similarity_top_k=self._sparse_top_k,
            )
            logger.info(f"BM25 索引构建完成, 共 {len(all_nodes)} 个节点")

    def _rrf_fusion(
        self,
        vector_results: List[NodeWithScore],
        bm25_results: List[NodeWithScore],
    ) -> List[NodeWithScore]:
        """
        Reciprocal Rank Fusion 融合算法

        公式: RRF(d) = Σ 1/(k + rank_i(d))
        其中 k 是平滑常数(默认60),rank_i 是文档在第 i 个结果列表中的排名
        """
        rrf_scores: dict[str, float] = {}
        node_map: dict[str, NodeWithScore] = {}

        for rank, result in enumerate(vector_results):
            node_id = result.node.node_id
            if node_id not in rrf_scores:
                rrf_scores[node_id] = 0
                node_map[node_id] = result
            rrf_scores[node_id] += 1.0 / (self._rrf_k + rank + 1)

        for rank, result in enumerate(bm25_results):
            node_id = result.node.node_id
            if node_id not in rrf_scores:
                rrf_scores[node_id] = 0
                node_map[node_id] = result
            rrf_scores[node_id] += 1.0 / (self._rrf_k + rank + 1)

        fused = sorted(
            [(node_id, score) for node_id, score in rrf_scores.items()],
            key=lambda x: x[1],
            reverse=True,
        )

        return [
            NodeWithScore(node=node_map[node_id].node, score=score)
            for node_id, score in fused[: self._similarity_top_k]
        ]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        self._ensure_bm25_ready()

        vector_results = self._vector_retriever.retrieve(query_bundle)
        bm25_results = self._bm25_retriever.retrieve(query_bundle)

        logger.debug(
            f"向量检索返回 {len(vector_results)} 结果, "
            f"BM25检索返回 {len(bm25_results)} 结果"
        )

        fused = self._rrf_fusion(vector_results, bm25_results)
        logger.debug(f"RRF融合后返回 {len(fused)} 结果")

        return fused

关于这个混合检索器的实现,有几个工程上的要点:

BM25 的懒初始化(Lazy Initialization):BM25 索引需要遍历所有节点来构建倒排索引,对于几千个节点来说这个过程需要几秒钟。如果我们把它放在引擎启动时同步执行,会导致服务启动变慢。所以采用懒初始化策略——只在第一次查询请求到达时才构建 BM25 索引,后续请求直接复用。代价是第一个用户会多等几秒,但可以通过预热(warm-up)机制来解决——服务启动后自动发一个 dummy query 来触发初始化。
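预热的实现可以很轻量。下面是一个示意,假设服务用 FastAPI 托管、引擎模块路径为 app.core.engine(路径、目录都是示例,按自己的工程结构调整):

python
from contextlib import asynccontextmanager
from fastapi import FastAPI

from app.core.engine import EnterpriseRAGEngine  # 示例路径

engine = EnterpriseRAGEngine()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # 启动时构建索引,并发一条 dummy 查询触发 BM25 懒初始化,
    # 这样第一个真实用户就不用承担几秒的构建开销
    engine.build_index_from_sources(
        [{"source_type": "file_directory", "source_id": 1, "path": "/data/docs"}]
    )
    await engine.query("预热查询:公司有哪些产品?")
    yield

app = FastAPI(lifespan=lifespan)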

jieba 中文分词的使用:BM25 的效果高度依赖分词质量。对于英文来说空格分词就够了,但中文没有天然的词边界。jieba.cut_for_search() 是搜索引擎模式的分词,它会生成较细粒度的词元(比如"知识库问答"会被切分为"知识"/"库"/"问答"三个词),这对提高召回率有帮助。如果你对分词精度有更高要求,可以考虑用 pkuseg 或 LAC 这些更先进的分词器。

RRF 参数 k=60 的选择:这个值来自原始 RRF 论文的推荐范围(通常在 20-100 之间,60 是经验最优值)。k 越小,排名靠前的结果权重越大(更激进);k 越大,不同排名之间的差异越小(更保守)。举个例子:某节点在向量结果中排第 1、在 BM25 结果中排第 3,它的 RRF 分数就是 1/(60+1) + 1/(60+3) ≈ 0.0323。在我们的场景中 60 效果不错——既不会让某一个检索方法完全主导结果,也不会过度稀释高质量结果的贡献。

引用构建器:CitationBuilder

引用溯源是企业知识库区别于通用 AI 助手的核心差异化功能之一:

python
import re
from dataclasses import dataclass
from typing import Optional
from llama_index.core.schema import NodeWithScore


@dataclass
class Citation:
    """单条引用"""
    source_title: str
    source_type: str  # pdf/docx/md/database
    section: Optional[str] = None
    snippet: Optional[str] = None
    relevance_score: float = 0.0
    url: Optional[str] = None  # 原文链接(如果有)


class CitationBuilder:
    """引用构建器——从回答和检索节点中提取格式化的引用"""

    def __init__(self, max_citations: int = 5, min_snippet_length: int = 50, max_snippet_length: int = 150):
        self.max_citations = max_citations
        self.min_snippet_length = min_snippet_length
        self.max_snippet_length = max_snippet_length

    def build(
        self,
        answer: str,
        nodes: list[NodeWithScore],
    ) -> list[dict]:
        """
        从回答文本和检索节点中构建引用列表

        核心思路:
        1. 提取回答中的关键陈述句
        2. 将每个陈述映射到最相关的源节点
        3. 格式化为标准引用格式
        """
        if not nodes:
            return []

        citations = []
        used_node_ids = set()

        # 方法一:按节点相关性取 top-N
        sorted_nodes = sorted(nodes, key=lambda x: x.score, reverse=True)

        for node in sorted_nodes[: self.max_citations]:
            if node.node.node_id in used_node_ids:
                continue

            citation = self._format_citation(node, cite_index=len(citations) + 1)
            citations.append(citation)
            used_node_ids.add(node.node.node_id)

        return citations

    def _format_citation(self, node: NodeWithScore, cite_index: int) -> dict:
        """将单个节点格式化为引用条目"""
        meta = node.node.metadata

        title = (
            meta.get("title")
            or meta.get("file_name")
            or meta.get("db_title")
            or "未知来源"
        )

        source_type = self._detect_source_type(meta)

        section = meta.get("header_hierarchy") or meta.get("page_label") or None

        snippet = self._extract_snippet(node.node.text)

        return {
            "id": f"cite-{len(citations) + 1}" if 'citations' in dir() else f"cite-{id(node) % 1000}",
            "title": title,
            "source_type": source_type,
            "section": section,
            "snippet": snippet,
            "relevance_score": round(node.score, 3),
            "doc_id": meta.get("doc_id", "") or meta.get("source_doc_id", ""),
        }

    def _detect_source_type(self, metadata: dict) -> str:
        """根据元数据推断文档类型"""
        file_name = metadata.get("file_name", "")
        if file_name.endswith(".pdf"):
            return "pdf"
        elif file_name.endswith(".docx"):
            return "word"
        elif file_name.endswith(".md") or file_name.endswith(".markdown"):
            return "markdown"
        elif "db_" in str(metadata.keys()):
            return "database"
        elif file_name.endswith(".html"):
            return "webpage"
        else:
            return "unknown"

    def _extract_snippet(self, text: str) -> str:
        """提取用于展示的原文片段"""
        if len(text) <= self.max_snippet_length:
            return text.strip()

        # 尝试在第一句话处截断
        sentences = re.split(r'(?<=[。!?.!?\n])', text)
        snippet = ""
        for sent in sentences:
            if len(snippet) + len(sent) > self.max_snippet_length:
                break
            snippet += sent

        if len(snippet) < self.min_snippet_length:
            snippet = text[: self.max_snippet_length]

        if len(text) > len(snippet):
            snippet += "..."

        return snippet.strip()

引用构建看起来简单,但在实际项目中有很多细节需要注意:

snippet 截断策略:上面的 _extract_snippet 方法优先按句子边界截断,而不是简单地按字符数硬切。这是因为按句子截断能保持语义完整性——用户看到的引用片段是一段有意义的话,而不是半截话。正则表达式 (?<=[。!?.!?\n]) 使用了 lookbehind 来匹配中文和英文的句末标点,确保断句位置自然。

去重处理:同一个文档可能被切成多个 chunk,检索时可能有多个来自同一文档的节点都进入了 top-k。如果不做去重,引用列表可能会出现"《员工手册》第 3 章"出现三次的情况。used_node_ids 集合用来追踪已经引用过的节点,避免重复。更进一步的做法是按文档聚合——同一文档只保留得分最高的那条引用。
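按文档聚合可以在调用 build() 之前做一步预处理,思路是同一 doc_id 只保留得分最高的节点。下面是一个简单示意(假设 chunk 的 metadata 里带有 doc_id,没有时退回用 file_name):

python
from llama_index.core.schema import NodeWithScore

def dedup_by_document(nodes: list[NodeWithScore]) -> list[NodeWithScore]:
    """同一文档只保留得分最高的节点,避免引用列表中同一文档重复出现"""
    best_per_doc: dict[str, NodeWithScore] = {}
    for n in nodes:
        doc_id = n.node.metadata.get("doc_id") or n.node.metadata.get("file_name", "")
        if doc_id not in best_per_doc or (n.score or 0) > (best_per_doc[doc_id].score or 0):
            best_per_doc[doc_id] = n
    return sorted(best_per_doc.values(), key=lambda x: x.score or 0, reverse=True)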

查询分类器:QueryClassifier

python
from enum import Enum
from dataclasses import dataclass
from typing import Optional
import re


class QuestionType(Enum):
    FACTUAL = "factual"           # 事实性:"公司成立于哪一年?"
    PROCEDURAL = "procedural"     # 流程性:"如何申请报销?"
    COMPARATIVE = "comparative"   # 对比性:"Pro 版和企业版的区别?"
    TROUBLESHOOTING = "troubleshooting"  # 排查类:"API 返回 403 怎么办?"
    DEFINITIONAL = "definitional" # 定义性:"什么是 RBAC?"
    OPINION = "opinion"           # 观点性:"你觉得哪个方案更好?"
    AMBIGUOUS = "ambiguous"       # 歧义性:"我们的产品怎么样?"


@dataclass
class ClassificationResult:
    question_type: QuestionType
    confidence: float
    clarification_needed: bool
    suggested_clarification: Optional[str] = None
    keywords_extracted: Optional[list[str]] = None


class QueryClassifier:
    """基于规则 + 启发的查询分类器"""

    PROCEDURAL_PATTERNS = [
        r"如何.*?[办做申请操作流程步骤]",
        r".*?怎么.*?",
        r".*?如何.*?",
        r".*?怎样.*?",
        r".*?步骤",
        r".*?流程",
        r".*?指南",
        r".*?教程",
        r".*?方法",
    ]

    COMPARATIVE_PATTERNS = [
        r".*?.*?的区别",
        r".*?.*?对比",
        r".*?vs\.?",
        r".*?比较.*?",
        r".*?差异",
        r".*?哪个好",
        r".*?选哪个",
        r".*?有什么不同",
    ]

    TROUBLESHOOTING_PATTERNS = [
        r".*?报错",
        r".*?错误",
        r".*?失败",
        r".*?异常",
        r".*?不行",
        r".*?解决.*?问题",
        r".*?修复",
        r".*?bug",
        r"\d{3}\s*(错误|error)",
        r".*?40[0-9]\b",
        r".*?50[0-9]\b",
    ]

    FACTUAL_PATTERNS = [
        r".*?是什么",
        r".*?是多少",
        r".*?有没有",
        r".*?是否",
        r".*?哪些",
        r".*?",
        r".*?哪里",
        r".*?什么时候",
        r".*?[个天月年]",
        r".*?[是否是不是].*?[吗呢]",
    ]

    OPINION_PATTERNS = [
        r".*?觉得.*?",
        r".*?推荐.*?",
        r".*?应该.*?",
        r".*?值得.*?",
        r".*?评价.*?",
        r".*?看法",
        r".*?建议.*?",
    ]

    def classify(self, question: str) -> ClassificationResult:
        """
        分类用户问题

        Returns:
            ClassificationResult 包含问题类型、置信度和是否需要澄清
        """
        question = question.strip()

        if len(question) < 3:
            return ClassificationResult(
                question_type=QuestionType.AMBIGUOUS,
                confidence=0.3,
                clarification_needed=True,
                suggested_clarification="您的问题太简短了,能否详细描述一下您想了解什么?",
            )

        scores = {
            QuestionType.PROCEDURAL: self._match_score(
                question, self.PROCEDURAL_PATTERNS
            ),
            QuestionType.COMPARATIVE: self._match_score(
                question, self.COMPARATIVE_PATTERNS
            ),
            QuestionType.TROUBLESHOOTING: self._match_score(
                question, self.TROUBLESHOOTING_PATTERNS
            ),
            QuestionType.FACTUAL: self._match_score(
                question, self.FACTUAL_PATTERNS
            ),
            QuestionType.OPINION: self._match_score(
                question, self.OPINION_PATTERNS
            ),
        }

        best_type = max(scores, key=scores.get)
        best_score = scores[best_type]

        if best_score < 0.3:
            best_type = QuestionType.AMBIGUOUS
            best_score = 0.2

        clarification = self._check_need_clarification(question, best_type)

        keywords = self._extract_keywords(question)

        return ClassificationResult(
            question_type=best_type,
            confidence=round(best_score, 3),
            clarification_needed=clarification["needed"],
            suggested_clarification=clarification["suggestion"],
            keywords_extracted=keywords,
        )

    def _match_score(self, text: str, patterns: list[str]) -> float:
        """计算文本与一组正则模式的匹配分数(命中 2 条及以上即视为强信号)"""
        if not patterns:
            return 0.0
        matches = sum(1 for pattern in patterns if re.search(pattern, text, re.IGNORECASE))
        # 按命中条数归一化到 [0, 1]:若按模式总数归一,0.3 的阈值几乎达不到
        return min(matches / 2.0, 1.0)

    def _check_need_clarification(
        self, question: str, qtype: QuestionType
    ) -> dict:
        """检查问题是否过于模糊需要澄清"""
        ambiguous_indicators = [
            len(question) < 8,
            "什么" in question and len(question) < 15,
            "怎么样" in question and len(question) < 15,
            qtype == QuestionType.AMBIGUOUS,
        ]

        needed = any(ambiguous_indicators)
        suggestion = None

        if needed:
            if "产品" in question or "功能" in question:
                suggestion = (
                    "您想了解哪个产品的信息?或者您可以描述一下"
                    "具体的业务场景,我会为您提供更精准的回答。"
                )
            elif "政策" in question or "规定" in question:
                suggestion = (
                    "请问您想了解哪方面的制度规范?"
                    "比如人事、财务、IT 还是其他领域?"
                )
            else:
                suggestion = (
                    "您的问题比较宽泛,能否提供更多细节?"
                    "例如涉及的具体产品、部门或场景。"
                )

        return {"needed": needed, "suggestion": suggestion}

    def _extract_keywords(self, question: str) -> list[str]:
        """提取问题中的关键词(用于日志分析和后续优化)"""
        import jieba
        import jieba.posseg as pseg

        words = pseg.cut(question)
        keywords = []
        stop_words = {
            "的", "了", "是", "在", "我", "有", "和", "就",
            "不", "人", "都", "一", "一个", "上", "也", "很",
            "到", "说", "要", "去", "你", "会", "着", "没有",
            "什么", "怎么", "如何", "哪", "哪些", "这个", "那个",
            "吗", "呢", "吧", "啊", "呀", "哦", "嗯",
        }

        for word, flag in words:
            if len(word) >= 2 and word not in stop_words:
                if flag.startswith(("n", "v", "nr", "nz")):
                    keywords.append(word)

        return keywords[:10]

这个查询分类器采用的是规则匹配而非 LLM 调用的方式。原因是:分类操作会在每次查询时执行,如果用 LLM 来分类会增加一次 API 调用和约 200-500ms 的额外延迟。对于企业内部的知识库问答,问题的类型分布相对固定且可预测(事实性问题最多约占 40%,流程性约 25%,其他类型合计约 35%),规则方法的准确率可以达到 85% 以上,足以满足需求。

当然,如果你的预算充足且对分类准确率有更高要求(>95%),可以用 LLM 替换规则方法——最简单的做法是直接复用引擎里已配置的 Settings.llm 做一次轻量的分类调用。一个折中方案是:规则方法作为快速路径(覆盖约 85% 的常见情况),对于规则无法确定的问题(confidence < 0.5)再 fallback 到 LLM 分类。
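下面是这个折中方案的一个极简示意(假设 Settings.llm 已在引擎初始化时配置好,prompt 与类型映射可按自己的分类体系调整):

python
from llama_index.core import Settings

def classify_with_fallback(classifier: QueryClassifier, question: str) -> QuestionType:
    """规则优先,置信度不足时回退到 LLM 分类"""
    result = classifier.classify(question)
    if result.confidence >= 0.5:
        return result.question_type

    prompt = (
        "请判断下面的问题属于哪一类,只输出类型英文名: "
        "factual / procedural / comparative / troubleshooting / definitional / opinion\n"
        f"问题: {question}"
    )
    raw = Settings.llm.complete(prompt).text.strip().lower()
    try:
        return QuestionType(raw)
    except ValueError:
        return result.question_type  # LLM 输出无法解析时,仍用规则结果兜底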

总结

本节完成了企业知识库系统五个核心模块的实现:DataLoaderHub 以策略模式加工厂模式统一管理各类数据源;EnterpriseRAGEngine 把检索、权限过滤、重排序、回答合成串成完整的查询流水线;HybridRetriever 用 RRF 融合向量检索与 BM25 关键词检索;CitationBuilder 负责引用溯源;QueryClassifier 则在入口处识别问题类型和需要澄清的模糊提问。
