From 89ca2a45d510836defaed97605fef2f57db4283c Mon Sep 17 00:00:00 2001 From: aaryanpunia Date: Wed, 24 Jul 2024 11:43:53 -0700 Subject: [PATCH] fix: create embeddings not batching requests correctly --- server/src/bin/ingestion-worker.rs | 10 +- server/src/operators/model_operator.rs | 243 +++++++++++++++++------- server/src/operators/search_operator.rs | 12 +- 3 files changed, 181 insertions(+), 84 deletions(-) diff --git a/server/src/bin/ingestion-worker.rs b/server/src/bin/ingestion-worker.rs index 05c93256c..33262e989 100644 --- a/server/src/bin/ingestion-worker.rs +++ b/server/src/bin/ingestion-worker.rs @@ -26,7 +26,7 @@ use trieve_server::operators::dataset_operator::get_dataset_by_id_query; use trieve_server::operators::event_operator::create_event_query; use trieve_server::operators::group_operator::get_groups_from_group_ids_query; use trieve_server::operators::model_operator::{ - create_embedding, create_embeddings, get_bm25_embeddings, get_sparse_vectors, + get_bm25_embeddings, get_dense_vector, get_dense_vectors, get_sparse_vectors, }; use trieve_server::operators::parse_operator::{ average_embeddings, coarse_doc_chunker, convert_html_to_text, @@ -557,7 +557,7 @@ pub async fn bulk_upload_chunks( let embedding_vectors = match dataset_config.SEMANTIC_ENABLED { true => { - let vectors = match create_embeddings( + let vectors = match get_dense_vectors( content_and_boosts .iter() .map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone())) @@ -845,7 +845,7 @@ async fn upload_chunk( true => { let chunks = coarse_doc_chunker(content.clone(), None, false, 20); - let embeddings = create_embeddings( + let embeddings = get_dense_vectors( chunks .iter() .map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone())) @@ -859,7 +859,7 @@ async fn upload_chunk( average_embeddings(embeddings)? } false => { - let embedding_vectors = create_embeddings( + let embedding_vectors = get_dense_vectors( vec![(content.clone(), payload.chunk.distance_phrase.clone())], "doc", dataset_config.clone(), @@ -1062,7 +1062,7 @@ async fn update_chunk( let embedding_vector = match dataset_config.SEMANTIC_ENABLED { true => { - let embedding = create_embedding( + let embedding = get_dense_vector( content.to_string(), payload.distance_phrase, "doc", diff --git a/server/src/operators/model_operator.rs b/server/src/operators/model_operator.rs index 1cb4b52f5..32aa37694 100644 --- a/server/src/operators/model_operator.rs +++ b/server/src/operators/model_operator.rs @@ -26,7 +26,7 @@ pub struct EmbeddingParameters { } #[tracing::instrument] -pub async fn create_embedding( +pub async fn get_dense_vector( message: String, distance_phrase: Option, embed_type: &str, @@ -35,11 +35,11 @@ pub async fn create_embedding( let parent_span = sentry::configure_scope(|scope| scope.get_span()); let transaction: sentry::TransactionOrSpan = match &parent_span { Some(parent) => parent - .start_child("create_embedding", "Create semantic dense embedding") + .start_child("get_dense_vector", "Create semantic dense embedding") .into(), None => { let ctx = sentry::TransactionContext::new( - "create_embedding", + "get_dense_vector", "Create semantic dense embedding", ); sentry::start_transaction(ctx).into() @@ -266,8 +266,8 @@ pub async fn get_sparse_vector( } #[tracing::instrument] -pub async fn create_embeddings( - content_and_boosts: Vec<(String, Option)>, +pub async fn get_dense_vectors( + content_and_distances: Vec<(String, Option)>, embed_type: &str, dataset_config: DatasetConfiguration, reqwest_client: reqwest::Client, @@ -275,11 +275,11 @@ pub async fn create_embeddings( let parent_span = sentry::configure_scope(|scope| scope.get_span()); let transaction: sentry::TransactionOrSpan = match &parent_span { Some(parent) => parent - .start_child("create_embedding", "Create semantic dense embedding") + .start_child("get_dense_vector", "Create semantic dense embedding") .into(), None => { let ctx = sentry::TransactionContext::new( - "create_embedding", + "get_dense_vector", "Create semantic dense embedding", ); sentry::start_transaction(ctx).into() @@ -323,31 +323,31 @@ pub async fn create_embeddings( embedding_api_key.to_string() }; - let thirty_message_groups = content_and_boosts.chunks(30); + let (contents, distance_phrases): (Vec<_>, Vec<_>) = + content_and_distances.clone().into_iter().unzip(); + let thirty_content_groups = contents.chunks(30); - let vec_futures: Vec<_> = thirty_message_groups + let filtered_distances_with_index = distance_phrases + .clone() + .iter() .enumerate() - .map(|(i, combined_messages)| { - let messages = combined_messages - .iter() - .map(|(x, _)| x) - .cloned() - .collect::>(); - - let boost_phrase_and_index = combined_messages - .iter() - .enumerate() - .filter_map(|(i, (_, y))| y.clone().map(|phrase| (i, phrase))) - .collect::>(); + .filter_map(|(index, distance_phrase)| { + distance_phrase + .clone() + .map(|distance_phrase| (index, distance_phrase)) + }) + .collect::>(); + let thirty_filterted_distances_with_indices = filtered_distances_with_index.chunks(30); - let boost_phrases = combined_messages + let vec_distance_futures: Vec<_> = thirty_filterted_distances_with_indices + .map(|thirty_distances| { + let distance_phrases = thirty_distances .iter() - .filter_map(|(_, y)| y.clone().map(|x| x.phrase.clone())) + .map(|(_, x)| x.phrase.clone()) .collect::>(); - let clipped_messages = messages + let clipped_messages = distance_phrases .iter() - .chain(boost_phrases.iter()) .map(|message| { if message.len() > 5000 { message.chars().take(12000).collect() @@ -406,65 +406,162 @@ pub async fn create_embeddings( ) })?; - let mut vectors: Vec> = embeddings - .data - .into_iter() - .map(|x| match x.embedding { - EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(), - EmbeddingOutput::Base64(_) => { - log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings"); - vec![] + let vectors_and_boosts: Vec<(Vec, &(usize, DistancePhrase))> = embeddings + .data + .into_iter() + .map(|x| match x.embedding { + EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(), + EmbeddingOutput::Base64(_) => { + log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings"); + vec![] + } + }) + .zip(thirty_distances) + .collect(); + + if vectors_and_boosts.iter().any(|x| x.0.is_empty()) { + return Err(ServiceError::InternalServerError( + "Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(), + )); } - }) - .collect(); - if vectors.iter().any(|x| x.is_empty()) { - return Err(ServiceError::InternalServerError( - "Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(), - )); - } + Ok(vectors_and_boosts) + }; - if !boost_phrase_and_index.is_empty() { - let boost_vectors = vectors - .split_off(messages.len()).to_vec(); + vectors_resp + }) + .collect(); - let mut vectors_sorted = vectors.clone(); - for ((og_index, phrase), boost_vector) in boost_phrase_and_index.iter().zip(boost_vectors) { - vectors_sorted[*og_index] = vectors_sorted[*og_index] - .iter() - .zip(boost_vector) - .map(|(vector_elem, boost_vec_elem)| vector_elem + phrase.distance_factor * boost_vec_elem) - .collect(); - } + let vec_content_futures: Vec<_> = thirty_content_groups + .map(|messages| { + let clipped_messages = messages + .iter() + .map(|message| { + if message.len() > 5000 { + message.chars().take(12000).collect() + } else { + message.clone() + } + }) + .collect::>(); - return Ok((i, vectors_sorted)); - } + let input = match embed_type { + "doc" => EmbeddingInput::StringArray(clipped_messages), + "query" => EmbeddingInput::String( + format!( + "{}{}", + dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_messages[0] + ) + .to_string(), + ), + _ => EmbeddingInput::StringArray(clipped_messages), + }; - Ok((i, vectors)) + let parameters = EmbeddingParameters { + model: dataset_config.EMBEDDING_MODEL_NAME.to_string(), + input, + truncate: true, }; - vectors_resp - }) + let cur_client = reqwest_client.clone(); + let url = embedding_base_url.clone(); + + let embedding_api_key = embedding_api_key.clone(); + + let vectors_resp = async move { + let embeddings_resp = cur_client + .post(&format!("{}/embeddings?api-version=2023-05-15", url)) + .header("Authorization", &format!("Bearer {}", &embedding_api_key.clone())) + .header("api-key", &embedding_api_key.clone()) + .header("Content-Type", "application/json") + .json(¶meters) + .send() + .await + .map_err(|_| { + ServiceError::BadRequest("Failed to send message to embedding server".to_string()) + })? + .text() + .await + .map_err(|_| { + ServiceError::BadRequest("Failed to get text from embeddings".to_string()) + })?; + + let embeddings: EmbeddingResponse = format_response(embeddings_resp.clone()) + .map_err(move |_e| { + log::error!("Failed to format response from embeddings server {:?}", embeddings_resp); + ServiceError::InternalServerError( + format!("Failed to format response from embeddings server {:?}", embeddings_resp) + ) + })?; + + let vectors: Vec> = embeddings + .data + .into_iter() + .map(|x| match x.embedding { + EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(), + EmbeddingOutput::Base64(_) => { + log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings"); + vec![] + } + }) + .collect(); + + if vectors.iter().any(|x| x.is_empty()) { + return Err(ServiceError::InternalServerError( + "Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(), + )); + } + Ok(vectors) + }; + + vectors_resp + + }) .collect(); - let all_chunk_vectors: Vec<(usize, Vec>)> = futures::future::join_all(vec_futures) + let mut content_vectors: Vec<_> = futures::future::join_all(vec_content_futures) .await .into_iter() - .collect::>)>, ServiceError>>()?; + .collect::, ServiceError>>()? + .into_iter() + .flatten() + .collect(); - let mut vectors_sorted = vec![]; - for index in 0..all_chunk_vectors.len() { - let (_, vectors_i) = all_chunk_vectors.iter().find(|(i, _)| *i == index).ok_or( - ServiceError::InternalServerError( - "Failed to get index i (this should never happen)".to_string(), - ), - )?; + let distance_vectors: Vec<_> = futures::future::join_all(vec_distance_futures) + .await + .into_iter() + .collect::, ServiceError>>()? + .into_iter() + .flatten() + .collect(); - vectors_sorted.extend(vectors_i.clone()); + if !distance_vectors.is_empty() { + content_vectors = content_vectors + .into_iter() + .enumerate() + .map(|(i, message)| { + let distance_vector = distance_vectors + .iter() + .find(|(_, (og_index, _))| *og_index == i); + match distance_vector { + Some((distance_vec, (_, distance_phrase))) => { + let distance_factor = distance_phrase.distance_factor; + message + .iter() + .zip(distance_vec) + .map(|(vec_elem, distance_elem)| { + vec_elem + distance_factor * distance_elem + }) + .collect() + } + None => message, + } + }) + .collect(); } transaction.finish(); - Ok(vectors_sorted) + Ok(content_vectors) } #[derive(Debug, Serialize, Deserialize)] @@ -492,31 +589,31 @@ pub struct CustomSparseEmbedData { #[tracing::instrument] pub async fn get_sparse_vectors( - messages: Vec<(String, Option)>, + content_and_boosts: Vec<(String, Option)>, embed_type: &str, reqwest_client: reqwest::Client, ) -> Result>, ServiceError> { - if messages.is_empty() { + if content_and_boosts.is_empty() { return Err(ServiceError::BadRequest( "No messages to encode".to_string(), )); } - let contents = messages + let contents = content_and_boosts .clone() .into_iter() .map(|(x, _)| x) .collect::>(); let thirty_content_groups = contents.chunks(30); - let filtered_boosts_with_index = messages + let filtered_boosts_with_index = content_and_boosts .into_iter() .enumerate() .filter_map(|(i, (_, y))| y.map(|boost_phrase| (i, boost_phrase))) .collect::>(); - let filtered_boosts_with_index_groups = filtered_boosts_with_index.chunks(30); + let thirty_filtered_boosts_with_indices = filtered_boosts_with_index.chunks(30); - let vec_boost_futures: Vec<_> = filtered_boosts_with_index_groups + let vec_boost_futures: Vec<_> = thirty_filtered_boosts_with_indices .enumerate() .map(|(i, thirty_boosts)| { let cur_client = reqwest_client.clone(); diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index d1d5c6658..dbcb2a01b 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -7,7 +7,7 @@ use super::group_operator::{ get_group_ids_from_tracking_ids_query, get_groups_from_group_ids_query, }; use super::model_operator::{ - create_embedding, cross_encoder, get_bm25_embeddings, get_sparse_vector, + cross_encoder, get_bm25_embeddings, get_dense_vector, get_sparse_vector, }; use super::qdrant_operator::{ count_qdrant_query, search_over_groups_query, GroupSearchResults, QdrantSearchQuery, VectorType, @@ -1451,7 +1451,7 @@ async fn get_qdrant_vector( )); } let embedding_vector = - create_embedding(data.query.clone(), None, "query", config.clone()).await?; + get_dense_vector(data.query.clone(), None, "query", config.clone()).await?; Ok(VectorType::Dense(embedding_vector)) } SearchMethod::BM25 => { @@ -1615,7 +1615,7 @@ pub async fn search_hybrid_chunks( let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); let dense_vector_future = - create_embedding(data.query.clone(), None, "query", dataset_config.clone()); + get_dense_vector(data.query.clone(), None, "query", dataset_config.clone()); let sparse_vector_future = get_sparse_vector(parsed_query.query.clone(), "query"); @@ -1829,7 +1829,7 @@ pub async fn search_hybrid_groups( let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); let dense_vector_future = - create_embedding(data.query.clone(), None, "query", dataset_config.clone()); + get_dense_vector(data.query.clone(), None, "query", dataset_config.clone()); let sparse_vector_future = get_sparse_vector(parsed_query.query.clone(), "query"); @@ -1981,7 +1981,7 @@ pub async fn semantic_search_over_groups( timer.add("start to create dense embedding vector"); let embedding_vector = - create_embedding(data.query.clone(), None, "query", dataset_config.clone()).await?; + get_dense_vector(data.query.clone(), None, "query", dataset_config.clone()).await?; timer.add("computed dense embedding"); @@ -2162,7 +2162,7 @@ pub async fn hybrid_search_over_groups( timer.add("start to create dense embedding vector and sparse vector"); let dense_embedding_vectors_future = - create_embedding(data.query.clone(), None, "query", dataset_config.clone()); + get_dense_vector(data.query.clone(), None, "query", dataset_config.clone()); let sparse_embedding_vector_future = get_sparse_vector(data.query.clone(), "query");