
Commit 23a7ff6

sdaulton authored and facebook-github-bot committed

add clone method to datasets (#2625)

Summary: This makes it far easier to obtain slices of different kinds of datasets (Supervised, MultiTask, Contextual), which will be helpful for things like doing LOOCV MBM in Ax.

Reviewed By: saitcakmak

Differential Revision: D65616941
1 parent 3c2ce15 commit 23a7ff6

File tree

3 files changed: +344 -57 lines changed


botorch/utils/containers.py

Lines changed: 14 additions & 0 deletions

@@ -8,10 +8,14 @@
 
 from __future__ import annotations
 
+import dataclasses
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
 from typing import Any
 
+import torch
+
 from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor
 
 
@@ -102,6 +106,9 @@ def _validate(self) -> None:
                 f"`event shape` {self.event_shape}."
             )
 
+    def clone(self) -> DenseContainer:
+        return dataclasses.replace(self)
+
 
 @dataclass(eq=False)
 class SliceContainer(BotorchContainer):
@@ -149,3 +156,10 @@ def _validate(self) -> None:
                 f"Shapes of `values` {values.shape} and `indices` "
                 f"{indices.shape} incompatible with `event_shape` {event_shape}."
             )
+
+    def clone(self) -> SliceContainer:
+        return type(self)(
+            values=self.values.clone(),
+            indices=self.indices.clone(),
+            event_shape=torch.Size(self.event_shape),
+        )
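For intuition on the two container clone methods above, here is a minimal sketch (not part of the commit) assuming the usual `DenseContainer(values=..., event_shape=...)` constructor: `DenseContainer.clone` is a shallow `dataclasses.replace`, so the copy wraps the same tensor, whereas `SliceContainer.clone` clones its `values` and `indices` tensors and therefore owns independent storage.

import torch

from botorch.utils.containers import DenseContainer

# A DenseContainer over a (3, 2) tensor whose trailing dimension is the event shape.
dense = DenseContainer(values=torch.rand(3, 2), event_shape=torch.Size([2]))

# dataclasses.replace(self) builds a new container around the same field values,
# so the clone shares the underlying tensor object (shallow copy).
dense_copy = dense.clone()
assert dense_copy.values is dense.values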

botorch/utils/datasets.py

Lines changed: 106 additions & 1 deletion

@@ -8,6 +8,9 @@
 
 from __future__ import annotations
 
+import copy
+
+import warnings
 from typing import Any
 
 import torch
@@ -70,6 +73,7 @@ def __init__(
         self._Yvar = Yvar
         self.feature_names = feature_names
         self.outcome_names = outcome_names
+        self.validate_init = validate_init
         if validate_init:
             self._validate()
 
@@ -147,6 +151,50 @@ def __eq__(self, other: Any) -> bool:
            and self.outcome_names == other.outcome_names
        )
 
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> SupervisedDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same tensors/lists.
+            mask: A `n`-dim boolean mask indicating which rows to keep. This is used along the -2 dimension.
+
+        Returns:
+            The new dataset.
+        """
+        new_X = self._X
+        new_Y = self._Y
+        new_Yvar = self._Yvar
+        feature_names = self.feature_names
+        outcome_names = self.outcome_names
+        if mask is not None:
+            if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]):
+                raise NotImplementedError(
+                    "Masking is not supported for BotorchContainers."
+                )
+            new_X = new_X[..., mask, :]
+            new_Y = new_Y[..., mask, :]
+            if new_Yvar is not None:
+                new_Yvar = new_Yvar[..., mask, :]
+        if deepcopy:
+            new_X = new_X.clone()
+            new_Y = new_Y.clone()
+            new_Yvar = new_Yvar.clone() if new_Yvar is not None else None
+            feature_names = copy.copy(self.feature_names)
+            outcome_names = copy.copy(self.outcome_names)
+        kwargs = {}
+        if new_Yvar is not None:
+            kwargs = {"Yvar": new_Yvar}
+        return type(self)(
+            X=new_X,
+            Y=new_Y,
+            feature_names=feature_names,
+            outcome_names=outcome_names,
+            validate_init=self.validate_init,
+            **kwargs,
+        )
+
 
 class RankingDataset(SupervisedDataset):
     r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
@@ -339,7 +387,7 @@ def from_joint_dataset(
                 outcome_names=[outcome_name],
             )
             datasets.append(new_dataset)
-        # Return the new
+        # Return the new dataset
         return cls(
            datasets=datasets,
            target_outcome_name=outcome_names_per_task.get(
@@ -466,6 +514,35 @@ def __eq__(self, other: Any) -> bool:
            and self.task_feature_index == other.task_feature_index
        )
 
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> MultiTaskDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same tensors/lists/datasets.
+            mask: A `n`-dim boolean mask indicating which rows to keep from the target dataset. This is used along the -2 dimension.
+
+        Returns:
+            The new dataset.
+        """
+        datasets = list(self.datasets.values())
+        if mask is not None or deepcopy:
+            new_datasets = []
+            for outcome, ds in self.datasets.items():
+                new_datasets.append(
+                    ds.clone(
+                        deepcopy=deepcopy,
+                        mask=mask if outcome == self.target_outcome_name else None,
+                    )
+                )
+            datasets = new_datasets
+        return MultiTaskDataset(
+            datasets=datasets,
+            target_outcome_name=self.target_outcome_name,
+            task_feature_index=self.task_feature_index,
+        )
+
 
 class ContextualDataset(SupervisedDataset):
     """This is a contextual dataset that is constructed from either a single
@@ -627,3 +704,31 @@ def _validate_decompositions(self) -> None:
                raise InputDataError(
                    f"{outcome} is missing in metric_decomposition."
                )
+
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> ContextualDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same tensors/lists/datasets.
+            mask: A `n`-dim boolean mask indicating which rows to keep. This is used along the -2
+                dimension. `n` here corresponds to the number of rows in an individual dataset.
+
+        Returns:
+            The new dataset.
+        """
+        datasets = list(self.datasets.values())
+        if mask is not None or deepcopy:
+            datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets]
+        if deepcopy:
+            parameter_decomposition = copy.deepcopy(self.parameter_decomposition)
+            metric_decomposition = copy.deepcopy(self.metric_decomposition)
+        else:
+            parameter_decomposition = self.parameter_decomposition
+            metric_decomposition = self.metric_decomposition
+        return ContextualDataset(
+            datasets=datasets,
+            parameter_decomposition=parameter_decomposition,
+            metric_decomposition=metric_decomposition,
+        )
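As a usage illustration (not part of the diff), here is a minimal sketch of how the new `clone` API could support the leave-one-out slicing mentioned in the summary, assuming the standard `SupervisedDataset` constructor with `feature_names`/`outcome_names`:

import torch

from botorch.utils.datasets import SupervisedDataset

# A toy SupervisedDataset with n=5 rows, d=2 features, and one outcome.
X = torch.rand(5, 2)
Y = torch.rand(5, 1)
dataset = SupervisedDataset(
    X=X,
    Y=Y,
    feature_names=["x1", "x2"],
    outcome_names=["y"],
)

# Shallow copy: the clone reuses the same tensors and name lists.
shallow = dataset.clone()

# Deep copy: tensors and name lists are duplicated.
deep = dataset.clone(deepcopy=True)

# Leave-one-out slices via the boolean `mask`, which is applied along the -2
# (row) dimension, e.g. for LOOCV-style model evaluation.
n = dataset.X.shape[-2]
loo_datasets = []
for i in range(n):
    mask = torch.ones(n, dtype=torch.bool)
    mask[i] = False
    loo_datasets.append(dataset.clone(mask=mask))

The same `clone(deepcopy=..., mask=...)` signature is shared by `MultiTaskDataset` and `ContextualDataset`, which delegate to the clones of their member datasets as shown in the diff above.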
