Commit 327d3c4

Update text2sql.py according to linter
1 parent bb47ef5 commit 327d3c4

File tree: 1 file changed, +24 -24 lines changed

llama-index-integrations/llms/llama-index-llms-ipex-llm/examples/text2sql.py (+24 -24)
@@ -7,6 +7,7 @@
 from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert
 import argparse
 
+
 def create_database_schema():
     engine = create_engine("sqlite:///:memory:")
     metadata_obj = MetaData()
@@ -23,6 +24,7 @@ def create_database_schema():
     metadata_obj.create_all(engine)
     return engine, city_stats_table
 
+
 def define_sql_database(engine, city_stats_table):
     sql_database = SQLDatabase(engine, include_tables=["city_stats"])
 
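For context, the two helpers touched by the hunks above build an in-memory SQLite database and wrap it in a LlamaIndex SQLDatabase. A minimal sketch of that pattern follows; the column definitions, the sample rows, and the SQLDatabase import path are illustrative assumptions rather than lines taken from this file:

# Sketch of the pattern visible in the surrounding diff context; only the
# function names and the calls shown in the hunks above come from the file.
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert

from llama_index.core import SQLDatabase  # assumed import path


def create_database_schema():
    engine = create_engine("sqlite:///:memory:")
    metadata_obj = MetaData()
    city_stats_table = Table(
        "city_stats",
        metadata_obj,
        Column("city_name", String(16), primary_key=True),  # hypothetical column
        Column("population", Integer),  # hypothetical column
    )
    metadata_obj.create_all(engine)
    return engine, city_stats_table


def define_sql_database(engine, city_stats_table):
    sql_database = SQLDatabase(engine, include_tables=["city_stats"])
    # hypothetical sample rows so the natural-language query has data to hit
    rows = [
        {"city_name": "Toronto", "population": 2930000},
        {"city_name": "Tokyo", "population": 13960000},
    ]
    for row in rows:
        with engine.begin() as connection:
            connection.execute(insert(city_stats_table).values(**row))
    return sql_database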

@@ -43,21 +45,19 @@ def define_sql_database(engine, city_stats_table):
 
     return sql_database
 
+
 def main(args):
     engine, city_stats_table = create_database_schema()
 
     sql_database = define_sql_database(engine, city_stats_table)
 
-    model_id=args.embedding_model_path
+    model_id = args.embedding_model_path
     device_map = args.device
 
 
-    embed_model = IpexLLMEmbedding(
-        model_id,
-        device=device_map
-    )
+    embed_model = IpexLLMEmbedding(model_id, device=device_map)
 
-    llm = IpexLLM.from_model_id(
+    llm = IpexLLM.from_model_id(
         model_name=args.model_path,
         tokenizer_name=args.model_path,
         context_window=512,
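Within main(), the changes above amount to spaces around the assignment, the embedding constructor collapsed onto one line, and what appears to be a whitespace-only touch of the IpexLLM.from_model_id line. Read together, the model setup is roughly the sketch below, wrapped in a hypothetical build_models() helper; the import paths and any keyword arguments after context_window are assumptions, since the rest of the call sits outside this hunk:

# Sketch of the model setup in main(); only model_id/device, model_name,
# tokenizer_name and context_window are confirmed by the hunk above.
from llama_index.embeddings.ipex_llm import IpexLLMEmbedding  # assumed import path
from llama_index.llms.ipex_llm import IpexLLM  # assumed import path


def build_models(args):
    model_id = args.embedding_model_path  # e.g. "BAAI/bge-small-en"
    device_map = args.device  # Intel CPU or Intel GPU

    embed_model = IpexLLMEmbedding(model_id, device=device_map)

    llm = IpexLLM.from_model_id(
        model_name=args.model_path,
        tokenizer_name=args.model_path,
        context_window=512,
        max_new_tokens=args.n_predict,  # assumed: wires up the -n flag
        device_map=device_map,  # assumed: not visible in this hunk
    )
    return embed_model, llm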
@@ -70,11 +70,11 @@ def main(args):
 
     # default retrieval (return_raw=True)
     nl_sql_retriever = NLSQLRetriever(
-        sql_database,
-        tables=["city_stats"],
-        llm=llm,
-        embed_model=embed_model,
-        return_raw=True
+        sql_database,
+        tables=["city_stats"],
+        llm=llm,
+        embed_model=embed_model,
+        return_raw=True
     )
 
     query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)
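The retriever hunk is a re-indentation only; the NLSQLRetriever arguments themselves are unchanged. For context, the retriever feeds a RetrieverQueryEngine, as the last context line shows. The sketch below, using a hypothetical run_query() helper, shows how the pieces connect; the import paths and the final query call are assumptions based on the -q flag rather than lines visible in the diff:

# Continuation of the sketches above; the helper name, import paths, and the
# query call are assumptions, the constructor arguments come from the hunk.
from llama_index.core.retrievers import NLSQLRetriever  # assumed import path
from llama_index.core.query_engine import RetrieverQueryEngine  # assumed import path


def run_query(sql_database, llm, embed_model, question):
    nl_sql_retriever = NLSQLRetriever(
        sql_database,
        tables=["city_stats"],
        llm=llm,
        embed_model=embed_model,
        return_raw=True,  # default retrieval mode, per the comment in the hunk
    )
    query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)
    return query_engine.query(question)  # assumed: the script reports this result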
@@ -84,13 +84,13 @@ def main(args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='LlamaIndex IpexLLM Example')
+    parser = argparse.ArgumentParser(description="LlamaIndex IpexLLM Example")
     parser.add_argument(
-        '-m',
-        '--model-path',
+        "-m",
+        "--model-path",
         type=str,
         required=True,
-        help='the path to transformers model'
+        help="the path to transformers model"
     )
     parser.add_argument(
         "--device",
@@ -101,24 +101,24 @@ def main(args):
         help="The device (Intel CPU or Intel GPU) the LLM model runs on",
     )
     parser.add_argument(
-        '-q',
-        '--question',
+        "-q",
+        "--question",
         type=str,
-        default='Which city has the highest population?',
-        help='question you want to ask.'
+        default="Which city has the highest population?",
+        help="question you want to ask."
     )
     parser.add_argument(
-        '-e',
-        '--embedding-model-path',
+        "-e",
+        "--embedding-model-path",
         default="BAAI/bge-small-en",
         help="the path to embedding model path"
     )
     parser.add_argument(
-        '-n',
-        '--n-predict',
+        "-n",
+        "--n-predict",
         type=int,
         default=32,
-        help='max number of predict tokens'
+        help="max number of predict tokens"
     )
     args = parser.parse_args()
 
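Taken together, the commit is purely stylistic: double quotes in place of single quotes, two blank lines before each top-level def, spaces around the assignment, the embedding constructor on one line, and re-indented call arguments. Behavior is unchanged, so an invocation of the example still looks something like the line below, where the model path and the device value are placeholders rather than values taken from the commit:

    python text2sql.py -m <path-to-llm-model> -e BAAI/bge-small-en --device cpu -q "Which city has the highest population?" -n 32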
