Skip to content

Commit 89ca2a4

Browse files
aaryanpuniaskeptrunedev
authored andcommitted
fix: create embeddings not batching requests correctly
1 parent f62a23a commit 89ca2a4

File tree

3 files changed

+181
-84
lines changed

3 files changed

+181
-84
lines changed

server/src/bin/ingestion-worker.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use trieve_server::operators::dataset_operator::get_dataset_by_id_query;
2626
use trieve_server::operators::event_operator::create_event_query;
2727
use trieve_server::operators::group_operator::get_groups_from_group_ids_query;
2828
use trieve_server::operators::model_operator::{
29-
create_embedding, create_embeddings, get_bm25_embeddings, get_sparse_vectors,
29+
get_bm25_embeddings, get_dense_vector, get_dense_vectors, get_sparse_vectors,
3030
};
3131
use trieve_server::operators::parse_operator::{
3232
average_embeddings, coarse_doc_chunker, convert_html_to_text,
@@ -557,7 +557,7 @@ pub async fn bulk_upload_chunks(
557557

558558
let embedding_vectors = match dataset_config.SEMANTIC_ENABLED {
559559
true => {
560-
let vectors = match create_embeddings(
560+
let vectors = match get_dense_vectors(
561561
content_and_boosts
562562
.iter()
563563
.map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
@@ -845,7 +845,7 @@ async fn upload_chunk(
845845
true => {
846846
let chunks = coarse_doc_chunker(content.clone(), None, false, 20);
847847

848-
let embeddings = create_embeddings(
848+
let embeddings = get_dense_vectors(
849849
chunks
850850
.iter()
851851
.map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone()))
@@ -859,7 +859,7 @@ async fn upload_chunk(
859859
average_embeddings(embeddings)?
860860
}
861861
false => {
862-
let embedding_vectors = create_embeddings(
862+
let embedding_vectors = get_dense_vectors(
863863
vec![(content.clone(), payload.chunk.distance_phrase.clone())],
864864
"doc",
865865
dataset_config.clone(),
@@ -1062,7 +1062,7 @@ async fn update_chunk(
10621062

10631063
let embedding_vector = match dataset_config.SEMANTIC_ENABLED {
10641064
true => {
1065-
let embedding = create_embedding(
1065+
let embedding = get_dense_vector(
10661066
content.to_string(),
10671067
payload.distance_phrase,
10681068
"doc",

server/src/operators/model_operator.rs

Lines changed: 170 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub struct EmbeddingParameters {
2626
}
2727

2828
#[tracing::instrument]
29-
pub async fn create_embedding(
29+
pub async fn get_dense_vector(
3030
message: String,
3131
distance_phrase: Option<DistancePhrase>,
3232
embed_type: &str,
@@ -35,11 +35,11 @@ pub async fn create_embedding(
3535
let parent_span = sentry::configure_scope(|scope| scope.get_span());
3636
let transaction: sentry::TransactionOrSpan = match &parent_span {
3737
Some(parent) => parent
38-
.start_child("create_embedding", "Create semantic dense embedding")
38+
.start_child("get_dense_vector", "Create semantic dense embedding")
3939
.into(),
4040
None => {
4141
let ctx = sentry::TransactionContext::new(
42-
"create_embedding",
42+
"get_dense_vector",
4343
"Create semantic dense embedding",
4444
);
4545
sentry::start_transaction(ctx).into()
@@ -266,20 +266,20 @@ pub async fn get_sparse_vector(
266266
}
267267

268268
#[tracing::instrument]
269-
pub async fn create_embeddings(
270-
content_and_boosts: Vec<(String, Option<DistancePhrase>)>,
269+
pub async fn get_dense_vectors(
270+
content_and_distances: Vec<(String, Option<DistancePhrase>)>,
271271
embed_type: &str,
272272
dataset_config: DatasetConfiguration,
273273
reqwest_client: reqwest::Client,
274274
) -> Result<Vec<Vec<f32>>, ServiceError> {
275275
let parent_span = sentry::configure_scope(|scope| scope.get_span());
276276
let transaction: sentry::TransactionOrSpan = match &parent_span {
277277
Some(parent) => parent
278-
.start_child("create_embedding", "Create semantic dense embedding")
278+
.start_child("get_dense_vector", "Create semantic dense embedding")
279279
.into(),
280280
None => {
281281
let ctx = sentry::TransactionContext::new(
282-
"create_embedding",
282+
"get_dense_vector",
283283
"Create semantic dense embedding",
284284
);
285285
sentry::start_transaction(ctx).into()
@@ -323,31 +323,31 @@ pub async fn create_embeddings(
323323
embedding_api_key.to_string()
324324
};
325325

326-
let thirty_message_groups = content_and_boosts.chunks(30);
326+
let (contents, distance_phrases): (Vec<_>, Vec<_>) =
327+
content_and_distances.clone().into_iter().unzip();
328+
let thirty_content_groups = contents.chunks(30);
327329

328-
let vec_futures: Vec<_> = thirty_message_groups
330+
let filtered_distances_with_index = distance_phrases
331+
.clone()
332+
.iter()
329333
.enumerate()
330-
.map(|(i, combined_messages)| {
331-
let messages = combined_messages
332-
.iter()
333-
.map(|(x, _)| x)
334-
.cloned()
335-
.collect::<Vec<String>>();
336-
337-
let boost_phrase_and_index = combined_messages
338-
.iter()
339-
.enumerate()
340-
.filter_map(|(i, (_, y))| y.clone().map(|phrase| (i, phrase)))
341-
.collect::<Vec<(usize, DistancePhrase)>>();
334+
.filter_map(|(index, distance_phrase)| {
335+
distance_phrase
336+
.clone()
337+
.map(|distance_phrase| (index, distance_phrase))
338+
})
339+
.collect::<Vec<(usize, DistancePhrase)>>();
340+
let thirty_filterted_distances_with_indices = filtered_distances_with_index.chunks(30);
342341

343-
let boost_phrases = combined_messages
342+
let vec_distance_futures: Vec<_> = thirty_filterted_distances_with_indices
343+
.map(|thirty_distances| {
344+
let distance_phrases = thirty_distances
344345
.iter()
345-
.filter_map(|(_, y)| y.clone().map(|x| x.phrase.clone()))
346+
.map(|(_, x)| x.phrase.clone())
346347
.collect::<Vec<String>>();
347348

348-
let clipped_messages = messages
349+
let clipped_messages = distance_phrases
349350
.iter()
350-
.chain(boost_phrases.iter())
351351
.map(|message| {
352352
if message.len() > 5000 {
353353
message.chars().take(12000).collect()
@@ -406,65 +406,162 @@ pub async fn create_embeddings(
406406
)
407407
})?;
408408

409-
let mut vectors: Vec<Vec<f32>> = embeddings
410-
.data
411-
.into_iter()
412-
.map(|x| match x.embedding {
413-
EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(),
414-
EmbeddingOutput::Base64(_) => {
415-
log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings");
416-
vec![]
409+
let vectors_and_boosts: Vec<(Vec<f32>, &(usize, DistancePhrase))> = embeddings
410+
.data
411+
.into_iter()
412+
.map(|x| match x.embedding {
413+
EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(),
414+
EmbeddingOutput::Base64(_) => {
415+
log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings");
416+
vec![]
417+
}
418+
})
419+
.zip(thirty_distances)
420+
.collect();
421+
422+
if vectors_and_boosts.iter().any(|x| x.0.is_empty()) {
423+
return Err(ServiceError::InternalServerError(
424+
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
425+
));
417426
}
418-
})
419-
.collect();
420427

421-
if vectors.iter().any(|x| x.is_empty()) {
422-
return Err(ServiceError::InternalServerError(
423-
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
424-
));
425-
}
428+
Ok(vectors_and_boosts)
429+
};
426430

427-
if !boost_phrase_and_index.is_empty() {
428-
let boost_vectors = vectors
429-
.split_off(messages.len()).to_vec();
431+
vectors_resp
432+
})
433+
.collect();
430434

431-
let mut vectors_sorted = vectors.clone();
432-
for ((og_index, phrase), boost_vector) in boost_phrase_and_index.iter().zip(boost_vectors) {
433-
vectors_sorted[*og_index] = vectors_sorted[*og_index]
434-
.iter()
435-
.zip(boost_vector)
436-
.map(|(vector_elem, boost_vec_elem)| vector_elem + phrase.distance_factor * boost_vec_elem)
437-
.collect();
438-
}
435+
let vec_content_futures: Vec<_> = thirty_content_groups
436+
.map(|messages| {
437+
let clipped_messages = messages
438+
.iter()
439+
.map(|message| {
440+
if message.len() > 5000 {
441+
message.chars().take(12000).collect()
442+
} else {
443+
message.clone()
444+
}
445+
})
446+
.collect::<Vec<String>>();
439447

440-
return Ok((i, vectors_sorted));
441-
}
448+
let input = match embed_type {
449+
"doc" => EmbeddingInput::StringArray(clipped_messages),
450+
"query" => EmbeddingInput::String(
451+
format!(
452+
"{}{}",
453+
dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_messages[0]
454+
)
455+
.to_string(),
456+
),
457+
_ => EmbeddingInput::StringArray(clipped_messages),
458+
};
442459

443-
Ok((i, vectors))
460+
let parameters = EmbeddingParameters {
461+
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
462+
input,
463+
truncate: true,
444464
};
445465

446-
vectors_resp
447-
})
466+
let cur_client = reqwest_client.clone();
467+
let url = embedding_base_url.clone();
468+
469+
let embedding_api_key = embedding_api_key.clone();
470+
471+
let vectors_resp = async move {
472+
let embeddings_resp = cur_client
473+
.post(&format!("{}/embeddings?api-version=2023-05-15", url))
474+
.header("Authorization", &format!("Bearer {}", &embedding_api_key.clone()))
475+
.header("api-key", &embedding_api_key.clone())
476+
.header("Content-Type", "application/json")
477+
.json(&parameters)
478+
.send()
479+
.await
480+
.map_err(|_| {
481+
ServiceError::BadRequest("Failed to send message to embedding server".to_string())
482+
})?
483+
.text()
484+
.await
485+
.map_err(|_| {
486+
ServiceError::BadRequest("Failed to get text from embeddings".to_string())
487+
})?;
488+
489+
let embeddings: EmbeddingResponse = format_response(embeddings_resp.clone())
490+
.map_err(move |_e| {
491+
log::error!("Failed to format response from embeddings server {:?}", embeddings_resp);
492+
ServiceError::InternalServerError(
493+
format!("Failed to format response from embeddings server {:?}", embeddings_resp)
494+
)
495+
})?;
496+
497+
let vectors: Vec<Vec<f32>> = embeddings
498+
.data
499+
.into_iter()
500+
.map(|x| match x.embedding {
501+
EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(),
502+
EmbeddingOutput::Base64(_) => {
503+
log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings");
504+
vec![]
505+
}
506+
})
507+
.collect();
508+
509+
if vectors.iter().any(|x| x.is_empty()) {
510+
return Err(ServiceError::InternalServerError(
511+
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
512+
));
513+
}
514+
Ok(vectors)
515+
};
516+
517+
vectors_resp
518+
519+
})
448520
.collect();
449521

450-
let all_chunk_vectors: Vec<(usize, Vec<Vec<f32>>)> = futures::future::join_all(vec_futures)
522+
let mut content_vectors: Vec<_> = futures::future::join_all(vec_content_futures)
451523
.await
452524
.into_iter()
453-
.collect::<Result<Vec<(usize, Vec<Vec<f32>>)>, ServiceError>>()?;
525+
.collect::<Result<Vec<_>, ServiceError>>()?
526+
.into_iter()
527+
.flatten()
528+
.collect();
454529

455-
let mut vectors_sorted = vec![];
456-
for index in 0..all_chunk_vectors.len() {
457-
let (_, vectors_i) = all_chunk_vectors.iter().find(|(i, _)| *i == index).ok_or(
458-
ServiceError::InternalServerError(
459-
"Failed to get index i (this should never happen)".to_string(),
460-
),
461-
)?;
530+
let distance_vectors: Vec<_> = futures::future::join_all(vec_distance_futures)
531+
.await
532+
.into_iter()
533+
.collect::<Result<Vec<_>, ServiceError>>()?
534+
.into_iter()
535+
.flatten()
536+
.collect();
462537

463-
vectors_sorted.extend(vectors_i.clone());
538+
if !distance_vectors.is_empty() {
539+
content_vectors = content_vectors
540+
.into_iter()
541+
.enumerate()
542+
.map(|(i, message)| {
543+
let distance_vector = distance_vectors
544+
.iter()
545+
.find(|(_, (og_index, _))| *og_index == i);
546+
match distance_vector {
547+
Some((distance_vec, (_, distance_phrase))) => {
548+
let distance_factor = distance_phrase.distance_factor;
549+
message
550+
.iter()
551+
.zip(distance_vec)
552+
.map(|(vec_elem, distance_elem)| {
553+
vec_elem + distance_factor * distance_elem
554+
})
555+
.collect()
556+
}
557+
None => message,
558+
}
559+
})
560+
.collect();
464561
}
465562

466563
transaction.finish();
467-
Ok(vectors_sorted)
564+
Ok(content_vectors)
468565
}
469566

470567
#[derive(Debug, Serialize, Deserialize)]
@@ -492,31 +589,31 @@ pub struct CustomSparseEmbedData {
492589

493590
#[tracing::instrument]
494591
pub async fn get_sparse_vectors(
495-
messages: Vec<(String, Option<BoostPhrase>)>,
592+
content_and_boosts: Vec<(String, Option<BoostPhrase>)>,
496593
embed_type: &str,
497594
reqwest_client: reqwest::Client,
498595
) -> Result<Vec<Vec<(u32, f32)>>, ServiceError> {
499-
if messages.is_empty() {
596+
if content_and_boosts.is_empty() {
500597
return Err(ServiceError::BadRequest(
501598
"No messages to encode".to_string(),
502599
));
503600
}
504601

505-
let contents = messages
602+
let contents = content_and_boosts
506603
.clone()
507604
.into_iter()
508605
.map(|(x, _)| x)
509606
.collect::<Vec<String>>();
510607
let thirty_content_groups = contents.chunks(30);
511608

512-
let filtered_boosts_with_index = messages
609+
let filtered_boosts_with_index = content_and_boosts
513610
.into_iter()
514611
.enumerate()
515612
.filter_map(|(i, (_, y))| y.map(|boost_phrase| (i, boost_phrase)))
516613
.collect::<Vec<(usize, BoostPhrase)>>();
517-
let filtered_boosts_with_index_groups = filtered_boosts_with_index.chunks(30);
614+
let thirty_filtered_boosts_with_indices = filtered_boosts_with_index.chunks(30);
518615

519-
let vec_boost_futures: Vec<_> = filtered_boosts_with_index_groups
616+
let vec_boost_futures: Vec<_> = thirty_filtered_boosts_with_indices
520617
.enumerate()
521618
.map(|(i, thirty_boosts)| {
522619
let cur_client = reqwest_client.clone();

0 commit comments

Comments
 (0)