编辑
2025-02-11
后端
00
请注意,本文编写于 87 天前,最后修改于 87 天前,其中某些信息可能已经过时。

目录

SQL 数据库交互智能体教程
1. 环境准备与依赖安装
2. 配置数据库
3. 定义工具
4. 查询检查
5. 构建智能体工作流
6. 评估智能体
总结

SQL 数据库交互智能体教程

在本教程中,我们将逐步构建一个能够回答有关 SQL 数据库问题的智能体。该智能体的主要功能包括:

  1. 从数据库中获取可用表
  2. 确定哪些表与问题相关
  3. 获取相关表的 DDL(数据定义语言)
  4. 根据问题和 DDL 信息生成查询
  5. 使用 LLM 检查查询中的常见错误
  6. 执行查询并返回结果
  7. 修正数据库引擎返回的错误,直到查询成功
  8. 根据结果生成响应

以下是实现该智能体的详细步骤。


1. 环境准备与依赖安装

首先,安装所需的库:

bash
pip install -U langgraph langchain_openai langchain_community

设置 OpenAI API 密钥:

python
import getpass import os def _set_env(key: str): if key not in os.environ: os.environ[key] = getpass.getpass(f"{key}:") _set_env("OPENAI_API_KEY")

2. 配置数据库

我们将使用 SQLite 数据库,并加载 chinook 示例数据库。该数据库代表一个数字媒体商店。

python
import requests url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db" response = requests.get(url) if response.status_code == 200: with open("Chinook.db", "wb") as file: file.write(response.content) print("文件已下载并保存为 Chinook.db") else: print(f"文件下载失败,状态码:{response.status_code}")

使用 langchain_community 中的 SQLDatabase 包装器与数据库交互:

python
from langchain_community.utilities import SQLDatabase db = SQLDatabase.from_uri("sqlite:///Chinook.db") print(db.dialect) print(db.get_usable_table_names()) db.run("SELECT * FROM Artist LIMIT 10;")

3. 定义工具

我们将定义以下工具供智能体使用:

  1. list_tables_tool:从数据库中获取可用表
  2. get_schema_tool:获取表的 DDL
  3. db_query_tool:执行查询并返回结果,如果查询失败则返回错误信息
python
from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_openai import ChatOpenAI toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o")) tools = toolkit.get_tools() list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables") get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema") print(list_tables_tool.invoke("")) print(get_schema_tool.invoke("Artist"))

手动定义 db_query_tool

python
from langchain_core.tools import tool @tool def db_query_tool(query: str) -> str: """ 对数据库执行 SQL 查询并返回结果。 如果查询不正确,则返回错误信息。 """ result = db.run_no_throw(query) if not result: return "错误:查询失败。请重写查询并重试。" return result print(db_query_tool.invoke("SELECT * FROM Artist LIMIT 10;"))

4. 查询检查

使用 LLM 检查查询中的常见错误:

python
from langchain_core.prompts import ChatPromptTemplate query_check_system = """你是一位 SQL 专家,注重细节。 请检查以下 SQLite 查询中的常见错误,包括: - 在 NULL 值上使用 NOT IN - 应使用 UNION ALL 时使用了 UNION - 在独占范围中使用 BETWEEN - 谓词中的数据类型不匹配 - 正确引用标识符 - 函数的参数数量是否正确 - 转换为正确的数据类型 - 连接时使用正确的列 如果存在上述错误,请重写查询。如果没有错误,则直接返回原始查询。 """ query_check_prompt = ChatPromptTemplate.from_messages( [("system", query_check_system), ("placeholder", "{messages}")] ) query_check = query_check_prompt | ChatOpenAI(model="gpt-4o", temperature=0).bind_tools( [db_query_tool], tool_choice="required" ) query_check.invoke({"messages": [("user", "SELECT * FROM Artist LIMIT 10;")]})

5. 构建智能体工作流

定义智能体的工作流,包括以下节点:

  1. retrieve:检索相关表
  2. grade_documents:评估表的相关性
  3. generate:生成查询
  4. transform_query:转换查询
python
from langgraph.prebuilt import ToolNode from langchain_core.messages import HumanMessage def retrieve(state): question = state["question"] tables = list_tables_tool.invoke(question) return {"tables": tables} def grade_documents(state): tables = state["tables"] schema = get_schema_tool.invoke(tables[0]) return {"schema": schema} def generate(state): schema = state["schema"] question = state["question"] query = llm.invoke([HumanMessage(content=f"根据以下 DDL 和问题生成 SQL 查询:{schema}\n问题:{question}")]) return {"query": query} def transform_query(state): query = state["query"] rewritten_query = query_check.invoke({"messages": [("user", query)]}) return {"query": rewritten_query} # 构建工作流 workflow = Graph() workflow.add_node("retrieve", retrieve) workflow.add_node("grade_documents", grade_documents) workflow.add_node("generate", generate) workflow.add_node("transform_query", transform_query) workflow.add_edge("retrieve", "grade_documents") workflow.add_edge("grade_documents", "generate") workflow.add_edge("generate", "transform_query") workflow.add_edge("transform_query", "retrieve") app = workflow.compile() # 执行工作流 question = "2009 年销售额最高的销售代理是谁?" for output in app.stream({"question": question}): for key, value in output.items(): print(f"节点:{key}, 输出:{value}")

6. 评估智能体

使用 LangSmith 对智能体进行评估,包括响应、工具调用和轨迹的评估。

python
from langsmith.evaluation import evaluate def predict_sql_agent_answer(example: dict): msg = {"messages": ("user", example["input"])} messages = app.invoke(msg) json_str = messages["messages"][-1].tool_calls[0]["args"] response = json_str["final_answer"] return {"response": response} def answer_evaluator(run, example) -> dict: input_question = example.inputs["input"] reference = example.outputs["output"] prediction = run.outputs["response"] llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) answer_grader = grade_prompt_answer_accuracy | llm score = answer_grader.invoke( { "question": input_question, "correct_answer": reference, "student_answer": prediction, } ) score = score["Score"] return {"key": "answer_v_reference_score", "score": score} try: experiment_results = evaluate( predict_sql_agent_answer, data="SQL Agent Response", evaluators=[answer_evaluator], num_repetitions=3, experiment_prefix="sql-agent-multi-step-response-v-reference", metadata={"version": "Chinook, gpt-4o multi-step-agent"}, ) except: print("请设置 LangSmith")

完整代码

py

总结

本教程实现了一个能够与 SQL 数据库交互的智能体,结合了检索、生成、检查和执行查询的功能。你可以根据实际需求调整工作流和工具,以优化智能体的性能。

本文作者:yowayimono

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!