diff --git a/configs/vision/pathology/offline/regression/tiger_til_score.yaml b/configs/vision/pathology/offline/regression/tiger_til_score.yaml new file mode 100644 index 000000000..eb8430e33 --- /dev/null +++ b/configs/vision/pathology/offline/regression/tiger_til_score.yaml @@ -0,0 +1,136 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 20} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/tiger_til} + max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: ${oc.env:SAVE_LAST, false} + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MAE} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, min} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 20} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: eva.callbacks.ClassificationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:MODEL_NAME, dino_vits16}/tiger_til} + dataloader_idx_map: + 0: train + 1: val + 2: test + metadata_keys: ["wsi_id"] + backbone: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} + model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} + overwrite: false + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: 
eva.HeadModule + init_args: + head: + class_path: eva.vision.models.networks.ABMIL + init_args: + input_size: ${oc.env:IN_FEATURES, 384} + criterion: torch.nn.MSELoss + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.001} + betas: [0.9, 0.999] + metrics: + common: + - class_path: eva.core.metrics.AverageLoss + - class_path: eva.core.metrics.RegressionMetrics + init_args: + prefix: null + postfix: null +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.datasets.MultiEmbeddingsRegressionDataset + init_args: &DATASET_ARGS + root: *DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + embeddings_transforms: + class_path: eva.core.data.transforms.Pad2DTensor + init_args: + pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200} + target_transforms: + class_path: eva.vision.data.transforms.common.Squeeze + init_args: + dim: -1 + val: + class_path: eva.datasets.MultiEmbeddingsRegressionDataset + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.datasets.MultiEmbeddingsRegressionDataset + init_args: + <<: *DATASET_ARGS + split: test + predict: + - class_path: eva.vision.datasets.TIGERTILScore + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/training/wsitils} + sampler: + class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler + init_args: + max_samples: *N_PATCHES + width: 224 + height: 224 + split: train + coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv + image_transforms: + class_path: eva.vision.data.transforms.common.ResizeAndCrop + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.TIGERTILScore + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + - class_path: eva.vision.datasets.TIGERTILScore + init_args: + <<: *PREDICT_DATASET_ARGS + split: test + 
dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + test: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/src/eva/core/data/datasets/__init__.py b/src/eva/core/data/datasets/__init__.py index c5e366827..6da04d3b4 100644 --- a/src/eva/core/data/datasets/__init__.py +++ b/src/eva/core/data/datasets/__init__.py @@ -6,6 +6,10 @@ MultiEmbeddingsClassificationDataset, ) from eva.core.data.datasets.dataset import TorchDataset +from eva.core.data.datasets.regression import ( + EmbeddingsRegressionDataset, + MultiEmbeddingsRegressionDataset, +) from eva.core.data.datasets.typings import DataSample __all__ = [ @@ -13,6 +17,8 @@ "MapDataset", "EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset", + "EmbeddingsRegressionDataset", + "MultiEmbeddingsRegressionDataset", "TorchDataset", "DataSample", ] diff --git a/src/eva/core/data/datasets/classification/multi_embeddings.py b/src/eva/core/data/datasets/classification/multi_embeddings.py index 399d5eab9..ba8a1e223 100644 --- a/src/eva/core/data/datasets/classification/multi_embeddings.py +++ b/src/eva/core/data/datasets/classification/multi_embeddings.py @@ -1,110 +1,16 @@ -"""Dataset class for where a sample corresponds to multiple embeddings.""" - -import os -from typing import Callable, Dict, List, Literal +"""Dataset class for where a classification task sample corresponds to multiple embeddings.""" import numpy as np -import torch -from typing_extensions import override -from eva.core.data.datasets import embeddings as embeddings_base +from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset -class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]): +class 
MultiEmbeddingsClassificationDataset(MultiEmbeddingsDataset): """Dataset class for where a sample corresponds to multiple embeddings. - Example use case: Slide level dataset where each slide has multiple patch embeddings. + Specialised for classification data with an int target type. """ - def __init__( - self, - root: str, - manifest_file: str, - split: Literal["train", "val", "test"], - column_mapping: Dict[str, str] = embeddings_base.default_column_mapping, - embeddings_transforms: Callable | None = None, - target_transforms: Callable | None = None, - ): - """Initialize dataset. - - Expects a manifest file listing the paths of `.pt` files containing tensor embeddings. - - The manifest must have a `column_mapping["multi_id"]` column that contains the - unique identifier group of embeddings. For oncology datasets, this would be usually - the slide id. Each row in the manifest file points to a .pt file that can contain - one or multiple embeddings (either as a list or stacked tensors). There can also be - multiple rows for the same `multi_id`, in which case the embeddings from the different - .pt files corresponding to that same `multi_id` will be stacked along the first dimension. - - Args: - root: Root directory of the dataset. - manifest_file: The path to the manifest file, which is relative to - the `root` argument. - split: The dataset split to use. The `split` column of the manifest - file will be splitted based on this value. - column_mapping: Defines the map between the variables and the manifest - columns. It will overwrite the `default_column_mapping` with - the provided values, so that `column_mapping` can contain only the - values which are altered or missing. - embeddings_transforms: A function/transform that transforms the embedding. - target_transforms: A function/transform that transforms the target. 
- """ - super().__init__( - manifest_file=manifest_file, - root=root, - split=split, - column_mapping=column_mapping, - embeddings_transforms=embeddings_transforms, - target_transforms=target_transforms, - ) - - self._multi_ids: List[int] - - @override - def setup(self): - super().setup() - self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique()) - - @override - def load_embeddings(self, index: int) -> torch.Tensor: - """Loads and stacks all embedding corresponding to the `index`'th multi_id.""" - # Get all embeddings for the given index (multi_id) - multi_id = self._multi_ids[index] - embedding_paths = self._data.loc[ - self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"] - ].to_list() - - # Load embeddings and stack them accross the first dimension - embeddings = [] - for path in embedding_paths: - embedding = torch.load(os.path.join(self._root, path), map_location="cpu") - if isinstance(embedding, list): - embedding = torch.stack(embedding, dim=0) - embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding) - embeddings = torch.cat(embeddings, dim=0) - - if not embeddings.ndim == 2: - raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.") - - return embeddings - - @override - def load_target(self, index: int) -> np.ndarray: - """Returns the target corresponding to the `index`'th multi_id. - - This method assumes that all the embeddings corresponding to the same `multi_id` - have the same target. If this is not the case, it will raise an error. 
- """ - multi_id = self._multi_ids[index] - targets = self._data.loc[ - self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"] - ] - - if not targets.nunique() == 1: - raise ValueError(f"Multiple targets found for {multi_id}.") - - return np.asarray(targets.iloc[0], dtype=np.int64) - - @override - def __len__(self) -> int: - return len(self._multi_ids) + def __init__(self, *args, **kwargs): + """Initialize dataset with the correct return type.""" + super().__init__(*args, target_type=np.int64, **kwargs) diff --git a/src/eva/core/data/datasets/multi_embeddings.py b/src/eva/core/data/datasets/multi_embeddings.py new file mode 100644 index 000000000..93b68d8ef --- /dev/null +++ b/src/eva/core/data/datasets/multi_embeddings.py @@ -0,0 +1,114 @@ +"""Dataset class for where a sample corresponds to multiple embeddings.""" + +import os +from typing import Any, Callable, Dict, List, Literal + +import numpy as np +import numpy.typing as npt +import torch +from typing_extensions import override + +from eva.core.data.datasets import embeddings as base + + +class MultiEmbeddingsDataset(base.EmbeddingsDataset[torch.Tensor]): + """Dataset class for where a sample corresponds to multiple embeddings. + + Example use case: Slide level dataset where each slide has multiple patch embeddings. + """ + + def __init__( + self, + root: str, + manifest_file: str, + split: Literal["train", "val", "test"], + column_mapping: Dict[str, str] = base.default_column_mapping, + embeddings_transforms: Callable | None = None, + target_transforms: Callable | None = None, + target_type: type[np.generic] = np.int64, + ): + """Initialize dataset. + + Expects a manifest file listing the paths of `.pt` files containing tensor embeddings. + + The manifest must have a `column_mapping["multi_id"]` column that contains the + unique identifier group of embeddings. For oncology datasets, this would be usually + the slide id. 
Each row in the manifest file points to a .pt file that can contain + one or multiple embeddings (either as a list or stacked tensors). There can also be + multiple rows for the same `multi_id`, in which case the embeddings from the different + .pt files corresponding to that same `multi_id` will be stacked along the first dimension. + + Args: + root: Root directory of the dataset. + manifest_file: The path to the manifest file, which is relative to + the `root` argument. + split: The dataset split to use. The `split` column of the manifest + file will be split based on this value. + column_mapping: Defines the map between the variables and the manifest + columns. It will overwrite the `default_column_mapping` with + the provided values, so that `column_mapping` can contain only the + values which are altered or missing. + embeddings_transforms: A function/transform that transforms the embedding. + target_transforms: A function/transform that transforms the target. + target_type: Desired type of the target data. + """ + super().__init__( + manifest_file=manifest_file, + root=root, + split=split, + column_mapping=column_mapping, + embeddings_transforms=embeddings_transforms, + target_transforms=target_transforms, + ) + + self._multi_ids: List[int] + self._target_type = target_type + + @override + def setup(self): + super().setup() + self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique()) + + @override + def load_embeddings(self, index: int) -> torch.Tensor: + """Loads and stacks all embeddings corresponding to the `index`'th multi_id.""" + # Get all embeddings for the given index (multi_id) + multi_id = self._multi_ids[index] + embedding_paths = self._data.loc[ + self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"] + ].to_list() + + # Load embeddings and stack them across the first dimension + embeddings = [] + for path in embedding_paths: + embedding = torch.load(os.path.join(self._root, path), 
map_location="cpu") + if isinstance(embedding, list): + embedding = torch.stack(embedding, dim=0) + embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding) + embeddings = torch.cat(embeddings, dim=0) + + if not embeddings.ndim == 2: + raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.") + + return embeddings + + @override + def load_target(self, index: int) -> npt.NDArray[Any]: + """Returns the target corresponding to the `index`'th multi_id. + + This method assumes that all the embeddings corresponding to the same `multi_id` + have the same target. If this is not the case, it will raise an error. + """ + multi_id = self._multi_ids[index] + targets = self._data.loc[ + self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"] + ] + + if not targets.nunique() == 1: + raise ValueError(f"Multiple targets found for {multi_id}.") + + return np.asarray(targets.iloc[0], dtype=self._target_type) + + @override + def __len__(self) -> int: + return len(self._multi_ids) diff --git a/src/eva/core/data/datasets/regression/__init__.py b/src/eva/core/data/datasets/regression/__init__.py new file mode 100644 index 000000000..2c000653f --- /dev/null +++ b/src/eva/core/data/datasets/regression/__init__.py @@ -0,0 +1,6 @@ +"""Embedding regression datasets API.""" + +from eva.core.data.datasets.regression.embeddings import EmbeddingsRegressionDataset +from eva.core.data.datasets.regression.multi_embeddings import MultiEmbeddingsRegressionDataset + +__all__ = ["EmbeddingsRegressionDataset", "MultiEmbeddingsRegressionDataset"] diff --git a/src/eva/core/data/datasets/regression/embeddings.py b/src/eva/core/data/datasets/regression/embeddings.py new file mode 100644 index 000000000..50dd01671 --- /dev/null +++ b/src/eva/core/data/datasets/regression/embeddings.py @@ -0,0 +1,15 @@ +"""Embeddings regression dataset.""" + +import torch +from typing_extensions import override + +from 
eva.core.data.datasets.classification import EmbeddingsClassificationDataset + + +class EmbeddingsRegressionDataset(EmbeddingsClassificationDataset): + """Embeddings dataset class for regression tasks.""" + + @override + def load_target(self, index: int) -> torch.Tensor: + target = self._data.at[index, self._column_mapping["target"]] + return torch.tensor(float(target), dtype=torch.float32) diff --git a/src/eva/core/data/datasets/regression/multi_embeddings.py b/src/eva/core/data/datasets/regression/multi_embeddings.py new file mode 100644 index 000000000..d3db9cee5 --- /dev/null +++ b/src/eva/core/data/datasets/regression/multi_embeddings.py @@ -0,0 +1,16 @@ +"""Dataset class for where a regression task sample corresponds to multiple embeddings.""" + +import numpy as np + +from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset + + +class MultiEmbeddingsRegressionDataset(MultiEmbeddingsDataset): + """Dataset class for where a sample corresponds to multiple embeddings. + + Specialised for regression data with a float target type. 
+ """ + + def __init__(self, *args, **kwargs): + """Initialize dataset with the correct return type.""" + super().__init__(*args, target_type=np.float32, **kwargs) diff --git a/src/eva/core/metrics/__init__.py b/src/eva/core/metrics/__init__.py index aed8c33ea..32b5af0c4 100644 --- a/src/eva/core/metrics/__init__.py +++ b/src/eva/core/metrics/__init__.py @@ -2,7 +2,11 @@ from eva.core.metrics.average_loss import AverageLoss from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy -from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics +from eva.core.metrics.defaults import ( + BinaryClassificationMetrics, + MulticlassClassificationMetrics, + RegressionMetrics, +) from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema __all__ = [ @@ -10,6 +14,7 @@ "BinaryBalancedAccuracy", "BinaryClassificationMetrics", "MulticlassClassificationMetrics", + "RegressionMetrics", "Metric", "MetricCollection", "MetricModule", diff --git a/src/eva/core/metrics/defaults/__init__.py b/src/eva/core/metrics/defaults/__init__.py index be65d7579..3a9bb789c 100644 --- a/src/eva/core/metrics/defaults/__init__.py +++ b/src/eva/core/metrics/defaults/__init__.py @@ -4,8 +4,6 @@ BinaryClassificationMetrics, MulticlassClassificationMetrics, ) +from eva.core.metrics.defaults.regression import RegressionMetrics -__all__ = [ - "MulticlassClassificationMetrics", - "BinaryClassificationMetrics", -] +__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics", "RegressionMetrics"] diff --git a/src/eva/core/metrics/defaults/regression/__init__.py b/src/eva/core/metrics/defaults/regression/__init__.py new file mode 100644 index 000000000..083c20122 --- /dev/null +++ b/src/eva/core/metrics/defaults/regression/__init__.py @@ -0,0 +1,5 @@ +"""Default regression metric collections API.""" + +from eva.core.metrics.defaults.regression.regression_metrics import RegressionMetrics + +__all__ = 
["RegressionMetrics"] diff --git a/src/eva/core/metrics/defaults/regression/regression_metrics.py b/src/eva/core/metrics/defaults/regression/regression_metrics.py new file mode 100644 index 000000000..3dddf2305 --- /dev/null +++ b/src/eva/core/metrics/defaults/regression/regression_metrics.py @@ -0,0 +1,37 @@ +"""Default metric collection for regression tasks.""" + +from torchmetrics import MeanAbsoluteError, MeanSquaredError, R2Score + +from eva.core.metrics import structs + + +class RegressionMetrics(structs.MetricCollection): + """Default metrics for regression tasks. + + Supports: + Mean Absolute Error + Root Mean Squared Error + R^2 score + """ + + def __init__( + self, + prefix: str | None = None, + postfix: str | None = None, + ) -> None: + """Initialises regression metrics. + + Args: + prefix: A string to prepend to metric names. + postfix: A string to append after metric names. + """ + super().__init__( + metrics={ + "MAE": MeanAbsoluteError(), + "RMSE": MeanSquaredError(squared=False), + "R2": R2Score(), + }, + prefix=prefix, + postfix=postfix, + compute_groups=[["MAE", "RMSE", "R2"]], + ) diff --git a/src/eva/vision/data/datasets/__init__.py b/src/eva/vision/data/datasets/__init__.py index 95ed8d847..81da744f7 100644 --- a/src/eva/vision/data/datasets/__init__.py +++ b/src/eva/vision/data/datasets/__init__.py @@ -14,6 +14,7 @@ UniToPatho, WsiClassificationDataset, ) +from eva.vision.data.datasets.regression import TIGERTILScore from eva.vision.data.datasets.segmentation import ( BCSS, BTCV, @@ -49,4 +50,5 @@ "VisionDataset", "MultiWsiDataset", "WsiDataset", + "TIGERTILScore", ] diff --git a/src/eva/vision/data/datasets/regression/__init__.py b/src/eva/vision/data/datasets/regression/__init__.py new file mode 100644 index 000000000..6c5f51cdf --- /dev/null +++ b/src/eva/vision/data/datasets/regression/__init__.py @@ -0,0 +1,5 @@ +"""Regression datasets API.""" + +from eva.vision.data.datasets.regression.tiger_til_score import TIGERTILScore + +__all__ = 
["TIGERTILScore"] diff --git a/src/eva/vision/data/datasets/regression/tiger_til_score.py b/src/eva/vision/data/datasets/regression/tiger_til_score.py new file mode 100644 index 000000000..0f62c4d37 --- /dev/null +++ b/src/eva/vision/data/datasets/regression/tiger_til_score.py @@ -0,0 +1,50 @@ +"""Tiger dataset class for regression targets.""" + +import functools +import os +from pathlib import Path +from typing import Dict + +import pandas as pd +import torch +from typing_extensions import override + +from eva.vision.data.datasets import tiger + + +class TIGERTILScore(tiger.TIGERBase): + """Dataset class for regression tasks using the TIGERTILS partition of the TIGER dataset. + + Predicts TIL scores, i.e. the proportion of the cell infiltrated by TILs. + """ + + @functools.cached_property + def annotations(self) -> Dict[str, float]: + """Loads per-slide regression targets from a CSV file. + + Expected CSV format: + image-id,tils-score + 103S,0.70 + ... + """ + targets_csv_path = os.path.join(self._root, "tiger-til-scores-wsitils.csv") + + if not os.path.isfile(targets_csv_path): + raise FileNotFoundError(f"Targets CSV file not found at: {targets_csv_path}") + + df = pd.read_csv(targets_csv_path) + if not {"image-id", "tils-score"} <= set(df.columns): + raise ValueError("targets_csv must contain 'image-id' and 'tils-score' columns.") + + return {str(row["image-id"]): float(row["tils-score"]) for _, row in df.iterrows()} + + @override + def load_target(self, index: int) -> torch.Tensor: + metadata = self.load_metadata(index=index) + slide_idx = metadata["slide_idx"] + file_path = self._file_paths[slide_idx] + slide_name = Path(file_path).stem + + target_value = self.annotations[slide_name] + tensor = torch.tensor([target_value], dtype=torch.float32) + return tensor diff --git a/src/eva/vision/data/datasets/tiger.py b/src/eva/vision/data/datasets/tiger.py new file mode 100644 index 000000000..12d34cfed --- /dev/null +++ b/src/eva/vision/data/datasets/tiger.py @@ -0,0 
+1,129 @@ +"""Abstract base class for TIGER datasets spanning different task types.""" + +import abc +import glob +import os +from typing import Any, Callable, Dict, List, Literal, Tuple + +import numpy as np +import torch +from torchvision import tv_tensors +from torchvision.transforms.v2 import functional +from typing_extensions import override + +from eva.vision.data.datasets import _validators, vision, wsi +from eva.vision.data.wsi.patching import samplers + + +class TIGERBase( + wsi.MultiWsiDataset, + vision.VisionDataset[tv_tensors.Image, torch.Tensor], + abc.ABC, +): + """Abstract base class for TIGER datasets spanning different task types.""" + + _train_split_ratio: float = 0.7 + _val_split_ratio: float = 0.15 + + _target_mpp: float = 0.5 + '''Target microns per pixel (mpp) for patches''' + + def __init__( + self, + root: str, + sampler: samplers.Sampler, + split: Literal["train", "val", "test"] | None = None, + width: int = 224, + height: int = 224, + backend: str = "openslide", + image_transforms: Callable | None = None, + coords_path: str | None = None, + seed: int = 42, + ) -> None: + """Initializes the dataset. + + Args: + root: Root directory of the dataset. + sampler: The sampler to use for sampling patch coordinates. + split: Dataset split to use. If `None`, the entire dataset is used. + width: Patch width in pixels. + height: Patch height in pixels. + backend: WSI reading backend. + image_transforms: Transforms to apply to patches. + coords_path: Optional path to save patch coordinates. + seed: Random seed. 
+ """ + self._root = root + self._split = split + self._width = width + self._height = height + self._seed = seed + + wsi.MultiWsiDataset.__init__( + self, + root=root, + file_paths=self._load_file_paths(split), + width=width, + height=height, + sampler=sampler, + target_mpp=self._target_mpp, + backend=backend, + image_transforms=image_transforms, + coords_path=coords_path, + ) + + @override + def prepare_data(self) -> None: + _validators.check_dataset_exists(self._root, False) + + @override + def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]: + return vision.VisionDataset.__getitem__(self, index) + + @override + def load_data(self, index: int) -> tv_tensors.Image: + image_array = wsi.MultiWsiDataset.__getitem__(self, index) + return functional.to_image(image_array) + + @override + def load_metadata(self, index: int) -> Dict[str, Any]: + return wsi.MultiWsiDataset.load_metadata(self, index) + + @abc.abstractmethod + def annotations(self) -> Dict[str, Any]: + """Annotates target data.""" + raise NotImplementedError + + @abc.abstractmethod + def load_target(self, index: int): + """Task-specific target loading.""" + raise NotImplementedError + + def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]: + """Loads the file paths of WSIs from wsibulk/images. + + Splits are assigned 70% train, 15% val, 15% test by filename sorting. 
+ """ + image_dir = os.path.join(self._root, "images") + all_paths = sorted(glob.glob(os.path.join(image_dir, "*.tif"))) + + if not all_paths: + raise FileNotFoundError(f"No .tif files found in {image_dir}") + + rng = np.random.default_rng(self._seed) # nosec B311 + rng.shuffle(all_paths) + + n_total = len(all_paths) + n_train = int(n_total * self._train_split_ratio) + n_val = int(n_total * self._val_split_ratio) + + if split == "train": + selected_paths = all_paths[:n_train] + elif split == "val": + selected_paths = all_paths[n_train : n_train + n_val] + elif split == "test": + selected_paths = all_paths[n_train + n_val :] + elif split is None: + selected_paths = all_paths + + return [os.path.relpath(path, self._root) for path in selected_paths] diff --git a/src/eva/vision/data/datasets/wsi.py b/src/eva/vision/data/datasets/wsi.py index 4c1c789a3..8e31d5644 100644 --- a/src/eva/vision/data/datasets/wsi.py +++ b/src/eva/vision/data/datasets/wsi.py @@ -179,7 +179,11 @@ def load_metadata(self, index: int) -> Dict[str, Any]: """Loads the metadata for the patch at the specified index.""" dataset_index, sample_index = self._get_dataset_idx(index), self._get_sample_idx(index) patch_metadata = self.datasets[dataset_index].load_metadata(sample_index) - return {"wsi_id": self.filename(index).split(".")[0]} | patch_metadata + return { + "wsi_id": self.filename(index).split(".")[0], + "slide_idx": dataset_index, + "patch_idx": sample_index, + } | patch_metadata def _load_datasets(self) -> list[WsiDataset]: logger.info(f"Initializing dataset with {len(self._file_paths)} WSIs ...") diff --git a/src/eva/vision/models/networks/abmil.py b/src/eva/vision/models/networks/abmil.py index bb2ca4820..03553bef1 100644 --- a/src/eva/vision/models/networks/abmil.py +++ b/src/eva/vision/models/networks/abmil.py @@ -34,8 +34,8 @@ class ABMIL(torch.nn.Module): def __init__( self, input_size: int, - output_size: int, projected_input_size: int | None, + output_size: int = 1, 
hidden_size_attention: int = 128, hidden_sizes_mlp: tuple = (128, 64), use_bias: bool = True,