fix: create embeddings not batching requests correctly
aaryanpunia authored and skeptrunedev committed Jul 24, 2024
1 parent f62a23a commit 89ca2a4
Showing 3 changed files with 181 additions and 84 deletions.
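For context on what the fix changes: the reworked get_dense_vectors in server/src/operators/model_operator.rs unzips the incoming (content, optional DistancePhrase) pairs, batches the content strings and the distance phrases into separate embedding requests of at most 30 inputs each, and then stitches the per-batch results back together in the original order. The sketch below is a minimal, simplified illustration of that batching shape, not the actual implementation: embed_batch is a hypothetical stand-in for the HTTP call to the embedding server, the distance phrases are plain strings without a distance_factor, and the batches run sequentially here while the real code issues them concurrently with futures::future::join_all.

    // Hypothetical, simplified sketch of the batching scheme in the new get_dense_vectors.
    // embed_batch stands in for one POST to the embedding server (at most 30 inputs).
    fn embed_batch(inputs: &[String]) -> Vec<Vec<f32>> {
        // Placeholder embedding: the real code sends EmbeddingParameters and parses
        // an EmbeddingResponse from the server.
        inputs.iter().map(|s| vec![s.len() as f32]).collect()
    }

    fn get_dense_vectors_sketch(
        content_and_distances: Vec<(String, Option<String>)>,
    ) -> Vec<Vec<f32>> {
        // Split the contents from their optional distance-boost phrases.
        let (contents, distance_phrases): (Vec<_>, Vec<_>) =
            content_and_distances.into_iter().unzip();

        // Remember the original index of every present distance phrase so its boost
        // can later be applied to the matching content vector.
        let filtered_distances: Vec<(usize, String)> = distance_phrases
            .into_iter()
            .enumerate()
            .filter_map(|(i, phrase)| phrase.map(|p| (i, p)))
            .collect();

        // Batch contents and distance phrases separately, 30 inputs per request.
        let content_vectors: Vec<Vec<f32>> = contents
            .chunks(30)
            .flat_map(|batch| embed_batch(batch))
            .collect();
        let _distance_vectors: Vec<Vec<f32>> = filtered_distances
            .chunks(30)
            .flat_map(|batch| {
                let phrases: Vec<String> = batch.iter().map(|(_, p)| p.clone()).collect();
                embed_batch(&phrases)
            })
            .collect();

        // The real function then adds distance_factor * boost_element onto the matching
        // content vector; see the recombination sketch after the model_operator.rs diff.
        content_vectors
    }

    fn main() {
        let inputs = vec![
            ("first chunk".to_string(), None),
            ("second chunk".to_string(), Some("boost phrase".to_string())),
        ];
        println!("{:?}", get_dense_vectors_sketch(inputs));
    }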
10 changes: 5 additions & 5 deletions server/src/bin/ingestion-worker.rs
@@ -26,7 +26,7 @@ use trieve_server::operators::dataset_operator::get_dataset_by_id_query;
use trieve_server::operators::event_operator::create_event_query;
use trieve_server::operators::group_operator::get_groups_from_group_ids_query;
use trieve_server::operators::model_operator::{
create_embedding, create_embeddings, get_bm25_embeddings, get_sparse_vectors,
get_bm25_embeddings, get_dense_vector, get_dense_vectors, get_sparse_vectors,
};
use trieve_server::operators::parse_operator::{
average_embeddings, coarse_doc_chunker, convert_html_to_text,
@@ -557,7 +557,7 @@ pub async fn bulk_upload_chunks(

let embedding_vectors = match dataset_config.SEMANTIC_ENABLED {
true => {
let vectors = match create_embeddings(
let vectors = match get_dense_vectors(
content_and_boosts
.iter()
.map(|(content, _, distance_boost)| (content.clone(), distance_boost.clone()))
@@ -845,7 +845,7 @@ async fn upload_chunk(
true => {
let chunks = coarse_doc_chunker(content.clone(), None, false, 20);

let embeddings = create_embeddings(
let embeddings = get_dense_vectors(
chunks
.iter()
.map(|chunk| (chunk.clone(), payload.chunk.distance_phrase.clone()))
@@ -859,7 +859,7 @@ async fn upload_chunk(
average_embeddings(embeddings)?
}
false => {
let embedding_vectors = create_embeddings(
let embedding_vectors = get_dense_vectors(
vec![(content.clone(), payload.chunk.distance_phrase.clone())],
"doc",
dataset_config.clone(),
@@ -1062,7 +1062,7 @@ async fn update_chunk(

let embedding_vector = match dataset_config.SEMANTIC_ENABLED {
true => {
let embedding = create_embedding(
let embedding = get_dense_vector(
content.to_string(),
payload.distance_phrase,
"doc",
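In ingestion-worker.rs the change is only a rename at the call sites: create_embedding and create_embeddings become get_dense_vector and get_dense_vectors with the same argument shapes. For the SEMANTIC_ENABLED branch that splits long content, the worker still embeds each coarse chunk and collapses the results with average_embeddings. The sketch below is a hypothetical illustration of what that element-wise averaging amounts to, not the repository's average_embeddings implementation.

    // Hypothetical element-wise average of chunk embeddings; assumes every vector
    // shares the same dimension. Not the repository's average_embeddings.
    fn average_embeddings_sketch(embeddings: &[Vec<f32>]) -> Option<Vec<f32>> {
        let dim = embeddings.first()?.len();
        let mut sums = vec![0.0f32; dim];
        for vector in embeddings {
            for (sum, value) in sums.iter_mut().zip(vector) {
                *sum += value;
            }
        }
        let count = embeddings.len() as f32;
        Some(sums.into_iter().map(|sum| sum / count).collect())
    }

    fn main() {
        let chunk_vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        // Prints Some([2.0, 3.0]): one averaged vector stands in for the whole document.
        println!("{:?}", average_embeddings_sketch(&chunk_vectors));
    }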
243 changes: 170 additions & 73 deletions server/src/operators/model_operator.rs
@@ -26,7 +26,7 @@ pub struct EmbeddingParameters {
}

#[tracing::instrument]
pub async fn create_embedding(
pub async fn get_dense_vector(
message: String,
distance_phrase: Option<DistancePhrase>,
embed_type: &str,
@@ -35,11 +35,11 @@ pub async fn create_embedding(
let parent_span = sentry::configure_scope(|scope| scope.get_span());
let transaction: sentry::TransactionOrSpan = match &parent_span {
Some(parent) => parent
.start_child("create_embedding", "Create semantic dense embedding")
.start_child("get_dense_vector", "Create semantic dense embedding")
.into(),
None => {
let ctx = sentry::TransactionContext::new(
"create_embedding",
"get_dense_vector",
"Create semantic dense embedding",
);
sentry::start_transaction(ctx).into()
@@ -266,20 +266,20 @@ pub async fn get_sparse_vector(
}

#[tracing::instrument]
pub async fn create_embeddings(
content_and_boosts: Vec<(String, Option<DistancePhrase>)>,
pub async fn get_dense_vectors(
content_and_distances: Vec<(String, Option<DistancePhrase>)>,
embed_type: &str,
dataset_config: DatasetConfiguration,
reqwest_client: reqwest::Client,
) -> Result<Vec<Vec<f32>>, ServiceError> {
let parent_span = sentry::configure_scope(|scope| scope.get_span());
let transaction: sentry::TransactionOrSpan = match &parent_span {
Some(parent) => parent
.start_child("create_embedding", "Create semantic dense embedding")
.start_child("get_dense_vector", "Create semantic dense embedding")
.into(),
None => {
let ctx = sentry::TransactionContext::new(
"create_embedding",
"get_dense_vector",
"Create semantic dense embedding",
);
sentry::start_transaction(ctx).into()
@@ -323,31 +323,31 @@ pub async fn create_embeddings(
embedding_api_key.to_string()
};

let thirty_message_groups = content_and_boosts.chunks(30);
let (contents, distance_phrases): (Vec<_>, Vec<_>) =
content_and_distances.clone().into_iter().unzip();
let thirty_content_groups = contents.chunks(30);

let vec_futures: Vec<_> = thirty_message_groups
let filtered_distances_with_index = distance_phrases
.clone()
.iter()
.enumerate()
.map(|(i, combined_messages)| {
let messages = combined_messages
.iter()
.map(|(x, _)| x)
.cloned()
.collect::<Vec<String>>();

let boost_phrase_and_index = combined_messages
.iter()
.enumerate()
.filter_map(|(i, (_, y))| y.clone().map(|phrase| (i, phrase)))
.collect::<Vec<(usize, DistancePhrase)>>();
.filter_map(|(index, distance_phrase)| {
distance_phrase
.clone()
.map(|distance_phrase| (index, distance_phrase))
})
.collect::<Vec<(usize, DistancePhrase)>>();
let thirty_filterted_distances_with_indices = filtered_distances_with_index.chunks(30);

let boost_phrases = combined_messages
let vec_distance_futures: Vec<_> = thirty_filterted_distances_with_indices
.map(|thirty_distances| {
let distance_phrases = thirty_distances
.iter()
.filter_map(|(_, y)| y.clone().map(|x| x.phrase.clone()))
.map(|(_, x)| x.phrase.clone())
.collect::<Vec<String>>();

let clipped_messages = messages
let clipped_messages = distance_phrases
.iter()
.chain(boost_phrases.iter())
.map(|message| {
if message.len() > 5000 {
message.chars().take(12000).collect()
@@ -406,65 +406,162 @@
)
})?;

let mut vectors: Vec<Vec<f32>> = embeddings
.data
.into_iter()
.map(|x| match x.embedding {
EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(),
EmbeddingOutput::Base64(_) => {
log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings");
vec![]
let vectors_and_boosts: Vec<(Vec<f32>, &(usize, DistancePhrase))> = embeddings
.data
.into_iter()
.map(|x| match x.embedding {
EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(),
EmbeddingOutput::Base64(_) => {
log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings");
vec![]
}
})
.zip(thirty_distances)
.collect();

if vectors_and_boosts.iter().any(|x| x.0.is_empty()) {
return Err(ServiceError::InternalServerError(
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
));
}
})
.collect();

if vectors.iter().any(|x| x.is_empty()) {
return Err(ServiceError::InternalServerError(
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
));
}
Ok(vectors_and_boosts)
};

if !boost_phrase_and_index.is_empty() {
let boost_vectors = vectors
.split_off(messages.len()).to_vec();
vectors_resp
})
.collect();

let mut vectors_sorted = vectors.clone();
for ((og_index, phrase), boost_vector) in boost_phrase_and_index.iter().zip(boost_vectors) {
vectors_sorted[*og_index] = vectors_sorted[*og_index]
.iter()
.zip(boost_vector)
.map(|(vector_elem, boost_vec_elem)| vector_elem + phrase.distance_factor * boost_vec_elem)
.collect();
}
let vec_content_futures: Vec<_> = thirty_content_groups
.map(|messages| {
let clipped_messages = messages
.iter()
.map(|message| {
if message.len() > 5000 {
message.chars().take(12000).collect()
} else {
message.clone()
}
})
.collect::<Vec<String>>();

return Ok((i, vectors_sorted));
}
let input = match embed_type {
"doc" => EmbeddingInput::StringArray(clipped_messages),
"query" => EmbeddingInput::String(
format!(
"{}{}",
dataset_config.EMBEDDING_QUERY_PREFIX, &clipped_messages[0]
)
.to_string(),
),
_ => EmbeddingInput::StringArray(clipped_messages),
};

Ok((i, vectors))
let parameters = EmbeddingParameters {
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
input,
truncate: true,
};

vectors_resp
})
let cur_client = reqwest_client.clone();
let url = embedding_base_url.clone();

let embedding_api_key = embedding_api_key.clone();

let vectors_resp = async move {
let embeddings_resp = cur_client
.post(&format!("{}/embeddings?api-version=2023-05-15", url))
.header("Authorization", &format!("Bearer {}", &embedding_api_key.clone()))
.header("api-key", &embedding_api_key.clone())
.header("Content-Type", "application/json")
.json(&parameters)
.send()
.await
.map_err(|_| {
ServiceError::BadRequest("Failed to send message to embedding server".to_string())
})?
.text()
.await
.map_err(|_| {
ServiceError::BadRequest("Failed to get text from embeddings".to_string())
})?;

let embeddings: EmbeddingResponse = format_response(embeddings_resp.clone())
.map_err(move |_e| {
log::error!("Failed to format response from embeddings server {:?}", embeddings_resp);
ServiceError::InternalServerError(
format!("Failed to format response from embeddings server {:?}", embeddings_resp)
)
})?;

let vectors: Vec<Vec<f32>> = embeddings
.data
.into_iter()
.map(|x| match x.embedding {
EmbeddingOutput::Float(v) => v.iter().map(|x| *x as f32).collect(),
EmbeddingOutput::Base64(_) => {
log::error!("Embedding server responded with Base64 and that is not currently supported for embeddings");
vec![]
}
})
.collect();

if vectors.iter().any(|x| x.is_empty()) {
return Err(ServiceError::InternalServerError(
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
));
}
Ok(vectors)
};

vectors_resp

})
.collect();

let all_chunk_vectors: Vec<(usize, Vec<Vec<f32>>)> = futures::future::join_all(vec_futures)
let mut content_vectors: Vec<_> = futures::future::join_all(vec_content_futures)
.await
.into_iter()
.collect::<Result<Vec<(usize, Vec<Vec<f32>>)>, ServiceError>>()?;
.collect::<Result<Vec<_>, ServiceError>>()?
.into_iter()
.flatten()
.collect();

let mut vectors_sorted = vec![];
for index in 0..all_chunk_vectors.len() {
let (_, vectors_i) = all_chunk_vectors.iter().find(|(i, _)| *i == index).ok_or(
ServiceError::InternalServerError(
"Failed to get index i (this should never happen)".to_string(),
),
)?;
let distance_vectors: Vec<_> = futures::future::join_all(vec_distance_futures)
.await
.into_iter()
.collect::<Result<Vec<_>, ServiceError>>()?
.into_iter()
.flatten()
.collect();

vectors_sorted.extend(vectors_i.clone());
if !distance_vectors.is_empty() {
content_vectors = content_vectors
.into_iter()
.enumerate()
.map(|(i, message)| {
let distance_vector = distance_vectors
.iter()
.find(|(_, (og_index, _))| *og_index == i);
match distance_vector {
Some((distance_vec, (_, distance_phrase))) => {
let distance_factor = distance_phrase.distance_factor;
message
.iter()
.zip(distance_vec)
.map(|(vec_elem, distance_elem)| {
vec_elem + distance_factor * distance_elem
})
.collect()
}
None => message,
}
})
.collect();
}

transaction.finish();
Ok(vectors_sorted)
Ok(content_vectors)
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -492,31 +589,31 @@ pub struct CustomSparseEmbedData {

#[tracing::instrument]
pub async fn get_sparse_vectors(
messages: Vec<(String, Option<BoostPhrase>)>,
content_and_boosts: Vec<(String, Option<BoostPhrase>)>,
embed_type: &str,
reqwest_client: reqwest::Client,
) -> Result<Vec<Vec<(u32, f32)>>, ServiceError> {
if messages.is_empty() {
if content_and_boosts.is_empty() {
return Err(ServiceError::BadRequest(
"No messages to encode".to_string(),
));
}

let contents = messages
let contents = content_and_boosts
.clone()
.into_iter()
.map(|(x, _)| x)
.collect::<Vec<String>>();
let thirty_content_groups = contents.chunks(30);

let filtered_boosts_with_index = messages
let filtered_boosts_with_index = content_and_boosts
.into_iter()
.enumerate()
.filter_map(|(i, (_, y))| y.map(|boost_phrase| (i, boost_phrase)))
.collect::<Vec<(usize, BoostPhrase)>>();
let filtered_boosts_with_index_groups = filtered_boosts_with_index.chunks(30);
let thirty_filtered_boosts_with_indices = filtered_boosts_with_index.chunks(30);

let vec_boost_futures: Vec<_> = filtered_boosts_with_index_groups
let vec_boost_futures: Vec<_> = thirty_filtered_boosts_with_indices
.enumerate()
.map(|(i, thirty_boosts)| {
let cur_client = reqwest_client.clone();
(The remaining 6 additions and 6 deletions are in a third changed file whose diff was not loaded in this view.)
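The other half of the change, visible at the end of the new get_dense_vectors above, re-attaches each distance-boost vector to the content vector it came from: every kept distance phrase carries its original index, and the boosted vector is content plus distance_factor times the phrase embedding, element by element. The sketch below is a self-contained, hypothetical reconstruction of that recombination step; the (index, factor, boost_vector) triples stand in for the (usize, DistancePhrase) pairs and their embedded phrases in the real code.

    // Hypothetical sketch of applying distance boosts back onto content vectors by
    // original index, mirroring the final loop of the new get_dense_vectors.
    fn apply_distance_boosts(
        content_vectors: Vec<Vec<f32>>,
        boosts: &[(usize, f32, Vec<f32>)], // (original index, distance_factor, boost vector)
    ) -> Vec<Vec<f32>> {
        content_vectors
            .into_iter()
            .enumerate()
            .map(|(i, content)| match boosts.iter().find(|(index, _, _)| *index == i) {
                Some((_, factor, boost)) => content
                    .iter()
                    .zip(boost)
                    .map(|(content_elem, boost_elem)| content_elem + factor * boost_elem)
                    .collect(),
                None => content,
            })
            .collect()
    }

    fn main() {
        let contents = vec![vec![1.0, 1.0], vec![0.5, 0.5]];
        // Only the second chunk has a distance phrase, with a distance_factor of 0.5.
        let boosts = vec![(1usize, 0.5f32, vec![1.0, -1.0])];
        // Prints [[1.0, 1.0], [1.0, 0.0]]: boosted = content + factor * boost.
        println!("{:?}", apply_distance_boosts(contents, &boosts));
    }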
