feature: add batch create dataset route
skeptrunedev committed Nov 13, 2024
1 parent f75c87a commit de0a2f4
Showing 6 changed files with 157 additions and 13 deletions.
8 changes: 8 additions & 0 deletions server/src/data/models.rs
@@ -1892,13 +1892,21 @@ pub struct DatasetEventCount {
}))]
#[diesel(table_name = datasets)]
pub struct Dataset {
/// Unique identifier of the dataset, auto-generated uuid created by Trieve
pub id: uuid::Uuid,
/// Name of the dataset
pub name: String,
/// Timestamp of the creation of the dataset
pub created_at: chrono::NaiveDateTime,
/// Timestamp of the last update of the dataset
pub updated_at: chrono::NaiveDateTime,
/// Unique identifier of the organization that owns the dataset
pub organization_id: uuid::Uuid,
/// Configuration of the dataset for RAG, embeddings, BM25, etc.
pub server_configuration: serde_json::Value,
/// Tracking ID of the dataset, can be any string, determined by the user. Tracking IDs are unique identifiers for datasets within an organization. They are designed to match the unique identifier of the dataset in the user's system.
pub tracking_id: Option<String>,
/// Flag to indicate if the dataset has been deleted. Deletes are handled async after the flag is set so as to avoid expensive search index compaction.
pub deleted: i32,
}
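
For illustration only, a Dataset row serialized to JSON might look roughly like the sketch below; the ids, name, tracking_id, and configuration values are hypothetical, and the exact timestamp formatting depends on how chrono::NaiveDateTime is serialized:

use serde_json::json;

// Hypothetical serialized Dataset; ids, names, and timestamps are made up.
let example_dataset = json!({
    "id": "d1f7e3a0-1c2b-4f5e-9a8d-123456789abc",
    "name": "docs-en",
    "created_at": "2024-11-13T00:00:00",
    "updated_at": "2024-11-13T00:00:00",
    "organization_id": "a0b1c2d3-e4f5-6789-abcd-ef0123456789",
    "server_configuration": { "MAX_LIMIT": 10000 },
    "tracking_id": "docs-en-v1",
    "deleted": 0
});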

2 changes: 1 addition & 1 deletion server/src/handlers/chunk_handler.rs
@@ -107,7 +107,7 @@ pub struct ChunkReqPayload {
pub metadata: Option<serde_json::Value>,
/// Tracking_id is a string which can be used to identify a chunk. This is useful for when you are coordinating with an external system and want to use the tracking_id to identify the chunk.
pub tracking_id: Option<String>,
/// Upsert when a chunk with the same tracking_id exists. By default this is false, and the request will fail if a chunk with the same tracking_id exists. If this is true, the chunk will be updated if a chunk with the same tracking_id exists.
/// Upsert when a chunk with the same tracking_id exists. By default this is false, and chunks will be ignored if another with the same tracking_id exists. If this is true, the chunk will be updated if a chunk with the same tracking_id exists.
pub upsert_by_tracking_id: Option<bool>,
/// Group ids are the Trieve generated ids of the groups that the chunk should be placed into. This is useful for when you want to create a chunk and add it to a group or multiple groups in one request. Groups with these Trieve generated ids must be created first, it cannot be arbitrarily created through this route.
pub group_ids: Option<Vec<uuid::Uuid>>,
94 changes: 84 additions & 10 deletions server/src/handlers/dataset_handler.rs
@@ -13,9 +13,9 @@ use crate::{
validate_crawl_options,
},
dataset_operator::{
clear_dataset_by_dataset_id_query, create_dataset_query, get_dataset_by_id_query,
get_dataset_usage_query, get_datasets_by_organization_id, get_tags_in_dataset_query,
soft_delete_dataset_by_id_query, update_dataset_query,
clear_dataset_by_dataset_id_query, create_dataset_query, create_datasets_query,
get_dataset_by_id_query, get_dataset_usage_query, get_datasets_by_organization_id,
get_tags_in_dataset_query, soft_delete_dataset_by_id_query, update_dataset_query,
},
dittofeed_operator::{
send_ditto_event, DittoDatasetCreated, DittoTrackProperties, DittoTrackRequest,
@@ -99,7 +99,7 @@ impl FromRequest for OrganizationWithSubAndPlan {
"MAX_LIMIT": 10000
}
}))]
pub struct CreateDatasetRequest {
pub struct CreateDatasetReqPayload {
/// Name of the dataset.
pub dataset_name: String,
/// Optional tracking ID for the dataset. Can be used to track the dataset in external systems. Must be unique within the organization. Strongly recommended to not use a valid uuid value as that will not work with the TR-Dataset header.
@@ -112,13 +112,13 @@ pub struct CreateDatasetRequest {

/// Create Dataset
///
/// Auth'ed user must be an owner of the organization to create a dataset.
/// Dataset will be created in the org specified via the TR-Organization header. Auth'ed user must be an owner of the organization to create a dataset.
#[utoipa::path(
post,
path = "/dataset",
context_path = "/api",
tag = "Dataset",
request_body(content = CreateDatasetRequest, description = "JSON request payload to create a new dataset", content_type = "application/json"),
request_body(content = CreateDatasetReqPayload, description = "JSON request payload to create a new dataset", content_type = "application/json"),
responses(
(status = 200, description = "Dataset created successfully", body = Dataset),
(status = 400, description = "Service error relating to creating the dataset", body = ErrorResponseBody),
Expand All @@ -132,7 +132,7 @@ pub struct CreateDatasetRequest {
)]
#[tracing::instrument(skip(pool))]
pub async fn create_dataset(
data: web::Json<CreateDatasetRequest>,
data: web::Json<CreateDatasetReqPayload>,
pool: web::Data<Pool>,
redis_pool: web::Data<RedisPool>,
org_with_sub_and_plan: OrganizationWithSubAndPlan,
@@ -234,7 +234,7 @@ pub async fn create_dataset(
"MAX_LIMIT": 10000
}
}))]
pub struct UpdateDatasetRequest {
pub struct UpdateDatasetReqPayload {
/// The id of the dataset you want to update.
pub dataset_id: Option<uuid::Uuid>,
/// The tracking ID of the dataset you want to update.
@@ -257,7 +257,7 @@ pub struct UpdateDatasetRequest {
path = "/dataset",
context_path = "/api",
tag = "Dataset",
request_body(content = UpdateDatasetRequest, description = "JSON request payload to update a dataset", content_type = "application/json"),
request_body(content = UpdateDatasetReqPayload, description = "JSON request payload to update a dataset", content_type = "application/json"),
responses(
(status = 200, description = "Dataset updated successfully", body = Dataset),
(status = 400, description = "Service error relating to updating the dataset", body = ErrorResponseBody),
@@ -272,7 +272,7 @@ pub struct UpdateDatasetRequest {
)]
#[tracing::instrument(skip(pool))]
pub async fn update_dataset(
data: web::Json<UpdateDatasetRequest>,
data: web::Json<UpdateDatasetReqPayload>,
pool: web::Data<Pool>,
redis_pool: web::Data<RedisPool>,
user: OwnerOnly,
@@ -700,13 +700,17 @@ pub struct GetAllTagsReqPayload {

#[derive(Serialize, Deserialize, Debug, ToSchema, Queryable)]
pub struct TagsWithCount {
/// Content of the tag
pub tag: String,
/// Number of chunks in the dataset with that tag
pub count: i64,
}

#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct GetAllTagsResponse {
/// List of tags with the number of chunks in the dataset with that tag.
pub tags: Vec<TagsWithCount>,
/// Total number of unique tags in the dataset.
pub total: i64,
}

@@ -751,3 +755,73 @@ pub async fn get_all_tags(
total: items.1,
}))
}

#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct CreateBatchDataset {
/// Name of the dataset.
pub dataset_name: String,
/// Optional tracking ID for the dataset. Can be used to track the dataset in external systems. Must be unique within the organization. Strongly recommended to not use a valid uuid value as that will not work with the TR-Dataset header.
pub tracking_id: Option<String>,
/// The configuration of the dataset. See the example request payload for the potential keys which can be set. It is possible to break your dataset's functionality by erroneously setting this field. We recommend setting this by creating a dataset at dashboard.trieve.ai and managing its settings there.
pub server_configuration: Option<DatasetConfigurationDTO>,
}

#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct CreateDatasetBatchReqPayload {
/// List of datasets to create
pub datasets: Vec<CreateBatchDataset>,
/// Upsert when a dataset with one of the specified tracking_ids already exists. By default this is false, and specified datasets with a tracking_id that already exists in the org will be ignored. If true, the existing dataset will be updated with the new dataset's details.
pub upsert: Option<bool>,
}
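
For reference, a request body matching CreateDatasetBatchReqPayload could look like the following sketch (built with serde_json's json! macro); the dataset names, tracking IDs, and the MAX_LIMIT value are illustrative, not taken from this commit:

use serde_json::json;

// Hypothetical example payload for POST /api/dataset/batch_create_datasets.
// Dataset names and tracking_ids are made up for illustration.
let payload = json!({
    "datasets": [
        {
            "dataset_name": "docs-en",
            "tracking_id": "docs-en-v1",
            "server_configuration": { "MAX_LIMIT": 10000 }
        },
        {
            "dataset_name": "docs-de",
            "tracking_id": "docs-de-v1"
        }
    ],
    "upsert": true
});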

/// Datasets
#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct Datasets(Vec<Dataset>);

/// Batch Create Datasets
///
/// Datasets will be created in the org specified via the TR-Organization header. Auth'ed user must be an owner of the organization to create datasets. If a tracking_id is ignored due to it already existing on the org, the response will not contain a dataset with that tracking_id and it can be assumed that a dataset with the missing tracking_id already exists.
#[utoipa::path(
post,
path = "/dataset/batch_create_datasets",
context_path = "/api",
tag = "Dataset",
request_body(content = CreateDatasetBatchReqPayload, description = "JSON request payload to bulk create datasets", content_type = "application/json"),
responses(
(status = 200, description = "Datasets created or upserted successfully", body = Datasets),
(status = 400, description = "Service error relating to creating the datasets", body = ErrorResponseBody),
),
params(
("TR-Organization" = uuid::Uuid, Header, description = "The organization id to use for the request"),
),
security(
("ApiKey" = ["owner"]),
)
)]
pub async fn batch_create_datasets(
data: web::Json<CreateDatasetBatchReqPayload>,
_user: OwnerOnly,
pool: web::Data<Pool>,
org_with_sub_and_plan: OrganizationWithSubAndPlan,
) -> Result<HttpResponse, ServiceError> {
let datasets = data
.datasets
.iter()
.map(|d| {
Dataset::from_details(
d.dataset_name.clone(),
org_with_sub_and_plan.organization.id,
d.tracking_id.clone(),
d.server_configuration
.clone()
.map(|c| c.into())
.unwrap_or_default(),
)
})
.collect::<Vec<_>>();

let created_or_upserted_datasets =
create_datasets_query(datasets, data.upsert, pool.clone()).await?;

Ok(HttpResponse::Ok().json(created_or_upserted_datasets))
}
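
As a rough usage sketch (not part of this commit), the new route could be called from a Rust client as below. The base URL, API key, and organization id are placeholders, and sending the key in the Authorization header is an assumption based on the ApiKey security scheme rather than anything shown in this diff. Assumes reqwest with the json feature enabled:

use reqwest::Client;
use serde_json::json;

// Hypothetical client-side call to the new batch create route.
// base_url, api_key, and org_id are placeholders supplied by the caller.
async fn batch_create_example(
    base_url: &str,
    api_key: &str,
    org_id: &str,
) -> reqwest::Result<Vec<serde_json::Value>> {
    let body = json!({
        "datasets": [
            { "dataset_name": "docs-en", "tracking_id": "docs-en-v1" }
        ],
        "upsert": false
    });

    Client::new()
        .post(format!("{base_url}/api/dataset/batch_create_datasets"))
        // Assumed auth header for the ApiKey scheme; not shown in this diff.
        .header("Authorization", api_key)
        .header("TR-Organization", org_id)
        .json(&body)
        .send()
        .await?
        .error_for_status()?
        // The route responds with the created (and upserted) datasets as a JSON array.
        .json::<Vec<serde_json::Value>>()
        .await
}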
13 changes: 11 additions & 2 deletions server/src/lib.rs
@@ -221,6 +221,7 @@ impl Modify for SecurityAddon {
handlers::organization_handler::get_organization_users,
handlers::organization_handler::update_all_org_dataset_configs,
handlers::dataset_handler::create_dataset,
handlers::dataset_handler::batch_create_datasets,
handlers::dataset_handler::update_dataset,
handlers::dataset_handler::delete_dataset,
handlers::dataset_handler::delete_dataset_by_tracking_id,
@@ -302,6 +303,7 @@ impl Modify for SecurityAddon {
handlers::dataset_handler::GetAllTagsReqPayload,
handlers::dataset_handler::GetAllTagsResponse,
handlers::dataset_handler::GetCrawlOptionsResponse,
handlers::dataset_handler::Datasets,
handlers::group_handler::RecommendGroupsReqPayload,
handlers::group_handler::RecommendGroupsResponse,
handlers::group_handler::SearchWithinGroupReqPayload,
@@ -344,8 +346,10 @@ impl Modify for SecurityAddon {
operators::search_operator::SearchOverGroupsResults,
operators::search_operator::SearchOverGroupsResponseBody,
operators::search_operator::SearchOverGroupsResponseTypes,
handlers::dataset_handler::CreateDatasetRequest,
handlers::dataset_handler::UpdateDatasetRequest,
handlers::dataset_handler::CreateDatasetReqPayload,
handlers::dataset_handler::CreateBatchDataset,
handlers::dataset_handler::CreateDatasetBatchReqPayload,
handlers::dataset_handler::UpdateDatasetReqPayload,
handlers::dataset_handler::GetDatasetsPagination,
data::models::DatasetConfigurationDTO,
handlers::chunk_handler::CrawlOpenAPIOptions,
@@ -805,6 +809,11 @@ pub fn main() -> std::io::Result<()> {
)
.route(web::put().to(handlers::dataset_handler::update_dataset))
)
.service(
web::resource("/batch_create_datasets").route(
web::post().to(handlers::dataset_handler::batch_create_datasets),
)
)
.service(
web::resource("/organization/{organization_id}").route(
web::get().to(
1 change: 1 addition & 0 deletions server/src/operators/chunk_operator.rs
@@ -1364,6 +1364,7 @@ pub async fn delete_chunk_metadata_query(
) -> Result<(), ServiceError> {
use crate::data::schema::chunk_group_bookmarks::dsl as chunk_group_bookmarks_columns;
use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns;

let mut conn = pool.get().await.map_err(|_e| {
ServiceError::InternalServerError("Failed to get postgres connection".to_string())
})?;
52 changes: 52 additions & 0 deletions server/src/operators/dataset_operator.rs
@@ -18,6 +18,7 @@ use clickhouse::Row;
use diesel::dsl::count;
use diesel::prelude::*;
use diesel::result::{DatabaseErrorKind, Error as DBError};
use diesel::upsert::excluded;
use diesel_async::RunQueryDsl;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
@@ -54,6 +55,57 @@ pub async fn create_dataset_query(
Ok(new_dataset)
}

#[tracing::instrument(skip(pool))]
pub async fn create_datasets_query(
datasets: Vec<Dataset>,
upsert: Option<bool>,
pool: web::Data<Pool>,
) -> Result<Vec<Dataset>, ServiceError> {
use crate::data::schema::datasets::dsl as datasets_columns;

let mut conn = pool
.get()
.await
.map_err(|_| ServiceError::BadRequest("Could not get database connection".to_string()))?;

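// When upsert is requested, conflicts on (tracking_id, organization_id) update the
// existing row's name and server_configuration in place; otherwise conflicting rows
// are skipped via ON CONFLICT DO NOTHING and omitted from the returned Vec.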
let created_or_upserted_datasets: Vec<Dataset> = if upsert.unwrap_or(false) {
diesel::insert_into(datasets_columns::datasets)
.values(&datasets)
.on_conflict((
datasets_columns::tracking_id,
datasets_columns::organization_id,
))
.do_update()
.set((
datasets_columns::name.eq(excluded(datasets_columns::name)),
datasets_columns::server_configuration
.eq(excluded(datasets_columns::server_configuration)),
))
.get_results::<Dataset>(&mut conn)
.await
.map_err(|err| {
log::error!("Could not create dataset batch: {}", err);
ServiceError::BadRequest(
"Could not create dataset batch due to pg error".to_string(),
)
})?
} else {
diesel::insert_into(datasets_columns::datasets)
.values(&datasets)
.on_conflict_do_nothing()
.get_results::<Dataset>(&mut conn)
.await
.map_err(|err| {
log::error!("Could not create dataset batch: {}", err);
ServiceError::BadRequest(
"Could not create dataset batch due to pg error".to_string(),
)
})?
};

Ok(created_or_upserted_datasets)
}

#[tracing::instrument(skip(pool))]
pub async fn get_dataset_by_id_query(
id: UnifiedId,
