Skip to content

Commit 1e01201

Browse files
authored
draft prepared statement support (#232)
1 parent 372c649 commit 1e01201

File tree

4 files changed

+303
-29
lines changed

4 files changed

+303
-29
lines changed

proxy/src/embeddings.rs

Lines changed: 93 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ pub struct EmbedCall {
1616
pub full_match: String,
1717
pub start_pos: usize,
1818
pub end_pos: usize,
19+
pub is_prepared: bool,
20+
pub query_param_index: Option<usize>,
21+
pub project_param_index: Option<usize>,
1922
}
2023

2124
/// Embedding provider that uses the jobmap cache to determine the appropriate embedding provider
@@ -64,20 +67,21 @@ impl JobMapEmbeddingProvider {
6467
pub fn parse_embed_calls(sql: &str) -> Result<Vec<EmbedCall>> {
6568
let mut calls = Vec::new();
6669

67-
// Regex to match vectorize.embed('query', 'project_name')
68-
let re = Regex::new(
70+
// matches vectorize.embed('query', 'project_name') string literals only
71+
let string_re = Regex::new(
6972
r"(?i)vectorize\.embed\s*\(\s*'([^']*(?:''[^']*)*)'\s*,\s*'([^']*(?:''[^']*)*)'\s*\)",
7073
)?;
7174

72-
let param_re = Regex::new(
73-
r"(?i)vectorize\.embed\s*\(\s*'([^']*(?:''[^']*)*)'\s*,\s*'([^']*(?:''[^']*)*)'\s*\)",
74-
)?;
75-
for mat in re.find_iter(sql) {
75+
// matches vectorize.embed($1, $2) prepared statement parameters
76+
let param_re = Regex::new(r"(?i)vectorize\.embed\s*\(\s*\$(\d+)\s*,\s*\$(\d+)\s*\)")?;
77+
78+
// Parse string literal calls
79+
for mat in string_re.find_iter(sql) {
7680
let full_match = mat.as_str().to_string();
7781
let start_pos = mat.start();
7882
let end_pos = mat.end();
7983

80-
if let Some(captures) = param_re.captures(&full_match) {
84+
if let Some(captures) = string_re.captures(&full_match) {
8185
let query = captures.get(1).unwrap().as_str().replace("''", "'");
8286
let project_name = captures.get(2).unwrap().as_str().replace("''", "'");
8387

@@ -87,14 +91,67 @@ pub fn parse_embed_calls(sql: &str) -> Result<Vec<EmbedCall>> {
8791
full_match,
8892
start_pos,
8993
end_pos,
94+
is_prepared: false,
95+
query_param_index: None,
96+
project_param_index: None,
97+
});
98+
}
99+
}
100+
101+
// parse prepared statement parameter calls
102+
for mat in param_re.find_iter(sql) {
103+
let full_match = mat.as_str().to_string();
104+
let start_pos = mat.start();
105+
let end_pos = mat.end();
106+
107+
if let Some(captures) = param_re.captures(&full_match) {
108+
// convert 1-based indices to 0-based (e.g. bind parameters from $1 -> 0)
109+
let query_param_index = captures.get(1).unwrap().as_str().parse::<usize>()? - 1;
110+
let project_param_index = captures.get(2).unwrap().as_str().parse::<usize>()? - 1;
111+
112+
calls.push(EmbedCall {
113+
query: String::new(), // filled from bind parameters
114+
project_name: String::new(), // filled from bind parameters
115+
full_match,
116+
start_pos,
117+
end_pos,
118+
is_prepared: true,
119+
query_param_index: Some(query_param_index),
120+
project_param_index: Some(project_param_index),
90121
});
91122
}
92123
}
93124

94125
Ok(calls)
95126
}
96127

97-
/// Rewrites SQL query by replacing vectorize.embed() calls with actual embeddings
128+
/// resolves prepared statement parameters in embed calls
129+
pub fn resolve_prepared_embed_calls(
130+
mut embed_calls: Vec<EmbedCall>,
131+
parameters: &[String],
132+
) -> Result<Vec<EmbedCall>, VectorizeError> {
133+
for call in &mut embed_calls {
134+
if call.is_prepared {
135+
if let (Some(query_idx), Some(project_idx)) =
136+
(call.query_param_index, call.project_param_index)
137+
{
138+
if query_idx >= parameters.len() || project_idx >= parameters.len() {
139+
return Err(VectorizeError::EmbeddingGenerationFailed(format!(
140+
"Parameter index out of bounds: query_idx={}, project_idx={}, params_len={}",
141+
query_idx,
142+
project_idx,
143+
parameters.len()
144+
)));
145+
}
146+
call.query = parameters[query_idx].clone();
147+
call.project_name = parameters[project_idx].clone();
148+
}
149+
}
150+
}
151+
Ok(embed_calls)
152+
}
153+
154+
/// rewrites SQL query by replacing vectorize.embed() calls with actual embeddings
98155
pub async fn rewrite_query_with_embeddings(
99156
sql: &str,
100157
provider: &JobMapEmbeddingProvider,
@@ -109,7 +166,7 @@ pub async fn rewrite_query_with_embeddings(
109166

110167
let mut rewritten = sql.to_string();
111168

112-
// Process calls in reverse order to maintain correct positions
169+
// process calls in reverse order to maintain correct positions
113170
for call in embed_calls.iter().rev() {
114171
let embeddings = provider
115172
.generate_embeddings(&call.query, &call.project_name)
@@ -138,6 +195,7 @@ mod tests {
138195
assert_eq!(calls.len(), 1);
139196
assert_eq!(calls[0].query, "hello world");
140197
assert_eq!(calls[0].project_name, "my_project");
198+
assert!(!calls[0].is_prepared);
141199
}
142200

143201
#[test]
@@ -149,8 +207,34 @@ mod tests {
149207
assert_eq!(calls.len(), 2);
150208
assert_eq!(calls[0].query, "query1");
151209
assert_eq!(calls[0].project_name, "project1");
210+
assert!(!calls[0].is_prepared);
152211
assert_eq!(calls[1].query, "query2");
153212
assert_eq!(calls[1].project_name, "project2");
213+
assert!(!calls[1].is_prepared);
214+
}
215+
216+
#[test]
217+
fn test_parse_prepared_embed_calls() {
218+
let sql = "SELECT vectorize.embed($1, $2)";
219+
let calls = parse_embed_calls(sql).unwrap();
220+
221+
assert_eq!(calls.len(), 1);
222+
assert!(calls[0].is_prepared);
223+
assert_eq!(calls[0].query_param_index, Some(0));
224+
assert_eq!(calls[0].project_param_index, Some(1));
225+
}
226+
227+
#[test]
228+
fn test_resolve_prepared_embed_calls() {
229+
let sql = "SELECT vectorize.embed($1, $2)";
230+
let mut calls = parse_embed_calls(sql).unwrap();
231+
let parameters = vec!["hello world".to_string(), "my_project".to_string()];
232+
233+
calls = resolve_prepared_embed_calls(calls, &parameters).unwrap();
234+
235+
assert_eq!(calls.len(), 1);
236+
assert_eq!(calls[0].query, "hello world");
237+
assert_eq!(calls[0].project_name, "my_project");
154238
}
155239

156240
#[test]

0 commit comments

Comments
 (0)