Skip to content

Commit

Permalink
Update text2sql.py according to linter
Browse files Browse the repository at this point in the history
  • Loading branch information
SichengStevenLi authored Aug 15, 2024
1 parent 1098331 commit b081d16
Showing 1 changed file with 9 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ def define_sql_database(engine, city_stats_table):

def main(args):
engine, city_stats_table = create_database_schema()

sql_database = define_sql_database(engine, city_stats_table)

model_id = args.embedding_model_path
device_map = args.device


embed_model = IpexLLMEmbedding(model_id, device=device_map)

llm = IpexLLM.from_model_id(
Expand All @@ -64,7 +63,7 @@ def main(args):
max_new_tokens=args.n_predict,
generate_kwargs={"temperature": 0.7, "do_sample": False},
model_kwargs={},
device_map=device_map
device_map=device_map,
)


Expand All @@ -74,7 +73,7 @@ def main(args):
tables=["city_stats"],
llm=llm,
embed_model=embed_model,
return_raw=True
return_raw=True,
)

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)
Expand All @@ -90,35 +89,31 @@ def main(args):
"--model-path",
type=str,
required=True,
help="the path to transformers model"
help="the path to transformers model",
)
parser.add_argument(
"--device",
"-d",
type=str,
default="cpu",
choices=["cpu", "xpu"],
help="The device (Intel CPU or Intel GPU) the LLM model runs on",
help="The device (Intel CPU or Intel GPU) the LLM model runs on"
)
parser.add_argument(
"-q",
"--question",
type=str,
default="Which city has the highest population?",
help="question you want to ask."
help="question you want to ask.",
)
parser.add_argument(
"-e",
"--embedding-model-path",
default="BAAI/bge-small-en",
help="the path to embedding model path"
help="the path to embedding model path",
)
parser.add_argument(
"-n",
"--n-predict",
type=int,
default=32,
help="max number of predict tokens"
"-n", "--n-predict", type=int, default=32, help="max number of predict tokens"
)
args = parser.parse_args()

Expand Down

0 comments on commit b081d16

Please sign in to comment.