@@ -48,13 +48,12 @@ def define_sql_database(engine, city_stats_table):
48
48
49
49
def main (args ):
50
50
engine , city_stats_table = create_database_schema ()
51
-
51
+
52
52
sql_database = define_sql_database (engine , city_stats_table )
53
-
53
+
54
54
model_id = args .embedding_model_path
55
55
device_map = args .device
56
56
57
-
58
57
embed_model = IpexLLMEmbedding (model_id , device = device_map )
59
58
60
59
llm = IpexLLM .from_model_id (
@@ -64,7 +63,7 @@ def main(args):
64
63
max_new_tokens = args .n_predict ,
65
64
generate_kwargs = {"temperature" : 0.7 , "do_sample" : False },
66
65
model_kwargs = {},
67
- device_map = device_map
66
+ device_map = device_map ,
68
67
)
69
68
70
69
@@ -74,7 +73,7 @@ def main(args):
74
73
tables = ["city_stats" ],
75
74
llm = llm ,
76
75
embed_model = embed_model ,
77
- return_raw = True
76
+ return_raw = True ,
78
77
)
79
78
80
79
query_engine = RetrieverQueryEngine .from_args (nl_sql_retriever , llm = llm )
@@ -90,35 +89,31 @@ def main(args):
90
89
"--model-path" ,
91
90
type = str ,
92
91
required = True ,
93
- help = "the path to transformers model"
92
+ help = "the path to transformers model" ,
94
93
)
95
94
parser .add_argument (
96
95
"--device" ,
97
96
"-d" ,
98
97
type = str ,
99
98
default = "cpu" ,
100
99
choices = ["cpu" , "xpu" ],
101
- help = "The device (Intel CPU or Intel GPU) the LLM model runs on" ,
100
+ help = "The device (Intel CPU or Intel GPU) the LLM model runs on"
102
101
)
103
102
parser .add_argument (
104
103
"-q" ,
105
104
"--question" ,
106
105
type = str ,
107
106
default = "Which city has the highest population?" ,
108
- help = "question you want to ask."
107
+ help = "question you want to ask." ,
109
108
)
110
109
parser .add_argument (
111
110
"-e" ,
112
111
"--embedding-model-path" ,
113
112
default = "BAAI/bge-small-en" ,
114
- help = "the path to embedding model path"
113
+ help = "the path to embedding model path" ,
115
114
)
116
115
parser .add_argument (
117
- "-n" ,
118
- "--n-predict" ,
119
- type = int ,
120
- default = 32 ,
121
- help = "max number of predict tokens"
116
+ "-n" , "--n-predict" , type = int , default = 32 , help = "max number of predict tokens"
122
117
)
123
118
args = parser .parse_args ()
124
119
0 commit comments