6 changes: 6 additions & 0 deletions breadbox/breadbox/api/dataset_uploads.py
@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.encoders import jsonable_encoder
from breadbox.compute.dataset_uploads_tasks import run_dataset_upload
from ..schemas.custom_http_exception import UserError

from ..schemas.dataset import DatasetParams, AddDatasetResponse
from .dependencies import get_user
@@ -66,6 +67,11 @@ def add_dataset_uploads(
"""
utils.check_celery()

if not dataset.is_transient and dataset.expiry_in_seconds is not None:
Contributor:

I personally prefer adding these types of validations within the pydantic model itself since it feels a bit cleaner and will catch a malformed request body faster. In addition, this prevents us from needing to do asserts or similar checks later on, which I notice you do in breadbox/breadbox/compute/dataset_uploads_tasks.py:87

Contributor Author:

I'm less of a fan of validations within the pydantic model ... but only for reasons I'd describe as more "habit" than anything.

I agree, this is a validation which would make perfect sense to include in the pydantic model. I'll fix.

raise UserError(
"Dataset was not marked as 'transient' but expiry_in_seconds is set."
)
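
For illustration, here is a minimal sketch of what the reviewer's suggestion might look like if this check lived on the pydantic model instead of the endpoint. It assumes a pydantic v2-style `model_validator` and uses a stripped-down stand-in for `SharedDatasetParams`; the hook and placement are assumptions, not part of this PR.

```python
# Hypothetical sketch (not part of this PR): validate the transient/expiry
# relationship on the pydantic model so a bad request fails at parse time.
# Assumes pydantic v2; a v1 root_validator would look similar.
from typing import Optional

from pydantic import BaseModel, model_validator


class SharedDatasetParams(BaseModel):
    is_transient: bool = False
    expiry_in_seconds: Optional[int] = None

    @model_validator(mode="after")
    def check_expiry_requires_transient(self) -> "SharedDatasetParams":
        # Mirrors the endpoint check above: an expiry only makes sense for
        # transient datasets.
        if self.expiry_in_seconds is not None and not self.is_transient:
            raise ValueError(
                "Dataset was not marked as 'transient' but expiry_in_seconds is set."
            )
        return self
```

With a validator like this, FastAPI would reject the malformed request body during parsing, before the task is ever enqueued.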

# Converts a data type (like a Pydantic model) to something compatible with JSON, in this case a dict. Although Celery uses a JSON serializer to serialize task arguments by default, pydantic models are too complex for that default serializer. Pydantic models have a built-in .dict() method, but it turns out it doesn't convert enums to strings, which celery can't JSON serialize, so I opted to use fastapi's jsonable_encoder(), which appears to successfully json serialize enums
dataset_json = jsonable_encoder(dataset)
result = run_dataset_upload.delay(dataset_json, user) # pyright: ignore
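
As a side note on the serialization comment above, the following hypothetical snippet (names made up, not from this PR) illustrates the difference it describes between pydantic's `.dict()` and fastapi's `jsonable_encoder` for enum fields:

```python
# Hypothetical illustration: .dict() keeps enum members in the output dict,
# while jsonable_encoder converts them to plain JSON-safe values.
import enum
import json

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel


class ValueType(enum.Enum):
    CONTINUOUS = "continuous"


class ExampleParams(BaseModel):
    name: str
    value_type: ValueType


params = ExampleParams(name="a dataset", value_type=ValueType.CONTINUOUS)

print(params.dict())             # {'name': 'a dataset', 'value_type': <ValueType.CONTINUOUS: 'continuous'>}
print(jsonable_encoder(params))  # {'name': 'a dataset', 'value_type': 'continuous'}

print(json.dumps(jsonable_encoder(params)))  # serializes cleanly
# json.dumps(params.dict()) would raise TypeError here: the enum member is not
# JSON serializable by default, which is the problem the comment describes.
```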
1 change: 1 addition & 0 deletions breadbox/breadbox/compute/analysis_tasks.py
@@ -512,6 +512,7 @@ def create_cell_line_group(
taiga_id=None,
dataset_metadata=None,
dataset_md5=None,
expiry=None,
)
dataset_service.add_matrix_dataset(
db,
1 change: 1 addition & 0 deletions breadbox/breadbox/compute/dataset_tasks.py
@@ -262,6 +262,7 @@ def upload_dataset(
allowed_values=valid_fields.valid_allowed_values,
dataset_metadata=dataset_metadata,
dataset_md5=None,
expiry=None,
Contributor:

I see that we are manually setting this field in multiple places even though it already defaults to None. Is there a reason we need to explicitly set it?

Contributor Author:

It's being set to None in the legacy code paths. I've only updated the dataset-v2 endpoint's code path to be aware of expiry.

Contributor Author:

Ah, I see what you're saying. On the pydantic model expiry is specified as:

expiry: Annotated[Optional[datetime], Field(default=None,)]

I'd love to leave out the expiry=None here, but pyright doesn't seem to be aware of the defaulting and complained that I wasn't providing expiry.

)

added_dataset = dataset_service.add_matrix_dataset(
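
On the pyright discussion above, here is a small hypothetical sketch (not from this PR) of the distinction being described. Whether the type checker actually sees the default in each case depends on its dataclass_transform handling of pydantic fields, which is an assumption here rather than something verified against this codebase.

```python
# Hypothetical sketch of the pyright discussion above. Both fields default to
# None at runtime in pydantic, but only the second declares its default in the
# assignment position, where static type checkers typically look for it.
from datetime import datetime
from typing import Annotated, Optional

from pydantic import BaseModel, Field


class ExampleFields(BaseModel):
    # Default lives only inside the Annotated metadata; pyright reportedly
    # still treats the constructor argument as required.
    expiry: Annotated[Optional[datetime], Field(default=None)]

    # Default is part of the assignment, so the synthesized __init__ marks the
    # argument as optional for type checkers as well.
    expiry_alt: Annotated[Optional[datetime], Field()] = None
```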
9 changes: 9 additions & 0 deletions breadbox/breadbox/compute/dataset_uploads_tasks.py
@@ -1,3 +1,4 @@
from datetime import timedelta
from uuid import UUID, uuid4
from typing import Any, List, Optional, Union, Literal, Dict

@@ -81,6 +82,12 @@ def dataset_upload(
)

dataset_id = str(uuid4())
expiry = None
if dataset_params.expiry_in_seconds is not None:
assert dataset_params.is_transient
expiry = dataset_crud.get_current_datetime() + timedelta(
seconds=dataset_params.expiry_in_seconds
)

unknown_ids = []

@@ -132,6 +139,7 @@ def dataset_upload(
sample_type_name=dataset_params.sample_type,
data_type=dataset_params.data_type,
is_transient=dataset_params.is_transient,
expiry=expiry,
group_id=str(dataset_params.group_id),
value_type=dataset_params.value_type,
priority=dataset_params.priority,
@@ -186,6 +194,7 @@ def dataset_upload(
index_type_name=dataset_params.index_type,
data_type=dataset_params.data_type,
is_transient=dataset_params.is_transient,
expiry=expiry,
group_id=str(dataset_params.group_id),
priority=dataset_params.priority,
taiga_id=dataset_params.taiga_id,
30 changes: 30 additions & 0 deletions breadbox/breadbox/crud/dataset.py
@@ -1,5 +1,6 @@
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, List, Type, Union, Tuple, Set
from uuid import UUID, uuid4
import warnings
@@ -677,6 +678,35 @@ def update_dataset(
return dataset


def get_current_datetime():
# this method only exists to allow us to mock it in tests. Since `datetime` is a built-in we're not able to
Contributor:

Thanks for adding this informational comment!

# mutate it.
return datetime.now()


def find_expired_datasets(db: SessionWithUser, max_age: timedelta) -> List[Dataset]:
"""
Finds transient datasets which can be deleted (because they've "expired")
Two ways a transient dataset can be expired:
1. the `expiry` field can be explicitly set, and that time is in the past
2. the upload_date is before now - `max_age`.
"""

now = get_current_datetime()
min_upload_date = now - max_age

expired_datasets = (
db.query(Dataset)
.filter(
Dataset.is_transient == True,
or_(Dataset.expiry < now, Dataset.upload_date < min_upload_date,),
)
.all()
)

return expired_datasets


def delete_dataset(
db: SessionWithUser, user: str, dataset: Dataset, filestore_location: str
):
6 changes: 6 additions & 0 deletions breadbox/breadbox/models/dataset.py
@@ -93,7 +93,13 @@ class Dataset(Base, UUIDMixin, GroupMixin):
data_type: Mapped[str] = mapped_column(
String, ForeignKey(DataType.data_type), nullable=False
)

is_transient: Mapped[bool] = mapped_column(Boolean, nullable=False)
# only meaningful for datasets where is_transient==True. Indicates when a transient dataset
# should be deleted.
expiry: Mapped[Optional[DateTime]] = mapped_column(
DateTime(timezone=True), nullable=True
)

priority: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
taiga_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
9 changes: 8 additions & 1 deletion breadbox/breadbox/schemas/dataset.py
@@ -9,7 +9,7 @@
from breadbox.schemas.custom_http_exception import UserError
from .group import Group
import enum

from datetime import datetime


# NOTE: Using multivalue Literals seems to be creating errors in pydantic models and fastapi request params.
Expand Down Expand Up @@ -106,6 +106,12 @@ class SharedDatasetParams(BaseModel):
description="Transient datasets can be deleted - should only be set to true for non-public short-term-use datasets like custom analysis results.",
),
] = False
expiry_in_seconds: Annotated[
Optional[int],
Field(
description="The number of seconds before this dataset is expired (only applies to transient datasets)"
),
] = None
dataset_metadata: Annotated[
Optional[Dict[str, Any]],
Body(
@@ -286,6 +292,7 @@ class SharedDatasetFields(BaseModel):
priority: Annotated[Optional[int], Field(default=None, gt=0,)]
taiga_id: Annotated[Optional[str], Field(default=None,)]
is_transient: Annotated[bool, Field(default=False,)]
expiry: Annotated[Optional[datetime], Field(default=None,)]
dataset_metadata: Annotated[
Optional[Dict[str, Any]], Field()
] # NOTE: Same as Dict[str, Any] = Field(None,)
11 changes: 10 additions & 1 deletion breadbox/breadbox/service/dataset.py
@@ -25,7 +25,11 @@
get_transient_group,
)

from ..crud.dataset import add_tabular_dimensions, add_matrix_dataset_dimensions
from ..crud.dataset import (
add_tabular_dimensions,
add_matrix_dataset_dimensions,
get_current_datetime,
)
from ..crud.dimension_types import (
set_properties_to_index,
add_metadata_dimensions,
@@ -348,6 +352,7 @@ def add_tabular_dataset(
index_type_name=dataset_in.index_type_name,
data_type=dataset_in.data_type,
is_transient=dataset_in.is_transient,
expiry=dataset_in.expiry,
group_id=group.id,
priority=dataset_in.priority,
taiga_id=dataset_in.taiga_id,
@@ -356,6 +361,7 @@
short_name=short_name,
version=version,
description=description,
upload_date=get_current_datetime(),
)
db.add(dataset)
db.flush()
@@ -415,6 +421,7 @@ def add_matrix_dataset(
sample_type_name=dataset_in.sample_type_name,
data_type=dataset_in.data_type,
is_transient=dataset_in.is_transient,
expiry=dataset_in.expiry,
group_id=group.id,
value_type=dataset_in.value_type,
priority=dataset_in.priority,
@@ -425,6 +432,7 @@
short_name=short_name,
description=description,
version=version,
upload_date=get_current_datetime(),
)
db.add(dataset)
db.flush()
@@ -494,6 +502,7 @@ def add_dimension_type(
priority=None,
dataset_metadata=None,
dataset_md5=None, # This may change!
expiry=None,
)

check_id_mapping_is_valid(db, reference_column_mappings)
24 changes: 24 additions & 0 deletions breadbox/commands.py
@@ -1,5 +1,8 @@
import re
from typing import List, Optional
from datetime import timedelta
from breadbox.crud.dataset import find_expired_datasets, delete_dataset

import click
import subprocess
import json
@@ -30,6 +33,27 @@ def cli():
pass


@cli.command()
@click.option("--dryrun", is_flag=True, default=False)
@click.option("--maxdays", default=60, type=int)
def delete_expired_datasets(maxdays, dryrun):
db = _get_db_connection()
settings = get_settings()
expired_datasets = find_expired_datasets(db, timedelta(days=maxdays))

print(f"Found {len(expired_datasets)} expired datasets")

with transaction(db):
for dataset in expired_datasets:
dataset_summary = f"{dataset.id} (upload_date={dataset.upload_date}, expiry={dataset.expiry})"
if dryrun:
print(f"dryrun: Would have deleted {dataset_summary}")
else:
print(f"Deleting {dataset_summary}")
delete_dataset(db, db.user, dataset, settings.filestore_location)
print("Done")


@cli.command()
@click.argument("user_email")
@click.argument("group_name")
110 changes: 101 additions & 9 deletions breadbox/tests/api/test_dataset_uploads.py
@@ -1,19 +1,30 @@
import io
from datetime import datetime

from fastapi.testclient import TestClient

from breadbox.db.session import SessionWithUser, SessionLocalWithUser
from breadbox.schemas.dataset import AddDatasetResponse
from breadbox.compute import dataset_uploads_tasks
from breadbox.celery_task import utils
from breadbox.models.dataset import TabularDataset, TabularCell, TabularColumn
from breadbox.models.dataset import TabularDataset, TabularCell, TabularColumn, Dataset
from sqlalchemy import and_
from datetime import timedelta

from typing import Dict
from ..utils import assert_status_ok
import pytest
import numpy as np
from ..utils import upload_and_get_file_ids
import json
import pandas as pd
from breadbox.models.dataset import AnnotationType
from fastapi.testclient import TestClient
from breadbox.schemas.dataset import ColumnMetadata
from breadbox.crud.access_control import PUBLIC_GROUP_ID, TRANSIENT_GROUP_ID
from tests import factories
from ..utils import assert_status_not_ok, assert_status_ok
from breadbox.crud import dataset as dataset_crud
from breadbox.service import dataset as dataset_service


class TestPost:
@@ -1125,13 +1136,94 @@ def test_add_tabular_dataset_with_invalid_list_str_vals(
assert tabular_dataset.status_code == 400


import json
import pandas as pd
from breadbox.models.dataset import AnnotationType
from fastapi.testclient import TestClient
from breadbox.schemas.dataset import ColumnMetadata
from breadbox.crud.access_control import PUBLIC_GROUP_ID
from tests import factories
def test_dataset_with_expiry(
client: TestClient, minimal_db: SessionWithUser, mock_celery, settings, monkeypatch
):
user = settings.admin_users[0]
headers = {"X-Forwarded-User": user}
one_day_in_seconds = 60 * 60 * 24

file = factories.continuous_matrix_csv_file()

factories.feature_type(minimal_db, minimal_db.user, "feature_name")
factories.sample_type(minimal_db, minimal_db.user, "sample_name")

file_ids, expected_md5 = upload_and_get_file_ids(client, file)

def override_time(m, mock_now):
m.setattr(dataset_crud, "get_current_datetime", lambda: mock_now)
m.setattr(dataset_service, "get_current_datetime", lambda: mock_now)

with monkeypatch.context() as m:
override_time(m, datetime(year=2025, month=1, day=1))
response = client.post(
"/dataset-v2/",
json={
"format": "matrix",
"name": "a dataset",
"units": "a unit",
"feature_type": "feature_name",
"sample_type": "sample_name",
"data_type": "User upload",
"file_ids": file_ids,
"dataset_md5": expected_md5,
"is_transient": False,
"expiry_in_seconds": one_day_in_seconds,
"group_id": TRANSIENT_GROUP_ID,
"value_type": "continuous",
"allowed_values": None,
},
headers=headers,
)

# verify we can't specify expiry on a non-transient dataset
assert_status_not_ok(response)

# try again with transient = True
response = client.post(
"/dataset-v2/",
json={
"format": "matrix",
"name": "a dataset",
"units": "a unit",
"feature_type": "feature_name",
"sample_type": "sample_name",
"data_type": "User upload",
"file_ids": file_ids,
"dataset_md5": expected_md5,
"is_transient": True,
"expiry_in_seconds": one_day_in_seconds,
"group_id": TRANSIENT_GROUP_ID,
"value_type": "continuous",
"allowed_values": None,
},
headers=headers,
)

# now this should work and we should see the expiry having been set
assert_status_ok(response)
dataset_id = response.json()["result"]["datasetId"]
dataset = minimal_db.query(Dataset).filter(Dataset.id == dataset_id).one()
assert dataset.expiry == datetime(year=2025, month=1, day=2)

# now let's try a few dates to make sure we recognize when this is expired
# first, make sure that nothing is found before the expiration date
with monkeypatch.context() as m:
override_time(m, datetime(year=2025, month=1, day=1, hour=1))
expired = dataset_crud.find_expired_datasets(minimal_db, timedelta(days=1000))
assert len(expired) == 0

# now make sure that we honor max_age when looking for expired data
with monkeypatch.context() as m:
override_time(m, datetime(year=2025, month=1, day=1, hour=1))
expired = dataset_crud.find_expired_datasets(minimal_db, timedelta(minutes=1))
assert len(expired) == 1

# okay, now make sure that we honor the expiration
with monkeypatch.context() as m:
override_time(m, datetime(year=2025, month=1, day=2, hour=1))
expired = dataset_crud.find_expired_datasets(minimal_db, timedelta(days=1000))
assert len(expired) == 1


def test_end_to_end_with_mismatched_metadata(
1 change: 1 addition & 0 deletions breadbox/tests/factories.py
@@ -354,6 +354,7 @@ def tabular_dataset(
taiga_id=taiga_id,
dataset_metadata=dataset_metadata,
dataset_md5=None,
expiry=None,
)

assert columns_metadata is not None