前言

前几篇我们搭建了 AI 应用的基础骨架、Prompt 工程体系和多轮对话管理。这一篇进入 RAG(检索增强生成)——让 AI 拥有自己的知识库,能基于私有文档回答问题。

市面上讲 RAG 的文章很多,但要么太理论(只讲概念),要么太碎片(只讲某个环节)。这篇的目标是:从零实现一个完整的、可运行的 RAG 系统

最终效果:上传 PDF/文档,自动切分、向量化、存入向量库,用户提问时检索相关片段并生成回答。

1. RAG 整体架构

用户问题
    │
    ▼
┌──────────────┐    ┌────────────────┐
│  Query 改写   │───▶│  Embedding 模型 │
│  改写/扩展    │    │  向量化问题     │
└──────────────┘    └───────┬────────┘
                            │
                            ▼
                 ┌──────────────────────┐
                 │  向量数据库           │
                 │  (Qdrant/Milvus)     │
                 │  → Top-K 检索        │
                 │  → 混合检索(可选)    │
                 └───────┬──────────────┘
                            │
                            ▼
                 ┌──────────────────────┐
                 │  Re-rank             │
                 │  重排序 Top-K 结果   │
                 └───────┬──────────────┘
                            │
                            ▼
                 ┌──────────────────────┐
                 │  LLM + Prompt        │
                 │  基于资料生成回答     │
                 └───────┬──────────────┘
                            │
                            ▼
                      最终答案

2. 文档解析与切分

2.1 文档解析

先支持最常见的 PDF 和 TXT 格式:

import os
from pathlib import Path


def parse_document(file_path: str) -> str:
    """解析文档,返回纯文本内容。"""
    ext = Path(file_path).suffix.lower()

    if ext == ".txt":
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

    elif ext == ".pdf":
        return parse_pdf(file_path)

    elif ext == ".md":
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()

    else:
        raise ValueError(f"Unsupported format: {ext}")


def parse_pdf(file_path: str) -> str:
    """解析 PDF 文件。使用 PyMuPDF(fitz)实现。"""
    try:
        import fitz  # PyMuPDF
    except ImportError:
        raise ImportError("Please install PyMuPDF: pip install pymupdf")

    doc = fitz.open(file_path)
    text = []
    for page in doc:
        text.append(page.get_text())
    doc.close()
    return "\n\n".join(text)

2.2 文本切分(Chunking)

切分是 RAG 效果的基础。三种常用策略:

策略 1:递归字符切分(推荐)

from typing import List


