2025-02-11
Backend

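Contents

Implementing Corrective RAG (self-correcting retrieval-augmented generation) with LangGraph
1. Environment setup
2. Create the index
3. Define the retrieval grader
4. Define the generator
5. Define the question rewriter
6. Define the web search tool
7. Define the graph state and nodes
8. Compile and run the graph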

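Implementing Corrective RAG (self-correcting retrieval-augmented generation) with LangGraph

In this tutorial we use the LangGraph framework to build a Corrective RAG (CRAG) system. The system grades each retrieved document for relevance to the question. When any document is judged irrelevant, it rewrites the question and runs a web search before generating the answer. The implementation steps follow.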


1. Environment setup

First, install the required Python packages and set the API keys.

python
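%%capture --no-stderr
# The % lines are Jupyter magics; in a plain shell, run the pip command without the leading %.
%pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python

# Recommended: run the following in its own notebook cell, separate from the install.
import getpass
import os


def _set_env(var: str):
    """Prompt for an environment variable if it is not already set."""
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")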
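In a non-interactive run (a script or a scheduled job, for example) there is nobody to answer a getpass prompt, so it can help to check for the keys up front and fail early. The following is a minimal sketch under that assumption; `missing_env_vars` is a hypothetical helper, not part of the tutorial or of any library.

```python
import os
from typing import Iterable, List, Mapping, Optional


def missing_env_vars(
    required: Iterable[str],
    env: Optional[Mapping[str, str]] = None,
) -> List[str]:
    """Return the names in `required` that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


# Example with an explicit mapping instead of the real environment.
# The value below is a placeholder, not a real credential.
sample_env = {"OPENAI_API_KEY": "placeholder-value"}
missing = missing_env_vars(["OPENAI_API_KEY", "TAVILY_API_KEY"], sample_env)
print(missing)
```

Calling `missing_env_vars(["OPENAI_API_KEY", "TAVILY_API_KEY"])` with no second argument checks the real process environment.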

2. Create the index

We load three blog posts and index them for later retrieval.

python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load each page, then flatten the per-URL lists into one list of documents
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split into chunks of about 250 tokens with no overlap
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add the chunks to the vector store
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
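To make the effect of `chunk_size` and `chunk_overlap` concrete, here is a simplified stand-in for the splitter. It is only a sketch: it counts whitespace-separated words rather than tiktoken tokens, it ignores the recursive separator logic of `RecursiveCharacterTextSplitter`, and `split_into_chunks` is a hypothetical name introduced for illustration.

```python
from typing import List


def split_into_chunks(text: str, chunk_size: int, chunk_overlap: int = 0) -> List[str]:
    """Split whitespace-separated words into windows of at most `chunk_size` words.

    Simplified stand-in for a token-based splitter: it counts words, not
    tiktoken tokens, and ignores paragraph or sentence boundaries.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must satisfy 0 <= overlap < chunk_size")
    words = text.split()
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # the last window already reached the end of the text
    return chunks


sample_text = " ".join(f"w{i}" for i in range(10))
chunks = split_into_chunks(sample_text, chunk_size=4)
print(chunks)
```

With zero overlap, as in the tutorial, each word lands in exactly one chunk; a positive overlap repeats the tail of one chunk at the head of the next.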

3. Define the retrieval grader

We define a grader that judges whether a retrieved document is relevant to the question.

python
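from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


# Data model for the structured grader output
class GradeDocuments(BaseModel):
    """Binary score for the relevance of a retrieved document."""

    binary_score: str = Field(
        description="Whether the document is relevant to the question, 'yes' or 'no'"
    )


# LLM with structured output bound to the data model
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt template
system = """You are a grader assessing the relevance of a retrieved document to a user question. \n
    If the document contains keywords or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader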
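Because `binary_score` is declared as a plain `str`, the schema itself does not rule out values other than 'yes' or 'no'. One defensive option, sketched below under that assumption, is to normalize the score before comparing it. `is_relevant` is a hypothetical helper, not part of the tutorial.

```python
def is_relevant(binary_score: str) -> bool:
    """Interpret a grader score; anything other than a clear 'yes' counts as not relevant."""
    # Drop surrounding whitespace, quotes, and trailing punctuation, then lowercase.
    normalized = binary_score.strip().strip(".!'\"").lower()
    return normalized == "yes"


for raw in ["yes", " Yes. ", "no", "maybe", ""]:
    print(repr(raw), is_relevant(raw))
```

Treating every unclear score as "not relevant" errs on the side of triggering the web search fallback rather than silently keeping a doubtful document.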

4. Define the generator

We define a generator that produces an answer from the retrieved documents.

python
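from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

# Prompt template pulled from the LangChain Hub
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)


# Post-processing: join document texts into one context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# Generation chain: prompt -> LLM -> plain string
rag_chain = prompt | llm | StrOutputParser()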
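To show what `format_docs` produces, the sketch below runs the same join on a tiny stand-in class. `Doc` is a hypothetical substitute for the LangChain `Document` type that models only the `page_content` attribute, so the example runs without any dependencies.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class Doc:
    """Hypothetical stand-in for a LangChain Document; only page_content is modeled."""

    page_content: str


def format_docs(docs: Iterable[Doc]) -> str:
    """Join document texts with a blank line between them, as in the tutorial."""
    return "\n\n".join(doc.page_content for doc in docs)


context = format_docs([Doc("First chunk."), Doc("Second chunk.")])
print(context)
```

The resulting string is what a prompt with a `{context}` placeholder would receive if the documents are formatted before the chain is invoked.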

5. Define the question rewriter

We define a rewriter that rephrases the question so that it works better as a web search query.

python
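# LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# Prompt template
system = """You are a question rewriter that converts an input question into a version better suited to web search. \n
    Look at the input and try to reason about its underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Initial question: \n\n {question} \n Formulate an improved question."),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()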
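The rewriter prompt is a list of (role, template) pairs in which `{question}` is filled in at call time. The sketch below imitates that substitution with plain `str.format`. It is an illustration only: it does not use `ChatPromptTemplate`, it ignores message classes and escaping rules, and `render_messages` is a hypothetical name.

```python
from typing import List, Tuple

# (role, template) pairs in the same shape the tutorial passes to from_messages.
REWRITE_TEMPLATES: List[Tuple[str, str]] = [
    ("system", "You are a question rewriter that converts an input question "
               "into a version better suited to web search."),
    ("human", "Initial question: \n\n {question} \n Formulate an improved question."),
]


def render_messages(
    templates: List[Tuple[str, str]], **variables: str
) -> List[Tuple[str, str]]:
    """Fill each template's {placeholders} with the given variables."""
    return [(role, text.format(**variables)) for role, text in templates]


messages = render_messages(REWRITE_TEMPLATES, question="agent memory")
for role, text in messages:
    print(f"[{role}] {text}")
```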

6. Define the web search tool

We use Tavily Search for web search.

python
from langchain_community.tools.tavily_search import TavilySearchResults

# Return at most three results per query
web_search_tool = TavilySearchResults(max_results=3)
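The web search node in step 7 reads a "content" field from each search result and joins the fields into a single text. The sketch below shows that merge on made-up data, assuming each result is a dict with "url" and "content" keys. `merge_search_results` is a hypothetical helper that also skips results whose content is empty or missing.

```python
from typing import Any, Dict, List


def merge_search_results(results: List[Dict[str, Any]]) -> str:
    """Concatenate the 'content' field of each result, skipping empty or missing ones."""
    contents = [r["content"] for r in results if r.get("content")]
    return "\n".join(contents)


# Made-up sample results for illustration only; not real search output.
sample_results = [
    {"url": "https://example.com/a", "content": "Snippet A."},
    {"url": "https://example.com/b", "content": ""},
    {"url": "https://example.com/c", "content": "Snippet C."},
]
merged = merge_search_results(sample_results)
print(merged)
```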

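7. Define the graph state and nodes

We define the state that flows through the graph, one function per node, and the routing function used by the conditional edge.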

python
from typing import List, TypedDict

from langchain_core.documents import Document


class GraphState(TypedDict):
    """State passed between the nodes of the graph."""

    question: str
    generation: str
    web_search: str  # "Yes" if a web search is needed, otherwise "No"
    documents: List[Document]


def retrieve(state):
    """Retrieve documents for the question."""
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


def generate(state):
    """Generate an answer from the current documents."""
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    # Pass the context as text by using the format_docs helper from step 4
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """Keep relevant documents and flag a web search if any document is irrelevant."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        if score.binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
    return {"documents": filtered_docs, "question": question, "web_search": web_search}


def transform_query(state):
    """Rewrite the question for web search."""
    print("---TRANSFORM QUERY---")
    question = state["question"]
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": state["documents"], "question": better_question}


def web_search(state):
    """Run a web search and add the results to the relevant documents."""
    print("---WEB SEARCH---")
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join(d["content"] for d in docs))
    # Supplement the documents that passed grading instead of discarding them
    return {"documents": state["documents"] + [web_results], "question": question}


def decide_to_generate(state):
    """Route to query rewriting or directly to generation."""
    print("---ASSESS GRADED DOCUMENTS---")
    if state["web_search"] == "Yes":
        print("---DECISION: AT LEAST ONE DOCUMENT IS NOT RELEVANT, TRANSFORM QUERY---")
        return "transform_query"
    else:
        print("---DECISION: GENERATE---")
        return "generate"
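Each node returns a dict containing only the keys it wants to change, and those values replace the ones in the shared state (this state defines no custom reducers). The sketch below imitates that merge and the grading rule in plain Python. `apply_update` and `summarize_grades` are hypothetical helpers, and plain strings stand in for documents.

```python
from typing import Any, Dict, List


def apply_update(state: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new state in which keys from `update` overwrite those in `state`."""
    return {**state, **update}


def summarize_grades(docs: List[str], scores: List[str]) -> Dict[str, Any]:
    """Keep docs graded 'yes'; request a web search if any doc was graded otherwise."""
    kept = [d for d, s in zip(docs, scores) if s == "yes"]
    web_search = "Yes" if len(kept) < len(docs) else "No"
    return {"documents": kept, "web_search": web_search}


state: Dict[str, Any] = {"question": "What are the types of agent memory?"}
# After the retrieve node: two documents were found.
state = apply_update(state, {"documents": ["doc-1", "doc-2"]})
# After the grading node: the second document was judged irrelevant.
state = apply_update(state, summarize_grades(state["documents"], ["yes", "no"]))
print(state)
```

As in the tutorial's `grade_documents`, an empty document list leaves the flag at "No", because no document was graded as irrelevant.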

8. Compile and run the graph

We wire the nodes together, compile the graph, and run it to generate an answer.

python
from pprint import pprint

from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

# Build the graph
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

# Run, printing the name of each node as it finishes
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

# Final generation: the last streamed value comes from the generate node
pprint(value["generation"])
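The topology built above can be checked without calling any model by walking the edges by hand. The sketch below is a plain-Python imitation, not LangGraph itself: the `START` and `END` strings, `EDGES`, `trace_path`, and `router` are hypothetical names that mirror the fixed edges and the conditional edge after `grade_documents`.

```python
from typing import Callable, Dict, List

START, END = "__start__", "__end__"

# Fixed edges of the graph built in this step.
EDGES: Dict[str, str] = {
    START: "retrieve",
    "retrieve": "grade_documents",
    "transform_query": "web_search_node",
    "web_search_node": "generate",
    "generate": END,
}


def trace_path(router: Callable[[dict], str], state: dict) -> List[str]:
    """Walk from START to END, consulting `router` after grade_documents."""
    path: List[str] = []
    node = EDGES[START]
    while node != END:
        path.append(node)
        node = router(state) if node == "grade_documents" else EDGES[node]
    return path


def router(state: dict) -> str:
    """Same rule as the conditional edge: rewrite the query when web search is flagged."""
    return "transform_query" if state["web_search"] == "Yes" else "generate"


print(trace_path(router, {"web_search": "No"}))
print(trace_path(router, {"web_search": "Yes"}))
```

The two printed paths correspond to the two possible runs: straight to generation when every document is relevant, or through query rewriting and web search otherwise.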

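After working through this tutorial, you should be able to build a Corrective RAG system with LangGraph that grades retrieved documents and falls back to web search when they are not all relevant. I hope it helps with your learning and development!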

Complete code

The single-file version below is a simplified variant: it indexes one post without splitting, grades documents with a keyword match instead of the LLM grader, and rewrites the query with a fixed string template instead of the LLM rewriter.

python
%%capture --no-stderr
%pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python

import getpass
import os
from pprint import pprint
from typing import TypedDict

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, START, StateGraph


def _set_env(var: str):
    """Prompt for an environment variable if it is not already set."""
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

# Create the index
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
docs = loader.load()
vectorstore = Chroma.from_documents(docs, embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()


# Define the state
class GraphState(TypedDict):
    question: str
    documents: list[Document]
    generation: str


# Retrieval node
def retrieve(state: GraphState):
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


# Document grading node
def grade_documents(state: GraphState):
    print("---GRADE DOCUMENTS---")
    question = state["question"]
    documents = state["documents"]
    # Simplified grading: keep a document if it contains any word of the question
    filtered_docs = []
    for d in documents:
        if any(word in d.page_content.lower() for word in question.lower().split()):
            filtered_docs.append(d)
    return {"documents": filtered_docs, "question": question}


# Decide whether to generate or to supplement with web search
def decide_to_generate(state: GraphState):
    print("---DECIDE TO GENERATE---")
    if len(state["documents"]) >= 1:
        return "generate"
    else:
        return "transform_query"


# Query transformation node
def transform_query(state: GraphState):
    print("---TRANSFORM QUERY---")
    question = state["question"]
    # Simplified rewrite: append a fixed hint to the question
    transformed_query = f"{question} (with additional context from web search)"
    return {"question": transformed_query}


# Web search node
def web_search_node(state: GraphState):
    print("---WEB SEARCH---")
    question = state["question"]
    web_retriever = TavilySearchAPIRetriever(k=3)
    documents = web_retriever.invoke(question)
    return {"documents": documents, "question": question}


# Generation node
def generate(state: GraphState):
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    prompt = ChatPromptTemplate.from_template(
        "Answer the question based on the context:\n\n{context}\n\nQuestion: {question}"
    )
    model = ChatOpenAI(model="gpt-3.5-turbo")
    chain = prompt | model | StrOutputParser()
    context = "\n".join(d.page_content for d in documents)
    generation = chain.invoke({"context": context, "question": question})
    return {"generation": generation}


# Build the graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search_node)
workflow.add_node("generate", generate)

workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

# Run
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

# Final generation
pprint(value["generation"])

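Author: yowayimono

Article link:

Copyright notice: Unless otherwise stated, all articles on this blog are licensed under BY-NC-SA. Please credit the source when reposting!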