Skip to content

Commit

Permalink
feature: get random suggested queries from the entire dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
skeptrunedev authored and cdxker committed Aug 25, 2024
1 parent 6bdc134 commit 73e204d
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 46 deletions.
4 changes: 2 additions & 2 deletions server/src/handlers/chunk_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ pub struct ChunkFilter {
"score_threshold": 0.5
}))]
pub struct SearchChunksReqPayload {
/// Can be either "semantic", "fulltext", or "hybrid". If specified as "hybrid", it will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE.
/// Can be either "semantic", "fulltext", "hybrid, or "bm25". If specified as "hybrid", it will pull in one page of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page of the nearest cosine distant vectors. "fulltext" will pull in one page of full-text results based on SPLADE. "bm25" will get one page of results scored using BM25 with the terms OR'd together.
pub search_type: SearchMethod,
/// Query is the search query. This can be any string. The query will be used to create an embedding vector and/or SPLADE vector which will be used to find the result set. You can either provide one query, or multiple with weights. Multi-query only works with Semantic Search.
pub query: QueryTypes,
Expand All @@ -951,7 +951,7 @@ pub struct SearchChunksReqPayload {
/// Set score_threshold to a float to filter out chunks with a score below the threshold for cosine distance metric
/// For Manhattan Distance, Euclidean Distance, and Dot Product, it will filter out scores above the threshold distance
/// This threshold applies before weight and bias modifications. If not specified, this defaults to no threshold
/// A threshold of 0 will default to no threashold
/// A threshold of 0 will default to no threshold
pub score_threshold: Option<f32>,
/// Set slim_chunks to true to avoid returning the content and chunk_html of the chunks. This is useful for when you want to reduce amount of data over the wire for latency improvement (typically 10-50ms). Default is false.
pub slim_chunks: Option<bool>,
Expand Down
126 changes: 90 additions & 36 deletions server/src/handlers/message_handler.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
use super::{
auth_handler::{AdminOnly, LoggedUser},
chunk_handler::{ChunkFilter, ParsedQuery, SearchChunksReqPayload},
chunk_handler::{ChunkFilter, ParsedQuery, ParsedQueryTypes, SearchChunksReqPayload},
};
use crate::{
data::models::{
self, ChunkMetadataTypes, DatasetAndOrgWithSubAndPlan, DatasetConfiguration,
HighlightOptions, LLMOptions, Pool, SearchMethod,
self, ChunkMetadata, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions,
LLMOptions, Pool, SearchMethod,
},
errors::ServiceError,
get_env,
operators::{
chunk_operator::{get_chunk_metadatas_from_point_ids, get_random_chunk_metadatas_query},
clickhouse_operator::EventQueue,
message_operator::{
create_topic_message_query, delete_message_query, get_message_by_sort_for_topic_query,
get_messages_for_topic_query, get_topic_messages, stream_response,
},
organization_operator::get_message_org_count,
parse_operator::convert_html_to_text,
search_operator::search_hybrid_chunks,
qdrant_operator::scroll_dataset_points,
search_operator::{assemble_qdrant_filter, search_chunks_query, search_hybrid_chunks},
},
};
use actix_web::{web, HttpResponse};
Expand Down Expand Up @@ -599,7 +601,11 @@ pub async fn regenerate_message(
#[derive(Deserialize, Serialize, Debug, ToSchema)]
pub struct SuggestedQueriesReqPayload {
/// The query to base the generated suggested queries off of using RAG. A hybrid search for 10 chunks from your dataset using this query will be performed and the context of the chunks will be used to generate the suggested queries.
pub query: String,
pub query: Option<String>,
/// Can be either "semantic", "fulltext", "hybrid, or "bm25". If specified as "hybrid", it will pull in one page of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page of the nearest cosine distant vectors. "fulltext" will pull in one page of full-text results based on SPLADE. "bm25" will get one page of results scored using BM25 with the terms OR'd together.
pub search_type: Option<SearchMethod>,
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
pub filters: Option<ChunkFilter>,
}

#[derive(Deserialize, Serialize, Debug, ToSchema)]
Expand Down Expand Up @@ -635,6 +641,7 @@ pub async fn get_suggested_queries(
pool: web::Data<Pool>,
_required_user: LoggedUser,
) -> Result<HttpResponse, ServiceError> {
let dataset_id = dataset_org_plan_sub.dataset.id;
let dataset_config =
DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.clone().server_configuration);

Expand All @@ -659,36 +666,87 @@ pub async fn get_suggested_queries(
.into()
};

let chunk_metadatas = search_hybrid_chunks(
SearchChunksReqPayload {
search_type: SearchMethod::Hybrid,
query: models::QueryTypes::Single(data.query.clone()),
page_size: Some(10),
..Default::default()
},
ParsedQuery {
query: data.query.clone(),
quote_words: None,
negated_words: None,
},
pool,
dataset_org_plan_sub.dataset.clone(),
&dataset_config,
&mut Timer::new(),
)
.await
.map_err(|err| ServiceError::BadRequest(err.to_string()))?
.score_chunks;
let search_type = data.search_type.clone().unwrap_or(SearchMethod::Hybrid);
let filters = data.filters.clone();

let chunk_metadatas = match data.query.clone() {
Some(query) => {
let search_req_payload = SearchChunksReqPayload {
search_type: search_type.clone(),
query: models::QueryTypes::Single(query.clone()),
page_size: Some(10),
filters,
..Default::default()
};
let parsed_query = ParsedQuery {
query,
quote_words: None,
negated_words: None,
};
match search_type {
SearchMethod::Hybrid => search_hybrid_chunks(
search_req_payload,
parsed_query,
pool,
dataset_org_plan_sub.dataset.clone(),
&dataset_config,
&mut Timer::new(),
)
.await
.map_err(|err| ServiceError::BadRequest(err.to_string()))?,
_ => search_chunks_query(
search_req_payload,
ParsedQueryTypes::Single(parsed_query),
pool,
dataset_org_plan_sub.dataset.clone(),
&dataset_config,
&mut Timer::new(),
)
.await
.map_err(|err| ServiceError::BadRequest(err.to_string()))?,
}
.score_chunks
.into_iter()
.filter_map(|chunk| chunk.metadata.clone().first().cloned())
.map(ChunkMetadata::from)
.collect::<Vec<ChunkMetadata>>()
}
None => {
let random_chunk = get_random_chunk_metadatas_query(dataset_id, 1, pool.clone())
.await?
.clone()
.first()
.cloned();
match random_chunk {
Some(chunk) => {
let filter =
assemble_qdrant_filter(filters, None, None, dataset_id, pool.clone())
.await?;

let qdrant_point_ids = scroll_dataset_points(
10,
Some(chunk.qdrant_point_id),
None,
dataset_config,
filter,
)
.await?;

get_chunk_metadatas_from_point_ids(qdrant_point_ids.clone(), pool)
.await?
.into_iter()
.map(ChunkMetadata::from)
.collect()
}
None => vec![],
}
}
};

let rag_content = chunk_metadatas
.iter()
.enumerate()
.map(|(idx, chunk)| {
let chunk = match chunk.metadata.first().unwrap() {
ChunkMetadataTypes::Metadata(chunk_metadata) => chunk_metadata,
_ => unreachable!("The operator should never return slim chunks for this"),
};

format!(
"Doc {}: {}",
idx + 1,
Expand Down Expand Up @@ -776,11 +834,11 @@ pub async fn get_suggested_queries(
.chat()
.create(parameters.clone())
.await
.expect("No OpenAI Completion for topic");
.expect("No LLM Completion for topic");
queries = match &query
.choices
.first()
.expect("No response for OpenAI completion")
.expect("No response for LLM completion")
.message
.content
{
Expand All @@ -795,10 +853,6 @@ pub async fn get_suggested_queries(
let mut engine: SimSearch<String> = SimSearch::new();

chunk_metadatas.iter().for_each(|chunk| {
let chunk = match chunk.metadata.first().unwrap() {
ChunkMetadataTypes::Metadata(chunk_metadata) => chunk_metadata,
_ => unreachable!("The operator should never return slim chunks for this"),
};
let content = convert_html_to_text(&chunk.chunk_html.clone().unwrap_or_default());

engine.insert(content.clone(), &content);
Expand Down
7 changes: 3 additions & 4 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,8 +914,7 @@ pub fn main() -> std::io::Result<()> {
.service(
web::resource("/search")
.route(web::post().to(handlers::group_handler::search_within_group))
.wrap(Compress::default())
,
.wrap(Compress::default()),
)
.service(
web::resource("/group_oriented_search").route(
Expand All @@ -926,8 +925,8 @@ pub fn main() -> std::io::Result<()> {
.service(
web::resource("/recommend").route(
web::post().to(handlers::group_handler::get_recommended_groups),
) .wrap(Compress::default())
,
)
.wrap(Compress::default()),
)
.service(
web::resource("/chunk/{chunk_group_id}")
Expand Down
70 changes: 70 additions & 0 deletions server/src/operators/chunk_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,76 @@ pub async fn get_content_chunk_from_point_ids_query(
Ok(content_chunks)
}

/// Returns up to `limit` chunk metadatas picked from a pseudo-random position
/// within the dataset's chunk id range.
///
/// Strategy: load the lowest and highest chunk ids for the dataset, rejection-
/// sample a v4 UUID until it falls inside `[lowest_id, highest_id]`, then load
/// the first `limit` chunks whose id is greater than that random point. Tag
/// sets are returned empty (`from_table_and_tag_set(_, vec![])`).
///
/// # Errors
/// Returns `ServiceError::BadRequest` when the pool, the id-range queries, or
/// the final load fail, and `ServiceError::NotFound` when the dataset has no
/// chunks.
#[tracing::instrument(skip(pool))]
pub async fn get_random_chunk_metadatas_query(
    dataset_id: uuid::Uuid,
    limit: u64,
    pool: web::Data<Pool>,
) -> Result<Vec<ChunkMetadata>, ServiceError> {
    use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns;

    // Surface pool exhaustion as a request error instead of panicking.
    let mut conn = pool.get().await.map_err(|err| {
        ServiceError::BadRequest(format!("Failed to get connection from pool: {:?}", err))
    })?;

    // BUG FIX: the lowest id must be fetched in ascending order. The previous
    // `order_by(id.desc()).first(...)` returned the HIGHEST id, which made the
    // rejection-sampling loop below effectively unsatisfiable. Queries are run
    // sequentially so only one mutable borrow of `conn` is live at a time.
    let lowest_id: uuid::Uuid = chunk_metadata_columns::chunk_metadata
        .filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
        .select(chunk_metadata_columns::id)
        .order_by(chunk_metadata_columns::id.asc())
        .first::<uuid::Uuid>(&mut conn)
        .await
        .map_err(|err| {
            ServiceError::BadRequest(format!(
                "Failed to load chunk with the lowest id in the dataset for random range: {:?}",
                err
            ))
        })?;
    let highest_ids: Vec<uuid::Uuid> = chunk_metadata_columns::chunk_metadata
        .filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
        .select(chunk_metadata_columns::id)
        .order_by(chunk_metadata_columns::id.desc())
        .limit(limit.try_into().unwrap_or(10))
        .load::<uuid::Uuid>(&mut conn)
        .await
        .map_err(|err| {
            ServiceError::BadRequest(format!(
                "Failed to load chunks with the highest id in the dataset for random range: {:?}",
                err
            ))
        })?;
    let highest_id = *highest_ids.first().ok_or_else(|| {
        ServiceError::NotFound(
            "Chunk with the highest id in the dataset not found for random range".to_string(),
        )
    })?;

    // Rejection-sample a v4 UUID until it lands inside the dataset's id range.
    // NOTE(review): this can spin for a long time when the id range is narrow
    // relative to the full UUID space — consider deriving an in-range uuid
    // arithmetically instead.
    let mut random_uuid = uuid::Uuid::new_v4();
    while random_uuid < lowest_id || random_uuid > highest_id {
        random_uuid = uuid::Uuid::new_v4();
    }

    // BUG FIX: ascending order takes the `limit` chunks immediately after the
    // random point; the previous descending order always returned the dataset's
    // highest ids regardless of the sampled value, so results were never random.
    let chunk_metadata_tables: Vec<ChunkMetadataTable> = chunk_metadata_columns::chunk_metadata
        .filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
        .filter(chunk_metadata_columns::id.gt(random_uuid))
        .order_by(chunk_metadata_columns::id.asc())
        .limit(limit.try_into().unwrap_or(10))
        .select(ChunkMetadataTable::as_select())
        .load::<ChunkMetadataTable>(&mut conn)
        .await
        .map_err(|_| ServiceError::BadRequest("Failed to load metadata".to_string()))?;

    Ok(chunk_metadata_tables
        .into_iter()
        .map(|chunk_metadata_table| {
            ChunkMetadata::from_table_and_tag_set(chunk_metadata_table, vec![])
        })
        .collect())
}

#[tracing::instrument(skip(pool))]
pub async fn get_metadata_from_id_query(
chunk_id: uuid::Uuid,
Expand Down
8 changes: 4 additions & 4 deletions server/src/operators/message_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,12 @@ pub async fn stream_response(
.chat()
.create(gen_inference_parameters)
.await
.expect("No OpenAI Completion for chunk search");
.expect("No LLM Completion for chunk search");

query = match &search_query_from_message_to_query_prompt
.choices
.get(0)
.expect("No response for OpenAI completion")
.expect("No response for LLM completion")
.message
.content
{
Expand Down Expand Up @@ -762,13 +762,13 @@ pub async fn get_topic_string(
.chat()
.create(parameters)
.await
.map_err(|_| ServiceError::BadRequest("No OpenAI Completion for topic".to_string()))?;
.map_err(|_| ServiceError::BadRequest("No LLM Completion for topic".to_string()))?;

let topic = match &query
.choices
.get(0)
.ok_or(ServiceError::BadRequest(
"No response for OpenAI completion".to_string(),
"No response for LLM completion".to_string(),
))?
.message
.content
Expand Down

0 comments on commit 73e204d

Please sign in to comment.