class RecursiveCharacterTextSplitter:
    """递归字符文本切分器。

    先按段落切,段落太长再按句子切,句子太长再按固定长度切。
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 128,
        separators: List[str] = None,
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", "。", "。", ",", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """递归切分文本。"""
        return self._split(text, self.separators, 0)

    def _split(self, text: str, separators: List[str], level: int) -> List[str]:
        if level >= len(separators):
            # 没有更多分隔符了,按固定长度切
            return self._split_fixed_length(text)

        separator = separators[level]
        if not separator:
            return self._split_fixed_length(text)

        chunks = []
        # 用当前分隔符切分
        for segment in text.split(separator):
            segment = segment.strip()
            if not segment:
                continue

            if len(segment)  List[str]:
        """固定长度切分(按字符数)。"""
        chunks = []
        start = 0
        while start  List[str]:
        """合并过小的 chunks。"""
        merged = []
        buffer = ""
        for c in chunks:
            if len(buffer) + len(c)  List[str]:
    """按 Markdown 标题或空行切分。"""
    import re
    # 先按标题切
    sections = re.split(r"\n#{1,3}\s+", text)
    result = []
    for s in sections:
        s = s.strip()
        if not s:
            continue
        if len(s) > 1500:
            # 太长的再按空行切
            paras = s.split("\n\n")
            for p in paras:
                p = p.strip()
                if p:
                    result.append(p)
        else:
            result.append(s)
    return result

2.3 生成元数据

每个 chunk 需要附带元数据,方便溯源和过滤:

def create_chunks_with_metadata(
    text: str,
    source: str,
    chunk_size: int = 512,
    chunk_overlap: int = 128,
) -> List[dict]:
    """切分文本并生成带元数据的 chunks。"""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    texts = splitter.split_text(text)

    chunks = []
    for i, content in enumerate(texts):
        chunks.append({
            "id": f"{source}#chunk{i}",
            "text": content,
            "metadata": {
                "source": source,
                "chunk_index": i,
                "total_chunks": len(texts),
            },
        })
    return chunks

3. Embedding 与向量存储

3.1 Embedding 封装

import numpy as np
from typing import List


class EmbeddingModel:
    """Embedding 模型封装,支持多种后端。"""

    def __init__(self, model_type: str = "local", model_name: str = None):
        self.model_type = model_type
        self.model_name = model_name
        self._model = None

        if model_type == "local":
            self._init_local()
        elif model_type == "api":
            self._init_api()

    def _init_local(self):
        """使用本地 Sentence-Transformers 模型。"""
        try:
            from sentence_transformers import SentenceTransformer
            name = self.model_name or "BAAI/bge-small-zh-v1.5"
            self._model = SentenceTransformer(name)
            self.dimension = self._model.get_sentence_embedding_dimension()
            print(f"Loaded local model: {name} (dim={self.dimension})")
        except ImportError:
            raise ImportError("Install: pip install sentence-transformers")

    def _init_api(self):
        """使用 API 版 Embedding 模型。"""
        from openai import OpenAI
        import os
        self._client = OpenAI(
            api_key=os.getenv("LLM_API_KEY"),
            base_url=os.getenv("LLM_BASE_URL"),
        )
        self.dimension = 1024  # bge-m3 dimension

    def encode(self, texts: List[str]) -> List[List[float]]:
        """将文本列表编码为向量。"""
        if self.model_type == "local":
            emb = self._model.encode(texts, normalize_embeddings=True)
            return emb.tolist()
        else:
            response = self._client.embeddings.create(
                model=self.model_name or "BAAI/bge-m3",
                input=texts,
            )
            return [d.embedding for d in response.data]

    def encode_query(self, text: str) -> List[float]:
        """编码查询(单条)。"""
        if self.model_type == "local":
            emb = self._model.encode([text], normalize_embeddings=True)
            return emb[0].tolist()
        else:
            return self.encode([text])[0]

Embedding 选型建议

模型 维度 适用场景 推荐度
BAAI/bge-small-zh-v1.5 512 中文轻量 ⭐⭐⭐⭐⭐
BAAI/bge-m3 1024 多语言、混合检索 ⭐⭐⭐⭐⭐
text-embedding-3-small 1536 英文通用 ⭐⭐⭐⭐
m3e-base 768 中文通用 ⭐⭐⭐⭐

3.2 向量存储(内存版,适合原型)

import numpy as np
from typing import List, Dict, Optional
import json


class VectorStore:
    """内存向量存储(原型用,生产环境用 Qdrant/Milvus)。"""

    def __init__(self, dimension: int = 512):
        self.dimension = dimension
        self.vectors = []       # List[np.array]
        self.documents = []     # List[dict]
        self.index = None       # FAISS index

    def add_documents(self, documents: List[dict]):
        """添加文档到向量库。"""
        for doc in documents:
            self.documents.append({
                "id": doc["id"],
                "text": doc["text"],
                "metadata": doc.get("metadata", {}),
            })

    def add_vectors(self, vectors: List[List[float]], documents: List[dict]):
        """添加向量和文档。"""
        for vec, doc in zip(vectors, documents):
            self.vectors.append(np.array(vec, dtype=np.float32))
            self.documents.append({
                "id": doc["id"],
                "text": doc["text"],
                "metadata": doc.get("metadata", {}),
            })

    def build_index(self):
        """构建 FAISS 索引加速检索。"""
        try:
            import faiss
        except ImportError:
            raise ImportError("Install faiss: pip install faiss-cpu")

        if not self.vectors:
            return

        matrix = np.array(self.vectors).astype(np.float32)
        self.index = faiss.IndexFlatIP(self.dimension)  # 内积 = 余弦相似度
        faiss.normalize_L2(matrix)
        self.index.add(matrix)
        print(f"Index built: {self.index.ntotal} vectors")

    def search(
        self,
        query_vector: List[float],
        k: int = 5,
        score_threshold: float = 0.0,
        filter_expr: Optional[dict] = None,
    ) -> List[dict]:
        """检索 Top-K 相似文档。"""
        query = np.array([query_vector]).astype(np.float32)
        faiss.normalize_L2(query)

        if self.index is None:
            self.build_index()

        if self.index is None:
            return []

        scores, indices = self.index.search(query, k * 2)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx = len(self.documents):
                continue
            if score = k:
                break

        return results

4. Query 改写与检索增强

4.1 Query 改写

用户的原始问题往往不适合直接检索。改写可以大幅提升检索效果:

class QueryRewriter:
    """Query 改写引擎。"""

    def __init__(self, llm_func):
        self.llm_func = llm_func  # chat_sync or similar
        self._prompt = """你是一个检索专家。用户的问题可能不够清晰,
