Skip to content

Commit b689cb1

Browse files
JklubienskiJklubienski
authored andcommitted
Implement TIGER TIL regression task
1 parent 986760e commit b689cb1

File tree

15 files changed

+516
-8
lines changed

15 files changed

+516
-8
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
---
2+
trainer:
3+
class_path: eva.Trainer
4+
init_args:
5+
n_runs: &N_RUNS ${oc.env:N_RUNS, 20}
6+
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/tiger_til}
7+
max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100}
8+
checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
9+
callbacks:
10+
- class_path: eva.callbacks.ConfigurationLogger
11+
- class_path: lightning.pytorch.callbacks.TQDMProgressBar
12+
init_args:
13+
refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
14+
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
15+
init_args:
16+
logging_interval: epoch
17+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
18+
init_args:
19+
filename: best
20+
save_last: ${oc.env:SAVE_LAST, false}
21+
save_top_k: 1
22+
monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MeanAbsoluteError}
23+
# MeanAbsoluteError is an error metric (lower is better), so the best checkpoint / early stopping must track the minimum.
mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, min}
24+
- class_path: lightning.pytorch.callbacks.EarlyStopping
25+
init_args:
26+
min_delta: 0
27+
patience: ${oc.env:PATIENCE, 20}
28+
monitor: *MONITOR_METRIC
29+
mode: *MONITOR_METRIC_MODE
30+
- class_path: eva.callbacks.ClassificationEmbeddingsWriter
31+
init_args:
32+
output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:MODEL_NAME, dino_vits16}/tiger_til}
33+
dataloader_idx_map:
34+
0: train
35+
1: val
36+
2: test
37+
metadata_keys: ["wsi_id"]
38+
backbone:
39+
class_path: eva.vision.models.ModelFromRegistry
40+
init_args:
41+
model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
42+
model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
43+
overwrite: false
44+
logger:
45+
- class_path: lightning.pytorch.loggers.TensorBoardLogger
46+
init_args:
47+
save_dir: *OUTPUT_ROOT
48+
name: ""
49+
model:
50+
class_path: eva.HeadModule
51+
init_args:
52+
head:
53+
class_path: eva.vision.models.networks.ABMIL
54+
init_args:
55+
input_size: ${oc.env:IN_FEATURES, 384}
56+
output_size: &NUM_CLASSES 1
57+
# task: regression
58+
criterion: torch.nn.MSELoss
59+
optimizer:
60+
class_path: torch.optim.AdamW
61+
init_args:
62+
lr: ${oc.env:LR_VALUE, 0.001}
63+
betas: [0.9, 0.999]
64+
metrics:
65+
common:
66+
- class_path: eva.core.metrics.AverageLoss
67+
- class_path: eva.core.metrics.RegressionMetrics
68+
init_args:
69+
prefix: null
70+
postfix: null
71+
data:
72+
class_path: eva.DataModule
73+
init_args:
74+
datasets:
75+
train:
76+
class_path: eva.datasets.MultiEmbeddingsRegressionDataset
77+
init_args: &DATASET_ARGS
78+
root: *DATASET_EMBEDDINGS_ROOT
79+
manifest_file: manifest.csv
80+
split: train
81+
embeddings_transforms:
82+
class_path: eva.core.data.transforms.Pad2DTensor
83+
init_args:
84+
pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200}
85+
target_transforms:
86+
class_path: eva.core.data.transforms.dtype.SqueezeTensor
87+
val:
88+
class_path: eva.datasets.MultiEmbeddingsRegressionDataset
89+
init_args:
90+
<<: *DATASET_ARGS
91+
split: val
92+
test:
93+
class_path: eva.datasets.MultiEmbeddingsRegressionDataset
94+
init_args:
95+
<<: *DATASET_ARGS
96+
split: test
97+
predict:
98+
- class_path: eva.vision.datasets.TIGERTILScore
99+
init_args: &PREDICT_DATASET_ARGS
100+
root: ${oc.env:DATA_ROOT, ./data/training/wsitils}
101+
sampler:
102+
class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler
103+
init_args:
104+
max_samples: *N_PATCHES
105+
width: 224
106+
height: 224
107+
target_mpp: 0.5
108+
split: train
109+
coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv
110+
image_transforms:
111+
class_path: eva.vision.data.transforms.common.ResizeAndCrop
112+
init_args:
113+
size: ${oc.env:RESIZE_DIM, 224}
114+
mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
115+
std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]}
116+
- class_path: eva.vision.datasets.TIGERTILScore
117+
init_args:
118+
<<: *PREDICT_DATASET_ARGS
119+
split: val
120+
- class_path: eva.vision.datasets.TIGERTILScore
121+
init_args:
122+
<<: *PREDICT_DATASET_ARGS
123+
split: test
124+
dataloaders:
125+
train:
126+
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32}
127+
num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
128+
shuffle: true
129+
val:
130+
batch_size: *BATCH_SIZE
131+
num_workers: *N_DATA_WORKERS
132+
test:
133+
batch_size: *BATCH_SIZE
134+
num_workers: *N_DATA_WORKERS
135+
predict:
136+
batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64}
137+
num_workers: *N_DATA_WORKERS

