RAG Implementation (RAGImplementation)

This skill covers implementing Retrieval-Augmented Generation (RAG) with the LangChain framework, spanning document processing, embedding generation, vector storage and retrieval, retrieval strategies, prompt construction, response generation, context window management, system evaluation, and performance optimization, with the goal of improving the accuracy and efficiency of knowledge-based question-answering systems.


RAG Implementation

Overview

A comprehensive guide to implementing Retrieval-Augmented Generation (RAG) with LangChain.

Prerequisites

  • Understanding of vector databases and embeddings
  • Familiarity with the LangChain framework
  • Knowledge of document processing and chunking
  • Experience integrating with LLMs
  • Understanding of semantic search concepts
  • Basic knowledge of retrieval strategies

Core Concepts

  • RAG architecture: Retrieval-Augmented Generation combines knowledge-base retrieval with LLM generation
  • Document chunking: splitting documents into smaller pieces for efficient processing and retrieval
  • Embeddings: converting text into vector representations for semantic similarity search
  • Vector stores: databases optimized for storing and retrieving vector embeddings (FAISS, Chroma, Pinecone, Weaviate)
  • Semantic search: finding documents by meaning rather than keyword matching
  • Hybrid search: combining semantic search with keyword search (BM25) for better results
  • Reranking: improving retrieved results by rescoring them with a more precise model
  • Context window management: keeping the amount of context passed to the LLM within its token limit
  • Multi-query retrieval: generating multiple queries from a single input to improve retrieval coverage
  • Self-querying: extracting metadata filters from the query to refine search results
  • Parent document retrieval: retrieving full parent documents rather than fragments for better context
  • RAG evaluation: measuring retrieval and generation quality with precision, recall, and F1 scores
  • Caching: storing frequently asked queries and responses to improve performance
  • Batch processing: handling multiple queries at once for efficiency

1. RAG Architecture Overview

┌──────────────┐
│  Documents   │
└──────┬───────┘
       │
       ▼
┌──────────────┐
│   Chunking   │
└──────┬───────┘
       │
       ▼
┌──────────────┐
│  Embedding   │
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ Vector Store │
└──────┬───────┘
       │
       ▼
┌──────────────┐
│  Retrieval   │
└──────┬───────┘
       │
       ▼
┌──────────────┐
│    Prompt    │
└──────┬───────┘
       │
       ▼
┌──────────────┐
│     LLM      │
└──────────────┘
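
Before walking through each stage in detail, here is a minimal end-to-end sketch of the pipeline above, assuming the classic langchain package, faiss-cpu, and an OPENAI_API_KEY environment variable are available; the sample text and parameters are illustrative.

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA

# 1. Chunk the raw text
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
docs = splitter.create_documents(["Paris is the capital of France. London is the capital of the UK."])

# 2-3. Embed the chunks and index them in a vector store
store = FAISS.from_documents(docs, OpenAIEmbeddings())

# 4-6. Retrieve, build the prompt, and generate with the LLM
chain = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(model_name="gpt-4"),
    chain_type="stuff",
    retriever=store.as_retriever(search_kwargs={"k": 4})
)
print(chain({"query": "What is the capital of France?"})["result"])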

2. Document Processing

2.1 Chunking Strategies

from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter
)
from langchain_experimental.text_splitter import SemanticChunker
from langchain.embeddings import OpenAIEmbeddings
from langchain.docstore.document import Document

class DocumentChunker:
    """Split documents into chunks."""

    def __init__(self, chunk_size=1000, chunk_overlap=200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def character_split(self, text: str) -> list[Document]:
        """Split by character count."""
        splitter = CharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separator="\n\n"
        )
        return splitter.create_documents([text])

    def recursive_split(self, text: str) -> list[Document]:
        """Split recursively by a list of separators."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        return splitter.create_documents([text])

    def token_split(self, text: str, tokenizer: str = "gpt-4") -> list[Document]:
        """Split by token count."""
        splitter = TokenTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            model_name=tokenizer  # use the target model's tokenizer
        )
        return splitter.create_documents([text])

    def semantic_split(self, text: str) -> list[Document]:
        """Split at semantic boundaries (requires langchain_experimental)."""
        splitter = SemanticChunker(
            OpenAIEmbeddings(),
            breakpoint_threshold_type="percentile"
        )
        return splitter.create_documents([text])

    def custom_split(self, text: str, pattern: str = "\n\n\n") -> list[Document]:
        """Split by a custom pattern."""
        chunks = text.split(pattern)
        return [Document(page_content=chunk) for chunk in chunks]

# Usage
chunker = DocumentChunker(chunk_size=1000, chunk_overlap=200)

text = "Long document text..."

chunks = chunker.recursive_split(text)
print(f"Created {len(chunks)} chunks")

2.2 Metadata Extraction

from typing import Dict, Any
import re
from datetime import datetime

class MetadataExtractor:
    """Extract metadata from documents."""

    @staticmethod
    def extract_basic(text: str) -> Dict[str, Any]:
        """Extract basic metadata."""
        return {
            "char_count": len(text),
            "word_count": len(text.split()),
            "sentence_count": len(re.split(r'[.!?]+', text)),
            "paragraph_count": len(text.split('\n\n')),
            "extracted_at": datetime.now().isoformat()
        }

    @staticmethod
    def extract_urls(text: str) -> list[str]:
        """Extract URLs from text."""
        url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        return re.findall(url_pattern, text)

    @staticmethod
    def extract_emails(text: str) -> list[str]:
        """Extract email addresses from text."""
        email_pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        return re.findall(email_pattern, text)

    @staticmethod
    def extract_dates(text: str) -> list[str]:
        """Extract dates from text."""
        date_patterns = [
            r'\d{4}-\d{2}-\d{2}',  # YYYY-MM-DD
            r'\d{2}/\d{2}/\d{4}',  # MM/DD/YYYY
            r'\d{1,2} [A-Za-z]+ \d{4}',  # DD Month YYYY
        ]
        dates = []
        for pattern in date_patterns:
            dates.extend(re.findall(pattern, text))
        return dates

    @staticmethod
    def extract_numbers(text: str) -> list[float]:
        """Extract numbers from text."""
        number_pattern = r'-?\d+\.?\d*'
        numbers = re.findall(number_pattern, text)
        return [float(num) for num in numbers]

    @staticmethod
    def extract_entities(text: str) -> Dict[str, Any]:
        """Extract named entities using NLP (spaCy, if installed)."""
        try:
            import spacy
            nlp = spacy.load("en_core_web_sm")
            doc = nlp(text)

            entities = {
                "persons": [],
                "organizations": [],
                "locations": [],
                "dates": []
            }

            for ent in doc.ents:
                if ent.label_ == "PERSON":
                    entities["persons"].append(ent.text)
                elif ent.label_ == "ORG":
                    entities["organizations"].append(ent.text)
                elif ent.label_ == "GPE":
                    entities["locations"].append(ent.text)
                elif ent.label_ == "DATE":
                    entities["dates"].append(ent.text)

            return entities
        except ImportError:
            return {}

    def extract_all(self, text: str, source: str = "unknown") -> Dict[str, Any]:
        """Extract all metadata."""
        return {
            "source": source,
            "basic": self.extract_basic(text),
            "urls": self.extract_urls(text),
            "emails": self.extract_emails(text),
            "dates": self.extract_dates(text),
            "numbers": self.extract_numbers(text),
            "entities": self.extract_entities(text)
        }

# Usage
extractor = MetadataExtractor()

metadata = extractor.extract_all(document_text, source="document.pdf")
print(metadata)

3. Embedding Generation

from langchain.embeddings import (
    OpenAIEmbeddings,
    HuggingFaceEmbeddings,
    CohereEmbeddings,
    SentenceTransformerEmbeddings
)
from typing import List

class EmbeddingGenerator:
    """Generate embeddings for documents."""

    def __init__(self, model: str = "openai"):
        self.model = model
        self.embedding_model = self._get_embedding_model()

    def _get_embedding_model(self):
        """Return the configured embedding model."""
        if self.model == "openai":
            return OpenAIEmbeddings()
        elif self.model == "huggingface":
            return HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2"
            )
        elif self.model == "cohere":
            return CohereEmbeddings()
        elif self.model == "sentence-transformers":
            return SentenceTransformerEmbeddings(
                model_name="all-MiniLM-L6-v2"
            )
        else:
            raise ValueError(f"Unknown model: {self.model}")

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Generate embeddings for documents."""
        return self.embedding_model.embed_documents(documents)

    def embed_query(self, query: str) -> List[float]:
        """Generate an embedding for a query."""
        return self.embedding_model.embed_query(query)

    def embed_batch(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Generate embeddings in batches."""
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.embed_documents(batch)
            embeddings.extend(batch_embeddings)

        return embeddings

# Usage
generator = EmbeddingGenerator(model="openai")

# Embed documents
documents = ["Document 1 text", "Document 2 text", "Document 3 text"]
document_embeddings = generator.embed_documents(documents)

# Embed a query
query = "What is the capital of France?"
query_embedding = generator.embed_query(query)
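
Embeddings are compared by vector similarity, and cosine similarity is the usual metric. A small helper (plain NumPy, no extra assumptions) makes the idea concrete:

import numpy as np

def cosine_similarity(a, b) -> float:
    """Cosine similarity between two embedding vectors."""
    a, b = np.asarray(a), np.asarray(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# The query embedding should score highest against the most relevant document
scores = [cosine_similarity(query_embedding, emb) for emb in document_embeddings]
print(f"Best match: document {scores.index(max(scores)) + 1}")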

4. Vector Stores and Retrieval

4.1 Vector Store Setup

from langchain.vectorstores import (
    FAISS,
    Chroma,
    Pinecone,
    Weaviate
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from typing import List, Optional

class VectorStoreManager:
    """Manage vector stores and retrieval."""

    def __init__(self, vector_store_type: str = "faiss"):
        self.vector_store_type = vector_store_type
        self.vector_store = None

    def create_faiss_store(
        self,
        documents: List[Document],
        embedding_model: Embeddings
    ):
        """Create a FAISS vector store (expects an embedding model, not raw vectors)."""
        self.vector_store = FAISS.from_documents(documents, embedding_model)
        return self.vector_store

    def create_chroma_store(
        self,
        documents: List[Document],
        embedding_model: Embeddings,
        collection_name: str = "documents"
    ):
        """Create a Chroma vector store."""
        self.vector_store = Chroma.from_documents(
            documents=documents,
            embedding=embedding_model,
            collection_name=collection_name
        )
        return self.vector_store

    def create_pinecone_store(
        self,
        documents: List[Document],
        embedding_model: Embeddings,
        index_name: str,
        environment: str = "us-west1-gcp"
    ):
        """Create a Pinecone vector store."""
        import pinecone

        # Initialize Pinecone
        pinecone.init(api_key="your-api-key", environment=environment)

        # Create the index if it does not exist yet
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                dimension=len(embedding_model.embed_query("probe")),
                metric="cosine"
            )

        self.vector_store = Pinecone.from_documents(
            documents=documents,
            embedding=embedding_model,
            index_name=index_name
        )
        return self.vector_store

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict] = None
    ) -> List[Document]:
        """Run a similarity search."""
        return self.vector_store.similarity_search(query, k=k, filter=filter)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict] = None
    ) -> List[tuple]:
        """Run a similarity search that also returns scores."""
        return self.vector_store.similarity_search_with_score(
            query, k=k, filter=filter
        )

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5
    ) -> List[Document]:
        """MMR search for diverse results."""
        return self.vector_store.max_marginal_relevance_search(
            query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult
        )

# Usage
manager = VectorStoreManager(vector_store_type="faiss")

# Create the vector store
documents = [
    Document(page_content="Paris is the capital of France."),
    Document(page_content="London is the capital of the United Kingdom."),
    Document(page_content="Berlin is the capital of Germany.")
]

vector_store = manager.create_faiss_store(documents, generator.embedding_model)

# Search
results = manager.similarity_search("capital of France", k=2)
for doc in results:
    print(doc.page_content)
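
In production you typically persist the index instead of rebuilding it on every start. FAISS supports saving to and loading from a local folder (the folder name here is illustrative; newer LangChain versions may also require allow_dangerous_deserialization=True when loading):

# Persist the FAISS index to disk and reload it later
vector_store.save_local("faiss_index")
restored_store = FAISS.load_local("faiss_index", generator.embedding_model)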

5. Retrieval Strategies

5.1 Semantic Search

from typing import List
from langchain.docstore.document import Document

class SemanticRetriever:
    """Semantic-search retriever."""

    def __init__(self, vector_store, search_kwargs: dict = None):
        self.vector_store = vector_store
        self.search_kwargs = search_kwargs or {"k": 4}

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Fetch relevant documents with semantic search."""
        return self.vector_store.similarity_search(
            query,
            **self.search_kwargs
        )

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Fetch relevant documents asynchronously."""
        return await self.vector_store.asimilarity_search(
            query,
            **self.search_kwargs
        )

# Usage
retriever = SemanticRetriever(vector_store, search_kwargs={"k": 3})
results = retriever.get_relevant_documents("capital of France")

5.2 Hybrid Search

from langchain.retrievers import BM25Retriever, EnsembleRetriever

class HybridRetriever:
    """Hybrid search combining semantic and keyword retrieval."""

    def __init__(self, vector_store, documents: List[Document]):
        self.vector_store = vector_store
        self.documents = documents

        # Create the semantic retriever
        self.semantic_retriever = vector_store.as_retriever(
            search_kwargs={"k": 5}
        )

        # Create the keyword retriever
        self.keyword_retriever = BM25Retriever.from_documents(documents)

        # Create the ensemble retriever
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[self.semantic_retriever, self.keyword_retriever],
            weights=[0.7, 0.3]  # 70% semantic, 30% keyword
        )

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Fetch relevant documents with hybrid search."""
        return self.ensemble_retriever.get_relevant_documents(query)

# Usage
hybrid_retriever = HybridRetriever(vector_store, documents)
results = hybrid_retriever.get_relevant_documents("capital of France")

5.3 Reranking

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from typing import List

class RerankingRetriever:
    """Retriever with a reranking stage."""

    def __init__(self, base_retriever, top_n: int = 10):
        self.base_retriever = base_retriever
        self.top_n = top_n

    def cross_encoder_rerank(
        self,
        query: str,
        documents: List[Document],
        cross_encoder
    ) -> List[Document]:
        """Rerank documents with a cross-encoder (e.g. sentence-transformers)."""
        # Score each document against the query
        scores = []
        for doc in documents:
            score = cross_encoder.predict([(query, doc.page_content)])[0]
            scores.append(score)

        # Sort by score
        sorted_docs = sorted(
            zip(documents, scores),
            key=lambda x: x[1],
            reverse=True
        )

        # Return the top N
        return [doc for doc, score in sorted_docs[:self.top_n]]

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Fetch reranked relevant documents."""
        # Initial retrieval
        initial_docs = self.base_retriever.get_relevant_documents(query)

        # Rerank (example uses Cohere; needs the cohere package and an API key)
        try:
            compressor = CohereRerank(top_n=self.top_n)

            retriever = ContextualCompressionRetriever(
                base_compressor=compressor,
                base_retriever=self.base_retriever
            )

            return retriever.get_relevant_documents(query)
        except ImportError:
            return initial_docs

# Usage
reranker = RerankingRetriever(vector_store.as_retriever(), top_n=3)
results = reranker.get_relevant_documents("capital of France")
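
The cross_encoder argument above is expected to expose a predict method over (query, document) pairs, as the sentence-transformers CrossEncoder does. A sketch assuming sentence-transformers is installed (the model name is a common public checkpoint):

from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

docs = vector_store.as_retriever(search_kwargs={"k": 10}).get_relevant_documents("capital of France")
reranked = reranker.cross_encoder_rerank("capital of France", docs, cross_encoder)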

6. Prompt Construction

from langchain.prompts import PromptTemplate, ChatPromptTemplate
from typing import List, Dict

class RAGPromptBuilder:
    """Build RAG-specific prompts."""

    def __init__(self):
        self.templates = {}

    def basic_rag_prompt(self) -> PromptTemplate:
        """Basic RAG prompt template."""
        template = """
Use the following pieces of context to answer the question at the end.
If you don't know the answer, say that you don't know; don't try to make up an answer.

Context:
{context}

Question: {question}

Answer:
"""
        return PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )

    def conversational_rag_prompt(self) -> ChatPromptTemplate:
        """Conversational RAG prompt template."""
        template = """
You are a helpful assistant. Use the following pieces of context to answer the question at the end.
If you don't know the answer, say that you don't know; don't try to make up an answer.

Context:
{context}

Chat history:
{chat_history}

Question: {question}

Answer:
"""
        return ChatPromptTemplate.from_template(template)

    def multi_query_rag_prompt(self) -> PromptTemplate:
        """Multi-query RAG prompt template."""
        template = """
Use the following pieces of context to answer the questions at the end.
If you don't know the answer, say that you don't know; don't try to make up an answer.

Context:
{context}

Questions:
{questions}

Answer:
"""
        return PromptTemplate(
            template=template,
            input_variables=["context", "questions"]
        )

    def build_prompt(
        self,
        context: str,
        query: str,
        prompt_type: str = "basic"
    ) -> str:
        """Build a prompt from context and query."""
        if prompt_type == "basic":
            return self.basic_rag_prompt().format(context=context, question=query)
        elif prompt_type == "conversational":
            # The conversational template also needs a chat history
            return self.conversational_rag_prompt().format(
                context=context, question=query, chat_history=""
            )
        elif prompt_type == "multi_query":
            return self.multi_query_rag_prompt().format(context=context, questions=query)
        else:
            raise ValueError(f"Unknown prompt type: {prompt_type}")

    def build_prompt_with_sources(
        self,
        context: List[Dict],
        query: str
    ) -> str:
        """Build a prompt with source citations."""
        context_str = "\n\n".join([
            f"Source {i+1}: {doc['page_content']}"
            for i, doc in enumerate(context)
        ])

        template = """
Use the following pieces of context to answer the question at the end. Include the source numbers in your answer.

Context:
{context}

Question: {question}

Answer:
"""
        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )

        return prompt.format(context=context_str, question=query)

# Usage
prompt_builder = RAGPromptBuilder()

# Basic RAG prompt
prompt = prompt_builder.build_prompt(
    context="Paris is the capital of France.",
    query="What is the capital of France?",
    prompt_type="basic"
)

print(prompt)

7. Response Generation

from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from typing import List, Dict

class RAGGenerator:
    """Generate responses with RAG."""

    def __init__(self, llm, retriever):
        self.llm = llm
        self.retriever = retriever

    def create_basic_chain(self):
        """Create a basic RAG chain."""
        chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True
        )
        return chain

    def create_conversational_chain(self):
        """Create a conversational RAG chain."""
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True
        )
        return chain

    def create_map_reduce_chain(self):
        """Create a map-reduce RAG chain."""
        chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="map_reduce",
            retriever=self.retriever,
            return_source_documents=True
        )
        return chain

    def create_refine_chain(self):
        """Create a refine RAG chain."""
        chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="refine",
            retriever=self.retriever,
            return_source_documents=True
        )
        return chain

    def query(self, query: str, chain_type: str = "basic") -> Dict:
        """Query the RAG system."""
        if chain_type == "basic":
            chain = self.create_basic_chain()
        elif chain_type == "conversational":
            chain = self.create_conversational_chain()
        elif chain_type == "map_reduce":
            chain = self.create_map_reduce_chain()
        elif chain_type == "refine":
            chain = self.create_refine_chain()
        else:
            raise ValueError(f"Unknown chain type: {chain_type}")

        if chain_type == "conversational":
            # Conversational chains expect a question plus chat history
            result = chain({"question": query, "chat_history": []})
            answer = result["answer"]
        else:
            result = chain({"query": query})
            answer = result["result"]

        return {
            "answer": answer,
            "source_documents": result.get("source_documents", [])
        }

    def query_with_sources(self, query: str) -> Dict:
        """Query and cite sources."""
        chain = self.create_basic_chain()
        result = chain({"query": query})

        # Format the sources
        sources = []
        for doc in result.get("source_documents", []):
            sources.append({
                "content": doc.page_content,
                "metadata": doc.metadata
            })

        return {
            "answer": result["result"],
            "sources": sources
        }

# Usage
llm = ChatOpenAI(model_name="gpt-4")
# RetrievalQA expects a LangChain BaseRetriever
retriever = vector_store.as_retriever(search_kwargs={"k": 3})

rag_generator = RAGGenerator(llm, retriever)

# Query
result = rag_generator.query("What is the capital of France?")
print(f"Answer: {result['answer']}")
print(f"Sources: {len(result['source_documents'])}")

8. Context Window Management

from typing import List
from langchain.docstore.document import Document

class ContextWindowManager:
    """Manage the context window for RAG."""

    def __init__(
        self,
        max_context_length: int = 4000,
        context_overlap: int = 200
    ):
        self.max_context_length = max_context_length
        self.context_overlap = context_overlap

    def fit_documents(
        self,
        documents: List[Document],
        query: str
    ) -> List[Document]:
        """Fit documents into the context window (character-based)."""
        # Measure document lengths
        doc_lengths = [len(doc.page_content) for doc in documents]

        # Select the documents that fit
        selected_docs = []
        total_length = 0

        for doc, length in zip(documents, doc_lengths):
            if total_length + length > self.max_context_length:
                break
            selected_docs.append(doc)
            total_length += length

        return selected_docs

    def sliding_window(
        self,
        documents: List[Document],
        window_size: int = 3
    ) -> List[List[Document]]:
        """Create sliding windows over the documents."""
        windows = []

        for i in range(0, len(documents), window_size):
            window = documents[i:i + window_size]
            windows.append(window)

        return windows

    def recursive_context_selection(
        self,
        documents: List[Document],
        query: str,
        max_depth: int = 3
    ) -> List[Document]:
        """Recursively select context."""
        # Score documents against the query
        scored_docs = self._score_documents(documents, query)

        # Keep the top documents
        top_docs = scored_docs[:max_depth]

        return [doc for doc, score in top_docs]

    def _score_documents(
        self,
        documents: List[Document],
        query: str
    ) -> List[tuple]:
        """Score documents by relevance."""
        # Simple scoring based on keyword overlap
        query_words = set(query.lower().split())
        scores = []

        for doc in documents:
            doc_words = set(doc.page_content.lower().split())
            overlap = len(query_words & doc_words)
            scores.append((doc, overlap))

        return sorted(scores, key=lambda x: x[1], reverse=True)

# Usage
context_manager = ContextWindowManager(
    max_context_length=4000,
    context_overlap=200
)

# Fit documents into the context
selected_docs = context_manager.fit_documents(documents, "capital of France")
print(f"Selected {len(selected_docs)} documents for the context")

9. RAG Evaluation

from typing import List, Dict
from langchain.evaluation import load_evaluator
import numpy as np

class RAGEvaluator:
    """Evaluate RAG system performance."""

    def __init__(self, rag_chain, retriever):
        self.rag_chain = rag_chain
        self.retriever = retriever

    def evaluate_retrieval(
        self,
        questions: List[str],
        ground_truth: List[List[str]]
    ) -> Dict[str, float]:
        """Evaluate retrieval performance."""
        results = {
            "precision": [],
            "recall": [],
            "f1": []
        }

        for question, truth in zip(questions, ground_truth):
            # Retrieve documents
            retrieved = self.retriever.get_relevant_documents(question)
            retrieved_texts = [doc.page_content for doc in retrieved]

            # Compute metrics
            true_positives = len(set(truth) & set(retrieved_texts))
            precision = true_positives / len(retrieved_texts) if retrieved_texts else 0
            recall = true_positives / len(truth) if truth else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            results["precision"].append(precision)
            results["recall"].append(recall)
            results["f1"].append(f1)

        return {
            "avg_precision": np.mean(results["precision"]),
            "avg_recall": np.mean(results["recall"]),
            "avg_f1": np.mean(results["f1"])
        }

    def evaluate_generation(
        self,
        questions: List[str],
        answers: List[str],
        ground_truth: List[str]
    ) -> Dict[str, float]:
        """Evaluate generation performance."""
        # Use a LangChain evaluator
        evaluator = load_evaluator("exact_match")

        results = []
        for question, answer, truth in zip(questions, answers, ground_truth):
            result = evaluator.evaluate_strings(
                prediction=answer,
                reference=truth
            )
            results.append(result["score"])

        return {
            "exact_match_rate": np.mean(results),
            "total_questions": len(results)
        }

    def evaluate_rag(
        self,
        questions: List[str],
        ground_truth_answers: List[str],
        ground_truth_contexts: List[List[str]]
    ) -> Dict[str, float]:
        """Evaluate the full RAG pipeline."""
        # Evaluate retrieval
        retrieval_metrics = self.evaluate_retrieval(
            questions,
            ground_truth_contexts
        )

        # Evaluate generation
        generated_answers = []
        for question in questions:
            result = self.rag_chain({"query": question})
            generated_answers.append(result["result"])

        generation_metrics = self.evaluate_generation(
            questions,
            generated_answers,
            ground_truth_answers
        )

        return {
            "retrieval": retrieval_metrics,
            "generation": generation_metrics
        }

# Usage
evaluator = RAGEvaluator(rag_chain, retriever)

# Evaluate
questions = ["What is the capital of France?", "What is the capital of Germany?"]
ground_truth_answers = ["Paris", "Berlin"]
ground_truth_contexts = [
    ["Paris is the capital of France."],
    ["Berlin is the capital of Germany."]
]

metrics = evaluator.evaluate_rag(
    questions,
    ground_truth_answers,
    ground_truth_contexts
)

print(f"Retrieval F1: {metrics['retrieval']['avg_f1']:.3f}")
print(f"Generation exact match rate: {metrics['generation']['exact_match_rate']:.3f}")

10. Advanced Patterns

10.1 Multi-Query Retrieval

from langchain.prompts import PromptTemplate
from typing import List

class MultiQueryRetriever:
    """Generate multiple queries for better retrieval."""

    def __init__(self, base_retriever, llm):
        self.base_retriever = base_retriever
        self.llm = llm

    def generate_queries(self, query: str, num_queries: int = 3) -> List[str]:
        """Generate multiple queries."""
        prompt = PromptTemplate.from_template("""
Generate {num_queries} different search queries for the following question, one per line.

Question: {question}

Queries:
""")

        response = self.llm.predict(
            prompt.format(question=query, num_queries=num_queries)
        )

        queries = response.strip().split('\n')
        return [q.strip() for q in queries if q.strip()]

    def get_relevant_documents(self, query: str) -> List:
        """Retrieve with multiple queries."""
        # Generate the queries
        queries = self.generate_queries(query)

        # Retrieve for each query
        all_documents = []
        for q in queries:
            docs = self.base_retriever.get_relevant_documents(q)
            all_documents.extend(docs)

        # Deduplicate
        seen = set()
        unique_docs = []
        for doc in all_documents:
            doc_hash = hash(doc.page_content)
            if doc_hash not in seen:
                seen.add(doc_hash)
                unique_docs.append(doc)

        return unique_docs

# Usage
multi_query_retriever = MultiQueryRetriever(vector_store.as_retriever(), llm)
results = multi_query_retriever.get_relevant_documents("capital of France")
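
LangChain also ships a built-in retriever that wraps this same pattern; a sketch using it (aliased to avoid clashing with the class above):

from langchain.retrievers.multi_query import MultiQueryRetriever as LCMultiQueryRetriever

built_in_retriever = LCMultiQueryRetriever.from_llm(
    retriever=vector_store.as_retriever(),
    llm=llm
)
results = built_in_retriever.get_relevant_documents("capital of France")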

10.2 Self-Querying

import json
from typing import Dict, List

class SelfQueryingRetriever:
    """Self-querying for metadata filtering."""

    def __init__(self, vector_store, llm):
        self.vector_store = vector_store
        self.llm = llm

    def extract_metadata_query(
        self,
        query: str
    ) -> Dict[str, str]:
        """Extract metadata filters from the query."""
        prompt = f"""
Analyze the following query and extract any metadata filters.
Return the result as a JSON object with a "filters" key containing the metadata conditions.

Query: {query}

Example output:
{{
    "filters": {{"category": "technology"}}
}}
"""

        response = self.llm.predict(prompt)

        # Parse the JSON response
        try:
            result = json.loads(response)
            return result.get("filters", {})
        except json.JSONDecodeError:
            return {}

    def get_relevant_documents(
        self,
        query: str
    ) -> List:
        """Retrieve documents with metadata filters."""
        # Extract the metadata filters
        filters = self.extract_metadata_query(query)

        # Retrieve with the filters
        if filters:
            return self.vector_store.similarity_search(
                query,
                k=4,
                filter=filters
            )
        else:
            return self.vector_store.similarity_search(query, k=4)

# Usage
self_query_retriever = SelfQueryingRetriever(vector_store, llm)
results = self_query_retriever.get_relevant_documents("recent technology documents")
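
LangChain's built-in SelfQueryRetriever performs the same filter extraction against a declared metadata schema. A sketch assuming the lark package is installed and a vector store with self-query support (e.g. Chroma or Pinecone); the field names are illustrative:

from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo

metadata_field_info = [
    AttributeInfo(name="year", description="Publication year", type="integer"),
    AttributeInfo(name="category", description="Document category", type="string"),
]

built_in_self_query = SelfQueryRetriever.from_llm(
    llm=llm,
    vectorstore=vector_store,
    document_contents="Technical documents",
    metadata_field_info=metadata_field_info
)
results = built_in_self_query.get_relevant_documents("recent technology documents")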

10.3 Parent Document Retrieval

from typing import List, Dict
from langchain.docstore.document import Document

class ParentDocumentRetriever:
    """Retrieve parent documents for better context."""

    def __init__(self, documents: List[Document], child_splitter, vector_store):
        self.documents = documents
        self.child_splitter = child_splitter
        self.vector_store = vector_store  # assumed to index the child documents

    def create_child_documents(self) -> List[Document]:
        """Create child documents from the parents."""
        child_docs = []

        for i, parent_doc in enumerate(self.documents):
            # Split the parent document
            child_chunks = self.child_splitter.split_documents([parent_doc])

            # Record a reference to the parent
            for child_doc in child_chunks:
                child_doc.metadata["parent_id"] = i
                child_docs.append(child_doc)

        return child_docs

    def get_relevant_documents(
        self,
        query: str,
        k: int = 4
    ) -> List[Document]:
        """Retrieve parent documents."""
        # Retrieve relevant child documents
        # (assumes the child documents are already indexed in the vector store)
        relevant_children = self.vector_store.similarity_search(query, k=k)

        # Collect the unique parents
        parent_ids = set()
        parent_docs = []

        for child in relevant_children:
            parent_id = child.metadata.get("parent_id")
            if parent_id is not None and parent_id not in parent_ids:
                parent_ids.add(parent_id)
                parent_docs.append(self.documents[parent_id])

        return parent_docs

# Usage
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=50)
parent_retriever = ParentDocumentRetriever(documents, child_splitter, vector_store)
results = parent_retriever.get_relevant_documents("capital of France")
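
LangChain's built-in ParentDocumentRetriever handles the child indexing and parent lookup automatically via a separate docstore. A sketch (aliased to avoid clashing with the class above; the collection name and chunk size are illustrative):

from langchain.retrievers import ParentDocumentRetriever as LCParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings

built_in_parent = LCParentDocumentRetriever(
    vectorstore=Chroma(collection_name="children", embedding_function=OpenAIEmbeddings()),
    docstore=InMemoryStore(),
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=400)
)
built_in_parent.add_documents(documents)
results = built_in_parent.get_relevant_documents("capital of France")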

11. Production Optimization

11.1 Caching

from langchain.cache import (
    InMemoryCache,
    GPTCache,
    RedisCache,
    SQLiteCache
)
from typing import Dict

class RAGCache:
    """Cache RAG responses."""

    def __init__(self, cache_type: str = "memory"):
        self.cache_type = cache_type
        self.cache = self._get_cache()

    def _get_cache(self):
        """Return a cache instance."""
        if self.cache_type == "memory":
            return InMemoryCache()
        elif self.cache_type == "gpt":
            return GPTCache()
        elif self.cache_type == "redis":
            import redis
            return RedisCache(redis.Redis())
        elif self.cache_type == "sqlite":
            return SQLiteCache(database_path="cache.db")
        else:
            raise ValueError(f"Unknown cache type: {self.cache_type}")

    def get(self, prompt: str, llm_string: str = "default"):
        """Look up a cached response (LangChain caches key on prompt + llm_string)."""
        return self.cache.lookup(prompt, llm_string)

    def set(self, prompt: str, response, llm_string: str = "default"):
        """Cache a response."""
        self.cache.update(prompt, llm_string, response)

    def clear(self):
        """Clear the cache."""
        self.cache.clear()

# Usage
cache = RAGCache(cache_type="memory")

# Check the cache
cached_response = cache.get("capital of France")
if cached_response:
    print(f"Cached: {cached_response}")
else:
    # Generate a response
    response = rag_chain({"query": "capital of France"})
    cache.set("capital of France", response["result"])
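
LangChain can also cache at the LLM layer globally, so identical prompts skip the model call entirely. A sketch using the classic global hook (newer versions expose set_llm_cache in langchain.globals instead):

import langchain
from langchain.cache import InMemoryCache

# Every LLM call in this process now checks the cache first
langchain.llm_cache = InMemoryCache()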

11.2 Batch Processing

from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed

class BatchRAGProcessor:
    """Process multiple RAG queries efficiently."""

    def __init__(self, rag_chain, max_workers: int = 4):
        self.rag_chain = rag_chain
        self.max_workers = max_workers

    def process_single(self, query: str) -> Dict:
        """Process a single query."""
        result = self.rag_chain({"query": query})
        return {
            "query": query,
            "answer": result["result"],
            "sources": result.get("source_documents", [])
        }

    def process_batch(self, queries: List[str]) -> List[Dict]:
        """Process queries as a batch."""
        results = []

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(self.process_single, query): query
                for query in queries
            }

            for future in as_completed(futures):
                try:
                    result = future.result()
                    results.append(result)
                except Exception as e:
                    print(f"Error processing query: {e}")

        return results

    def process_batch_async(self, queries: List[str]) -> List[Dict]:
        """Process a batch asynchronously."""
        import asyncio

        async def async_process(query):
            result = await self.rag_chain.acall({"query": query})
            return {
                "query": query,
                "answer": result["result"],
                "sources": result.get("source_documents", [])
            }

        async def process_all():
            tasks = [async_process(q) for q in queries]
            return await asyncio.gather(*tasks)

        return asyncio.run(process_all())

# Usage
processor = BatchRAGProcessor(rag_chain, max_workers=4)

queries = [
    "What is the capital of France?",
    "What is the capital of Germany?",
    "What is the capital of Italy?"
]

results = processor.process_batch(queries)

12. Common Pitfalls

12.1 Common Issues

# 1. Poor chunking
# Problem: chunks are too small or too large
# Solution: use a chunk size appropriate for your model

# 2. Insufficient context
# Problem: not enough context in the retrieved documents
# Solution: increase k or use parent document retrieval

# 3. Irrelevant retrieval
# Problem: the retrieved documents don't contain the answer
# Solution: use multi-query retrieval or reranking

# 4. Hallucination
# Problem: the LLM generates incorrect information
# Solution: add source citations and verify answers

# 5. Slow performance
# Problem: retrieval is slow
# Solution: use an efficient vector store (FAISS) and caching

from typing import List, Dict
from langchain.docstore.document import Document

class RAGValidator:
    """Validate a RAG system for common issues."""

    @staticmethod
    def validate_chunking(documents: List[Document]) -> Dict[str, bool]:
        """Validate document chunking."""
        issues = {
            "too_small": False,
            "too_large": False,
            "inconsistent": False
        }

        lengths = [len(doc.page_content) for doc in documents]
        if not lengths:
            return issues
        avg_length = sum(lengths) / len(lengths)

        if avg_length < 100:
            issues["too_small"] = True
        if avg_length > 2000:
            issues["too_large"] = True

        # Check consistency
        std_dev = (sum((x - avg_length) ** 2 for x in lengths) / len(lengths)) ** 0.5
        if std_dev > avg_length * 0.5:
            issues["inconsistent"] = True

        return issues

    @staticmethod
    def validate_retrieval(
        query: str,
        retrieved_docs: List[Document],
        answer: str
    ) -> Dict[str, bool]:
        """Validate retrieval quality."""
        issues = {
            "no_relevant_docs": False,
            "answer_not_in_docs": False
        }

        if not retrieved_docs:
            issues["no_relevant_docs"] = True
            return issues

        # Check whether the answer is grounded in the retrieved documents
        answer_words = set(answer.lower().split())
        doc_words = set()
        for doc in retrieved_docs:
            doc_words.update(doc.page_content.lower().split())

        overlap = len(answer_words & doc_words) / len(answer_words) if answer_words else 0
        if overlap < 0.3:
            issues["answer_not_in_docs"] = True

        return issues

# Usage
validator = RAGValidator()

chunking_issues = validator.validate_chunking(documents)
retrieval_issues = validator.validate_retrieval(
    "capital of France",
    retrieved_docs,
    answer
)

print(f"Chunking issues: {chunking_issues}")
print(f"Retrieval issues: {retrieval_issues}")
