forked from langroid/langroid
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsql_chat.py
212 lines (172 loc) · 6.6 KB
/
sql_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Example showing how to chat with a SQL database.
Note: if you are using this with a Postgres database, you will need to:
(a) Install the PostgreSQL dev libraries for your platform, e.g.
    - `sudo apt-get install libpq-dev` on Ubuntu,
    - `brew install postgresql` on Mac, etc.
(b) Install langroid with the postgres extra, e.g. `pip install langroid[postgres]`
    or `poetry add langroid[postgres]` or `poetry install -E postgres`
    or `uv pip install langroid[postgres]` or `uv add langroid[postgres]`.
    If this gives you an error, try `pip install psycopg2-binary` in your virtualenv.
"""
import typer
from rich import print
from rich.prompt import Prompt
from typing import Dict, Any
import json
import os
from langroid.exceptions import LangroidImportError
try:
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import Engine
except ImportError as e:
raise LangroidImportError(extra="sql", error=str(e))
from prettytable import PrettyTable
try:
from .utils import get_database_uri, fix_uri
except ImportError:
from utils import get_database_uri, fix_uri
from langroid.agent.task import Task
from langroid.agent.special.sql.sql_chat_agent import (
SQLChatAgentConfig,
SQLChatAgent,
)
from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
from langroid.utils.configuration import set_global, Settings
from langroid.utils.constants import SEND_TO
import logging
# Module-level logger for this script.
logger = logging.getLogger(__name__)
# Typer CLI application; `main` below is registered on it via `@app.command()`.
app = typer.Typer()
def create_descriptions_file(filepath: str, engine: Engine) -> None:
    """
    Create an empty descriptions JSON file for SQLAlchemy tables.

    This function inspects the database, generates a template with an empty
    "description" for every table and an empty string for every column, and
    writes that template (indented JSON) to a new file.

    Args:
        filepath: The path to the file where the descriptions should be written.
        engine: The SQLAlchemy Engine connected to the database to describe.

    Raises:
        FileExistsError: If the file at `filepath` already exists.

    Returns:
        None
    """
    # Fail fast, before touching the database, when the target already exists.
    if os.path.exists(filepath):
        raise FileExistsError(f"File {filepath} already exists.")

    inspector = inspect(engine)
    descriptions: Dict[str, Dict[str, Any]] = {
        table_name: {
            "description": "",
            "columns": {col["name"]: "" for col in inspector.get_columns(table_name)},
        }
        for table_name in inspector.get_table_names()
    }

    # Exclusive-create mode ("x") raises FileExistsError if the file appeared
    # after the check above, so an existing file is never silently overwritten.
    with open(filepath, "x", encoding="utf-8") as f:
        json.dump(descriptions, f, indent=4)
def load_context_descriptions(engine: Engine) -> dict:
    """
    Ask the user for a path to a JSON file and load context descriptions from it.

    The user may enter:
      - a path (or hit enter for the default) to load that JSON file;
      - 's' to skip, in which case an empty dict is returned;
      - 'n' to generate a blank template with `create_descriptions_file`.
        After the template is written the user is prompted again, so they
        can fill it in before it is loaded.

    Invalid input (missing file, unreadable file, malformed JSON, or JSON that
    is not an object) prints an error and re-prompts.

    Args:
        engine: SQLAlchemy Engine used to build a new template when the user
            chooses 'n'.

    Returns:
        dict: The context descriptions, or an empty dictionary if the user
        decides to skip this step.
    """
    while True:
        # Strip once so the same cleaned value is used for both the command
        # comparison and as the file path.
        filepath = Prompt.ask(
            "[blue]Enter the path to your context descriptions file. \n"
            "('n' to create a NEW file, 's' to SKIP, or Hit enter to use DEFAULT) ",
            default="examples/data-qa/sql-chat/demo.json",
        ).strip()

        if filepath == "s":
            return {}

        if filepath == "n":
            new_path = Prompt.ask(
                "[blue]To create a new context description file, enter the path",
                default="examples/data-qa/sql-chat/description.json",
            ).strip()
            print(f"[blue]Creating new context description file at {new_path}...")
            try:
                create_descriptions_file(new_path, engine)
            except FileExistsError:
                print(f"[red]The file '{new_path}' already exists. Please try again.")
                continue
            except OSError as e:
                print(f"[red]Could not create '{new_path}': {e}. Please try again.")
                continue
            print(
                f"[blue] Please fill in the descriptions in {new_path}, "
                f"then try again."
            )
            # Re-prompt rather than immediately loading the still-empty template.
            continue

        # Try to load the file
        if not os.path.exists(filepath):
            print(f"[red]The file '{filepath}' does not exist. Please try again.")
            continue

        try:
            with open(filepath, "r", encoding="utf-8") as file:
                data = json.load(file)
        except json.JSONDecodeError:
            print(
                f"[red]The file '{filepath}' is not a valid JSON file. Please try again."
            )
            continue
        except OSError as e:
            # e.g. the path is a directory or is not readable.
            print(f"[red]Could not read '{filepath}': {e}. Please try again.")
            continue

        if not isinstance(data, dict):
            print(
                f"[red]The file '{filepath}' must contain a JSON object. "
                f"Please try again."
            )
            continue
        return data
def _print_schema(inspector: Any) -> None:
    """Print every table reported by `inspector` with its column names and types."""
    for table_name in inspector.get_table_names():
        print(f"[blue]Table: {table_name}")
        # One PrettyTable per database table, listing its columns.
        table = PrettyTable()
        table.field_names = ["Column Name", "Type"]
        for column in inspector.get_columns(table_name):
            table.add_row([column["name"], column["type"]])
        print(table)


@app.command()
def main(
    debug: bool = typer.Option(False, "--debug", "-d", help="debug mode"),
    no_stream: bool = typer.Option(False, "--nostream", "-ns", help="no streaming"),
    nocache: bool = typer.Option(False, "--nocache", "-nc", help="don't use cache"),
    tools: bool = typer.Option(
        False, "--tools", "-t", help="use langroid tools instead of function-calling"
    ),
    cache_type: str = typer.Option(
        "redis", "--cachetype", "-ct", help="redis or momento"
    ),
    schema_tools: bool = typer.Option(
        False, "--schema_tools", "-st", help="use schema tools"
    ),
) -> None:
    """
    Chat with a SQL database through a langroid SQLChatAgent.

    Prompts for a database URI (or builds one interactively with 'i'),
    optionally loads table/column context descriptions, prints the schema,
    and then runs a SQLChatAgent task against the database.
    """
    set_global(
        Settings(
            debug=debug,
            cache=not nocache,
            stream=not no_stream,
            cache_type=cache_type,
        )
    )
    print("[blue]Welcome to the SQL database chatbot!\n")

    database_uri = Prompt.ask(
        "\n[blue]Enter the URI for your SQL database\n"
        "(type 'i' for interactive, or hit enter for default)\n",
        default="sqlite:///examples/data-qa/sql-chat/demo.db",
    )
    if database_uri == "i":
        database_uri = get_database_uri()
    database_uri = fix_uri(database_uri)

    # sqlalchemy is already a hard dependency of this script; make_url lets us
    # mask the password so credentials never end up in the logs.
    from sqlalchemy.engine import make_url

    logger.warning(
        "Using database URI: %s",
        make_url(database_uri).render_as_string(hide_password=True),
    )

    # Create engine and inspector
    engine = create_engine(database_uri)
    inspector = inspect(engine)

    context_descriptions = load_context_descriptions(engine)

    _print_schema(inspector)

    agent_config = SQLChatAgentConfig(
        name="sql",
        database_uri=database_uri,
        use_tools=tools,
        use_functions_api=not tools,
        show_stats=False,
        chat_mode=True,
        use_helper=True,
        context_descriptions=context_descriptions,  # Add context descriptions to the config
        use_schema_tools=schema_tools,
        addressing_prefix=SEND_TO,
        llm=OpenAIGPTConfig(
            chat_model=OpenAIChatModel.GPT4,
        ),
    )
    agent = SQLChatAgent(agent_config)
    # interactive=False, but the user still gets a chance to respond
    # when explicitly addressed by the LLM.
    task = Task(agent, interactive=False)
    task.run()


if __name__ == "__main__":
    app()