@@ -16,6 +16,9 @@ pub struct EmbedCall {
16
16
pub full_match : String ,
17
17
pub start_pos : usize ,
18
18
pub end_pos : usize ,
19
+ pub is_prepared : bool ,
20
+ pub query_param_index : Option < usize > ,
21
+ pub project_param_index : Option < usize > ,
19
22
}
20
23
21
24
/// Embedding provider that uses the jobmap cache to determine the appropriate embedding provider
@@ -64,20 +67,21 @@ impl JobMapEmbeddingProvider {
64
67
pub fn parse_embed_calls ( sql : & str ) -> Result < Vec < EmbedCall > > {
65
68
let mut calls = Vec :: new ( ) ;
66
69
67
- // Regex to match vectorize.embed('query', 'project_name')
68
- let re = Regex :: new (
70
+ // matches vectorize.embed('query', 'project_name') string literals only
71
+ let string_re = Regex :: new (
69
72
r"(?i)vectorize\.embed\s*\(\s*'([^']*(?:''[^']*)*)'\s*,\s*'([^']*(?:''[^']*)*)'\s*\)" ,
70
73
) ?;
71
74
72
- let param_re = Regex :: new (
73
- r"(?i)vectorize\.embed\s*\(\s*'([^']*(?:''[^']*)*)'\s*,\s*'([^']*(?:''[^']*)*)'\s*\)" ,
74
- ) ?;
75
- for mat in re. find_iter ( sql) {
75
+ // matches vectorize.embed($1, $2) prepared statement parameters
76
+ let param_re = Regex :: new ( r"(?i)vectorize\.embed\s*\(\s*\$(\d+)\s*,\s*\$(\d+)\s*\)" ) ?;
77
+
78
+ // Parse string literal calls
79
+ for mat in string_re. find_iter ( sql) {
76
80
let full_match = mat. as_str ( ) . to_string ( ) ;
77
81
let start_pos = mat. start ( ) ;
78
82
let end_pos = mat. end ( ) ;
79
83
80
- if let Some ( captures) = param_re . captures ( & full_match) {
84
+ if let Some ( captures) = string_re . captures ( & full_match) {
81
85
let query = captures. get ( 1 ) . unwrap ( ) . as_str ( ) . replace ( "''" , "'" ) ;
82
86
let project_name = captures. get ( 2 ) . unwrap ( ) . as_str ( ) . replace ( "''" , "'" ) ;
83
87
@@ -87,14 +91,67 @@ pub fn parse_embed_calls(sql: &str) -> Result<Vec<EmbedCall>> {
87
91
full_match,
88
92
start_pos,
89
93
end_pos,
94
+ is_prepared : false ,
95
+ query_param_index : None ,
96
+ project_param_index : None ,
97
+ } ) ;
98
+ }
99
+ }
100
+
101
+ // parse prepared statement parameter calls
102
+ for mat in param_re. find_iter ( sql) {
103
+ let full_match = mat. as_str ( ) . to_string ( ) ;
104
+ let start_pos = mat. start ( ) ;
105
+ let end_pos = mat. end ( ) ;
106
+
107
+ if let Some ( captures) = param_re. captures ( & full_match) {
108
+ // convert 1-based indices to 0-based (e.g. bind parameters from $1 -> 0)
109
+ let query_param_index = captures. get ( 1 ) . unwrap ( ) . as_str ( ) . parse :: < usize > ( ) ? - 1 ;
110
+ let project_param_index = captures. get ( 2 ) . unwrap ( ) . as_str ( ) . parse :: < usize > ( ) ? - 1 ;
111
+
112
+ calls. push ( EmbedCall {
113
+ query : String :: new ( ) , // filled from bind parameters
114
+ project_name : String :: new ( ) , // filled from bind parameters
115
+ full_match,
116
+ start_pos,
117
+ end_pos,
118
+ is_prepared : true ,
119
+ query_param_index : Some ( query_param_index) ,
120
+ project_param_index : Some ( project_param_index) ,
90
121
} ) ;
91
122
}
92
123
}
93
124
94
125
Ok ( calls)
95
126
}
96
127
97
- /// Rewrites SQL query by replacing vectorize.embed() calls with actual embeddings
128
+ /// resolves prepared statement parameters in embed calls
129
+ pub fn resolve_prepared_embed_calls (
130
+ mut embed_calls : Vec < EmbedCall > ,
131
+ parameters : & [ String ] ,
132
+ ) -> Result < Vec < EmbedCall > , VectorizeError > {
133
+ for call in & mut embed_calls {
134
+ if call. is_prepared {
135
+ if let ( Some ( query_idx) , Some ( project_idx) ) =
136
+ ( call. query_param_index , call. project_param_index )
137
+ {
138
+ if query_idx >= parameters. len ( ) || project_idx >= parameters. len ( ) {
139
+ return Err ( VectorizeError :: EmbeddingGenerationFailed ( format ! (
140
+ "Parameter index out of bounds: query_idx={}, project_idx={}, params_len={}" ,
141
+ query_idx,
142
+ project_idx,
143
+ parameters. len( )
144
+ ) ) ) ;
145
+ }
146
+ call. query = parameters[ query_idx] . clone ( ) ;
147
+ call. project_name = parameters[ project_idx] . clone ( ) ;
148
+ }
149
+ }
150
+ }
151
+ Ok ( embed_calls)
152
+ }
153
+
154
+ /// rewrites SQL query by replacing vectorize.embed() calls with actual embeddings
98
155
pub async fn rewrite_query_with_embeddings (
99
156
sql : & str ,
100
157
provider : & JobMapEmbeddingProvider ,
@@ -109,7 +166,7 @@ pub async fn rewrite_query_with_embeddings(
109
166
110
167
let mut rewritten = sql. to_string ( ) ;
111
168
112
- // Process calls in reverse order to maintain correct positions
169
+ // process calls in reverse order to maintain correct positions
113
170
for call in embed_calls. iter ( ) . rev ( ) {
114
171
let embeddings = provider
115
172
. generate_embeddings ( & call. query , & call. project_name )
@@ -138,6 +195,7 @@ mod tests {
138
195
assert_eq ! ( calls. len( ) , 1 ) ;
139
196
assert_eq ! ( calls[ 0 ] . query, "hello world" ) ;
140
197
assert_eq ! ( calls[ 0 ] . project_name, "my_project" ) ;
198
+ assert ! ( !calls[ 0 ] . is_prepared) ;
141
199
}
142
200
143
201
#[ test]
@@ -149,8 +207,34 @@ mod tests {
149
207
assert_eq ! ( calls. len( ) , 2 ) ;
150
208
assert_eq ! ( calls[ 0 ] . query, "query1" ) ;
151
209
assert_eq ! ( calls[ 0 ] . project_name, "project1" ) ;
210
+ assert ! ( !calls[ 0 ] . is_prepared) ;
152
211
assert_eq ! ( calls[ 1 ] . query, "query2" ) ;
153
212
assert_eq ! ( calls[ 1 ] . project_name, "project2" ) ;
213
+ assert ! ( !calls[ 1 ] . is_prepared) ;
214
+ }
215
+
216
#[test]
fn test_parse_prepared_embed_calls() {
    // A single call that uses placeholders rather than string literals.
    let parsed = parse_embed_calls("SELECT vectorize.embed($1, $2)").unwrap();

    assert_eq!(parsed.len(), 1);

    let call = &parsed[0];
    assert!(call.is_prepared);
    // $1 / $2 are 1-based in the SQL text but are stored as 0-based indices.
    assert_eq!(call.query_param_index, Some(0));
    assert_eq!(call.project_param_index, Some(1));
}
226
+
227
#[test]
fn test_resolve_prepared_embed_calls() {
    // Bind parameters in placeholder order: $1 is the query, $2 the project name.
    let parameters = vec!["hello world".to_string(), "my_project".to_string()];

    let parsed = parse_embed_calls("SELECT vectorize.embed($1, $2)").unwrap();
    let resolved = resolve_prepared_embed_calls(parsed, &parameters).unwrap();

    assert_eq!(resolved.len(), 1);

    let call = &resolved[0];
    assert_eq!(call.query, "hello world");
    assert_eq!(call.project_name, "my_project");
}
155
239
156
240
#[ test]
0 commit comments