{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 1\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Agents\n",
"\n",
"LangChain has a SQL Agent which provides a more flexible way of interacting with SQL databases than a chain. The main advantages of using the SQL Agent are:\n",
"\n",
"- It can answer questions based on the database's schema as well as on its content (for example, describing a specific table).\n",
"- It can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly.\n",
"- It can query the database as many times as needed to answer the user's question.\n",
"- It saves tokens by retrieving only the schemas of relevant tables.\n",
"\n",
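"The error-recovery behavior in the second bullet can be pictured with a minimal, dependency-free sketch. It uses Python's built-in `sqlite3` module rather than LangChain, and `fix_query` is a hypothetical placeholder for the LLM call that reads the error and rewrites the query. The real agent decides for itself how to repair a failing query.\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"\n",
"def fix_query(query: str, error: str) -> str:\n",
"    \"\"\"Hypothetical stand-in for the LLM: it only repairs one known typo.\"\"\"\n",
"    return query.replace(\"Artists\", \"Artist\")\n",
"\n",
"\n",
"def run_with_retries(conn: sqlite3.Connection, query: str, max_attempts: int = 3) -> list:\n",
"    \"\"\"Run a query; on failure, pass the error back and retry the rewritten query.\"\"\"\n",
"    for _ in range(max_attempts):\n",
"        try:\n",
"            return conn.execute(query).fetchall()\n",
"        except sqlite3.OperationalError as exc:\n",
"            # The agent sees the error text (e.g. \"no such table: Artists\") and regenerates the query.\n",
"            query = fix_query(query, str(exc))\n",
"    raise RuntimeError(f\"Query still failing after {max_attempts} attempts\")\n",
"\n",
"\n",
"conn = sqlite3.connect(\":memory:\")\n",
"conn.execute(\"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)\")\n",
"conn.execute(\"INSERT INTO Artist (Name) VALUES ('AC/DC'), ('Accept')\")\n",
"print(run_with_retries(conn, \"SELECT COUNT(*) FROM Artists\"))  # [(2,)]\n",
"```\n",
"\n",
"Capping the number of attempts keeps a query that cannot be repaired from looping forever.\n",
"\n",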
"To initialize the agent, we'll use the [create_sql_agent](https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html) constructor. This agent uses the `SQLDatabaseToolkit`, which contains tools to:\n",
"\n",
"* Create and execute queries\n",
"* Check query syntax\n",
"* Retrieve table descriptions\n",
"* ... and more\n",
"\n",
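"Retrieving descriptions only for the tables a question needs is what keeps the prompt small. As a rough, stdlib-only illustration (not LangChain's implementation, which goes through SQLAlchemy and also appends sample rows), the sketch below reads just the requested `CREATE TABLE` statements from SQLite's `sqlite_master` catalog. `get_schema` is a hypothetical helper name:\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"\n",
"def get_schema(conn: sqlite3.Connection, tables: list[str]) -> str:\n",
"    \"\"\"Return the CREATE TABLE statements for the requested tables only.\"\"\"\n",
"    placeholders = \", \".join(\"?\" for _ in tables)\n",
"    rows = conn.execute(\n",
"        f\"SELECT sql FROM sqlite_master WHERE type = 'table' AND name IN ({placeholders})\",\n",
"        tables,\n",
"    ).fetchall()\n",
"    return \"\\n\\n\".join(sql for (sql,) in rows)\n",
"\n",
"\n",
"conn = sqlite3.connect(\":memory:\")\n",
"conn.executescript(\n",
"    \"\"\"\n",
"    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);\n",
"    CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);\n",
"    CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);\n",
"    \"\"\"\n",
")\n",
"# Only the Artist and Album definitions would go into the prompt; Genre is left out.\n",
"print(get_schema(conn, [\"Artist\", \"Album\"]))\n",
"```\n",
"\n",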
"## Setup\n",
"\n",
"First, install the required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
"\n",
"# Uncomment the below to use LangSmith. Not required.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The example below uses a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
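"If the `sqlite3` command-line shell isn't available, the same steps can be scripted with Python's built-in `sqlite3` module. Here is a minimal sketch, assuming the script was saved as `Chinook_Sqlite.sql` next to the notebook. `build_db` is an illustrative helper name:\n",
"\n",
"```python\n",
"import sqlite3\n",
"from pathlib import Path\n",
"\n",
"\n",
"def build_db(sql_path: str, db_path: str) -> None:\n",
"    \"\"\"Create db_path by running every statement in the SQL script at sql_path.\"\"\"\n",
"    # utf-8-sig also tolerates a leading byte-order mark, if the file has one.\n",
"    script = Path(sql_path).read_text(encoding=\"utf-8-sig\")\n",
"    conn = sqlite3.connect(db_path)\n",
"    try:\n",
"        conn.executescript(script)\n",
"    finally:\n",
"        conn.close()\n",
"\n",
"\n",
"if Path(\"Chinook_Sqlite.sql\").exists():\n",
"    build_db(\"Chinook_Sqlite.sql\", \"Chinook.db\")\n",
"```\n",
"\n",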
"Now `Chinook.db` is in our directory, and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent\n",
"\n",
"We'll use an OpenAI chat model and an `\"openai-tools\"` agent, which will use OpenAI's function-calling API to drive the agent's tool selection and invocations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the runs below show, the agent first chooses which tables are relevant and then adds the schema for those tables, plus a few sample rows, to the prompt."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.agent_toolkits import create_sql_agent\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_schema` with `Invoice,Customer`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Invoice table:\n",
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n",
"responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n",
"\n",
"Here is the SQL query:\n",
"\n",
"```sql\n",
"SELECT c.Country, SUM(i.Total) AS TotalSales\n",
"FROM Invoice i\n",
"JOIN Customer c ON i.CustomerId = c.CustomerId\n",
"GROUP BY c.Country\n",
"ORDER BY TotalSales DESC\n",
"LIMIT 10;\n",
"```\n",
"\n",
"Now, I will execute this query to get the total sales per country.\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
"\n",
"1. USA: $523.06\n",
"2. Canada: $303.96\n",
"3. France: $195.10\n",
"4. Brazil: $190.10\n",
"5. Germany: $156.48\n",
"6. United Kingdom: $112.86\n",
"7. Czech Republic: $90.24\n",
"8. Portugal: $77.24\n",
"9. India: $75.26\n",
"10. Chile: $46.62\n",
"\n",
"To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
" 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" \"List the total sales per country. Which country's customers spent the most?\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_schema` with `PlaylistTrack`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from PlaylistTrack table:\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n",
"\n",
"Here is the schema of the `PlaylistTrack` table:\n",
"\n",
"```\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"```\n",
"\n",
"The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n",
"\n",
"Here are three sample rows from the `PlaylistTrack` table:\n",
"\n",
"```\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\n",
"```\n",
"\n",
"Please let me know if there is anything else I can help with.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'Describe the playlisttrack table',\n",
" 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\"Describe the playlisttrack table\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using a dynamic few-shot prompt\n",
"\n",
"To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case, we'll create a few-shot prompt with an example selector that dynamically builds the few-shot prompt based on the user input. This helps the model write better queries by inserting relevant example queries into the prompt for it to use as reference.\n",
"\n",
"First, we need some examples of user inputs paired with their SQL queries:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"examples = [\n",
" {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n",
" {\n",
" \"input\": \"Find all albums for the artist 'AC/DC'.\",\n",
" \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n",
" },\n",
" {\n",
" \"input\": \"List all tracks in the 'Rock' genre.\",\n",
" \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n",
" },\n",
" {\n",
" \"input\": \"Find the total duration of all tracks.\",\n",
" \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n",
" },\n",
" {\n",
" \"input\": \"List all customers from Canada.\",\n",
" \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n",
" },\n",
" {\n",
" \"input\": \"How many tracks are there in the album with ID 5?\",\n",
" \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n",
" },\n",
" {\n",
" \"input\": \"Find the total number of invoices.\",\n",
" \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n",
" },\n",
" {\n",
" \"input\": \"List all tracks that are longer than 5 minutes.\",\n",
" \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n",
" },\n",
" {\n",
" \"input\": \"Who are the top 5 customers by total purchase?\",\n",
" \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n",
" },\n",
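" # Note: Chinook's Album table has no ReleaseDate column, so this example query would fail if executed.\n",
" {\n",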
" \"input\": \"Which albums are from the year 2000?\",\n",
" \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n",
" },\n",
" {\n",
" \"input\": \"How many employees are there\",\n",
" \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create an example selector. It takes the actual user input and selects a number of examples to add to our few-shot prompt. We'll use a `SemanticSimilarityExampleSelector`, which performs a semantic search using the embeddings and vector store we configure to find the examples most similar to our input:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import FAISS\n",
"from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"example_selector = SemanticSimilarityExampleSelector.from_examples(\n",
" examples,\n",
" OpenAIEmbeddings(),\n",
" FAISS,\n",
" k=5,\n",
" input_keys=[\"input\"],\n",
")"
]
},
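{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the selector embeds each example's `input`, embeds the user's question, and keeps the `k` closest examples. To show that ranking without any API calls, the sketch below swaps the OpenAI embeddings for a crude bag-of-words cosine similarity. This is an illustrative substitute, not what `SemanticSimilarityExampleSelector` does internally, and `select_examples` and `toy_examples` are hypothetical names:\n",
"\n",
"```python\n",
"import math\n",
"import re\n",
"from collections import Counter\n",
"\n",
"\n",
"def bag_of_words(text: str) -> Counter:\n",
"    \"\"\"Lower-cased word counts: a crude stand-in for an embedding vector.\"\"\"\n",
"    return Counter(re.findall(r\"[a-z0-9]+\", text.lower()))\n",
"\n",
"\n",
"def cosine(a: Counter, b: Counter) -> float:\n",
"    dot = sum(a[w] * b[w] for w in a)\n",
"    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))\n",
"    return dot / norm if norm else 0.0\n",
"\n",
"\n",
"def select_examples(examples: list[dict], user_input: str, k: int = 2) -> list[dict]:\n",
"    \"\"\"Return the k examples whose input text is most similar to user_input.\"\"\"\n",
"    query_vec = bag_of_words(user_input)\n",
"    ranked = sorted(\n",
"        examples,\n",
"        key=lambda ex: cosine(query_vec, bag_of_words(ex[\"input\"])),\n",
"        reverse=True,\n",
"    )\n",
"    return ranked[:k]\n",
"\n",
"\n",
"toy_examples = [\n",
"    {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n",
"    {\"input\": \"Find the total number of invoices.\", \"query\": \"SELECT COUNT(*) FROM Invoice;\"},\n",
"    {\"input\": \"How many employees are there\", \"query\": 'SELECT COUNT(*) FROM \"Employee\"'},\n",
"]\n",
"# Prints \"How many employees are there\", then \"List all artists.\"\n",
"for ex in select_examples(toy_examples, \"How many artists are there?\"):\n",
"    print(ex[\"input\"])\n",
"```\n",
"\n",
"A real embedding model also captures paraphrases that share no words, which this word-overlap version cannot."
]
},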
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create our `FewShotPromptTemplate`, which takes our example selector, an example prompt for formatting each example, and a prefix and suffix string to place before and after the formatted examples:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import (\n",
" ChatPromptTemplate,\n",
" FewShotPromptTemplate,\n",
" MessagesPlaceholder,\n",
" PromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
")\n",
"\n",
"system_prefix = \"\"\"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
"\n",
"Here are some examples of user inputs and their corresponding SQL queries:\"\"\"\n",
"\n",
"few_shot_prompt = FewShotPromptTemplate(\n",
" example_selector=example_selector,\n",
" example_prompt=PromptTemplate.from_template(\n",
" \"User input: {input}\\nSQL query: {query}\"\n",
" ),\n",
" input_variables=[\"input\", \"dialect\", \"top_k\"],\n",
" prefix=system_prefix,\n",
" suffix=\"\",\n",
")"
]
},
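{
"cell_type": "markdown",
"metadata": {},
"source": [
"Roughly speaking, the template's output is the prefix, each selected example rendered with `example_prompt`, and the suffix, joined by blank lines. The plain-Python sketch below approximates that assembly. It assumes a `\"\\n\\n\"` separator and that empty pieces (such as `suffix=\"\"`) are dropped, which matches the formatted prompt printed further down but is not LangChain's own code. `format_few_shot` is a hypothetical helper:\n",
"\n",
"```python\n",
"def format_few_shot(prefix: str, examples: list[dict], suffix: str = \"\", separator: str = \"\\n\\n\") -> str:\n",
"    \"\"\"Join the prefix, the formatted examples, and the suffix, skipping empty pieces.\"\"\"\n",
"    formatted = [f\"User input: {ex['input']}\\nSQL query: {ex['query']}\" for ex in examples]\n",
"    pieces = [prefix, *formatted, suffix]\n",
"    return separator.join(piece for piece in pieces if piece)\n",
"\n",
"\n",
"text = format_few_shot(\n",
"    \"Here are some examples of user inputs and their corresponding SQL queries:\",\n",
"    [{\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"}],\n",
")\n",
"print(text)\n",
"```"
]
},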
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since our underlying agent is an [OpenAI tools agent](/docs/modules/agents/agent_types/openai_tools), which uses OpenAI function calling, our full prompt should be a chat prompt with a human message template and an agent_scratchpad `MessagesPlaceholder`. The few-shot prompt will be used for our system message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"full_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" SystemMessagePromptTemplate(prompt=few_shot_prompt),\n",
" (\"human\", \"{input}\"),\n",
" MessagesPlaceholder(\"agent_scratchpad\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"System: You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
"\n",
"Here are some examples of user inputs and their corresponding SQL queries:\n",
"\n",
"User input: List all artists.\n",
"SQL query: SELECT * FROM Artist;\n",
"\n",
"User input: How many employees are there\n",
"SQL query: SELECT COUNT(*) FROM \"Employee\"\n",
"\n",
"User input: How many tracks are there in the album with ID 5?\n",
"SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n",
"\n",
"User input: List all tracks in the 'Rock' genre.\n",
"SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
"\n",
"User input: Which albums are from the year 2000?\n",
"SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n",
"Human: How many artists are there\n"
]
}
],
"source": [
"# Example formatted prompt\n",
"prompt_val = full_prompt.invoke(\n",
" {\n",
" \"input\": \"How many artists are there\",\n",
" \"top_k\": 5,\n",
" \"dialect\": \"SQLite\",\n",
" \"agent_scratchpad\": [],\n",
" }\n",
")\n",
"print(prompt_val.to_string())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we can create our agent with our custom prompt:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"agent = create_sql_agent(\n",
" llm=llm,\n",
" db=db,\n",
" prompt=full_prompt,\n",
" verbose=True,\n",
" agent_type=\"openai-tools\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try it out:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 artists in the database.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'How many artists are there?',\n",
" 'output': 'There are 275 artists in the database.'}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.invoke({\"input\": \"How many artists are there?\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dealing with high-cardinality columns\n",
"\n",
"To filter on columns that contain proper nouns such as addresses, song names, or artists, we first need to double-check the spelling so that the filter matches the data correctly.\n",
"\n",
"We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. This way, the agent can make sure it understands which entity the user is referring to before building the target query.\n",
"\n",
"First, we need the unique values for each entity we want. We define a function that runs a query, flattens the result into a list of strings, strips standalone numbers, and removes duplicates:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Os Cães Ladram Mas A Caravana Não Pára',\n",
" 'War',\n",
" 'Mais Do Mesmo',\n",
" \"Up An' Atom\",\n",
" 'Riot Act']"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ast\n",
"import re\n",
"\n",
"\n",
"def query_as_list(db, query):\n",
" res = db.run(query)\n",
" res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
" res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
" return list(set(res))\n",
"\n",
"\n",
"artists = query_as_list(db, \"SELECT Name FROM Artist\")\n",
"albums = query_as_list(db, \"SELECT Title FROM Album\")\n",
"albums[:5]"
]
},
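{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what each step of `query_as_list` does, here it is applied by hand to a made-up result string in the format `db.run` returns. The titles below are illustrative, not taken from Chinook:\n",
"\n",
"```python\n",
"import ast\n",
"import re\n",
"\n",
"# A made-up db.run() result: the string form of a list of row tuples.\n",
"raw = \"[('Greatest Hits',), ('Greatest Hits',), ('Live 1993',), (None,)]\"\n",
"\n",
"rows = ast.literal_eval(raw)  # back to a list of tuples\n",
"values = [el for row in rows for el in row if el]  # flatten, dropping None and empty strings\n",
"cleaned = [re.sub(r\"\\b\\d+\\b\", \"\", v).strip() for v in values]  # remove standalone numbers\n",
"unique = sorted(set(cleaned))  # deduplicate (sorted only to make the output stable)\n",
"print(unique)  # ['Greatest Hits', 'Live']\n",
"```\n",
"\n",
"Stripping numbers collapses near-duplicates such as numbered volumes into one entry, at the cost of merging titles that differ only by a number."
]
},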
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can proceed with creating the custom **retriever tool** and the final agent:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents.agent_toolkits import create_retriever_tool\n",
"\n",
"vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())\n",
"retriever = vector_db.as_retriever(search_kwargs={\"k\": 5})\n",
"description = \"\"\"Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \\\n",
"valid proper nouns. Use the noun most similar to the search.\"\"\"\n",
"retriever_tool = create_retriever_tool(\n",
" retriever,\n",
" name=\"search_proper_nouns\",\n",
" description=description,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"system = \"\"\"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool! \n",
"\n",
"You have access to the following tables: {table_names}\n",
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", system), (\"human\", \"{input}\"), MessagesPlaceholder(\"agent_scratchpad\")]\n",
")\n",
"agent = create_sql_agent(\n",
" llm=llm,\n",
" db=db,\n",
" extra_tools=[retriever_tool],\n",
" prompt=prompt,\n",
" agent_type=\"openai-tools\",\n",
" verbose=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `search_proper_nouns` with `{'query': 'alis in chain'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mAlice In Chains\n",
"\n",
"Aisha Duo\n",
"\n",
"Xis\n",
"\n",
"Da Lama Ao Caos\n",
"\n",
"A-Sides\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'How many albums does alis in chain have?',\n",
" 'output': 'Alice In Chains has 1 album.'}"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.invoke({\"input\": \"How many albums does alis in chain have?\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the agent used the `search_proper_nouns` tool to look up the correct spelling of the artist's name (\"Alice In Chains\") before querying the database."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"Under the hood, `create_sql_agent` is just passing in SQL tools to more generic agent constructors. To learn more about the built-in generic agent types as well as how to build custom agents, head to the [Agents Modules](/docs/modules/agents/).\n",
"\n",
"The built-in `AgentExecutor` runs a simple Agent action -> Tool call -> Agent action... loop. To build more complex agent runtimes, head to the [LangGraph section](/docs/langgraph)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pampa-labs",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}