src/eva/core/data/datasets/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,19 @@
66
MultiEmbeddingsClassificationDataset,
77
)
88
from eva.core.data.datasets.dataset import TorchDataset
9+
from eva.core.data.datasets.regression import (
10+
EmbeddingsRegressionDataset,
11+
MultiEmbeddingsRegressionDataset,
12+
)
913
from eva.core.data.datasets.typings import DataSample
1014

1115
__all__ = [
1216
"Dataset",
1317
"MapDataset",
1418
"EmbeddingsClassificationDataset",
1519
"MultiEmbeddingsClassificationDataset",
20+
"EmbeddingsRegressionDataset",
21+
"MultiEmbeddingsRegressionDataset",
1622
"TorchDataset",
1723
"DataSample",
1824
]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Embedding regression datasets API."""
2+
3+
from eva.core.data.datasets.regression.embeddings import EmbeddingsRegressionDataset
4+
from eva.core.data.datasets.regression.multi_embeddings import MultiEmbeddingsRegressionDataset
5+
6+
__all__ = ["EmbeddingsRegressionDataset", "MultiEmbeddingsRegressionDataset"]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Embeddings regression dataset."""
2+
3+
import os
4+
5+
import torch
6+
from typing_extensions import override
7+
8+
from eva.core.data.datasets import embeddings as embeddings_base
9+
10+
11+
class EmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
    """Embeddings dataset class for regression tasks.

    Each sample is one embedding tensor loaded from a `.pt` file, paired with a
    scalar `float32` target read from the manifest.

    NOTE: This barely differs from the `EmbeddingsClassificationDataset`, but
    the two have been kept apart for abstraction.
    """

    @override
    def load_embeddings(self, index: int) -> torch.Tensor:
        """Loads the embedding tensor of the `index`'th sample.

        Args:
            index: The index of the sample to load.

        Returns:
            The loaded tensor after `squeeze(0)`.

        Raises:
            ValueError: If the `.pt` file stores a list that does not contain
                exactly one tensor.
        """
        filename = self.filename(index)
        embeddings_path = os.path.join(self._root, filename)
        tensor = torch.load(embeddings_path, map_location="cpu")
        if isinstance(tensor, list):
            # `!= 1` (rather than `> 1`) also rejects an empty list, which would
            # otherwise surface as an opaque IndexError on `tensor[0]`.
            if len(tensor) != 1:
                raise ValueError(
                    f"Expected a single tensor in the .pt file, but found {len(tensor)}."
                )
            tensor = tensor[0]
        return tensor.squeeze(0)

    @override
    def load_target(self, index: int) -> torch.Tensor:
        """Returns the target of the `index`'th sample as a scalar `float32` tensor.

        Args:
            index: The index of the sample whose target is read.
        """
        target = self._data.at[index, self._column_mapping["target"]]
        return torch.tensor(float(target), dtype=torch.float32)

    @override
    def __len__(self) -> int:
        """Returns the number of rows in the loaded manifest data."""
        return len(self._data)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Dataset class for where a sample corresponds to multiple embeddings (regression)."""
2+
3+
import os
4+
from typing import Callable, Dict, List, Literal
5+
6+
import torch
7+
from typing_extensions import override
8+
9+
from eva.core.data.datasets import embeddings as embeddings_base
10+
11+
12+
class MultiEmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
    """Dataset class for regression with multiple embeddings per sample."""

    def __init__(
        self,
        root: str,
        manifest_file: str,
        split: Literal["train", "val", "test"],
        column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
        embeddings_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ):
        """Initialize dataset.

        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.

        The manifest must have a `column_mapping["multi_id"]` column that contains the
        unique identifier of a group of embeddings. For oncology datasets, this would
        usually be the slide id. Each row in the manifest file points to a .pt file that
        can contain one or multiple embeddings (either as a list or stacked tensors).
        There can also be multiple rows for the same `multi_id`, in which case the
        embeddings from the different .pt files corresponding to that same `multi_id`
        will be concatenated along the first dimension.

        Args:
            root: Root directory of the dataset.
            manifest_file: The path to the manifest file, which is relative to
                the `root` argument.
            split: The dataset split to use. The rows of the manifest file
                will be filtered on its `split` column based on this value.
            column_mapping: Defines the map between the variables and the manifest
                columns. It will overwrite the `default_column_mapping` with
                the provided values, so that `column_mapping` can contain only the
                values which are altered or missing.
            embeddings_transforms: A function/transform that transforms the embedding.
            target_transforms: A function/transform that transforms the target.
        """
        super().__init__(
            manifest_file=manifest_file,
            root=root,
            split=split,
            column_mapping=column_mapping,
            embeddings_transforms=embeddings_transforms,
            target_transforms=target_transforms,
        )
        # Populated in `setup()`; declared here only for type checkers.
        self._multi_ids: List[int]

    @override
    def setup(self):
        """Runs the base setup and collects the unique `multi_id` values."""
        super().setup()
        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())

    @override
    def load_embeddings(self, index: int) -> torch.Tensor:
        """Loads and concatenates all embeddings of the `index`'th multi_id.

        Args:
            index: Position of the group in the list of unique multi ids.

        Returns:
            A 2D tensor with the embeddings of every `.pt` file of the group
            concatenated along the first dimension.

        Raises:
            ValueError: If the concatenated result is not a 2D tensor.
        """
        # Get all embedding paths for the given index (multi_id)
        multi_id = self._multi_ids[index]
        embedding_paths = self._data.loc[
            self._data[self._column_mapping["multi_id"]] == multi_id,
            self._column_mapping["path"],
        ].to_list()

        embeddings = []
        for path in embedding_paths:
            embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
            if isinstance(embedding, list):
                embedding = torch.stack(embedding, dim=0)
            # Promote a single 1D embedding to a one-row 2D tensor so it can be concatenated.
            embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
        embeddings = torch.cat(embeddings, dim=0)

        if embeddings.ndim != 2:
            # Single-line f-string: a backslash continuation inside the literal
            # would leak the source indentation into the message.
            raise ValueError(f"Expected 2D tensor, got {embeddings.ndim}D for {multi_id}.")

        return embeddings

    @override
    def load_target(self, index: int) -> torch.Tensor:
        """Returns the target corresponding to the `index`'th multi_id.

        This method assumes that all the embeddings corresponding to the same `multi_id`
        have the same target. If this is not the case, it will raise an error.

        Args:
            index: Position of the group in the list of unique multi ids.

        Returns:
            The shared target of the group as a scalar `float32` tensor.

        Raises:
            ValueError: If the rows of the group do not share one unique target.
        """
        multi_id = self._multi_ids[index]
        targets = self._data.loc[
            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
        ]

        if targets.nunique() != 1:
            raise ValueError(f"Multiple targets found for {multi_id}.")

        # Explicit float conversion, consistent with `EmbeddingsRegressionDataset.load_target`.
        return torch.tensor(float(targets.iloc[0]), dtype=torch.float32)

    @override
    def __len__(self) -> int:
        """Returns the number of unique multi ids (i.e. groups of embeddings)."""
        return len(self._multi_ids)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Core data transforms."""
