In this tutorial, we implement a Corrective RAG (self-correcting retrieval-augmented generation) system with the LangGraph framework. The system grades the retrieved documents and, when they are not relevant to the question, dynamically supplements them with web search so that the final answer is more accurate. The detailed implementation steps follow.
First, install the required Python packages and set the API keys.
```python
%%capture --no-stderr
%pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python

import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")
```
We load and index three blog posts for later retrieval.
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add the chunks to the vector store
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()
```
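As a quick sanity check (the question here is just an illustrative choice), you can pull a few chunks back out of the index:

```python
# Fetch chunks for a sample question and peek at the first one
docs = retriever.invoke("What is agent memory?")
print(len(docs))
print(docs[0].page_content[:200])
```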
We define a grader that assesses whether a retrieved document is relevant to the question.
```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


# Data model
class GradeDocuments(BaseModel):
    """Binary score for a relevance check on a retrieved document."""

    binary_score: str = Field(
        description="Document is relevant to the question, 'yes' or 'no'"
    )


# LLM with structured output
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing the relevance of a retrieved document to a user question. \n
If the document contains keywords or semantic meaning related to the question, grade it as relevant. \n
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader
```
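A quick way to exercise the grader, reusing the retriever built in the indexing step:

```python
# Grade a single retrieved chunk against a sample question
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[0].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
```

A relevant chunk should come back as `binary_score='yes'`.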
We define a generator that answers the question from the retrieved documents.
```python
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)


# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# Chain
rag_chain = prompt | llm | StrOutputParser()
```
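For example, reusing `question` and `docs` from the grader check above:

```python
# Generate an answer from the formatted retrieved chunks
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)
```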
We define a question re-writer that converts the input question into a version better suited for web search.
```python
# LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# Prompt
system = """You are a question re-writer that converts an input question to a better version that is optimized \n
for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
```
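A sample invocation (the exact wording of the rewritten question will vary by model):

```python
# Rewrite a terse question into a search-friendly one
print(question_rewriter.invoke({"question": "agent memory"}))
```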
We use Tavily Search for web search.
```python
from langchain_community.tools.tavily_search import TavilySearchResults

web_search_tool = TavilySearchResults(k=3)
```
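The tool returns a list of result dicts; the `content` field of each result is what the web-search node below concatenates into a single document:

```python
# Example search; each result dict includes "url" and "content"
results = web_search_tool.invoke({"query": "types of agent memory"})
print(results[0]["content"][:200])
```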
We define the graph state and the individual nodes.
```python
from typing import List, TypedDict

from langchain.schema import Document


class GraphState(TypedDict):
    """Represents the state of the graph."""

    question: str
    generation: str
    web_search: str
    documents: List[Document]


def retrieve(state):
    """Retrieve documents."""
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


def generate(state):
    """Generate an answer."""
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """Grade the relevance of each retrieved document."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        if score.binary_score == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
    return {"documents": filtered_docs, "question": question, "web_search": web_search}


def transform_query(state):
    """Rewrite the question."""
    print("---TRANSFORM QUERY---")
    question = state["question"]
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": state["documents"], "question": better_question}


def web_search(state):
    """Supplement the filtered documents with web search results."""
    print("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]
    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join([d["content"] for d in docs]))
    # Append rather than replace, so relevant retrieved chunks are kept
    documents.append(web_results)
    return {"documents": documents, "question": question}


def decide_to_generate(state):
    """Decide whether to generate an answer or rewrite the question first."""
    print("---ASSESS GRADED DOCUMENTS---")
    if state["web_search"] == "Yes":
        # At least one document was graded not relevant, so rewrite the
        # question and supplement with web search
        print("---DECISION: SOME DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
        return "transform_query"
    else:
        print("---DECISION: GENERATE---")
        return "generate"
```
We compile the graph and run it to generate an answer.
```python
from langgraph.graph import END, START, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

# Build the graph
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

# Run
from pprint import pprint

inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

# Final generation
pprint(value["generation"])
```
This tutorial has shown how to implement a Corrective RAG system with LangGraph, dynamically supplementing retrieval with web search based on document relevance to produce more accurate answers. We hope it helps your learning and development!
Finally, here is the complete code. This standalone variant simplifies the pipeline above: it indexes a single post, grades documents by keyword overlap instead of the LLM grader, and rewrites the query with a fixed template rather than the LLM re-writer.

```python
%%capture --no-stderr
%pip install -U langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python

import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

from typing import TypedDict

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langgraph.graph import END, START, StateGraph

# Build the index
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
docs = loader.load()
vectorstore = Chroma.from_documents(docs, embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()


# Define the state
class GraphState(TypedDict):
    question: str
    documents: list[Document]
    generation: str


# Retrieval node
def retrieve(state: GraphState):
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


# Document grading node: keep documents that share keywords with the question
def grade_documents(state: GraphState):
    print("---GRADE DOCUMENTS---")
    question = state["question"]
    documents = state["documents"]
    filtered_docs = []
    for d in documents:
        if any(word in d.page_content.lower() for word in question.lower().split()):
            filtered_docs.append(d)
    return {"documents": filtered_docs, "question": question}


# Decide whether to generate or fall back to web search
def decide_to_generate(state: GraphState):
    print("---DECIDE TO GENERATE---")
    if len(state["documents"]) >= 1:
        return "generate"
    else:
        return "transform_query"


# Query transformation node: rewrite the question into a more specific query
def transform_query(state: GraphState):
    print("---TRANSFORM QUERY---")
    question = state["question"]
    transformed_query = f"{question} (with additional context from web search)"
    return {"question": transformed_query}


# Web search node
def web_search_node(state: GraphState):
    print("---WEB SEARCH---")
    question = state["question"]
    web_retriever = TavilySearchAPIRetriever(k=3)
    documents = web_retriever.invoke(question)
    return {"documents": documents, "question": question}


# Generation node
def generate(state: GraphState):
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    prompt = ChatPromptTemplate.from_template(
        "Answer the question based on the context:\n\n{context}\n\nQuestion: {question}"
    )
    model = ChatOpenAI(model="gpt-3.5-turbo")
    chain = prompt | model | StrOutputParser()
    context = "\n".join([d.page_content for d in documents])
    generation = chain.invoke({"context": context, "question": question})
    return {"generation": generation}


# Build the graph (StateGraph merges each node's partial state updates)
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search_node)
workflow.add_node("generate", generate)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

# Run
from pprint import pprint

inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

# Final generation
pprint(value["generation"])
```