 
 from __future__ import annotations
 
+import copy
+
+import warnings
 from typing import Any
 
 import torch
@@ -70,6 +73,7 @@ def __init__(
         self._Yvar = Yvar
         self.feature_names = feature_names
         self.outcome_names = outcome_names
+        self.validate_init = validate_init
         if validate_init:
             self._validate()
 
@@ -147,6 +151,50 @@ def __eq__(self, other: Any) -> bool:
             and self.outcome_names == other.outcome_names
         )
 
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> SupervisedDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same tensors/lists.
+            mask: A `n`-dim boolean mask indicating which rows to keep. This is used along the -2 dimension.
+
+        Returns:
+            The new dataset.
+        """
+        new_X = self._X
+        new_Y = self._Y
+        new_Yvar = self._Yvar
+        feature_names = self.feature_names
+        outcome_names = self.outcome_names
+        if mask is not None:
+            if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]):
+                raise NotImplementedError(
+                    "Masking is not supported for BotorchContainers."
+                )
+            new_X = new_X[..., mask, :]
+            new_Y = new_Y[..., mask, :]
+            if new_Yvar is not None:
+                new_Yvar = new_Yvar[..., mask, :]
+        if deepcopy:
+            new_X = new_X.clone()
+            new_Y = new_Y.clone()
+            new_Yvar = new_Yvar.clone() if new_Yvar is not None else None
+            feature_names = copy.copy(self.feature_names)
+            outcome_names = copy.copy(self.outcome_names)
+        kwargs = {}
+        if new_Yvar is not None:
+            kwargs = {"Yvar": new_Yvar}
+        return type(self)(
+            X=new_X,
+            Y=new_Y,
+            feature_names=feature_names,
+            outcome_names=outcome_names,
+            validate_init=self.validate_init,
+            **kwargs,
+        )
+
 
 class RankingDataset(SupervisedDataset):
     r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
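(Not part of the diff.) A minimal usage sketch of the new `SupervisedDataset.clone`, assuming plain tensor-valued `X`/`Y` (masking raises `NotImplementedError` for `BotorchContainer`s); the dataset, feature names, and mask below are illustrative:

import torch

from botorch.utils.datasets import SupervisedDataset

X = torch.rand(5, 2)
Y = torch.rand(5, 1)
dataset = SupervisedDataset(
    X=X, Y=Y, feature_names=["x1", "x2"], outcome_names=["y"]
)

# A shallow clone shares the underlying tensors with the original.
shallow = dataset.clone()
assert shallow.X is dataset.X

# A boolean mask over the row (-2) dimension keeps only the selected rows;
# deepcopy=True copies the tensors and name lists as well.
mask = (Y > Y.median()).squeeze(-1)
subset = dataset.clone(deepcopy=True, mask=mask)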
@@ -339,7 +387,7 @@ def from_joint_dataset(
                 outcome_names=[outcome_name],
             )
             datasets.append(new_dataset)
-        # Return the new
+        # Return the new dataset
         return cls(
             datasets=datasets,
             target_outcome_name=outcome_names_per_task.get(
@@ -466,6 +514,35 @@ def __eq__(self, other: Any) -> bool:
             and self.task_feature_index == other.task_feature_index
         )
 
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> MultiTaskDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same tensors/lists/datasets.
+            mask: A `n`-dim boolean mask indicating which rows to keep from the target dataset. This is used along the -2 dimension.
+
+        Returns:
+            The new dataset.
+        """
+        datasets = list(self.datasets.values())
+        if mask is not None or deepcopy:
+            new_datasets = []
+            for outcome, ds in self.datasets.items():
+                new_datasets.append(
+                    ds.clone(
+                        deepcopy=deepcopy,
+                        mask=mask if outcome == self.target_outcome_name else None,
+                    )
+                )
+            datasets = new_datasets
+        return MultiTaskDataset(
+            datasets=datasets,
+            target_outcome_name=self.target_outcome_name,
+            task_feature_index=self.task_feature_index,
+        )
+
 
 class ContextualDataset(SupervisedDataset):
     """This is a contextual dataset that is constructed from either a single
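(Not part of the diff.) In `MultiTaskDataset.clone` above, the mask is applied only to the target outcome's dataset; the remaining per-task datasets are cloned unmasked. A hedged sketch, with the two single-outcome datasets below made up for illustration:

import torch

from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset

ds_target = SupervisedDataset(
    X=torch.rand(8, 2),
    Y=torch.rand(8, 1),
    feature_names=["x1", "x2"],
    outcome_names=["target"],
)
ds_aux = SupervisedDataset(
    X=torch.rand(20, 2),
    Y=torch.rand(20, 1),
    feature_names=["x1", "x2"],
    outcome_names=["aux"],
)
mtd = MultiTaskDataset(datasets=[ds_target, ds_aux], target_outcome_name="target")

# The mask has as many entries as the target dataset has rows (8 here) and
# drops rows only from the "target" dataset; "aux" keeps all 20 rows.
keep = torch.ones(8, dtype=torch.bool)
keep[:2] = False
pruned = mtd.clone(deepcopy=True, mask=keep)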
@@ -627,3 +704,31 @@ def _validate_decompositions(self) -> None:
                     raise InputDataError(
                         f"{outcome} is missing in metric_decomposition."
                     )
+
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> ContextualDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same tensors/lists/datasets.
+            mask: A `n`-dim boolean mask indicating which rows to keep. This is used along the -2
+                dimension. `n` here corresponds to the number of rows in an individual dataset.
+
+        Returns:
+            The new dataset.
+        """
+        datasets = list(self.datasets.values())
+        if mask is not None or deepcopy:
+            datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets]
+        if deepcopy:
+            parameter_decomposition = copy.deepcopy(self.parameter_decomposition)
+            metric_decomposition = copy.deepcopy(self.metric_decomposition)
+        else:
+            parameter_decomposition = self.parameter_decomposition
+            metric_decomposition = self.metric_decomposition
+        return ContextualDataset(
+            datasets=datasets,
+            parameter_decomposition=parameter_decomposition,
+            metric_decomposition=metric_decomposition,
+        )
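(Not part of the diff.) `ContextualDataset.clone` follows the same delegation pattern and, when `deepcopy=True`, also deep-copies the decomposition dicts. Assuming `cd` is an already-constructed `ContextualDataset` (construction omitted here):

# Deep clone: tensors, name lists, and decomposition dicts are all copied.
cd_deep = cd.clone(deepcopy=True)
assert cd_deep.parameter_decomposition is not cd.parameter_decomposition

# Shallow clone: the per-outcome datasets and decomposition dicts are reused.
cd_shallow = cd.clone()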