前言
前几篇我们搭建了 AI 应用的基础骨架、Prompt 工程体系和多轮对话管理。这一篇进入 RAG(检索增强生成)——让 AI 拥有自己的知识库,能基于私有文档回答问题。
市面上讲 RAG 的文章很多,但要么太理论(只讲概念),要么太碎片(只讲某个环节)。这篇的目标是:从零实现一个完整的、可运行的 RAG 系统。
最终效果:上传 PDF/文档,自动切分、向量化、存入向量库,用户提问时检索相关片段并生成回答。
1. RAG 整体架构
用户问题
│
▼
┌──────────────┐ ┌────────────────┐
│ Query 改写 │───▶│ Embedding 模型 │
│ 改写/扩展 │ │ 向量化问题 │
└──────────────┘ └───────┬────────┘
│
▼
┌──────────────────────┐
│ 向量数据库 │
│ (Qdrant/Milvus) │
│ → Top-K 检索 │
│ → 混合检索(可选) │
└───────┬──────────────┘
│
▼
┌──────────────────────┐
│ Re-rank │
│ 重排序 Top-K 结果 │
└───────┬──────────────┘
│
▼
┌──────────────────────┐
│ LLM + Prompt │
│ 基于资料生成回答 │
└───────┬──────────────┘
│
▼
最终答案
2. 文档解析与切分
2.1 文档解析
先支持最常见的 PDF 和 TXT 格式:
import os
from pathlib import Path
def parse_document(file_path: str) -> str:
"""解析文档,返回纯文本内容。"""
ext = Path(file_path).suffix.lower()
if ext == ".txt":
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
elif ext == ".pdf":
return parse_pdf(file_path)
elif ext == ".md":
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
return f.read()
else:
raise ValueError(f"Unsupported format: {ext}")
def parse_pdf(file_path: str) -> str:
"""解析 PDF 文件。使用 PyMuPDF(fitz)实现。"""
try:
import fitz # PyMuPDF
except ImportError:
raise ImportError("Please install PyMuPDF: pip install pymupdf")
doc = fitz.open(file_path)
text = []
for page in doc:
text.append(page.get_text())
doc.close()
return "\n\n".join(text)
2.2 文本切分(Chunking)
切分是 RAG 效果的基础。三种常用策略:
策略 1:递归字符切分(推荐)
from typing import List
class RecursiveCharacterTextSplitter:
"""递归字符文本切分器。
先按段落切,段落太长再按句子切,句子太长再按固定长度切。
"""
def __init__(
self,
chunk_size: int = 512,
chunk_overlap: int = 128,
separators: List[str] = None,
):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.separators = separators or ["\n\n", "\n", "。", "。", ",", " ", ""]
def split_text(self, text: str) -> List[str]:
"""递归切分文本。"""
return self._split(text, self.separators, 0)
def _split(self, text: str, separators: List[str], level: int) -> List[str]:
if level >= len(separators):
# 没有更多分隔符了,按固定长度切
return self._split_fixed_length(text)
separator = separators[level]
if not separator:
return self._split_fixed_length(text)
chunks = []
# 用当前分隔符切分
for segment in text.split(separator):
segment = segment.strip()
if not segment:
continue
if len(segment) List[str]:
"""固定长度切分(按字符数)。"""
chunks = []
start = 0
while start List[str]:
"""合并过小的 chunks。"""
merged = []
buffer = ""
for c in chunks:
if len(buffer) + len(c) List[str]:
"""按 Markdown 标题或空行切分。"""
import re
# 先按标题切
sections = re.split(r"\n#{1,3}\s+", text)
result = []
for s in sections:
s = s.strip()
if not s:
continue
if len(s) > 1500:
# 太长的再按空行切
paras = s.split("\n\n")
for p in paras:
p = p.strip()
if p:
result.append(p)
else:
result.append(s)
return result
2.3 生成元数据
每个 chunk 需要附带元数据,方便溯源和过滤:
def create_chunks_with_metadata(
text: str,
source: str,
chunk_size: int = 512,
chunk_overlap: int = 128,
) -> List[dict]:
"""切分文本并生成带元数据的 chunks。"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
texts = splitter.split_text(text)
chunks = []
for i, content in enumerate(texts):
chunks.append({
"id": f"{source}#chunk{i}",
"text": content,
"metadata": {
"source": source,
"chunk_index": i,
"total_chunks": len(texts),
},
})
return chunks
3. Embedding 与向量存储
3.1 Embedding 封装
import numpy as np
from typing import List
class EmbeddingModel:
"""Embedding 模型封装,支持多种后端。"""
def __init__(self, model_type: str = "local", model_name: str = None):
self.model_type = model_type
self.model_name = model_name
self._model = None
if model_type == "local":
self._init_local()
elif model_type == "api":
self._init_api()
def _init_local(self):
"""使用本地 Sentence-Transformers 模型。"""
try:
from sentence_transformers import SentenceTransformer
name = self.model_name or "BAAI/bge-small-zh-v1.5"
self._model = SentenceTransformer(name)
self.dimension = self._model.get_sentence_embedding_dimension()
print(f"Loaded local model: {name} (dim={self.dimension})")
except ImportError:
raise ImportError("Install: pip install sentence-transformers")
def _init_api(self):
"""使用 API 版 Embedding 模型。"""
from openai import OpenAI
import os
self._client = OpenAI(
api_key=os.getenv("LLM_API_KEY"),
base_url=os.getenv("LLM_BASE_URL"),
)
self.dimension = 1024 # bge-m3 dimension
def encode(self, texts: List[str]) -> List[List[float]]:
"""将文本列表编码为向量。"""
if self.model_type == "local":
emb = self._model.encode(texts, normalize_embeddings=True)
return emb.tolist()
else:
response = self._client.embeddings.create(
model=self.model_name or "BAAI/bge-m3",
input=texts,
)
return [d.embedding for d in response.data]
def encode_query(self, text: str) -> List[float]:
"""编码查询(单条)。"""
if self.model_type == "local":
emb = self._model.encode([text], normalize_embeddings=True)
return emb[0].tolist()
else:
return self.encode([text])[0]
Embedding 选型建议:
| 模型 | 维度 | 适用场景 | 推荐度 |
|---|---|---|---|
| BAAI/bge-small-zh-v1.5 | 512 | 中文轻量 | ⭐⭐⭐⭐⭐ |
| BAAI/bge-m3 | 1024 | 多语言、混合检索 | ⭐⭐⭐⭐⭐ |
| text-embedding-3-small | 1536 | 英文通用 | ⭐⭐⭐⭐ |
| m3e-base | 768 | 中文通用 | ⭐⭐⭐⭐ |
3.2 向量存储(内存版,适合原型)
import numpy as np
from typing import List, Dict, Optional
import json
class VectorStore:
"""内存向量存储(原型用,生产环境用 Qdrant/Milvus)。"""
def __init__(self, dimension: int = 512):
self.dimension = dimension
self.vectors = [] # List[np.array]
self.documents = [] # List[dict]
self.index = None # FAISS index
def add_documents(self, documents: List[dict]):
"""添加文档到向量库。"""
for doc in documents:
self.documents.append({
"id": doc["id"],
"text": doc["text"],
"metadata": doc.get("metadata", {}),
})
def add_vectors(self, vectors: List[List[float]], documents: List[dict]):
"""添加向量和文档。"""
for vec, doc in zip(vectors, documents):
self.vectors.append(np.array(vec, dtype=np.float32))
self.documents.append({
"id": doc["id"],
"text": doc["text"],
"metadata": doc.get("metadata", {}),
})
def build_index(self):
"""构建 FAISS 索引加速检索。"""
try:
import faiss
except ImportError:
raise ImportError("Install faiss: pip install faiss-cpu")
if not self.vectors:
return
matrix = np.array(self.vectors).astype(np.float32)
self.index = faiss.IndexFlatIP(self.dimension) # 内积 = 余弦相似度
faiss.normalize_L2(matrix)
self.index.add(matrix)
print(f"Index built: {self.index.ntotal} vectors")
def search(
self,
query_vector: List[float],
k: int = 5,
score_threshold: float = 0.0,
filter_expr: Optional[dict] = None,
) -> List[dict]:
"""检索 Top-K 相似文档。"""
query = np.array([query_vector]).astype(np.float32)
faiss.normalize_L2(query)
if self.index is None:
self.build_index()
if self.index is None:
return []
scores, indices = self.index.search(query, k * 2)
results = []
for score, idx in zip(scores[0], indices[0]):
if idx = len(self.documents):
continue
if score = k:
break
return results
4. Query 改写与检索增强
4.1 Query 改写
用户的原始问题往往不适合直接检索。改写可以大幅提升检索效果:
class QueryRewriter:
"""Query 改写引擎。"""
def __init__(self, llm_func):
self.llm_func = llm_func # chat_sync or similar
self._prompt = """你是一个检索专家。用户的问题可能不够清晰,
你需要将其改写成更适合向量检索的形式。
要求:
- 提取核心关键词和实体
- 补充同义词或相关术语
- 保持简洁,不超过 50 个字
- 如果问题已经清晰,直接返回原问题
用户问题:{question}
改写后的检索查询:"""
def rewrite(self, question: str) -> str:
"""改写问题。"""
reply = self.llm_func([
{"role": "system", "content": self._prompt.format(question=question)}
])
return reply.strip()
def expand(self, question: str) -> list:
"""生成多个检索查询(HyDE 方法)。"""
prompt = """用户提问:{question}
请生成 3 个不同的检索查询,每个一行。
这些查询应该从不同角度覆盖用户的问题。
"""
reply = self.llm_func([
{"role": "system", "content": prompt.format(question=question)}
])
queries = [q.strip() for q in reply.split("\n") if q.strip()]
return queries[:3]
4.2 混合检索
class HybridRetriever:
"""混合检索:向量检索 + 关键词检索(BM25)。"""
def __init__(self, vector_store: VectorStore, alpha: float = 0.5):
self.vector_store = vector_store
self.alpha = alpha # 向量检索权重,1-alpha 为 BM25 权重
self._bm25 = None
def _build_bm25(self, texts: List[str]):
"""构建 BM25 索引。"""
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError("Install: pip install rank-bm25")
tokenized = [t.split() for t in texts]
self._bm25 = BM25Okapi(tokenized)
def search(
self,
query_vector: List[float],
query_text: str,
k: int = 5,
) -> List[dict]:
"""混合检索。"""
# 向量检索
vec_results = self.vector_store.search(query_vector, k=k * 2)
# BM25 检索
if self._bm25:
tokenized_query = query_text.split()
bm25_scores = self._bm25.get_scores(tokenized_query)
docs = self.vector_store.documents
bm25_results = sorted(
[
(i, bm25_scores[i])
for i in range(len(bm25_scores))
],
key=lambda x: -x[1],
)[:k]
# RRF 合并
rrf_scores = {}
for rank, r in enumerate(vec_results):
doc_id = r.get("id", str(rank))
rrf_scores[doc_id] = rrf_scores.get(doc_id, 0) + 1 / (60 + rank + 1)
if self._bm25:
for rank, (idx, score) in enumerate(bm25_results):
doc_id = docs[idx].get("id", str(idx))
rrf_scores[doc_id] = rrf_scores.get(doc_id, 0) + 1 / (60 + rank + 1)
sorted_results = sorted(rrf_scores.items(), key=lambda x: -x[1])
top_ids = [doc_id for doc_id, _ in sorted_results[:k]]
# 返回结果
doc_map = {d.get("id", str(i)): d for i, d in enumerate(self.vector_store.documents)}
return [doc_map[doc_id] for doc_id in top_ids if doc_id in doc_map]
5. Re-ranking
检索到的 Top-K 结果需要用 Cross-Encoder 重排序,能大幅提升最终回答质量:
class ReRanker:
"""用 Cross-Encoder 重排序检索结果。"""
def __init__(self, model_name: str = "BAAI/bge-reranker-v2-m3"):
try:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.model.eval()
except Exception as e:
print(f"Re-ranker not available: {e}")
self.model = None
def rerank(self, query: str, documents: List[dict], top_k: int = 5) -> List[dict]:
"""重排序并返回 Top-K。"""
if not self.model or not documents:
return documents[:top_k]
pairs = [[query, doc["text"][:512]] for doc in documents]
inputs = self.tokenizer(
pairs, padding=True, truncation=True,
max_length=512, return_tensors="pt",
)
import torch
with torch.no_grad():
scores = self.model(**inputs).logits.squeeze(-1).tolist()
if isinstance(scores, float):
scores = [scores]
for doc, score in zip(documents, scores):
doc["rerank_score"] = score
documents.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
return documents[:top_k]
为什么需要 Re-rank?
| 阶段 | 模型 | 速度 | 精度 | 适用 |
|---|---|---|---|---|
| 检索 | Bi-Encoder(bge-m3) | 快(毫秒级) | 一般 | 初筛 Top-50 |
| 重排序 | Cross-Encoder(bge-reranker) | 慢(百毫秒级) | 高 | 精排 Top-5 |
Bi-Encoder 把文本独立编码为向量,速度快但精度有限。Cross-Encoder 把 query 和 doc 一起送入模型计算相似度,精度高但慢。所以标准做法是:先用 Bi-Encoder 检索 Top-50,再用 Cross-Encoder 重排取 Top-5。
6. 完整 RAG Pipeline
class RAGPipeline:
"""完整的 RAG Pipeline。"""
def __init__(
self,
embedding_model: EmbeddingModel,
llm_func,
vector_store: VectorStore = None,
):
self.embedder = embedding_model
self.llm = llm_func
self.vector_store = vector_store or VectorStore(embedding_model.dimension)
self.rewriter = QueryRewriter(llm_func)
self.reranker = ReRanker()
self._rag_prompt = """你是一个基于知识库的 AI 助手。
请根据以下检索到的资料来回答用户的问题。
如果资料不足以回答问题,请直接说"资料中没有相关信息",不要编造。
## 检索资料
{context}
## 用户问题
{question}
请基于以上资料回答,并在回答末尾注明引用的资料来源。"""
def ingest(self, file_path: str):
"""导入文档到知识库。"""
# 1. 解析
text = parse_document(file_path)
source = os.path.basename(file_path)
# 2. 切分
chunks = create_chunks_with_metadata(text, source)
# 3. 向量化
texts = [c["text"] for c in chunks]
vectors = self.embedder.encode(texts)
# 4. 存入向量库
self.vector_store.add_vectors(vectors, chunks)
print(f"Ingested {len(chunks)} chunks from {source}")
def query(self, question: str, k: int = 5) -> dict:
"""回答用户问题。"""
# 1. Query 改写
rewritten = self.rewriter.rewrite(question)
print(f" Rewritten: {rewritten}")
# 2. 向量化问题
query_vector = self.embedder.encode_query(rewritten)
# 3. 检索
results = self.vector_store.search(
query_vector, k=k * 3,
)
# 4. Re-rank
results = self.reranker.rerank(rewritten, results, top_k=k)
# 5. 构建 Prompt
context = "\n\n---\n\n".join([
f"[来源 {r.get('metadata', {}).get('source', 'unknown')}] {r['text']}"
for r in results
])
prompt = self._rag_prompt.format(context=context, question=question)
# 6. 生成回答
answer = self.llm([
{"role": "system", "content": prompt}
])
return {
"answer": answer,
"sources": [
{
"source": r.get("metadata", {}).get("source", "unknown"),
"text": r["text"][:200],
"score": r.get("score", 0),
}
for r in results
],
}
7. 集成到之前的 AI 应用
在系列第一篇的 ai-chat-app 中集成 RAG:
# rag_app.py - 在 AI Chat 中添加 RAG 知识库功能
from flask import Flask, request, jsonify
from rag_pipeline import RAGPipeline
from llm import chat_sync
app = Flask(__name__)
rag = RAGPipeline(
embedding_model=EmbeddingModel(model_type="local", model_name="BAAI/bge-small-zh-v1.5"),
llm_func=chat_sync,
)
UPLOAD_FOLDER = "knowledge_base"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
@app.route("/api/rag/upload", methods=["POST"])
def upload_document():
"""上传文档到知识库。"""
if "file" not in request.files:
return jsonify({"error": "No file"}), 400
file = request.files["file"]
file_path = os.path.join(UPLOAD_FOLDER, file.filename)
file.save(file_path)
try:
rag.ingest(file_path)
return jsonify({"status": "ok", "message": f"Imported {file.filename}"})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/api/rag/query", methods=["POST"])
def rag_query():
"""基于知识库回答问题。"""
data = request.json
question = data.get("question", "")
if not question:
return jsonify({"error": "Question required"}), 400
result = rag.query(question)
return jsonify({
"answer": result["answer"],
"sources": result["sources"],
"mode": "rag",
})
@app.route("/api/rag/status", methods=["GET"])
def get_status():
"""查看知识库状态。"""
doc_count = len(rag.vector_store.documents)
return jsonify({
"document_count": doc_count,
"status": "ready" if doc_count > 0 else "empty",
})
8. 效果调优清单
当 RAG 效果不好时,按这个清单排查:
回答不准确?
├─ 检索不到相关内容?
│ ├─ Chunk 太大/太小?→ 尝试 chunk_size=256/512/1024
│ ├─ Embedding 模型不适合中文?→ 用 bge-small-zh/bge-m3
│ ├─ Query 太模糊?→ 加 Query Rewriter
│ └─ 需要混合检索?→ 加 BM25
├─ 检索到了但答案不对?
│ ├─ Re-rank 了吗?→ 加 Cross-Encoder 重排序
│ ├─ Prompt 不够好?→ 明确要求"基于资料回答"
│ └─ 上下文太长?→ 限制使用 Top-3 而不是 Top-5
└─ 答案在胡说?
├─ 模型幻觉?→ Prompt 加"不知道就说不知道"
└─ 资料质量差?→ 检查源文档质量
9. 生产环境注意事项
- 向量数据库选型:原型用内存/FAISS,生产用 Qdrant(单机)或 Milvus(分布式)
- Embedding 缓存:相同的文档内容不要重复计算向量,用 hash 做缓存
- 增量更新:支持文档追加而不是每次全量重建索引
- 权限隔离:多租户场景下,每个用户的文档向量用 metadata 中的 user_id 字段隔离
- 监控指标:
rag_metrics = {
"retrieval_latency_ms": 150, # 检索耗时
"retrieval_recall@5": 0.85, # 前5命中率
"generation_latency_ms": 1200, # 生成耗时
"total_chunks": 15000, # 知识库总量
}
总结
RAG 系统的核心链路:
PDF/TXT → 切分 Chunk → Embedding → 向量库
用户问题 → Query 改写 → 向量检索 → Re-rank → LLM 回答
每个环节都有优化空间,但先把全链路跑通,再根据实际效果逐步调优,是最高效的方式。
本文是 《AI 应用开发实战》系列 的第 4 篇。
系列目录:
1. ✅ 从零搭建你的第一个 AI 应用
2. ✅ Prompt 工程实战
3. ✅ 多轮对话进阶
4. ✅ 从零实现 RAG 系统 ← 你在这里
5. 📝 AI Agent——工具调用与自主决策本文由 Zyentor(智元界) 原创发布