
Commit fc336a3

Jklubienski authored and committed
Refactor for clarity and address suggested changes
1 parent b689cb1 commit fc336a3

File tree

13 files changed (+292 additions, -364 deletions)


configs/vision/pathology/offline/regression/tiger_til_score.yaml

Lines changed: 6 additions & 5 deletions
@@ -19,7 +19,7 @@ trainer:
           filename: best
           save_last: ${oc.env:SAVE_LAST, false}
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MeanAbsoluteError}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MAE}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
@@ -40,7 +40,7 @@ trainer:
         init_args:
           model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
           model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
-          overwrite: false
+          overwrite: true
   logger:
     - class_path: lightning.pytorch.loggers.TensorBoardLogger
       init_args:
@@ -53,8 +53,6 @@ model:
       class_path: eva.vision.models.networks.ABMIL
       init_args:
         input_size: ${oc.env:IN_FEATURES, 384}
-        output_size: &NUM_CLASSES 1
-        # task: regression
     criterion: torch.nn.MSELoss
     optimizer:
       class_path: torch.optim.AdamW
@@ -83,7 +81,10 @@ data:
           init_args:
             pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200}
         target_transforms:
-          class_path: eva.core.data.transforms.dtype.SqueezeTensor
+          class_path: eva.vision.data.transforms.common.Squeeze
+          init_args:
+            dim: -1
+
     val:
       class_path: eva.datasets.MultiEmbeddingsRegressionDataset
       init_args:
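
The last hunk swaps the regression target transform for eva.vision.data.transforms.common.Squeeze with dim: -1. A minimal sketch of the assumed behaviour, in plain PyTorch (assuming the transform wraps torch.Tensor.squeeze): it collapses a trailing singleton dimension so targets match the shape torch.nn.MSELoss expects.

    import torch

    # A regression target as stored per sample, with a trailing singleton
    # dimension after collation: shape (batch, 1).
    target = torch.tensor([[0.42], [0.17]])

    # With dim=-1 only the last dimension is squeezed, matching the
    # `dim: -1` init_args in the config above.
    squeezed = target.squeeze(dim=-1)
    print(squeezed.shape)  # torch.Size([2])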
Lines changed: 7 additions & 101 deletions
@@ -1,110 +1,16 @@
-"""Dataset class for where a sample corresponds to multiple embeddings."""
-
-import os
-from typing import Callable, Dict, List, Literal
+"""Dataset class for where a classification task sample corresponds to multiple embeddings."""
 
 import numpy as np
-import torch
-from typing_extensions import override
 
-from eva.core.data.datasets import embeddings as embeddings_base
+from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset
 
 
-class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+class MultiEmbeddingsClassificationDataset(MultiEmbeddingsDataset):
     """Dataset class for where a sample corresponds to multiple embeddings.
 
-    Example use case: Slide level dataset where each slide has multiple patch embeddings.
+    Specialised for classification data with an int target type.
     """
 
-    def __init__(
-        self,
-        root: str,
-        manifest_file: str,
-        split: Literal["train", "val", "test"],
-        column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
-        embeddings_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-    ):
-        """Initialize dataset.
-
-        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
-
-        The manifest must have a `column_mapping["multi_id"]` column that contains the
-        unique identifier group of embeddings. For oncology datasets, this would be usually
-        the slide id. Each row in the manifest file points to a .pt file that can contain
-        one or multiple embeddings (either as a list or stacked tensors). There can also be
-        multiple rows for the same `multi_id`, in which case the embeddings from the different
-        .pt files corresponding to that same `multi_id` will be stacked along the first dimension.
-
-        Args:
-            root: Root directory of the dataset.
-            manifest_file: The path to the manifest file, which is relative to
-                the `root` argument.
-            split: The dataset split to use. The `split` column of the manifest
-                file will be splitted based on this value.
-            column_mapping: Defines the map between the variables and the manifest
-                columns. It will overwrite the `default_column_mapping` with
-                the provided values, so that `column_mapping` can contain only the
-                values which are altered or missing.
-            embeddings_transforms: A function/transform that transforms the embedding.
-            target_transforms: A function/transform that transforms the target.
-        """
-        super().__init__(
-            manifest_file=manifest_file,
-            root=root,
-            split=split,
-            column_mapping=column_mapping,
-            embeddings_transforms=embeddings_transforms,
-            target_transforms=target_transforms,
-        )
-
-        self._multi_ids: List[int]
-
-    @override
-    def setup(self):
-        super().setup()
-        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
-
-    @override
-    def load_embeddings(self, index: int) -> torch.Tensor:
-        """Loads and stacks all embedding corresponding to the `index`'th multi_id."""
-        # Get all embeddings for the given index (multi_id)
-        multi_id = self._multi_ids[index]
-        embedding_paths = self._data.loc[
-            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
-        ].to_list()
-
-        # Load embeddings and stack them accross the first dimension
-        embeddings = []
-        for path in embedding_paths:
-            embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
-            if isinstance(embedding, list):
-                embedding = torch.stack(embedding, dim=0)
-            embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
-        embeddings = torch.cat(embeddings, dim=0)
-
-        if not embeddings.ndim == 2:
-            raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")
-
-        return embeddings
-
-    @override
-    def load_target(self, index: int) -> np.ndarray:
-        """Returns the target corresponding to the `index`'th multi_id.
-
-        This method assumes that all the embeddings corresponding to the same `multi_id`
-        have the same target. If this is not the case, it will raise an error.
-        """
-        multi_id = self._multi_ids[index]
-        targets = self._data.loc[
-            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
-        ]
-
-        if not targets.nunique() == 1:
-            raise ValueError(f"Multiple targets found for {multi_id}.")
-
-        return np.asarray(targets.iloc[0], dtype=np.int64)
-
-    @override
-    def __len__(self) -> int:
-        return len(self._multi_ids)
+    def __init__(self, *args, **kwargs):
+        """Initialize dataset with the correct return type."""
+        super().__init__(*args, target_type=np.int64, **kwargs)
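
After this refactor the classification dataset is a thin wrapper that only pins the target dtype to np.int64 and forwards everything else to MultiEmbeddingsDataset. A hypothetical usage sketch (the import location, root path, and manifest name are assumptions, not taken from this diff):

    from eva.core.data.datasets.classification import (
        MultiEmbeddingsClassificationDataset,  # assumed export location
    )

    dataset = MultiEmbeddingsClassificationDataset(
        root="/data/slide_embeddings",  # hypothetical dataset root
        manifest_file="manifest.csv",   # hypothetical manifest
        split="train",
    )
    dataset.setup()      # builds the list of unique multi_ids
    print(len(dataset))  # one item per multi_id, not per .pt file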
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+"""Dataset class for where a sample corresponds to multiple embeddings."""
+
+import os
+from typing import Any, Callable, Dict, List, Literal
+
+import numpy as np
+import numpy.typing as npt
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets import embeddings as embeddings_base
+
+
+class MultiEmbeddingsDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+    """Dataset class for where a sample corresponds to multiple embeddings.
+
+    Example use case: Slide level dataset where each slide has multiple patch embeddings.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"],
+        column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+        target_type: type[np.generic] = np.int64,
+    ):
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
+
+        The manifest must have a `column_mapping["multi_id"]` column that contains the
+        unique identifier group of embeddings. For oncology datasets, this would usually be
+        the slide id. Each row in the manifest file points to a .pt file that can contain
+        one or multiple embeddings (either as a list or stacked tensors). There can also be
+        multiple rows for the same `multi_id`, in which case the embeddings from the different
+        .pt files corresponding to that same `multi_id` will be stacked along the first dimension.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The `split` column of the manifest
+                file will be filtered based on this value.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+            target_type: The desired type of the target data.
+        """
+        super().__init__(
+            manifest_file=manifest_file,
+            root=root,
+            split=split,
+            column_mapping=column_mapping,
+            embeddings_transforms=embeddings_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._multi_ids: List[int]
+        self._target_type = target_type
+
+    @override
+    def setup(self):
+        super().setup()
+        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
+
+    @override
+    def load_embeddings(self, index: int) -> torch.Tensor:
+        """Loads and stacks all embeddings corresponding to the `index`'th multi_id."""
+        # Get all embeddings for the given index (multi_id)
+        multi_id = self._multi_ids[index]
+        embedding_paths = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
+        ].to_list()
+
+        # Load embeddings and stack them across the first dimension
+        embeddings = []
+        for path in embedding_paths:
+            embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
+            if isinstance(embedding, list):
+                embedding = torch.stack(embedding, dim=0)
+            embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
+        embeddings = torch.cat(embeddings, dim=0)
+
+        if not embeddings.ndim == 2:
+            raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")
+
+        return embeddings
+
+    @override
+    def load_target(self, index: int) -> npt.NDArray[Any]:
+        """Returns the target corresponding to the `index`'th multi_id.
+
+        This method assumes that all the embeddings corresponding to the same `multi_id`
+        have the same target. If this is not the case, it will raise an error.
+        """
+        multi_id = self._multi_ids[index]
+        targets = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
+        ]
+
+        if not targets.nunique() == 1:
+            raise ValueError(f"Multiple targets found for {multi_id}.")
+
+        return np.asarray(targets.iloc[0], dtype=self._target_type)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._multi_ids)
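
The stacking logic in load_embeddings is the core of the new base class. A self-contained sketch of that behaviour, using in-memory tensors in place of the .pt files (the 384-dim shapes are illustrative):

    import torch

    # Two loaded files sharing one multi_id: a list of 1D embeddings and a
    # single 1D embedding, the two layouts load_embeddings handles.
    loaded = [
        [torch.randn(384), torch.randn(384)],  # list -> stacked to (2, 384)
        torch.randn(384),                      # 1D tensor -> unsqueezed to (1, 384)
    ]

    embeddings = []
    for embedding in loaded:
        if isinstance(embedding, list):
            embedding = torch.stack(embedding, dim=0)
        embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)

    result = torch.cat(embeddings, dim=0)
    print(result.shape)  # torch.Size([3, 384]): all patch embeddings for the id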
Lines changed: 3 additions & 27 deletions
@@ -1,39 +1,15 @@
 """Embeddings regression dataset."""
 
-import os
-
 import torch
 from typing_extensions import override
 
-from eva.core.data.datasets import embeddings as embeddings_base
-
+from eva.core.data.datasets.classification import EmbeddingsClassificationDataset
 
-class EmbeddingsRegressionDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
-    """Embeddings dataset class for regression tasks.
 
-    NOTE: This barely changes from the EmbeddingsClassificationDataset
-    but they have been kept apart for abstraction
-
-    """
-
-    @override
-    def load_embeddings(self, index: int) -> torch.Tensor:
-        filename = self.filename(index)
-        embeddings_path = os.path.join(self._root, filename)
-        tensor = torch.load(embeddings_path, map_location="cpu")
-        if isinstance(tensor, list):
-            if len(tensor) > 1:
-                raise ValueError(
-                    f"Expected a single tensor in the .pt file, but found {len(tensor)}."
-                )
-            tensor = tensor[0]
-        return tensor.squeeze(0)
+class EmbeddingsRegressionDataset(EmbeddingsClassificationDataset):
+    """Embeddings dataset class for regression tasks."""
 
     @override
     def load_target(self, index: int) -> torch.Tensor:
         target = self._data.at[index, self._column_mapping["target"]]
         return torch.tensor(float(target), dtype=torch.float32)
-
-    @override
-    def __len__(self) -> int:
-        return len(self._data)
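
With load_embeddings now inherited, the regression subclass only changes the target path: the manifest value is cast to a float32 tensor, which is what the torch.nn.MSELoss criterion in the config above expects. A minimal sketch of that cast (the manifest value is a made-up example):

    import torch

    raw = "0.37"  # hypothetical value from the manifest's target column
    target = torch.tensor(float(raw), dtype=torch.float32)
    print(target)  # tensor(0.3700)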
