python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > Python RAG检索增强

Python实现RAG检索增强生成的完整教学

作者:程序猿の搬砖日记

检索增强生成(Retrieval-Augmented Generation, RAG)是当下最热门的 AI 应用架构之一,本文用 Python 从零实现一个完整的 RAG 系统,让大模型基于你的私有知识精准回答,希望对大家有所帮助

一、为什么需要 RAG

大模型的两个致命限制:

限制具体表现后果
知识截止训练数据有时间窗口不知道最新发生的事
私有数据盲区只见过公开数据不懂你公司的业务

解决思路有两种:

方案A:微调(Fine-tuning)方案B:RAG
重新训练模型不改模型,外 挂知识库
成本高、周期长成本低、即插即用
知识"烧进"权重知识实时检索
更新知识需要重新训练随时增删文档

90% 的场景,RAG 是更好的选择。

二、RAG 核心原理:三步走

RAG 的工作流程 可以用一句话概括:先检索,再生成。

┌──────────┐     ┌──────────┐     ┌──────────┐
│  第一步   │     │  第二步   │     │  第三步   │
│  文档处理  │────▶│  向量检索  │────▶│  增强生成  │
│          │     │          │     │          │
│ 文档切分  │     │ 问题向量化 │     │ 拼接上下文 │
│ 向量化存储 │     │ 相似度匹配  │     │ 大模型生成  │
└──────────┘     └──────────┘     └──────────┘

详细流程:

三、环境准备

pip install openai chromadb sentence-transformers numpy

说明:chromadb 是轻量级向量数据库,sentence-transformers 用于本地文本向量化(不依赖 API)。

四、实战 1:最小 RAG 系统 ——50 行代码搞定

先看一个最简版,理解核心逻辑:

# mini_rag.py
import numpy as np
from openai import OpenAI
client = OpenAI()
# ========== 知识库 ==========
documents = [
    "公司的年假政策:入职满1年有5天年假,满3年有10天年假,满5年有15天年假。",
    "报销流程:填写报销单 → 部门主管审批 → 财务审核 → 打款,一般3-5个工作日完成。",
    "上班时间为周一至周五 9:00-18:00,午休时间 12:00-13:30。",
    "远程办公政策:每周可申请最多2天远程办公,需提前一天在OA系统申请。",
    "试用期薪资为正式薪资的80%,试用期一般为3个月,表现优秀可提前转正。",
    "公司提供五险一金,公积金缴纳比例为12%,个人和公司各承担一半。",
    "年终奖发放规则:入职满一年的员工可获得1-3个月薪资的年终奖。",
    "加班政策:工作日加班按1.5倍计算,周末加班按2倍计算,法定节假日按3倍计算。",
]
def simple_embed(texts: list[str]) -> list[list[float]]:
    """使用 OpenAI API 生成文本向量"""
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input=texts
    )
    return [item.embedding for item in response.data]
def cosine_similarity(a, b):
    """计算余弦相似度"""
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def rag_query(question: str, top_k: int = 2) -> str:
    """最简 RAG 查询"""
    # 1. 检索:找最相关的文档
    doc_vectors = simple_embed(documents)
    question_vector = simple_embed([question])[0]
    # 计算相似度并排序
    scores = [cosine_similarity(question_vector, dv) for dv in doc_vectors]
    top_indices = np.argsort(scores)[-top_k:][::-1]
    # 2. 拼接上下文
    context = "\n".join([f"[文档{i+1}] {documents[idx]}" for i, idx in enumerate(top_indices)])
    # 3. 生成回答
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": f"根据以下参考文档回答用户问题。如果文档中没有相关信息,请说明。\n\n参考文档:\n{context}"},
            {"role": "user", "content": question}
        ]
    )
    return response.choices[0].message.content
# 测试
if __name__ == "__main__":
    questions = [
        "我入职半年了,能休多少天年假?",
        "我想报销一笔费用,流程是什么?",
        "加班工资怎么算?",
        "公司年终奖怎么发的?",
    ]
    for q in questions:
        print(f"❓ 问:{q}")
        answer = rag_query(q)
        print(f"💬 答:{answer}")
        print("-" * 60)

运行效果:

❓ 问:我入职半年了,能休多少天年假?
💬 答:根据公司政策,入职满1年才有5天年假。您目前入职半年,还未满1年,
暂时还不能享受年假。建议您在入职满1年后再申请年假。
--------------------------------------------------
❓ 问:加班工资怎么算?
💬 答:工作日加班按1.5倍计算,周末加班按2倍计算,法定节假日按3倍计算。
--------------------------------------------------

可以看到,模型不是在"瞎编",而是严格基于检索到的文档回答。

五、实战 2:生产级 RAG —— 完整知识库问答系统

最小版本够理解原理,但生产环境需要更健壮的实现。

5.1 文档加载与切分

