7
7
from sqlalchemy import create_engine , MetaData , Table , Column , String , Integer , insert
8
8
import argparse
9
9
10
+
10
11
def create_database_schema ():
11
12
engine = create_engine ("sqlite:///:memory:" )
12
13
metadata_obj = MetaData ()
@@ -23,6 +24,7 @@ def create_database_schema():
23
24
metadata_obj .create_all (engine )
24
25
return engine , city_stats_table
25
26
27
+
26
28
def define_sql_database (engine , city_stats_table ):
27
29
sql_database = SQLDatabase (engine , include_tables = ["city_stats" ])
28
30
@@ -43,21 +45,19 @@ def define_sql_database(engine, city_stats_table):
43
45
44
46
return sql_database
45
47
48
+
46
49
def main (args ):
47
50
engine , city_stats_table = create_database_schema ()
48
51
49
52
sql_database = define_sql_database (engine , city_stats_table )
50
53
51
- model_id = args .embedding_model_path
54
+ model_id = args .embedding_model_path
52
55
device_map = args .device
53
56
54
57
55
- embed_model = IpexLLMEmbedding (
56
- model_id ,
57
- device = device_map
58
- )
58
+ embed_model = IpexLLMEmbedding (model_id , device = device_map )
59
59
60
- llm = IpexLLM .from_model_id (
60
+ llm = IpexLLM .from_model_id (
61
61
model_name = args .model_path ,
62
62
tokenizer_name = args .model_path ,
63
63
context_window = 512 ,
@@ -70,11 +70,11 @@ def main(args):
70
70
71
71
# default retrieval (return_raw=True)
72
72
nl_sql_retriever = NLSQLRetriever (
73
- sql_database ,
74
- tables = ["city_stats" ],
75
- llm = llm ,
76
- embed_model = embed_model ,
77
- return_raw = True
73
+ sql_database ,
74
+ tables = ["city_stats" ],
75
+ llm = llm ,
76
+ embed_model = embed_model ,
77
+ return_raw = True
78
78
)
79
79
80
80
query_engine = RetrieverQueryEngine .from_args (nl_sql_retriever , llm = llm )
@@ -84,13 +84,13 @@ def main(args):
84
84
85
85
86
86
if __name__ == "__main__" :
87
- parser = argparse .ArgumentParser (description = ' LlamaIndex IpexLLM Example' )
87
+ parser = argparse .ArgumentParser (description = " LlamaIndex IpexLLM Example" )
88
88
parser .add_argument (
89
- '-m' ,
90
- ' --model-path' ,
89
+ "-m" ,
90
+ " --model-path" ,
91
91
type = str ,
92
92
required = True ,
93
- help = ' the path to transformers model'
93
+ help = " the path to transformers model"
94
94
)
95
95
parser .add_argument (
96
96
"--device" ,
@@ -101,24 +101,24 @@ def main(args):
101
101
help = "The device (Intel CPU or Intel GPU) the LLM model runs on" ,
102
102
)
103
103
parser .add_argument (
104
- '-q' ,
105
- ' --question' ,
104
+ "-q" ,
105
+ " --question" ,
106
106
type = str ,
107
- default = ' Which city has the highest population?' ,
108
- help = ' question you want to ask.'
107
+ default = " Which city has the highest population?" ,
108
+ help = " question you want to ask."
109
109
)
110
110
parser .add_argument (
111
- '-e' ,
112
- ' --embedding-model-path' ,
111
+ "-e" ,
112
+ " --embedding-model-path" ,
113
113
default = "BAAI/bge-small-en" ,
114
114
help = "the path to embedding model path"
115
115
)
116
116
parser .add_argument (
117
- '-n' ,
118
- ' --n-predict' ,
117
+ "-n" ,
118
+ " --n-predict" ,
119
119
type = int ,
120
120
default = 32 ,
121
- help = ' max number of predict tokens'
121
+ help = " max number of predict tokens"
122
122
)
123
123
args = parser .parse_args ()
124
124
0 commit comments