Splitter #719


Merged
merged 39 commits into from
Jun 13, 2025
Commits (39)
00f06d2
first draft
thibaultdvx Apr 16, 2025
58b5ce5
DataLoaderConfig
thibaultdvx Apr 16, 2025
c24dfce
Merge remote-tracking branch 'upstream/clinicadl_v2' into dataloader
thibaultdvx Apr 16, 2025
2d93014
tests
thibaultdvx Apr 17, 2025
60474e1
changes in spit
thibaultdvx Apr 17, 2025
5468b69
sphinx-autodoc-typehints in doc dependencies
thibaultdvx Apr 17, 2025
21889ea
first draft
thibaultdvx Apr 23, 2025
551c1bc
splitter objects
thibaultdvx Apr 23, 2025
8fe5aa2
remove split_utils
thibaultdvx Apr 24, 2025
19b5063
docstrings
thibaultdvx Apr 24, 2025
e1daee9
documentation
thibaultdvx Apr 24, 2025
f5f7864
Merge remote-tracking branch 'upstream/clinicadl_v2' into splitter
thibaultdvx May 19, 2025
0324e3d
add dostring on parameter "dataset"
thibaultdvx May 19, 2025
61b38c1
complete docstring on the argument "dataset"
thibaultdvx May 19, 2025
7033c34
typo in documentation
thibaultdvx May 19, 2025
ca92cd6
first tests
thibaultdvx May 20, 2025
a7840e1
change concat dataframe
thibaultdvx May 21, 2025
02bb399
change paired dataframe
thibaultdvx May 21, 2025
ab39b13
change dataframe for unpaired
thibaultdvx May 21, 2025
2fe4d7a
unittests for make_splits
thibaultdvx Jun 6, 2025
86d2c76
change how to deal with empty datasets
thibaultdvx Jun 11, 2025
344e772
update test data
thibaultdvx Jun 11, 2025
add173e
test
thibaultdvx Jun 11, 2025
9c49665
doc
thibaultdvx Jun 11, 2025
36b6153
try ending multiprocessing tests
thibaultdvx Jun 11, 2025
13d88ca
test multiprocessing ending in test
thibaultdvx Jun 11, 2025
3c23092
Merge branch 'splitter' of https://github.com/thibaultdvx/clinicadl i…
thibaultdvx Jun 11, 2025
aee6086
Revert "test multiprocessing ending in test"
thibaultdvx Jun 11, 2025
e88b43a
try ending multiprocessing in tests
thibaultdvx Jun 11, 2025
ad012cb
fix tests
thibaultdvx Jun 11, 2025
71cf73f
fix tests
thibaultdvx Jun 11, 2025
ee358c1
test if dataloader is the problem
thibaultdvx Jun 11, 2025
3eb397f
test if num_workers is the problem
thibaultdvx Jun 11, 2025
1a9f09b
test if num_workers is the problem
thibaultdvx Jun 11, 2025
25adca4
test conftest
thibaultdvx Jun 11, 2025
8fbc910
try self multiprocessing
thibaultdvx Jun 11, 2025
fc9bd4b
change threshold value for test
thibaultdvx Jun 11, 2025
6e24663
skip some tests for macos
thibaultdvx Jun 11, 2025
52d32e2
minor change
camillebrianceau Jun 13, 2025
47 changes: 26 additions & 21 deletions clinicadl/data/dataloader/config.py
@@ -1,26 +1,22 @@
from typing import Iterator, Optional, Union, overload
from typing import Iterator, Optional, overload

from pydantic import NonNegativeInt, PositiveInt, model_validator
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import DistributedSampler, Sampler, WeightedRandomSampler

from clinicadl.data.datasets import (
CapsDataset,
ConcatDataset,
PairedDataset,
UnpairedDataset,
)
from clinicadl.data.datasets.types import Dataset, SimpleDataset, TupleDataset
from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.seed import pl_worker_init_function

from .batch import SimpleBatch, simple_collate_fn, tuple_collate_fn

SimpleDataset = Union[CapsDataset, ConcatDataset]
TupleDataset = Union[PairedDataset, UnpairedDataset]
Dataset = Union[SimpleDataset, TupleDataset]


class DataLoader(TorchDataLoader):
"""
Overwrites :py:class:`torch.utils.data.DataLoader` only to add a `set_epoch` method.
"""
@@ -294,6 +290,8 @@ def get_object(
------
ValueError
If only one of ``dp_degree`` and ``rank`` is not ``None``.
ValueError
If ``rank`` is greater than or equal to ``dp_degree``.
ValueError
If the dataset is an :py:class:`~clinicadl.data.datasets.UnpairedDataset`,
and ``sampling_weights`` is not ``None``.
@@ -304,6 +302,24 @@ def get_object(
If ``sampling_weights`` is not ``None`` and the associated column cannot
be converted to float values.
"""
if (rank is not None and dp_degree is None) or (
dp_degree is not None and rank is None
):
raise ValueError(
"For data parallelism, none of 'dp_degree' and 'rank' can be None. "
f"Got rank={rank} and dp_degree={dp_degree}"
)

if dp_degree is None:
dp_degree = 1
rank = 0

if rank >= dp_degree:
raise ValueError(
"'rank' must be strictly smaller than 'dp_degree'. Got "
f"dp_degree={dp_degree} and rank={rank}"
)

return DataLoader(
dataset=dataset,
sampler=self._generate_sampler(dataset, dp_degree, rank),
@@ -317,26 +333,15 @@ def get_object(
def _generate_sampler(
self,
dataset: CapsDataset,
dp_degree: Optional[int],
rank: Optional[int],
dp_degree: int,
rank: int,
) -> Sampler:
"""
Returns a WeightedRandomSampler if self.sampling_weights is not None, otherwise
a DistributedSampler, even when data parallelism is not performed (in this case
the degree of data parallelism is set to 1, so it is equivalent to a simple PyTorch
RandomSampler if self.shuffle is True or no sampler if self.shuffle is False).
"""
if (rank is not None and dp_degree is None) or (
dp_degree is not None and rank is None
):
raise ValueError(
"For data parallelism, none of 'dp_degree' and 'rank' can be None. "
f"Got rank={rank} and dp_degree={dp_degree}"
)
if dp_degree is None:
dp_degree = 1
rank = 0

if self.sampling_weights and rank is not None:
weights = self._get_weights(dataset, self.sampling_weights)
length = len(weights) // dp_degree + int(rank < len(weights) % dp_degree)
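For readers skimming the diff, here is a minimal, hypothetical sketch (plain PyTorch, outside ClinicaDL) of the validation and sampler selection this file now performs. The `make_sampler` helper, its signature, and the toy arguments are illustrative assumptions, not the PR's actual code:

```python
from typing import Optional, Sequence

from torch.utils.data import Dataset, DistributedSampler, Sampler, WeightedRandomSampler


def make_sampler(
    dataset: Dataset,
    shuffle: bool,
    weights: Optional[Sequence[float]] = None,
    dp_degree: Optional[int] = None,
    rank: Optional[int] = None,
) -> Sampler:
    """Hypothetical helper mirroring the checks moved into get_object."""
    # 'dp_degree' and 'rank' must be either both set or both None.
    if (dp_degree is None) != (rank is None):
        raise ValueError(f"Got rank={rank} and dp_degree={dp_degree}")
    if dp_degree is None:
        dp_degree, rank = 1, 0  # no data parallelism: a single "replica"
    if rank >= dp_degree:
        raise ValueError(
            "'rank' must be strictly smaller than 'dp_degree'. "
            f"Got dp_degree={dp_degree} and rank={rank}"
        )

    if weights is not None:
        # Share the draws across ranks: the first len(weights) % dp_degree
        # ranks each get one extra sample (same formula as in the diff above).
        n_samples = len(weights) // dp_degree + int(rank < len(weights) % dp_degree)
        return WeightedRandomSampler(weights, num_samples=n_samples)

    # With dp_degree=1 and rank=0 this degenerates to a plain RandomSampler
    # (shuffle=True) or sequential iteration (shuffle=False).
    return DistributedSampler(dataset, num_replicas=dp_degree, rank=rank, shuffle=shuffle)
```

This degeneracy at `dp_degree=1` is presumably why the PR can return a `DistributedSampler` unconditionally, even when no data parallelism is requested.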
61 changes: 19 additions & 42 deletions clinicadl/data/datasets/caps_dataset.py
@@ -26,14 +26,10 @@
)
from clinicadl.transforms.extraction import ExtractionMethod, Sample
from clinicadl.transforms.transforms import Transforms
from clinicadl.tsvtools.utils import (
check_df,
tsv_to_df,
)
from clinicadl.tsvtools.utils import read_data
from clinicadl.utils.exceptions import (
ClinicaDLArgumentError,
ClinicaDLCAPSError,
ClinicaDLTSVError,
)
from clinicadl.utils.typing import DataType, PathType

@@ -332,7 +328,10 @@ def to_tensors(
self._count_samples()

def read_tensor_conversion(
self, json_name: str, check_transforms: bool = True, load_also: list[str] = []
self,
json_name: str,
check_transforms: bool = True,
load_also: Optional[list[str]] = None,
) -> None:
"""
To read an old tensor conversion. The function will check that
@@ -365,7 +364,7 @@ def read_tensor_conversion(
.. warning::
**To use carefully**. You must be sure that the transforms match before setting ``check_transforms=False``.

load_also : list[str] (optional, default=[])
load_also : list[str] (optional, default=None)
To load additional information potentially stored in ``.pt`` files. By default, only the image, the label, and masks
mentioned in the argument ``masks`` of the CapsDataset will be loaded.

@@ -421,28 +420,20 @@ def subset(self, data: DataType) -> CapsDataset:
ClinicaDLTSVError
If the DataFrame associated to ``data`` does not contain the columns ``"participant_id"``
and ``"session_id"``.
ClinicaDLTSVError
If some (participant, session) pairs mentioned in ``data`` are not in the current CapsDataset.
ClinicaDLCAPSError
If no (participant, session) pairs mentioned in ``data`` are in the current CapsDataset
(this would lead to an empty dataset).
"""
new_df = self._check_data_instance(data).set_index([PARTICIPANT_ID, SESSION_ID])

try:
subset_df = (
self.df.set_index([PARTICIPANT_ID, SESSION_ID])
.loc[new_df.index]
.reset_index()
)
except KeyError as exc:
missing_pairs = new_df.index.difference(
self.df.set_index([PARTICIPANT_ID, SESSION_ID]).index
)
new_df = read_data(data, check_protected_names=False).set_index(
[PARTICIPANT_ID, SESSION_ID]
)
df = self.df.set_index([PARTICIPANT_ID, SESSION_ID])
subset_df = df.loc[new_df.index.intersection(df.index)].reset_index()

err_message = (
"Some couples (participant, session) are not in the dataset:\n"
if len(subset_df) == 0:
raise ClinicaDLCAPSError(
"No (participant, session) pairs mentioned in 'data' are in the CapsDataset. This would lead to an empty dataset!"
)
for pair in missing_pairs:
err_message += f" - {pair} \n"
raise ClinicaDLTSVError(err_message) from exc

dataset = deepcopy(self)
dataset.df = subset_df
@@ -621,7 +612,7 @@ def _check_label(self, label: Optional[str]) -> Optional[Union[Column, Mask]]:
"""
if isinstance(label, str):
if label in self.df.columns:
if isinstance(self.df[label].iloc[0], str):
if not pd.api.types.is_numeric_dtype(self.df[label]):
label_list = self.df[label].unique()
if len(label_list) > 5:
raise ClinicaDLArgumentError(
@@ -727,24 +718,10 @@ def _get_df_from_input(self, data: Optional[DataType]) -> pd.DataFrame:
f"'data' must be a Pandas DataFrame, a path to a TSV file or None. Got {data}"
)

df = self._check_data_instance(data)
df = read_data(data)

return deepcopy(df)

@staticmethod
def _check_data_instance(data: DataType) -> pd.DataFrame:
"""
Checks the DataFrame passed by the user (either as a DataFrame or
as a path to a TSV). Returns the checked DataFrame.
"""
if isinstance(data, (str, Path)):
path = Path(data)
df = tsv_to_df(path)
elif isinstance(data, pd.DataFrame):
df = check_df(data)

return df # pylint: disable=possibly-used-before-assignment

### for __getitem__ ###
def _get_meta_data(self, idx: int) -> Tuple[str, str, int]:
"""
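As a hedged illustration of the new `subset` behavior in this file — unknown (participant, session) pairs are now silently dropped, and an error is raised only when the intersection is empty — here is a self-contained pandas sketch. The toy IDs and the plain `ValueError` (standing in for `ClinicaDLCAPSError`) are assumptions:

```python
import pandas as pd

# Toy stand-ins for CapsDataset.df and the user's 'data' (hypothetical IDs).
df = pd.DataFrame(
    {"participant_id": ["sub-01", "sub-02"], "session_id": ["ses-M000", "ses-M000"]}
)
data = pd.DataFrame(
    {"participant_id": ["sub-02", "sub-99"], "session_id": ["ses-M000", "ses-M000"]}
)

indexed = df.set_index(["participant_id", "session_id"])
wanted = data.set_index(["participant_id", "session_id"]).index

# Unknown pairs (sub-99 here) are silently dropped by the intersection...
subset_df = indexed.loc[wanted.intersection(indexed.index)].reset_index()

# ...and an error is raised only if nothing at all matched.
if len(subset_df) == 0:
    raise ValueError("No (participant, session) pairs mentioned in 'data' are in the dataset.")

print(subset_df)  # a single row: (sub-02, ses-M000)
```

This is what lets `ConcatDataset.subset` (next file) pass the full `data` to every underlying dataset and simply skip the ones that come back empty.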
58 changes: 23 additions & 35 deletions clinicadl/data/datasets/concat.py
@@ -10,10 +10,10 @@
import pandas as pd
from torch.utils.data import ConcatDataset as TorchConcatDataset

from clinicadl.dictionary.words import N_SAMPLES, PARTICIPANT_ID, SESSION_ID
from clinicadl.dictionary.words import DATASET_ID, PARTICIPANT_ID, SESSION_ID
from clinicadl.transforms.extraction import Sample
from clinicadl.transforms.extraction.slice import Slice
from clinicadl.utils.exceptions import ClinicaDLCAPSError, ClinicaDLTSVError
from clinicadl.utils.exceptions import ClinicaDLCAPSError
from clinicadl.utils.typing import DataType

from .caps_dataset import CapsDataset
@@ -163,38 +163,28 @@ def subset(self, data: DataType) -> ConcatDataset:
If the DataFrame associated to ``data`` does not contain the columns ``"participant_id"``
and ``"session_id"``.
ClinicaDLCAPSError
If some (participant, session) pairs mentioned in ``data`` are not in any of the CapsDatasets
forming the ConcatDataset.
If no (participant, session) pairs mentioned in ``data`` are in at least one of the underlying datasets.
This would lead to an empty ConcatDataset.
"""
data = CapsDataset._check_data_instance(data).set_index(
[PARTICIPANT_ID, SESSION_ID]
)

in_a_df = {(participant, session): False for participant, session in data.index}
datasets = []
sub_datasets = []
not_empty = False
for dataset in self.datasets:
participants_sessions = dataset.get_participant_session_couples()
participants_sessions = data.index.intersection(participants_sessions)

for participant_session in participants_sessions:
in_a_df[participant_session] = True

sub_data = data.loc[participants_sessions]
try:
datasets.append(dataset.subset(sub_data.reset_index()))
except ClinicaDLTSVError:
sub_datasets.append(dataset.subset(data))
not_empty = True
except ClinicaDLCAPSError: # empty dataset
continue

raise_error = False
err_message = "Some couples (participant, session) are not in any of the datasets forming the ConcatDataset:\n"
for participant_session in in_a_df:
if not in_a_df[participant_session]:
raise_error = True
err_message += f" - {participant_session} \n"
if raise_error:
raise ClinicaDLCAPSError(err_message)
if not not_empty:
raise ClinicaDLCAPSError(
"No (participant, session) pairs mentioned in 'data' are in the ConcatDataset. This would lead to an empty dataset!"
)

return ConcatDataset(datasets, ignore_spacing=True, raise_warnings=False)
return ConcatDataset(
sub_datasets,
ignore_spacing=True,
raise_warnings=False,
)

def describe(self) -> tuple[Dict[str, Any], ...]:
"""
Expand Down Expand Up @@ -357,16 +347,14 @@ def _concat_dfs(self) -> pd.DataFrame:
"""
Concatenates the dataframes from all the datasets.
"""
df = pd.concat(
[
dataset.df[[PARTICIPANT_ID, SESSION_ID, N_SAMPLES]]
for dataset in self.datasets
],
df: pd.DataFrame = pd.concat(
[dataset.df for dataset in self.datasets],
keys=range(len(self.datasets)),
names=["dataset_id"],
names=[DATASET_ID],
)
CapsDataset._map_indices_to_images(df)

return df.reset_index(
drop=False,
level=0,
level=DATASET_ID,
).reset_index(drop=True)
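A short pandas sketch of what the rewritten `_concat_dfs` does with `keys`/`names`; the toy frames are hypothetical, and `DATASET_ID` is assumed to resolve to the string "dataset_id":

```python
import pandas as pd

# Hypothetical per-dataset frames, standing in for each CapsDataset.df.
df_a = pd.DataFrame({"participant_id": ["sub-01"], "session_id": ["ses-M000"]})
df_b = pd.DataFrame({"participant_id": ["sub-02"], "session_id": ["ses-M000"]})

# 'keys' prepends an index level identifying the source dataset, and 'names'
# labels that level ("dataset_id" is assumed to be the value of DATASET_ID).
stacked = pd.concat([df_a, df_b], keys=range(2), names=["dataset_id"])

# Turn that level back into a regular column, then drop the leftover row index.
result = stacked.reset_index(drop=False, level="dataset_id").reset_index(drop=True)
print(result)
#    dataset_id participant_id session_id
# 0           0         sub-01   ses-M000
# 1           1         sub-02   ses-M000
```

Keeping the full per-dataset frames (rather than only the participant/session/sample columns, as before) is what allows the subsequent `CapsDataset._map_indices_to_images(df)` call to operate on the concatenated result.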