# chunker.py
import re
from dataclasses import dataclass
@dataclass
class Chunk:
    """文档块"""
    content: str
    metadata: dict  # 存储来源、页码等元信息
def split_text(
    text: str,
    chunk_size: int = 300,
    chunk_overlap: int = 50,
    separator: str = "\n"
) -> list[Chunk]:
    """
    智能文本切分
    Args:
        text: 原始文本
        chunk_size: 每个块的最大字符数
        chunk_overlap: 相邻块的重叠字符数
        separator: 切分分隔符
    Returns:
        切分后的文档块列表
    """
    # 按分隔符先切分
    segments = text.split(separator)
    segments = [s.strip() for s in segments if s.strip()]
    chunks = []
    current_chunk = ""
    for segment in segments:
        # 如果单个段落就超长,按句子再切
        if len(segment) > chunk_size:
            sentences = re.split(r'[。!?;\.\!\?;]', segment)
            sentences = [s.strip() for s in sentences if s.strip()]
        else:
            sentences = [segment]
        for sentence in sentences:
            if len(current_chunk) + len(sentence) + 1 > chunk_size:
                if current_chunk:
                    chunks.append(Chunk(
                        content=current_chunk,
                        metadata={"char_count": len(current_chunk)}
                    ))
                # 保留重叠部分
                overlap_text = current_chunk[-chunk_overlap:] if chunk_overlap > 0 else ""
                current_chunk = overlap_text + sentence
            else:
                current_chunk = current_chunk + separator + sentence if current_chunk else sentence
    if current_chunk:
        chunks.append(Chunk(
            content=current_chunk,
            metadata={"char_count": len(current_chunk)}
        ))
    return chunks

5.2 向量数据库

# vector_store.py
import chromadb
from chromadb.config import Settings
class KnowledgeBase:
    """基于 ChromaDB 的知识库"""
    def __init__(self, collection_name: str = "my_knowledge"):
        self.client = chromadb.PersistentClient(path="./chroma_db")
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}  # 使用余弦相似度
        )
        self.embed_fn = None
    def set_embed_function(self, embed_fn):
        """设置向量化函数"""
        self.embed_fn = embed_fn
    def add_documents(self, chunks: list):
        """添加文档块到知识库"""
        if not self.embed_fn:
            raise ValueError("请先设置向量化函数")
        contents = [chunk.content for chunk in chunks]
        embeddings = self.embed_fn(contents)
        ids = [f"doc_{i}" for i in range(self.collection.count(), self.collection.count() + len(chunks))]
        metadatas = [chunk.metadata for chunk in chunks]
        self.collection.add(
            ids=ids,
            documents=contents,
            embeddings=embeddings,
            metadatas=metadatas
        )
        print(f"✅ 已添加 {len(chunks)} 个文档块,知识库总量:{self.collection.count()}")
    def search(self, query: str, top_k: int = 3) -> list[dict]:
        """检索最相关的文档块"""
        query_embedding = self.embed_fn([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )
        return [
            {
                "content": doc,
                "metadata": meta,
                "score": 1 - dist  # 距离转相似度
            }
            for doc, meta, dist in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0]
            )
        ]

5.3 本地向量化模型 (免费、无需 API)

# embedder.py
from sentence_transformers import SentenceTransformer
class LocalEmbedder:
    """使用本地模型生成向量,完全免费"""
    def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
        """
        中文文本向量化模型
        Args:
            model_name: 模型名称,首次使用会自动下载
                       推荐:shibing624/text2vec-base-chinese(中文,1024维)
                       备选:BAAI/bge-small-zh-v1.5(中文,512维,更快)
        """
        print(f"正在加载向量化模型:{model_name}...")
        self.model = SentenceTransformer(model_name)
        print("模型加载完成!")
    def embed(self, texts: list[str]) -> list[list[float]]:
        """批量生成文本向量"""
        embeddings = self.model.encode(texts, show_progress_bar=False)
        return embeddings.tolist()

5.4 完整 RAG 问答系统

