2025-02-11
Backend
Note: this article was written 87 days ago and last modified 87 days ago, so some information in it may be outdated.

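Contents

Tutorial: Implementing Agentic RAG (Retrieval-Augmented Generation) with LangGraph
1. Environment Setup
2. Creating the Retriever
3. Creating the Retriever Tool
4. Defining the Agent State
5. Defining Nodes and Edges
5.1 Grading Document Relevance
5.2 Calling the Agent
5.3 Rewriting the Question
5.4 Generating the Answer
6. Building the Graph
7. Running the Graph
Complete Code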

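Tutorial: Implementing Agentic RAG (Retrieval-Augmented Generation) with LangGraph

In this tutorial, we use the LangGraph framework to build an Agentic RAG (Retrieval-Augmented Generation) system. Given a user's question, an agent (an LLM with a retrieval tool bound to it) decides whether to retrieve information from an index. Retrieved documents are graded for relevance: relevant documents are used to generate the answer, and irrelevant ones cause the question to be rewritten and retried. The implementation steps follow.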


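1. Environment Setup

First, install the required Python packages and set the OpenAI API key. The `%%capture` and `%pip` lines are Jupyter magics, so run them in a notebook cell (in a plain shell, run `pip install` directly).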

```python
%%capture --no-stderr
%pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters beautifulsoup4
```

```python
import getpass
import os


def _set_env(key: str):
    # Prompt for the key only if it is not already set in the environment
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


_set_env("OPENAI_API_KEY")
```
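`_set_env` only prompts when the variable is missing. For scripts or tests, a minimal sketch can make that behavior checkable without a terminal. The name `set_env` and its injectable `prompt` and `environ` parameters are hypothetical additions for illustration, not part of the tutorial:

```python
import os
from typing import Callable, MutableMapping


def set_env(
    key: str,
    prompt: Callable[[str], str],
    environ: MutableMapping[str, str] = os.environ,
) -> str:
    """Prompt for `key` only if it is missing from `environ`, then return its value."""
    if key not in environ:
        environ[key] = prompt(f"{key}:")
    return environ[key]


# Use a plain dict and a fake prompt so the real environment is never touched
fake_env: dict[str, str] = {}
set_env("OPENAI_API_KEY", lambda _label: "sk-test", fake_env)
print(sorted(fake_env))
```

In a notebook you would pass `getpass.getpass` as `prompt` and keep the default `environ`.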

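2. Creating the Retriever

We first load three of Lilian Weng's blog posts, split them into chunks, and index the chunks in a Chroma vector store.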

```python
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Load each page; every load() call returns a list of Documents, so flatten them
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split into chunks whose size is measured in tiktoken tokens
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to the vector database
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
```
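To make `chunk_size` and `chunk_overlap` concrete, here is a simplified sliding-window sketch over a token list. It is an assumption-laden illustration, not how `RecursiveCharacterTextSplitter` works: the real splitter breaks text recursively on separators (paragraphs, lines, spaces) and only measures length in tokens, so its chunk boundaries differ. The function name `sliding_window_chunks` is hypothetical.

```python
def sliding_window_chunks(
    tokens: list[str], chunk_size: int, chunk_overlap: int
) -> list[list[str]]:
    """Split tokens into windows of `chunk_size`; neighbors share `chunk_overlap` tokens."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):  # this window reached the end
            break
    return chunks


tokens = [f"t{i}" for i in range(10)]
for chunk in sliding_window_chunks(tokens, chunk_size=4, chunk_overlap=2):
    print(chunk)
```

With the tutorial's `chunk_size=100, chunk_overlap=50`, consecutive chunks share roughly half their content, which helps a sentence cut at one boundary still appear whole in a neighboring chunk.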

3. Creating the Retriever Tool

We wrap the retriever in a tool that the agent can call by name.

```python
from langchain.tools.retriever import create_retriever_tool

retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_blog_posts",
    "Search and return information about Lilian Weng blog posts on LLM agents, prompt engineering, and adversarial attacks on LLMs.",
)

tools = [retriever_tool]
```
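Conceptually, the tool takes a query, asks the retriever for the most similar chunks, and returns their text to the model as a single string. The toy sketch below mimics that with cosine similarity over hand-written vectors; `Doc`, `cosine`, and `toy_retrieve` are hypothetical names, and the real vector store's embeddings and distance metric differ.

```python
import math
from dataclasses import dataclass


@dataclass
class Doc:
    page_content: str
    embedding: list[float]  # hand-written toy vector; the tutorial uses OpenAIEmbeddings


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity, returning 0.0 when either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def toy_retrieve(query_embedding: list[float], docs: list[Doc], k: int = 2) -> str:
    """Rank docs by similarity to the query and join the top-k texts with blank lines."""
    ranked = sorted(docs, key=lambda d: cosine(query_embedding, d.embedding), reverse=True)
    return "\n\n".join(d.page_content for d in ranked[:k])


docs = [
    Doc("Short-term memory is in-context learning.", [1.0, 0.1, 0.0]),
    Doc("Long-term memory uses an external vector store.", [0.9, 0.3, 0.0]),
    Doc("Adversarial prompts can jailbreak LLMs.", [0.0, 0.2, 1.0]),
]
print(toy_retrieve([1.0, 0.2, 0.0], docs))
```

The tool's description string matters as much as the retrieval itself: it is what the model reads when deciding whether a question warrants calling `retrieve_blog_posts`.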

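4. Defining the Agent State

We define a state object that carries the list of messages through the graph. The `add_messages` reducer makes each node's returned messages get appended to this list instead of overwriting it.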

```python
from typing import Annotated, Sequence

from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    # add_messages merges new messages into the existing list rather than replacing it
    messages: Annotated[Sequence[BaseMessage], add_messages]
```
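The simplified stand-in below illustrates the core merge rule of a message reducer like `add_messages`: a message with a new id is appended, and a message whose id already exists replaces the old one. `Msg` and `append_or_replace` are hypothetical names, not LangGraph API, and the real reducer also handles message conversion and removal.

```python
import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class Msg:
    """Hypothetical minimal message type standing in for LangChain's BaseMessage."""

    role: str
    content: str
    id: Optional[str] = None


def append_or_replace(left: list[Msg], right: list[Msg]) -> list[Msg]:
    """Merge `right` into `left`: known ids replace in place, others are appended."""
    merged = list(left)
    index_by_id = {m.id: i for i, m in enumerate(merged) if m.id is not None}
    for msg in right:
        if msg.id is None:
            # Give id-less messages a fresh id so they are always appended
            msg = Msg(msg.role, msg.content, str(uuid.uuid4()))
        if msg.id in index_by_id:
            merged[index_by_id[msg.id]] = msg
        else:
            index_by_id[msg.id] = len(merged)
            merged.append(msg)
    return merged


# Each node returns only its new messages; the reducer folds them into the state
state = [Msg("user", "What does Lilian Weng say about the types of agent memory?", id="1")]
state = append_or_replace(state, [Msg("ai", "calling retrieve_blog_posts", id="2")])
state = append_or_replace(state, [Msg("ai", "calling retrieve_blog_posts (edited)", id="2")])
print([m.content for m in state])
```

This is why every node in this tutorial returns `{"messages": [response]}`: a one-element list of new messages, not the full history.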

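5. Defining Nodes and Edges

Next we define the graph's nodes (the agent, retrieval, question rewriting, and answer generation) and its routing logic. Document grading is not a node: it is a conditional edge that decides where to go after retrieval.

5.1 Grading Document Relevance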

```python
from typing import Literal

from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


def grade_documents(state) -> Literal["generate", "rewrite"]:
    """Conditional edge: return "generate" if the retrieved documents are relevant, else "rewrite"."""

    # Schema for the grader's structured output
    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = ChatOpenAI(temperature=0, model="gpt-4o", streaming=True)
    llm_with_tool = model.with_structured_output(grade)

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n Here is the retrieved document: \n\n {context} \n\n Here is the user question: {question} \n If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    chain = prompt | llm_with_tool

    messages = state["messages"]
    last_message = messages[-1]

    # The first message is the user's question; the last is the retriever's ToolMessage
    question = messages[0].content
    docs = last_message.content

    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score

    if score == "yes":
        return "generate"
    else:
        return "rewrite"
```
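Because `grade_documents` is a conditional edge, its return value is the name of the next node. The schema only *describes* `binary_score` as 'yes' or 'no', so a model could plausibly return `"Yes"` or add whitespace, which the strict `score == "yes"` check would send to `rewrite`. A small hedged hardening (the helper `route_on_grade` is hypothetical) normalizes the score before routing:

```python
from typing import Literal


def route_on_grade(binary_score: str) -> Literal["generate", "rewrite"]:
    """Map the grader's score to the next node; anything other than 'yes' triggers a rewrite."""
    return "generate" if binary_score.strip().lower() == "yes" else "rewrite"


print(route_on_grade(" Yes\n"), route_on_grade("no"))
```

An alternative is to declare the field as `Literal["yes", "no"]` in the Pydantic schema, which constrains the structured output at the source.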

5.2 Calling the Agent

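```python
from langchain_openai import ChatOpenAI


def agent(state):
    """Invoke the model with the retriever tool bound; it either answers or requests a tool call."""
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    # Return a list so the add_messages reducer appends the response to the state
    return {"messages": [response]}
```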
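With the tool bound, the model either answers directly or returns an `AIMessage` whose `tool_calls` list requests `retrieve_blog_posts`. The prebuilt `tools_condition` used when building the graph routes on exactly that. The simplified stand-in below shows the decision; `AIMsg` and `route_after_agent` are hypothetical names, and `END` is hard-coded to `"__end__"`, the string value of LangGraph's `END` constant.

```python
from dataclasses import dataclass, field
from typing import Any

END = "__end__"  # same string value as langgraph.graph.END


@dataclass
class AIMsg:
    """Hypothetical stand-in for an AIMessage: text plus any requested tool calls."""

    content: str
    tool_calls: list[dict[str, Any]] = field(default_factory=list)


def route_after_agent(messages: list[AIMsg]) -> str:
    """Go to the tool node if the last message requests a tool, otherwise finish."""
    last = messages[-1]
    return "tools" if last.tool_calls else END


call = {"name": "retrieve_blog_posts", "args": {"query": "types of agent memory"}, "id": "call_1"}
print(route_after_agent([AIMsg("", tool_calls=[call])]))
print(route_after_agent([AIMsg("Hello! How can I help?")]))
```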

5.3 Rewriting the Question

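```python
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI


def rewrite(state):
    """Ask the model to reformulate the original question for another retrieval attempt."""
    messages = state["messages"]
    question = messages[0].content

    msg = [
        HumanMessage(
            content=f""" \n Look at the input and try to reason about the underlying semantic intent / meaning. \n Here is the initial question: \n ------- \n {question} \n ------- \n Formulate an improved question: """,
        )
    ]

    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    response = model.invoke(msg)
    # The improved question is appended to the history, and the graph routes back to the agent
    return {"messages": [response]}
```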
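Because `rewrite` feeds back into `agent`, a question whose retrieved documents keep being graded irrelevant can cycle agent → retrieve → rewrite until LangGraph's `recursion_limit` stops the run with an error. One hedged way to bound the loop, assuming you add a hypothetical `rewrite_count` field to `AgentState` and increment it in `rewrite`, is to route on both the grade and the count:

```python
from typing import Literal


def route_with_limit(
    binary_score: str, rewrite_count: int, max_rewrites: int = 2
) -> Literal["generate", "rewrite"]:
    """Rewrite only while under budget; after that, answer from whatever was retrieved.

    `rewrite_count` and `max_rewrites` are hypothetical additions, not tutorial state.
    """
    if binary_score == "yes" or rewrite_count >= max_rewrites:
        return "generate"
    return "rewrite"


print(route_with_limit("no", rewrite_count=0), route_with_limit("no", rewrite_count=2))
```

Falling through to `generate` trades a possibly weaker answer for a guaranteed end to the run; returning a fixed "I could not find this" message would be another reasonable choice.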

5.4 Generating the Answer

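```python
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI


def generate(state):
    """Answer the original question using the retrieved documents as context."""
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]

    # Text returned by the retriever tool (already a single string)
    docs = last_message.content

    # RAG prompt from LangChain Hub; it expects "context" and "question" variables
    prompt = hub.pull("rlm/rag-prompt")

    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)

    rag_chain = prompt | llm | StrOutputParser()

    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}
```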
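`hub.pull("rlm/rag-prompt")` downloads a template that takes `context` and `question`, which is why the chain is invoked with exactly those two keys. If you would rather not depend on the hub at runtime, a local stand-in could look like the sketch below. The wording is an illustrative paraphrase and is not guaranteed to match the hub prompt; `RAG_PROMPT` and `build_rag_prompt` are hypothetical names.

```python
RAG_PROMPT = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know.\n"
    "Question: {question}\n"
    "Context: {context}\n"
    "Answer:"
)


def build_rag_prompt(question: str, context: str) -> str:
    """Fill the template; substituted values are inserted literally, even if they contain braces."""
    return RAG_PROMPT.format(question=question, context=context)


print(build_rag_prompt("What are the types of agent memory?", "Short-term and long-term memory."))
```

Inside LangChain, `ChatPromptTemplate.from_template(RAG_PROMPT)` would turn this string into a prompt object that can replace the `hub.pull` call in the chain.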

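6. Building the Graph

We combine the nodes and edges into a complete graph. The agent either calls the retriever or ends the run; after retrieval, `grade_documents` routes to `generate` or `rewrite`; and `rewrite` loops back to the agent.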

```python
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode, tools_condition

workflow = StateGraph(AgentState)

# Nodes
workflow.add_node("agent", agent)
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)

# Every run starts at the agent
workflow.add_edge(START, "agent")

# tools_condition returns "tools" if the agent requested a tool call, otherwise END
workflow.add_conditional_edges(
    "agent",
    tools_condition,
    {
        "tools": "retrieve",
        END: END,
    },
)

# After retrieval, grade_documents returns the next node's name: "generate" or "rewrite"
workflow.add_conditional_edges(
    "retrieve",
    grade_documents,
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")

graph = workflow.compile()
```
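To see the control flow without calling any model, here is a plain-Python walk over the same edges. `run_graph` and its stub callbacks are hypothetical; in the real graph these decisions are made by `tools_condition` and `grade_documents`.

```python
from typing import Callable

END = "__end__"


def run_graph(
    decide_tool_call: Callable[[int], bool],
    grade: Callable[[int], str],
    max_steps: int = 20,
) -> list[str]:
    """Follow the tutorial's edges using stub decisions and record the visited nodes.

    `attempt` counts rewrites so the stubs can behave differently on each retry.
    """
    path = ["agent"]
    attempt = 0
    while len(path) < max_steps:
        node = path[-1]
        if node == "agent":
            nxt = "retrieve" if decide_tool_call(attempt) else END
        elif node == "retrieve":
            nxt = "generate" if grade(attempt) == "yes" else "rewrite"
        elif node == "rewrite":
            attempt += 1
            nxt = "agent"
        elif node == "generate":
            nxt = END
        else:
            raise ValueError(f"unknown node {node!r}")
        path.append(nxt)
        if nxt == END:
            break
    return path


# First retrieval graded irrelevant, second one relevant
print(run_graph(lambda attempt: True, lambda attempt: "yes" if attempt >= 1 else "no"))
```

The `max_steps` cap plays the same role here that `recursion_limit` plays in LangGraph: it stops a run that would otherwise loop forever.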

7. Running the Graph

Finally, we run the graph and print each node's output as it streams.

```python
import pprint

inputs = {
    "messages": [
        ("user", "What does Lilian Weng say about the types of agent memory?"),
    ]
}

for output in graph.stream(inputs):
    # Each streamed item maps the node that just ran to the state update it returned
    for key, value in output.items():
        pprint.pprint(f"Output from node '{key}':")
        pprint.pprint("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")
```
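The loop above treats each streamed item as a dict keyed by the node that just ran (the `"updates"` stream mode; passing `stream_mode="updates"` to `graph.stream` makes this explicit). A small hedged helper, fed hand-written chunks rather than real model output, extracts the path a run took; `visited_nodes` and `fake_chunks` are hypothetical names.

```python
from typing import Any, Iterable


def visited_nodes(chunks: Iterable[dict[str, Any]]) -> list[str]:
    """Return node names in order from update-style stream chunks of the form {node: update}."""
    return [node for chunk in chunks for node in chunk]


# Hand-written chunks standing in for real graph.stream output
fake_chunks = [
    {"agent": {"messages": ["<AIMessage requesting retrieve_blog_posts>"]}},
    {"retrieve": {"messages": ["<ToolMessage with retrieved text>"]}},
    {"generate": {"messages": ["<final answer>"]}},
]
print(visited_nodes(fake_chunks))
```

Since `grade_documents` is a conditional edge rather than a node, it never appears in this list; a run with one rewrite would show `agent`, `retrieve`, `rewrite`, `agent`, `retrieve`, `generate`.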

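Complete Code

The complete implementation follows. Run the package installation in its own notebook cell; otherwise `%%capture` also suppresses the printed output of the run.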

```python
%%capture --no-stderr
%pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters beautifulsoup4
```

```python
import getpass
import os
import pprint
from typing import Annotated, Literal, Sequence

from langchain import hub
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel, Field
from typing_extensions import TypedDict


# Environment setup
def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


_set_env("OPENAI_API_KEY")

# Create the retriever
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

# Create the retriever tool
retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_blog_posts",
    "Search and return information about Lilian Weng blog posts on LLM agents, prompt engineering, and adversarial attacks on LLMs.",
)
tools = [retriever_tool]


# Define the agent state
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]


# Define nodes and edges
def grade_documents(state) -> Literal["generate", "rewrite"]:
    class grade(BaseModel):
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    model = ChatOpenAI(temperature=0, model="gpt-4o", streaming=True)
    llm_with_tool = model.with_structured_output(grade)
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n Here is the retrieved document: \n\n {context} \n\n Here is the user question: {question} \n If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )
    chain = prompt | llm_with_tool
    messages = state["messages"]
    last_message = messages[-1]
    question = messages[0].content
    docs = last_message.content
    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score
    if score == "yes":
        return "generate"
    else:
        return "rewrite"


def agent(state):
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    return {"messages": [response]}


def rewrite(state):
    messages = state["messages"]
    question = messages[0].content
    msg = [
        HumanMessage(
            content=f""" \n Look at the input and try to reason about the underlying semantic intent / meaning. \n Here is the initial question: \n ------- \n {question} \n ------- \n Formulate an improved question: """,
        )
    ]
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    response = model.invoke(msg)
    return {"messages": [response]}


def generate(state):
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]
    docs = last_message.content
    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, streaming=True)
    rag_chain = prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}


# Build the graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent)
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges(
    "agent",
    tools_condition,
    {
        "tools": "retrieve",
        END: END,
    },
)
workflow.add_conditional_edges(
    "retrieve",
    grade_documents,
)
workflow.add_edge("generate", END)
workflow.add_edge("rewrite", "agent")
graph = workflow.compile()

# Run the graph
inputs = {
    "messages": [
        ("user", "What does Lilian Weng say about the types of agent memory?"),
    ]
}
for output in graph.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Output from node '{key}':")
        pprint.pprint("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")
```

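In this tutorial, you've seen how to build an Agentic RAG system with LangGraph that decides, for each question, whether to retrieve information, checks whether what it retrieved is relevant, and then generates an answer. I hope this helps with your learning and development!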

Author: yowayimono


License: Unless otherwise stated, all articles on this blog are licensed under BY-NC-SA. Please credit the source when reposting!