From de0a2f4f6bbb2fca059bca01bb5a36db2bdf6e51 Mon Sep 17 00:00:00 2001 From: skeptrune Date: Tue, 12 Nov 2024 16:52:04 -0800 Subject: [PATCH] feature: add batch create dataset route --- server/src/data/models.rs | 8 ++ server/src/handlers/chunk_handler.rs | 2 +- server/src/handlers/dataset_handler.rs | 94 +++++++++++++++++++++--- server/src/lib.rs | 13 +++- server/src/operators/chunk_operator.rs | 1 + server/src/operators/dataset_operator.rs | 52 +++++++++++++ 6 files changed, 157 insertions(+), 13 deletions(-) diff --git a/server/src/data/models.rs b/server/src/data/models.rs index 02d6ad405..c7430f9d4 100644 --- a/server/src/data/models.rs +++ b/server/src/data/models.rs @@ -1892,13 +1892,21 @@ pub struct DatasetEventCount { }))] #[diesel(table_name = datasets)] pub struct Dataset { + /// Unique identifier of the dataset, auto-generated uuid created by Trieve pub id: uuid::Uuid, + /// Name of the dataset pub name: String, + /// Timestamp of the creation of the dataset pub created_at: chrono::NaiveDateTime, + /// Timestamp of the last update of the dataset pub updated_at: chrono::NaiveDateTime, + /// Unique identifier of the organization that owns the dataset pub organization_id: uuid::Uuid, + /// Configuration of the dataset for RAG, embeddings, BM25, etc. pub server_configuration: serde_json::Value, + /// Tracking ID of the dataset, can be any string, determined by the user. Tracking ID's are unique identifiers for datasets within an organization. They are designed to match the unique identifier of the dataset in the user's system. pub tracking_id: Option, + /// Flag to indicate if the dataset has been deleted. Deletes are handled async after the flag is set so as to avoid expensive search index compaction. pub deleted: i32, } diff --git a/server/src/handlers/chunk_handler.rs b/server/src/handlers/chunk_handler.rs index 50bc6cee7..6a7c495b0 100644 --- a/server/src/handlers/chunk_handler.rs +++ b/server/src/handlers/chunk_handler.rs @@ -107,7 +107,7 @@ pub struct ChunkReqPayload { pub metadata: Option, /// Tracking_id is a string which can be used to identify a chunk. This is useful for when you are coordinating with an external system and want to use the tracking_id to identify the chunk. pub tracking_id: Option, - /// Upsert when a chunk with the same tracking_id exists. By default this is false, and the request will fail if a chunk with the same tracking_id exists. If this is true, the chunk will be updated if a chunk with the same tracking_id exists. + /// Upsert when a chunk with the same tracking_id exists. By default this is false, and chunks will be ignored if another with the same tracking_id exists. If this is true, the chunk will be updated if a chunk with the same tracking_id exists. pub upsert_by_tracking_id: Option, /// Group ids are the Trieve generated ids of the groups that the chunk should be placed into. This is useful for when you want to create a chunk and add it to a group or multiple groups in one request. Groups with these Trieve generated ids must be created first, it cannot be arbitrarily created through this route. pub group_ids: Option>, diff --git a/server/src/handlers/dataset_handler.rs b/server/src/handlers/dataset_handler.rs index 4ccb91054..e66cd9d2c 100644 --- a/server/src/handlers/dataset_handler.rs +++ b/server/src/handlers/dataset_handler.rs @@ -13,9 +13,9 @@ use crate::{ validate_crawl_options, }, dataset_operator::{ - clear_dataset_by_dataset_id_query, create_dataset_query, get_dataset_by_id_query, - get_dataset_usage_query, get_datasets_by_organization_id, get_tags_in_dataset_query, - soft_delete_dataset_by_id_query, update_dataset_query, + clear_dataset_by_dataset_id_query, create_dataset_query, create_datasets_query, + get_dataset_by_id_query, get_dataset_usage_query, get_datasets_by_organization_id, + get_tags_in_dataset_query, soft_delete_dataset_by_id_query, update_dataset_query, }, dittofeed_operator::{ send_ditto_event, DittoDatasetCreated, DittoTrackProperties, DittoTrackRequest, @@ -99,7 +99,7 @@ impl FromRequest for OrganizationWithSubAndPlan { "MAX_LIMIT": 10000 } }))] -pub struct CreateDatasetRequest { +pub struct CreateDatasetReqPayload { /// Name of the dataset. pub dataset_name: String, /// Optional tracking ID for the dataset. Can be used to track the dataset in external systems. Must be unique within the organization. Strongly recommended to not use a valid uuid value as that will not work with the TR-Dataset header. @@ -112,13 +112,13 @@ pub struct CreateDatasetRequest { /// Create Dataset /// -/// Auth'ed user must be an owner of the organization to create a dataset. +/// Dataset will be created in the org specified via the TR-Organization header. Auth'ed user must be an owner of the organization to create a dataset. #[utoipa::path( post, path = "/dataset", context_path = "/api", tag = "Dataset", - request_body(content = CreateDatasetRequest, description = "JSON request payload to create a new dataset", content_type = "application/json"), + request_body(content = CreateDatasetReqPayload, description = "JSON request payload to create a new dataset", content_type = "application/json"), responses( (status = 200, description = "Dataset created successfully", body = Dataset), (status = 400, description = "Service error relating to creating the dataset", body = ErrorResponseBody), @@ -132,7 +132,7 @@ pub struct CreateDatasetRequest { )] #[tracing::instrument(skip(pool))] pub async fn create_dataset( - data: web::Json, + data: web::Json, pool: web::Data, redis_pool: web::Data, org_with_sub_and_plan: OrganizationWithSubAndPlan, @@ -234,7 +234,7 @@ pub async fn create_dataset( "MAX_LIMIT": 10000 } }))] -pub struct UpdateDatasetRequest { +pub struct UpdateDatasetReqPayload { /// The id of the dataset you want to update. pub dataset_id: Option, /// The tracking ID of the dataset you want to update. @@ -257,7 +257,7 @@ pub struct UpdateDatasetRequest { path = "/dataset", context_path = "/api", tag = "Dataset", - request_body(content = UpdateDatasetRequest, description = "JSON request payload to update a dataset", content_type = "application/json"), + request_body(content = UpdateDatasetReqPayload, description = "JSON request payload to update a dataset", content_type = "application/json"), responses( (status = 200, description = "Dataset updated successfully", body = Dataset), (status = 400, description = "Service error relating to updating the dataset", body = ErrorResponseBody), @@ -272,7 +272,7 @@ pub struct UpdateDatasetRequest { )] #[tracing::instrument(skip(pool))] pub async fn update_dataset( - data: web::Json, + data: web::Json, pool: web::Data, redis_pool: web::Data, user: OwnerOnly, @@ -700,13 +700,17 @@ pub struct GetAllTagsReqPayload { #[derive(Serialize, Deserialize, Debug, ToSchema, Queryable)] pub struct TagsWithCount { + /// Content of the tag pub tag: String, + /// Number of chunks in the dataset with that tag pub count: i64, } #[derive(Serialize, Deserialize, Debug, ToSchema)] pub struct GetAllTagsResponse { + /// List of tags with the number of chunks in the dataset with that tag. pub tags: Vec, + /// Total number of unique tags in the dataset. pub total: i64, } @@ -751,3 +755,73 @@ pub async fn get_all_tags( total: items.1, })) } + +#[derive(Serialize, Deserialize, Debug, ToSchema)] +pub struct CreateBatchDataset { + /// Name of the dataset. + pub dataset_name: String, + /// Optional tracking ID for the dataset. Can be used to track the dataset in external systems. Must be unique within the organization. Strongly recommended to not use a valid uuid value as that will not work with the TR-Dataset header. + pub tracking_id: Option, + /// The configuration of the dataset. See the example request payload for the potential keys which can be set. It is possible to break your dataset's functionality by erroneously setting this field. We recommend setting through creating a dataset at dashboard.trieve.ai and managing it's settings there. + pub server_configuration: Option, +} + +#[derive(Serialize, Deserialize, Debug, ToSchema)] +pub struct CreateDatasetBatchReqPayload { + /// List of datasets to create + pub datasets: Vec, + /// Upsert when a dataset with one of the specified tracking_ids already exists. By default this is false and specified datasets with a tracking_id that already exists in the org will not be ignored. If true, the existing dataset will be updated with the new dataset's details. + pub upsert: Option, +} + +/// Datasets +#[derive(Serialize, Deserialize, Debug, ToSchema)] +pub struct Datasets(Vec); + +/// Batch Create Datasets +/// +/// Datasets will be created in the org specified via the TR-Organization header. Auth'ed user must be an owner of the organization to create datasets. If a tracking_id is ignored due to it already existing on the org, the response will not contain a dataset with that tracking_id and it can be assumed that a dataset with the missing tracking_id already exists. +#[utoipa::path( + post, + path = "/dataset/batch_create_datasets", + context_path = "/api", + tag = "Dataset", + request_body(content = CreateDatasetBatchReqPayload, description = "JSON request payload to bulk create datasets", content_type = "application/json"), + responses( + (status = 200, description = "Page of tags requested with all tags and the number of chunks in the dataset with that tag plus the total number of unique tags for the whole datset", body = Datasets), + (status = 400, description = "Service error relating to finding items by tag", body = ErrorResponseBody), + ), + params( + ("TR-Organization" = uuid::Uuid, Header, description = "The organization id to use for the request"), + ), + security( + ("ApiKey" = ["owner"]), + ) +)] +pub async fn batch_create_datasets( + data: web::Json, + _user: OwnerOnly, + pool: web::Data, + org_with_sub_and_plan: OrganizationWithSubAndPlan, +) -> Result { + let datasets = data + .datasets + .iter() + .map(|d| { + Dataset::from_details( + d.dataset_name.clone(), + org_with_sub_and_plan.organization.id, + d.tracking_id.clone(), + d.server_configuration + .clone() + .map(|c| c.into()) + .unwrap_or_default(), + ) + }) + .collect::>(); + + let created_or_upserted_datasets = + create_datasets_query(datasets, data.upsert, pool.clone()).await?; + + Ok(HttpResponse::Ok().json(created_or_upserted_datasets)) +} diff --git a/server/src/lib.rs b/server/src/lib.rs index f61cd38d2..f01e084ea 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -221,6 +221,7 @@ impl Modify for SecurityAddon { handlers::organization_handler::get_organization_users, handlers::organization_handler::update_all_org_dataset_configs, handlers::dataset_handler::create_dataset, + handlers::dataset_handler::batch_create_datasets, handlers::dataset_handler::update_dataset, handlers::dataset_handler::delete_dataset, handlers::dataset_handler::delete_dataset_by_tracking_id, @@ -302,6 +303,7 @@ impl Modify for SecurityAddon { handlers::dataset_handler::GetAllTagsReqPayload, handlers::dataset_handler::GetAllTagsResponse, handlers::dataset_handler::GetCrawlOptionsResponse, + handlers::dataset_handler::Datasets, handlers::group_handler::RecommendGroupsReqPayload, handlers::group_handler::RecommendGroupsResponse, handlers::group_handler::SearchWithinGroupReqPayload, @@ -344,8 +346,10 @@ impl Modify for SecurityAddon { operators::search_operator::SearchOverGroupsResults, operators::search_operator::SearchOverGroupsResponseBody, operators::search_operator::SearchOverGroupsResponseTypes, - handlers::dataset_handler::CreateDatasetRequest, - handlers::dataset_handler::UpdateDatasetRequest, + handlers::dataset_handler::CreateDatasetReqPayload, + handlers::dataset_handler::CreateBatchDataset, + handlers::dataset_handler::CreateDatasetBatchReqPayload, + handlers::dataset_handler::UpdateDatasetReqPayload, handlers::dataset_handler::GetDatasetsPagination, data::models::DatasetConfigurationDTO, handlers::chunk_handler::CrawlOpenAPIOptions, @@ -805,6 +809,11 @@ pub fn main() -> std::io::Result<()> { ) .route(web::put().to(handlers::dataset_handler::update_dataset)) ) + .service( + web::resource("/batch_create_datasets").route( + web::post().to(handlers::dataset_handler::batch_create_datasets), + ) + ) .service( web::resource("/organization/{organization_id}").route( web::get().to( diff --git a/server/src/operators/chunk_operator.rs b/server/src/operators/chunk_operator.rs index 79a9a97b8..3bc0779c9 100644 --- a/server/src/operators/chunk_operator.rs +++ b/server/src/operators/chunk_operator.rs @@ -1364,6 +1364,7 @@ pub async fn delete_chunk_metadata_query( ) -> Result<(), ServiceError> { use crate::data::schema::chunk_group_bookmarks::dsl as chunk_group_bookmarks_columns; use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns; + let mut conn = pool.get().await.map_err(|_e| { ServiceError::InternalServerError("Failed to get postgres connection".to_string()) })?; diff --git a/server/src/operators/dataset_operator.rs b/server/src/operators/dataset_operator.rs index f1925fc60..f8aaaf607 100644 --- a/server/src/operators/dataset_operator.rs +++ b/server/src/operators/dataset_operator.rs @@ -18,6 +18,7 @@ use clickhouse::Row; use diesel::dsl::count; use diesel::prelude::*; use diesel::result::{DatabaseErrorKind, Error as DBError}; +use diesel::upsert::excluded; use diesel_async::RunQueryDsl; use itertools::Itertools; use serde::{Deserialize, Serialize}; @@ -54,6 +55,57 @@ pub async fn create_dataset_query( Ok(new_dataset) } +#[tracing::instrument(skip(pool))] +pub async fn create_datasets_query( + datasets: Vec, + upsert: Option, + pool: web::Data, +) -> Result, ServiceError> { + use crate::data::schema::datasets::dsl as datasets_columns; + + let mut conn = pool + .get() + .await + .map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?; + + let created_or_upserted_datasets: Vec = if upsert.unwrap_or(false) { + diesel::insert_into(datasets_columns::datasets) + .values(&datasets) + .on_conflict(( + datasets_columns::tracking_id, + datasets_columns::organization_id, + )) + .do_update() + .set(( + datasets_columns::name.eq(excluded(datasets_columns::name)), + datasets_columns::server_configuration + .eq(excluded(datasets_columns::server_configuration)), + )) + .get_results::(&mut conn) + .await + .map_err(|err| { + log::error!("Could not create dataset batch: {}", err); + ServiceError::BadRequest( + "Could not create dataset batch due to pg error".to_string(), + ) + })? + } else { + diesel::insert_into(datasets_columns::datasets) + .values(&datasets) + .on_conflict_do_nothing() + .get_results::(&mut conn) + .await + .map_err(|err| { + log::error!("Could not create dataset batch: {}", err); + ServiceError::BadRequest( + "Could not create dataset batch due to pg error".to_string(), + ) + })? + }; + + Ok(created_or_upserted_datasets) +} + #[tracing::instrument(skip(pool))] pub async fn get_dataset_by_id_query( id: UnifiedId,