# rag_system.py
import json
from openai import OpenAI
from chunker import Chunk, split_text
from vector_store import KnowledgeBase
from embedder import LocalEmbedder
client = OpenAI()
class RAGSystem:
    """完整的 RAG 问答系统"""
    def __init__(self):
        # 初始化向量化模型
        self.embedder = LocalEmbedder()
        # 初始化知识库
        self.kb = KnowledgeBase()
        self.kb.set_embed_function(self.embedder.embed)
    def ingest(self, text: str, source: str = "unknown"):
        """将文档导入知识库"""
        chunks = split_text(text, chunk_size=300, chunk_overlap=50)
        # 添加来源信息
        for chunk in chunks:
            chunk.metadata["source"] = source
        self.kb.add_documents(chunks)
    def query(self, question: str, top_k: int = 3, show_context: bool = False) -> str:
        """
        问答主函数
        Args:
            question: 用户问题
            top_k: 检索的文档块数量
            show_context: 是否显示检索到的上下文
        Returns:
            模型回答
        """
        # 1. 检索相关文档
        results = self.kb.search(question, top_k=top_k)
        if not results:
            return "抱歉,知识库中暂无相关文档。"
        # 过滤低相关度结果
        results = [r for r in results if r["score"] > 0.3]
        if show_context:
            print("\n📄 检索到的相关文档:")
            for i, r in enumerate(results):
                print(f"  [{i+1}] (相似度: {r['score']:.3f}) {r['content'][:100]}...")
        # 2. 拼接上下文
        context_parts = []
        for i, r in enumerate(results):
            context_parts.append(f"[参考资料{i+1}](来源:{r['metadata'].get('source', '未知')})\n{r['content']}")
        context = "\n\n".join(context_parts)
        # 3. 调用大模型生成回答
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "system",
                    "content": f"""你是一个精准的知识库问答助手。请严格根据以下参考资料回答用户问题。
规则:
1. 只根据参考资料中的信息回答,不要编造
2. 如果参考资料中没有相关信息,明确说明
3. 引用信息时标注出处(如"根据参考资料1")
4. 回答要简洁准确
参考资料:
{context}"""
                },
                {"role": "user", "content": question}
            ],
            temperature=0.1  # 低温度,减少"幻觉"
        )
        return response.choices[0].message.content
# ========== 使用示例 ==========
if __name__ == "__main__":
    # 初始化 RAG 系统
    rag = RAGSystem()
    # 导入知识文档
    company_docs = """
    公司员工手册 2026版
    第一章 考勤管理
    上班时间为周一至周五 9:00-18:00,午休时间 12:00-13:30。
    迟到15分钟以内扣50元,超过15分钟按旷工半天处理。
    每月全勤奖200元,需当月无迟到早退记录。
    第二章 休假制度
    年假:入职满1年5天,满3年10天,满5年15天,满10年20天。
    病假:每年带薪病假5天,需提供医院证明。
    事假:需提前申请,每年事假不超过10天,超过部分按旷工处理。
    婚假:法定婚假3天,晚婚增加7天。
    产假:女员工产假158天,男员工陪产假15天。
    第三章 薪酬福利
    薪资结构:基本工资 + 绩效奖金 + 餐补(500元/月) + 交通补贴(300元/月)。
    五险一金:公积金缴纳比例12%,社保按国家标准缴纳。
    年终奖:入职满一年可获1-3个月薪资,根据年度绩效评定。
    调薪:每年4月和10月各有一次调薪窗口。
    第四章 培训发展
    新员工入职培训为期3天,包含公司文化、制度规范、安全教育。
    每季度有一次内部技术分享会。
    每年可申请最高5000元的外部培训费用报销。
    晋升评审每半年一次,分别在1月和7月。
    """.strip()
    rag.ingest(company_docs, source="公司员工手册2026版")
    # 测试问答
    questions = [
        "我刚入职,可以请年假吗?",
        "迟到会怎样?",
        "我想参加外部培训,公司有补贴吗?",
        "公司什么时候调薪?",
    ]
    for q in questions:
        print(f"\n{'='*60}")
        print(f"❓ 问:{q}")
        print(f"{'='*60}")
        answer = rag.query(q, show_context=True)
        print(f"\n💬 答:{answer}")

运行效果:

❓ 问:我想参加外部培训,公司有补贴吗?
📄 检索到的相关文档:
  [1] (相似度: 0.872) 晋升评审每半年一次,分别在1月和7月。每年可申请最高5000元的外部培训费用报...
  [2] (相似度: 0.845) 新员工入职培训为期3天,包含公司文化、制度规范、安全教育。每季度有一次内部技...

💬 答:根据参考资料1,公司每年可申请最高5000元的外部培训费用报销。(来源:公司员工手册2026版)

六、Chunk 切分策略详解

切分策略直接影响检索质量:

策略适用场景优点缺点
固定长度切分通用场景简单高效可能切断语义
按段落切分结构化文档保持语义完整块大小不均匀
按句子切分短文档粒度精细上下文可能不足
递归切分复杂文档自适应实现较复杂
语义切分高质量需求效果最好计算成本高

推荐做法:先用按段落切分,再对超长段落做固定长度切分 + 重叠(本文的实现方式)。

关键参数选择:

# 推荐参数
chunk_size = 300     # 中文场景 200-500 字符
chunk_overlap = 50   # 重叠 10%-20%
top_k = 3            # 检索 3-5 个块

七、向量检索 vs  关键词检索

维度关键词检索(BM25)向量检索(Embedding)混合检索
原理词频匹配语义相似度两者结合
精确匹配
语义理解
专有名词
推荐场景简单搜索问答系统生产环境

