-"""Dataset class where a sample corresponds to multiple embeddings."""
-
-import os
-from typing import Callable, Dict, List, Literal
+"""Dataset class where a classification task sample corresponds to multiple embeddings."""

 import numpy as np
-import torch
-from typing_extensions import override

-from eva.core.data.datasets import embeddings as embeddings_base
+from eva.core.data.datasets.multi_embeddings import MultiEmbeddingsDataset


-class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+class MultiEmbeddingsClassificationDataset(MultiEmbeddingsDataset):
     """Dataset class where a sample corresponds to multiple embeddings.

-    Example use case: a slide-level dataset where each slide has multiple patch embeddings.
+    Specialised for classification data with an int target type.
     """

-    def __init__(
-        self,
-        root: str,
-        manifest_file: str,
-        split: Literal["train", "val", "test"],
-        column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
-        embeddings_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-    ):
-        """Initialize dataset.
-
-        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
-
-        The manifest must have a `column_mapping["multi_id"]` column that contains the
-        unique identifier of each group of embeddings. For oncology datasets, this would
-        usually be the slide id. Each row in the manifest file points to a `.pt` file that
-        can contain one or multiple embeddings (either as a list or as stacked tensors).
-        There can also be multiple rows for the same `multi_id`, in which case the embeddings
-        from the different `.pt` files corresponding to that same `multi_id` will be stacked
-        along the first dimension.
-
-        Args:
-            root: Root directory of the dataset.
-            manifest_file: The path to the manifest file, which is relative to
-                the `root` argument.
-            split: The dataset split to use. The manifest is filtered on its
-                `split` column based on this value.
-            column_mapping: Defines the map between the variables and the manifest
-                columns. It will overwrite the `default_column_mapping` with
-                the provided values, so that `column_mapping` can contain only the
-                values which are altered or missing.
-            embeddings_transforms: A function/transform that transforms the embedding.
-            target_transforms: A function/transform that transforms the target.
-        """
-        super().__init__(
-            manifest_file=manifest_file,
-            root=root,
-            split=split,
-            column_mapping=column_mapping,
-            embeddings_transforms=embeddings_transforms,
-            target_transforms=target_transforms,
-        )
-
-        self._multi_ids: List[int]
-
-    @override
-    def setup(self):
-        super().setup()
-        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
-
-    @override
-    def load_embeddings(self, index: int) -> torch.Tensor:
-        """Loads and stacks all embeddings corresponding to the `index`'th multi_id."""
-        # Get the paths of all embeddings for the given index (multi_id)
-        multi_id = self._multi_ids[index]
-        embedding_paths = self._data.loc[
-            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
-        ].to_list()
-
-        # Load the embeddings and stack them along the first dimension
-        embeddings = []
-        for path in embedding_paths:
-            embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
-            if isinstance(embedding, list):
-                embedding = torch.stack(embedding, dim=0)
-            embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
-        embeddings = torch.cat(embeddings, dim=0)
-
-        if embeddings.ndim != 2:
-            raise ValueError(f"Expected a 2D tensor, got a {embeddings.ndim}D tensor for {multi_id}.")
-
-        return embeddings
-
-    @override
-    def load_target(self, index: int) -> np.ndarray:
-        """Returns the target corresponding to the `index`'th multi_id.
-
-        This method assumes that all the embeddings corresponding to the same `multi_id`
-        have the same target. If this is not the case, it will raise an error.
-        """
-        multi_id = self._multi_ids[index]
-        targets = self._data.loc[
-            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
-        ]
-
-        if targets.nunique() != 1:
-            raise ValueError(f"Multiple targets found for {multi_id}.")
-
-        return np.asarray(targets.iloc[0], dtype=np.int64)
-
-    @override
-    def __len__(self) -> int:
-        return len(self._multi_ids)
+    def __init__(self, *args, **kwargs):
+        """Initialize the dataset with an integer target type."""
+        super().__init__(*args, target_type=np.int64, **kwargs)
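
For context, the removed docstring above describes the manifest layout: several rows may share one `multi_id` (e.g. a slide id), each pointing to a `.pt` file whose embeddings get stacked along the first dimension. A minimal sketch of such a manifest, assuming the default `column_mapping` uses the column names `path`, `multi_id`, `target`, and `split` (the concrete file names and labels below are invented for illustration):

import pandas as pd

# Hypothetical manifest: the two slide_1 rows would be loaded and concatenated
# into a single (n_embeddings, embedding_dim) tensor for that sample.
manifest = pd.DataFrame(
    {
        "path": ["slide_1/patches_0.pt", "slide_1/patches_1.pt", "slide_2/patches_0.pt"],
        "multi_id": ["slide_1", "slide_1", "slide_2"],
        "target": [0, 0, 1],  # rows sharing a multi_id must agree on the target
        "split": ["train", "train", "train"],
    }
)
print(manifest.groupby("multi_id")["path"].apply(list))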
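
After the refactor, the subclass only pins the target dtype and forwards everything else to `MultiEmbeddingsDataset`. Assuming the base class kept the constructor parameters of the removed `__init__` (`root`, `manifest_file`, `split`, and the optional transforms, which the `*args, **kwargs` forwarding suggests but the diff does not show), usage would look like:

# Hypothetical usage; the paths and file names are placeholders.
dataset = MultiEmbeddingsClassificationDataset(
    root="/data/embeddings",
    manifest_file="manifest.csv",
    split="train",
)
dataset.setup()
embeddings = dataset.load_embeddings(0)  # stacked 2D tensor for one multi_id
target = dataset.load_target(0)          # numpy array with dtype np.int64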