你需要将其改写成更适合向量检索的形式。

要求:
- 提取核心关键词和实体
- 补充同义词或相关术语
- 保持简洁,不超过 50 个字
- 如果问题已经清晰,直接返回原问题

用户问题:{question}

改写后的检索查询:"""

    def rewrite(self, question: str) -> str:
        """改写问题。"""
        reply = self.llm_func([
            {"role": "system", "content": self._prompt.format(question=question)}
        ])
        return reply.strip()

    def expand(self, question: str) -> list:
        """生成多个检索查询(HyDE 方法)。"""
        prompt = """用户提问:{question}

请生成 3 个不同的检索查询,每个一行。
这些查询应该从不同角度覆盖用户的问题。
"""
        reply = self.llm_func([
            {"role": "system", "content": prompt.format(question=question)}
        ])
        queries = [q.strip() for q in reply.split("\n") if q.strip()]
        return queries[:3]

4.2 混合检索

class HybridRetriever:
    """混合检索:向量检索 + 关键词检索(BM25)。"""

    def __init__(self, vector_store: VectorStore, alpha: float = 0.5):
        self.vector_store = vector_store
        self.alpha = alpha  # 向量检索权重,1-alpha 为 BM25 权重
        self._bm25 = None

    def _build_bm25(self, texts: List[str]):
        """构建 BM25 索引。"""
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            raise ImportError("Install: pip install rank-bm25")

        tokenized = [t.split() for t in texts]
        self._bm25 = BM25Okapi(tokenized)

    def search(
        self,
        query_vector: List[float],
        query_text: str,
        k: int = 5,
    ) -> List[dict]:
        """混合检索。"""
        # 向量检索
        vec_results = self.vector_store.search(query_vector, k=k * 2)

        # BM25 检索
        if self._bm25:
            tokenized_query = query_text.split()
            bm25_scores = self._bm25.get_scores(tokenized_query)
            docs = self.vector_store.documents
            bm25_results = sorted(
                [
                    (i, bm25_scores[i])
                    for i in range(len(bm25_scores))
                ],
                key=lambda x: -x[1],
            )[:k]

        # RRF 合并
        rrf_scores = {}
        for rank, r in enumerate(vec_results):
            doc_id = r.get("id", str(rank))
            rrf_scores[doc_id] = rrf_scores.get(doc_id, 0) + 1 / (60 + rank + 1)

        if self._bm25:
            for rank, (idx, score) in enumerate(bm25_results):
                doc_id = docs[idx].get("id", str(idx))
                rrf_scores[doc_id] = rrf_scores.get(doc_id, 0) + 1 / (60 + rank + 1)

        sorted_results = sorted(rrf_scores.items(), key=lambda x: -x[1])
        top_ids = [doc_id for doc_id, _ in sorted_results[:k]]

        # 返回结果
        doc_map = {d.get("id", str(i)): d for i, d in enumerate(self.vector_store.documents)}
        return [doc_map[doc_id] for doc_id in top_ids if doc_id in doc_map]

5. Re-ranking

检索到的 Top-K 结果需要用 Cross-Encoder 重排序,能大幅提升最终回答质量:

class ReRanker:
    """用 Cross-Encoder 重排序检索结果。"""

    def __init__(self, model_name: str = "BAAI/bge-reranker-v2-m3"):
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.eval()
        except Exception as e:
            print(f"Re-ranker not available: {e}")
            self.model = None

    def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[dict]:
        """重排序并返回 Top-K。"""
        if not self.model or not documents:
            return documents[:top_k]

        pairs = [[query, doc["text"][:512]] for doc in documents]
        inputs = self.tokenizer(
            pairs, padding=True, truncation=True,
            max_length=512, return_tensors="pt",
        )

        import torch
        with torch.no_grad():
            scores = self.model(**inputs).logits.squeeze(-1).tolist()

        if isinstance(scores, float):
            scores = [scores]

        for doc, score in zip(documents, scores):
            doc["rerank_score"] = score

        documents.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
        return documents[:top_k]

为什么需要 Re-rank?

阶段 模型 速度 精度 适用
检索 Bi-Encoder(bge-m3) 快(毫秒级) 一般 初筛 Top-50
重排序 Cross-Encoder(bge-reranker) 慢(百毫秒级) 精排 Top-5

Bi-Encoder 把文本独立编码为向量,速度快但精度有限。Cross-Encoder 把 query 和 doc 一起送入模型计算相似度,精度高但慢。所以标准做法是:先用 Bi-Encoder 检索 Top-50,再用 Cross-Encoder 重排取 Top-5

6. 完整 RAG Pipeline

class RAGPipeline:
    """完整的 RAG Pipeline。"""

    def __init__(
        self,
        embedding_model: EmbeddingModel,
        llm_func,
        vector_store: VectorStore = None,
    ):
        self.embedder = embedding_model
        self.llm = llm_func
        self.vector_store = vector_store or VectorStore(embedding_model.dimension)
        self.rewriter = QueryRewriter(llm_func)
        self.reranker = ReRanker()
        self._rag_prompt = """你是一个基于知识库的 AI 助手。