生产环境建议使用混合检索:

def hybrid_search(query: str, kb, bm25_results: list, alpha: float = 0.7) -> list:
    """
    混合检索:结合向量检索和关键词检索
    Args:
        alpha: 向量检索的权重(0-1),1.0表示纯向量检索
    """
    vector_results = kb.search(query, top_k=5)
    # 简单的分数融合
    combined = {}
    for r in vector_results:
        key = r["content"]
        combined[key] = combined.get(key, 0) + alpha * r["score"]
    for r in bm25_results:
        key = r["content"]
        combined[key] = combined.get(key, 0) + (1 - alpha) * r["score"]
    # 按综合分数排序
    return sorted(combined.items(), key=lambda x: -x[1])

八、评估 RAG 系统效果

怎么知道你的 RAG 系统好不好?需要从两个维度评估:

8.1 检索质量

def evaluate_retrieval(test_cases: list[dict], rag_system) -> dict:
    """
    评估检索质量
    test_cases 格式:
    [
        {"question": "年假多少天", "expected_keywords": ["5天", "10天", "15天"]},
        ...
    ]
    """
    results = {"total": len(test_cases), "hit": 0, "details": []}
    for case in test_cases:
        search_results = rag_system.kb.search(case["question"], top_k=3)
        retrieved_text = " ".join([r["content"] for r in search_results])
        # 检查关键词是否被检索到
        hit = any(kw in retrieved_text for kw in case["expected_keywords"])
        if hit:
            results["hit"] += 1
        results["details"].append({
            "question": case["question"],
            "hit": hit,
            "expected": case["expected_keywords"],
            "retrieved": retrieved_text[:200]
        })
    results["recall"] = results["hit"] / results["total"]
    return results
# 使用示例
test_data = [
    {"question": "年假多少天", "expected_keywords": ["5天", "10天", "15天"]},
    {"question": "迟到怎么扣钱", "expected_keywords": ["15分钟", "50元"]},
    {"question": "培训报销额度", "expected_keywords": ["5000"]},
]
eval_result = evaluate_retrieval(test_data, rag)
print(f"检索召回率:{eval_result['recall']:.1%}")

8.2 生成质量(人工评估更可靠)

评估维度说明评分标准
准确性回答是否正确1-5分
完整性信息是否全面1-5分
忠实性是否忠于原文1-5分
简洁性是否废话少1-5分

九、常见问题与优化

Q1:检索不到相关内容?

原因:Chunk 太大或太小,或向量化模型不适合中文。

解决方案:

# 1. 调整 chunk_size
chunk_size = 200  # 试试更小的块
# 2. 增加 top_k
top_k = 5  # 检索更多候选
# 3. 换一个更好的向量化模型
# BAAI/bge-large-zh-v1.5(中文最佳)

Q2:回答出现"幻觉"(编造信息)?

解决方案:

# 1. 降低 temperature
temperature = 0.0
# 2. 强化系统提示词
system_prompt = """严格根据参考资料回答。如果参考资料中没有相关信息,
必须回答"根据现有资料,我无法回答这个问题",绝不可编造。"""
# 3. 添加来源引用要求
system_prompt += "\n每个回答必须标注参考资料的编号。"

Q3:多轮对话如何处理?

def multi_turn_rag(messages: list[dict], rag_system) -> str:
    """支持多轮对话的 RAG"""
    # 用最后一条用户消息检索
    last_question = messages[-1]["content"]
    # 如果是追问,结合上下文改写问题
    if len(messages) > 1:
        rewrite_prompt = f"根据对话历史,将用户的最新提问改写为独立问题:\n"
        for msg in messages:
            rewrite_prompt += f"{msg['role']}: {msg['content']}\n"
        rewrite = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": rewrite_prompt}]
        )
        search_query = rewrite.choices[0].message.content
    else:
        search_query = last_question
    return rag_system.query(search_query)

Q4:知识库很大怎么办?

知识库规模推荐方案
< 1000 篇文档ChromaDB(轻量本地)
1000-10万 篇Milvus / Qdrant(专业向量数据库)
> 10万 篇Elasticsearch + 向量检索混合方案

总结

RAG 的核心三步走:

文档切分 → 向量化存储 → 检索 + 生成

组件作用推荐选择
文档切分控制检索粒度按段落 + 固定长度
向量化模型将文本转为向量text2vec-base-chinese / BGE
向量数据库存储和检索向量ChromaDB / Milvus
大语言模型生成最终回答GPT-4o-mini / DeepSeek

RAG 让大模型拥有了你的私有知识,是 AI 落地企业应用的第一步。掌握了它,你就可以构建智能客服、文档问答、知识助手等各种应用。

以上就是Python实现RAG检索增强生成的完整教学的详细内容,更多关于Python RAG检索增强的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文