22

3-
from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor
3+
from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor, SqueezeTensor
44
from eva.core.data.transforms.padding import Pad2DTensor
55
from eva.core.data.transforms.sampling import SampleFromAxis
66

7-
__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"]
7+
__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis", "SqueezeTensor"]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Type casting related transforms."""
22

33
from eva.core.data.transforms.dtype.array import ArrayToFloatTensor, ArrayToTensor
4+
from eva.core.data.transforms.dtype.tensor import SqueezeTensor
45

5-
__all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
6+
__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "SqueezeTensor"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Transformations to change the shape of tensors."""
2+
3+
import torch
4+
5+
6+
class SqueezeTensor:
    """Squeezes a singleton dimension of a tensor, e.g. [B, 1] to [B]."""

    def __init__(self, dim: int = -1) -> None:
        """Initializes the transform.

        Args:
            dim: The dimension to squeeze. Defaults to the last dimension
                (`-1`), which matches the previous hard-coded behaviour.
        """
        self._dim = dim

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Call method for the transformation.

        Args:
            tensor: The input tensor to be squeezed.

        Returns:
            The tensor with dimension `dim` removed if it has size 1;
            otherwise the tensor is returned with its shape unchanged.
        """
        return tensor.squeeze(self._dim)

src/eva/core/metrics/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22

33
from eva.core.metrics.average_loss import AverageLoss
44
from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
5-
from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
5+
from eva.core.metrics.defaults import (
6+
BinaryClassificationMetrics,
7+
MulticlassClassificationMetrics,
8+
RegressionMetrics,
9+
)
610
from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
711

812
__all__ = [
913
"AverageLoss",
1014
"BinaryBalancedAccuracy",
1115
"BinaryClassificationMetrics",
1216
"MulticlassClassificationMetrics",
17+
"RegressionMetrics",
1318
"Metric",
1419
"MetricCollection",
1520
"MetricModule",

src/eva/core/metrics/defaults/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
BinaryClassificationMetrics,
55
MulticlassClassificationMetrics,
66
)
7+
from eva.core.metrics.defaults.regression import RegressionMetrics
78

8-
__all__ = [
9-
"MulticlassClassificationMetrics",
10-
"BinaryClassificationMetrics",
11-
]
9+
__all__ = ["MulticlassClassificationMetrics", "BinaryClassificationMetrics", "RegressionMetrics"]

0 commit comments

Comments
 (0)