请根据以下检索到的资料来回答用户的问题。
如果资料不足以回答问题,请直接说"资料中没有相关信息",不要编造。

## 检索资料

{context}

## 用户问题

{question}

请基于以上资料回答,并在回答末尾注明引用的资料来源。"""

    def ingest(self, file_path: str):
        """导入文档到知识库。"""
        # 1. 解析
        text = parse_document(file_path)
        source = os.path.basename(file_path)

        # 2. 切分
        chunks = create_chunks_with_metadata(text, source)

        # 3. 向量化
        texts = [c["text"] for c in chunks]
        vectors = self.embedder.encode(texts)

        # 4. 存入向量库
        self.vector_store.add_vectors(vectors, chunks)
        print(f"Ingested {len(chunks)} chunks from {source}")

    def query(self, question: str, k: int = 5) -> dict:
        """回答用户问题。"""
        # 1. Query 改写
        rewritten = self.rewriter.rewrite(question)
        print(f"  Rewritten: {rewritten}")

        # 2. 向量化问题
        query_vector = self.embedder.encode_query(rewritten)

        # 3. 检索
        results = self.vector_store.search(
            query_vector, k=k * 3,
        )

        # 4. Re-rank
        results = self.reranker.rerank(rewritten, results, top_k=k)

        # 5. 构建 Prompt
        context = "\n\n---\n\n".join([
            f"[来源 {r.get('metadata', {}).get('source', 'unknown')}] {r['text']}"
            for r in results
        ])

        prompt = self._rag_prompt.format(context=context, question=question)

        # 6. 生成回答
        answer = self.llm([
            {"role": "system", "content": prompt}
        ])

        return {
            "answer": answer,
            "sources": [
                {
                    "source": r.get("metadata", {}).get("source", "unknown"),
                    "text": r["text"][:200],
                    "score": r.get("score", 0),
                }
                for r in results
            ],
        }

7. 集成到之前的 AI 应用

在系列第一篇的 ai-chat-app 中集成 RAG:

# rag_app.py - 在 AI Chat 中添加 RAG 知识库功能

from flask import Flask, request, jsonify
from rag_pipeline import RAGPipeline
from llm import chat_sync

app = Flask(__name__)
rag = RAGPipeline(
    embedding_model=EmbeddingModel(model_type="local", model_name="BAAI/bge-small-zh-v1.5"),
    llm_func=chat_sync,
)

UPLOAD_FOLDER = "knowledge_base"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)


@app.route("/api/rag/upload", methods=["POST"])
def upload_document():
    """上传文档到知识库。"""
    if "file" not in request.files:
        return jsonify({"error": "No file"}), 400

    file = request.files["file"]
    file_path = os.path.join(UPLOAD_FOLDER, file.filename)
    file.save(file_path)

    try:
        rag.ingest(file_path)
        return jsonify({"status": "ok", "message": f"Imported {file.filename}"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/rag/query", methods=["POST"])
def rag_query():
    """基于知识库回答问题。"""
    data = request.json
    question = data.get("question", "")

    if not question:
        return jsonify({"error": "Question required"}), 400

    result = rag.query(question)

    return jsonify({
        "answer": result["answer"],
        "sources": result["sources"],
        "mode": "rag",
    })


@app.route("/api/rag/status", methods=["GET"])
def get_status():
    """查看知识库状态。"""
    doc_count = len(rag.vector_store.documents)
    return jsonify({
        "document_count": doc_count,
        "status": "ready" if doc_count > 0 else "empty",
    })

8. 效果调优清单

当 RAG 效果不好时,按这个清单排查:

回答不准确?
├─ 检索不到相关内容?
│   ├─ Chunk 太大/太小?→ 尝试 chunk_size=256/512/1024
│   ├─ Embedding 模型不适合中文?→ 用 bge-small-zh/bge-m3
│   ├─ Query 太模糊?→ 加 Query Rewriter
│   └─ 需要混合检索?→ 加 BM25
├─ 检索到了但答案不对?
│   ├─ Re-rank 了吗?→ 加 Cross-Encoder 重排序
│   ├─ Prompt 不够好?→ 明确要求"基于资料回答"
│   └─ 上下文太长?→ 限制使用 Top-3 而不是 Top-5
└─ 答案在胡说?
    ├─ 模型幻觉?→ Prompt 加"不知道就说不知道"
    └─ 资料质量差?→ 检查源文档质量

9. 生产环境注意事项

  1. 向量数据库选型:原型用内存/FAISS,生产用 Qdrant(单机)或 Milvus(分布式)
  2. Embedding 缓存:相同的文档内容不要重复计算向量,用 hash 做缓存
  3. 增量更新:支持文档追加而不是每次全量重建索引
  4. 权限隔离:多租户场景下,每个用户的文档向量用 metadata 中的 user_id 字段隔离
  5. 监控指标
rag_metrics = {
    "retrieval_latency_ms": 150,     # 检索耗时
    "retrieval_recall@5": 0.85,      # 前5命中率
    "generation_latency_ms": 1200,   # 生成耗时
    "total_chunks": 15000,           # 知识库总量
}

总结

RAG 系统的核心链路:

PDF/TXT  →  切分 Chunk  →  Embedding  →  向量库
用户问题  →  Query 改写  →  向量检索   →  Re-rank  →  LLM 回答

每个环节都有优化空间,但先把全链路跑通,再根据实际效果逐步调优,是最高效的方式。


本文是 《AI 应用开发实战》系列 的第 4 篇。
系列目录:
1. ✅ 从零搭建你的第一个 AI 应用
2. ✅ Prompt 工程实战
3. ✅ 多轮对话进阶
4. ✅ 从零实现 RAG 系统 ← 你在这里
5. 📝 AI Agent——工具调用与自主决策

本文由 Zyentor(智元界) 原创发布