diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs index b588bb460..1830e3b9d 100644 --- a/server/src/handlers/chunk_handler.rs +++ b/server/src/handlers/chunk_handler.rs @@ -932,7 +932,7 @@ pub struct ChunkFilter { "score_threshold": 0.5 }))] pub struct SearchChunksReqPayload { - /// Can be either "semantic", "fulltext", or "hybrid". If specified as "hybrid", it will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. + /// Can be either "semantic", "fulltext", "hybrid", or "bm25". If specified as "hybrid", it will pull in one page of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page of the nearest cosine distant vectors. "fulltext" will pull in one page of full-text results based on SPLADE. "bm25" will get one page of results scored using BM25 with the terms OR'd together. pub search_type: SearchMethod, /// Query is the search query. This can be any string. The query will be used to create an embedding vector and/or SPLADE vector which will be used to find the result set. You can either provide one query, or multiple with weights. Multi-query only works with Semantic Search. pub query: QueryTypes, @@ -951,7 +951,7 @@ pub struct SearchChunksReqPayload { /// Set score_threshold to a float to filter out chunks with a score below the threshold for cosine distance metric /// For Manhattan Distance, Euclidean Distance, and Dot Product, it will filter out scores above the threshold distance /// This threshold applies before weight and bias modifications. 
If not specified, this defaults to no threshold - /// A threshold of 0 will default to no threashold + /// A threshold of 0 will default to no threshold pub score_threshold: Option<f32>, /// Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typically 10-50ms). Default is false. pub slim_chunks: Option<bool>, diff --git a/server/src/handlers/message_handler.rs b/server/src/handlers/message_handler.rs index a13625805..37b1acc73 100644 --- a/server/src/handlers/message_handler.rs +++ b/server/src/handlers/message_handler.rs @@ -1,15 +1,16 @@ use super::{ auth_handler::{AdminOnly, LoggedUser}, - chunk_handler::{ChunkFilter, ParsedQuery, SearchChunksReqPayload}, + chunk_handler::{ChunkFilter, ParsedQuery, ParsedQueryTypes, SearchChunksReqPayload}, }; use crate::{ data::models::{ - self, ChunkMetadataTypes, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, - HighlightOptions, LLMOptions, Pool, SearchMethod, + self, ChunkMetadata, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions, + LLMOptions, Pool, SearchMethod, }, errors::ServiceError, get_env, operators::{ + chunk_operator::{get_chunk_metadatas_from_point_ids, get_random_chunk_metadatas_query}, clickhouse_operator::EventQueue, message_operator::{ create_topic_message_query, delete_message_query, get_message_by_sort_for_topic_query, @@ -17,7 +18,8 @@ use crate::{ }, organization_operator::get_message_org_count, parse_operator::convert_html_to_text, - search_operator::search_hybrid_chunks, + qdrant_operator::scroll_dataset_points, + search_operator::{assemble_qdrant_filter, search_chunks_query, search_hybrid_chunks}, }, }; use actix_web::{web, HttpResponse}; @@ -599,7 +601,11 @@ pub async fn regenerate_message( #[derive(Deserialize, Serialize, Debug, ToSchema)] pub struct SuggestedQueriesReqPayload { /// The query to base the generated suggested queries off of using RAG. 
A hybrid search for 10 chunks from your dataset using this query will be performed and the context of the chunks will be used to generate the suggested queries. - pub query: String, + pub query: Option<String>, + /// Can be either "semantic", "fulltext", "hybrid", or "bm25". If specified as "hybrid", it will pull in one page of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page of the nearest cosine distant vectors. "fulltext" will pull in one page of full-text results based on SPLADE. "bm25" will get one page of results scored using BM25 with the terms OR'd together. + pub search_type: Option<SearchMethod>, + /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. + pub filters: Option<ChunkFilter>, } #[derive(Deserialize, Serialize, Debug, ToSchema)] @@ -635,6 +641,7 @@ pub async fn get_suggested_queries( pool: web::Data<Pool>, _required_user: LoggedUser, ) -> Result<HttpResponse, ServiceError> { + let dataset_id = dataset_org_plan_sub.dataset.id; let dataset_config = DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.clone().server_configuration); @@ -659,36 +666,87 @@ pub async fn get_suggested_queries( .into() }; - let chunk_metadatas = search_hybrid_chunks( - SearchChunksReqPayload { - search_type: SearchMethod::Hybrid, - query: models::QueryTypes::Single(data.query.clone()), - page_size: Some(10), - ..Default::default() - }, - ParsedQuery { - query: data.query.clone(), - quote_words: None, - negated_words: None, - }, - pool, - dataset_org_plan_sub.dataset.clone(), - &dataset_config, - &mut Timer::new(), - ) - .await - .map_err(|err| ServiceError::BadRequest(err.to_string()))? 
- .score_chunks; + let search_type = data.search_type.clone().unwrap_or(SearchMethod::Hybrid); + let filters = data.filters.clone(); + + let chunk_metadatas = match data.query.clone() { + Some(query) => { + let search_req_payload = SearchChunksReqPayload { + search_type: search_type.clone(), + query: models::QueryTypes::Single(query.clone()), + page_size: Some(10), + filters, + ..Default::default() + }; + let parsed_query = ParsedQuery { + query, + quote_words: None, + negated_words: None, + }; + match search_type { + SearchMethod::Hybrid => search_hybrid_chunks( + search_req_payload, + parsed_query, + pool, + dataset_org_plan_sub.dataset.clone(), + &dataset_config, + &mut Timer::new(), + ) + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?, + _ => search_chunks_query( + search_req_payload, + ParsedQueryTypes::Single(parsed_query), + pool, + dataset_org_plan_sub.dataset.clone(), + &dataset_config, + &mut Timer::new(), + ) + .await + .map_err(|err| ServiceError::BadRequest(err.to_string()))?, + } + .score_chunks + .into_iter() + .filter_map(|chunk| chunk.metadata.clone().first().cloned()) + .map(ChunkMetadata::from) + .collect::>() + } + None => { + let random_chunk = get_random_chunk_metadatas_query(dataset_id, 1, pool.clone()) + .await? + .clone() + .first() + .cloned(); + match random_chunk { + Some(chunk) => { + let filter = + assemble_qdrant_filter(filters, None, None, dataset_id, pool.clone()) + .await?; + + let qdrant_point_ids = scroll_dataset_points( + 10, + Some(chunk.qdrant_point_id), + None, + dataset_config, + filter, + ) + .await?; + + get_chunk_metadatas_from_point_ids(qdrant_point_ids.clone(), pool) + .await? 
+ .into_iter() + .map(ChunkMetadata::from) + .collect() + } + None => vec![], + } + } + }; let rag_content = chunk_metadatas .iter() .enumerate() .map(|(idx, chunk)| { - let chunk = match chunk.metadata.first().unwrap() { - ChunkMetadataTypes::Metadata(chunk_metadata) => chunk_metadata, - _ => unreachable!("The operator should never return slim chunks for this"), - }; - format!( "Doc {}: {}", idx + 1, @@ -776,11 +834,11 @@ pub async fn get_suggested_queries( .chat() .create(parameters.clone()) .await - .expect("No OpenAI Completion for topic"); + .expect("No LLM Completion for topic"); queries = match &query .choices .first() - .expect("No response for OpenAI completion") + .expect("No response for LLM completion") .message .content { @@ -795,10 +853,6 @@ pub async fn get_suggested_queries( let mut engine: SimSearch = SimSearch::new(); chunk_metadatas.iter().for_each(|chunk| { - let chunk = match chunk.metadata.first().unwrap() { - ChunkMetadataTypes::Metadata(chunk_metadata) => chunk_metadata, - _ => unreachable!("The operator should never return slim chunks for this"), - }; let content = convert_html_to_text(&chunk.chunk_html.clone().unwrap_or_default()); engine.insert(content.clone(), &content); diff --git a/server/src/lib.rs b/server/src/lib.rs index 7f882a6a2..766295553 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -914,8 +914,7 @@ pub fn main() -> std::io::Result<()> { .service( web::resource("/search") .route(web::post().to(handlers::group_handler::search_within_group)) - .wrap(Compress::default()) -, + .wrap(Compress::default()), ) .service( web::resource("/group_oriented_search").route( @@ -926,8 +925,8 @@ pub fn main() -> std::io::Result<()> { .service( web::resource("/recommend").route( web::post().to(handlers::group_handler::get_recommended_groups), - ) .wrap(Compress::default()) -, + ) + .wrap(Compress::default()), ) .service( web::resource("/chunk/{chunk_group_id}") diff --git a/server/src/operators/chunk_operator.rs 
b/server/src/operators/chunk_operator.rs index 3be0ec802..ecdc8aa3d 100644 --- a/server/src/operators/chunk_operator.rs +++ b/server/src/operators/chunk_operator.rs @@ -338,6 +338,76 @@ pub async fn get_content_chunk_from_point_ids_query( Ok(content_chunks) } +#[tracing::instrument(skip(pool))] +pub async fn get_random_chunk_metadatas_query( + dataset_id: uuid::Uuid, + limit: u64, + pool: web::Data<Pool>, +) -> Result<Vec<ChunkMetadata>, ServiceError> { + use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns; + let mut random_uuid = uuid::Uuid::new_v4(); + + let mut conn = pool + .get() + .await + .expect("Failed to get connection from pool"); + + let get_lowest_id_future = chunk_metadata_columns::chunk_metadata + .filter(chunk_metadata_columns::dataset_id.eq(dataset_id)) + .select(chunk_metadata_columns::id) + .order_by(chunk_metadata_columns::id.asc()) + .first::<uuid::Uuid>(&mut conn); + let get_highest_ids_future = chunk_metadata_columns::chunk_metadata + .filter(chunk_metadata_columns::dataset_id.eq(dataset_id)) + .select(chunk_metadata_columns::id) + .order_by(chunk_metadata_columns::id.desc()) + .limit(limit.try_into().unwrap_or(10)) + .load::<uuid::Uuid>(&mut conn); + let (lowest_id, highest_ids) = futures::join!(get_lowest_id_future, get_highest_ids_future); + let lowest_id: uuid::Uuid = lowest_id.map_err(|err| { + ServiceError::BadRequest(format!( + "Failed to load chunk with the lowest id in the dataset for random range: {:?}", + err + )) + })?; + let highest_ids: Vec<uuid::Uuid> = highest_ids.map_err(|err| { + ServiceError::BadRequest(format!( + "Failed to load chunks with the highest id in the dataset for random range: {:?}", + err + )) + })?; + let highest_id = match highest_ids.get(0) { + Some(id) => *id, + None => { + return Err(ServiceError::NotFound( + "Chunk with the highest id in the dataset not found for random range".to_string(), + )) + } + }; + while (random_uuid < lowest_id) || (random_uuid > highest_id) { + random_uuid = uuid::Uuid::new_v4(); + } + + let chunk_metadatas: Vec<ChunkMetadataTable> = 
chunk_metadata_columns::chunk_metadata + .filter(chunk_metadata_columns::dataset_id.eq(dataset_id)) + .filter(chunk_metadata_columns::id.gt(random_uuid)) + .order_by(chunk_metadata_columns::id.desc()) + .limit(limit.try_into().unwrap_or(10)) + .select(ChunkMetadataTable::as_select()) + .load::(&mut conn) + .await + .map_err(|_| ServiceError::BadRequest("Failed to load metadata".to_string()))?; + + let chunk_metadatas = chunk_metadatas + .into_iter() + .map(|chunk_metadata_table| { + ChunkMetadata::from_table_and_tag_set(chunk_metadata_table, vec![]) + }) + .collect(); + + Ok(chunk_metadatas) +} + #[tracing::instrument(skip(pool))] pub async fn get_metadata_from_id_query( chunk_id: uuid::Uuid, diff --git a/server/src/operators/message_operator.rs b/server/src/operators/message_operator.rs index 4cf862475..af0b13df2 100644 --- a/server/src/operators/message_operator.rs +++ b/server/src/operators/message_operator.rs @@ -330,12 +330,12 @@ pub async fn stream_response( .chat() .create(gen_inference_parameters) .await - .expect("No OpenAI Completion for chunk search"); + .expect("No LLM Completion for chunk search"); query = match &search_query_from_message_to_query_prompt .choices .get(0) - .expect("No response for OpenAI completion") + .expect("No response for LLM completion") .message .content { @@ -762,13 +762,13 @@ pub async fn get_topic_string( .chat() .create(parameters) .await - .map_err(|_| ServiceError::BadRequest("No OpenAI Completion for topic".to_string()))?; + .map_err(|_| ServiceError::BadRequest("No LLM Completion for topic".to_string()))?; let topic = match &query .choices .get(0) .ok_or(ServiceError::BadRequest( - "No response for OpenAI completion".to_string(), + "No response for LLM completion".to_string(), ))? .message .content