add clone method to datasets
Summary: This makes it far easier to obtain slices of different kinds of datasets (Supervised, MultiTask, Contextual), which will be helpful for things like doing LOOCV with MBM in Ax.

Differential Revision: D65616941
sdaulton authored and facebook-github-bot committed Nov 14, 2024
1 parent 92d73e4 commit 9790fc7
Showing 3 changed files with 343 additions and 57 deletions.
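
For orientation, a minimal usage sketch of the new API (the data and names below are illustrative, not taken from the diff):

import torch
from botorch.utils.datasets import SupervisedDataset

X = torch.rand(5, 2)
Y = torch.rand(5, 1)
dataset = SupervisedDataset(
    X=X, Y=Y, feature_names=["x1", "x2"], outcome_names=["y"]
)

# LOOCV-style slice: drop the i-th row via a boolean mask over the -2 dim.
i = 0
mask = torch.ones(5, dtype=torch.bool)
mask[i] = False
train_data = dataset.clone(mask=mask)  # 4 rows remain

# A deep copy with independent tensors and name lists:
dataset_copy = dataset.clone(deepcopy=True)
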
14 changes: 14 additions & 0 deletions botorch/utils/containers.py
@@ -8,10 +8,14 @@

from __future__ import annotations

import dataclasses

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any

import torch

from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor


@@ -102,6 +106,9 @@ def _validate(self) -> None:
                f"`event shape` {self.event_shape}."
            )

    def clone(self) -> DenseContainer:
        return dataclasses.replace(self)


@dataclass(eq=False)
class SliceContainer(BotorchContainer):
@@ -149,3 +156,10 @@ def _validate(self) -> None:
                f"Shapes of `values` {values.shape} and `indices` "
                f"{indices.shape} incompatible with `event_shape` {event_shape}."
            )

    def clone(self) -> SliceContainer:
        return type(self)(
            values=self.values.clone(),
            indices=self.indices.clone(),
            event_shape=torch.Size(self.event_shape),
        )
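
A short sketch of the container-level semantics (illustrative values; the shared-tensor behavior follows from the dataclasses.replace call above):

import torch
from botorch.utils.containers import DenseContainer

dense = DenseContainer(values=torch.rand(10, 3), event_shape=torch.Size([3]))
dense_copy = dense.clone()

# dataclasses.replace returns a new container over the same tensor object,
assert dense_copy is not dense
assert dense_copy.values is dense.values
# whereas SliceContainer.clone() clones its `values` and `indices` tensors.
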
106 changes: 105 additions & 1 deletion botorch/utils/datasets.py
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import copy

import warnings
from typing import Any

@@ -71,6 +73,7 @@ def __init__(
        self._Yvar = Yvar
        self.feature_names = feature_names
        self.outcome_names = outcome_names
        self.validate_init = validate_init
        if validate_init:
            self._validate()

@@ -148,6 +151,50 @@ def __eq__(self, other: Any) -> bool:
            and self.outcome_names == other.outcome_names
        )

    def clone(
        self, deepcopy: bool = False, mask: Tensor | None = None
    ) -> SupervisedDataset:
        """Return a copy of the dataset.

        Args:
            deepcopy: If True, perform a deep copy. Otherwise, use the same
                tensors/lists.
            mask: A `n`-dim boolean mask indicating which rows to keep. This is
                used along the -2 dimension.

        Returns:
            The new dataset.
        """
        new_X = self._X
        new_Y = self._Y
        new_Yvar = self._Yvar
        feature_names = self.feature_names
        outcome_names = self.outcome_names
        if mask is not None:
            if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]):
                raise NotImplementedError(
                    "Masking is not supported for BotorchContainers."
                )
            new_X = new_X[..., mask, :]
            new_Y = new_Y[..., mask, :]
            if new_Yvar is not None:
                new_Yvar = new_Yvar[..., mask, :]
        if deepcopy:
            new_X = new_X.clone()
            new_Y = new_Y.clone()
            new_Yvar = new_Yvar.clone() if new_Yvar is not None else None
            feature_names = copy.copy(self.feature_names)
            outcome_names = copy.copy(self.outcome_names)
        kwargs = {}
        if new_Yvar is not None:
            kwargs = {"Yvar": new_Yvar}
        return type(self)(
            X=new_X,
            Y=new_Y,
            feature_names=feature_names,
            outcome_names=outcome_names,
            validate_init=self.validate_init,
            **kwargs,
        )
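
A minimal sketch of the `Yvar` handling above (illustrative data; note that the `BotorchContainer` branch raises rather than masking):

import torch
from botorch.utils.datasets import SupervisedDataset

noisy = SupervisedDataset(
    X=torch.rand(5, 2),
    Y=torch.rand(5, 1),
    Yvar=torch.full((5, 1), 0.1),
    feature_names=["x1", "x2"],
    outcome_names=["y"],
)
mask = torch.tensor([False, True, True, True, True])  # leave out row 0
subset = noisy.clone(mask=mask)
assert subset.Yvar.shape == torch.Size([4, 1])  # Yvar is masked alongside X and Y
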


class FixedNoiseDataset(SupervisedDataset):
    r"""A SupervisedDataset with an additional field `Yvar` that stipulates
@@ -373,7 +420,7 @@ def from_joint_dataset(
                outcome_names=[outcome_name],
            )
            datasets.append(new_dataset)
-        # Return the new
+        # Return the new dataset
        return cls(
            datasets=datasets,
            target_outcome_name=outcome_names_per_task.get(
@@ -500,6 +547,35 @@ def __eq__(self, other: Any) -> bool:
            and self.task_feature_index == other.task_feature_index
        )

    def clone(
        self, deepcopy: bool = False, mask: Tensor | None = None
    ) -> MultiTaskDataset:
        """Return a copy of the dataset.

        Args:
            deepcopy: If True, perform a deep copy. Otherwise, use the same
                tensors/lists/datasets.
            mask: A `n`-dim boolean mask indicating which rows to keep from the
                target dataset. This is used along the -2 dimension.

        Returns:
            The new dataset.
        """
        datasets = list(self.datasets.values())
        if mask is not None or deepcopy:
            new_datasets = []
            for outcome, ds in self.datasets.items():
                new_datasets.append(
                    ds.clone(
                        deepcopy=deepcopy,
                        mask=mask if outcome == self.target_outcome_name else None,
                    )
                )
            datasets = new_datasets
        return MultiTaskDataset(
            datasets=datasets,
            target_outcome_name=self.target_outcome_name,
            task_feature_index=self.task_feature_index,
        )
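
A hedged sketch of the target-only masking (hypothetical two-task data; assumes this toy construction passes MultiTaskDataset validation):

import torch
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset

source = SupervisedDataset(
    X=torch.rand(8, 2), Y=torch.rand(8, 1),
    feature_names=["x1", "x2"], outcome_names=["source"],
)
target = SupervisedDataset(
    X=torch.rand(4, 2), Y=torch.rand(4, 1),
    feature_names=["x1", "x2"], outcome_names=["target"],
)
mtd = MultiTaskDataset(datasets=[source, target], target_outcome_name="target")

# The mask is applied only to the target dataset; "source" keeps all 8 rows.
keep = torch.tensor([True, True, True, False])
subset = mtd.clone(mask=keep)
assert subset.datasets["target"].X.shape[0] == 3
assert subset.datasets["source"].X.shape[0] == 8
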


class ContextualDataset(SupervisedDataset):
    """This is a contextual dataset that is constructed from either a single
@@ -661,3 +737,31 @@ def _validate_decompositions(self) -> None:
                    raise InputDataError(
                        f"{outcome} is missing in metric_decomposition."
                    )

    def clone(
        self, deepcopy: bool = False, mask: Tensor | None = None
    ) -> ContextualDataset:
        """Return a copy of the dataset.

        Args:
            deepcopy: If True, perform a deep copy. Otherwise, use the same
                tensors/lists/datasets.
            mask: A `n`-dim boolean mask indicating which rows to keep. This is
                used along the -2 dimension. `n` here corresponds to the number
                of rows in an individual dataset.

        Returns:
            The new dataset.
        """
        datasets = list(self.datasets.values())
        if mask is not None or deepcopy:
            datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets]
        if deepcopy:
            parameter_decomposition = copy.deepcopy(self.parameter_decomposition)
            metric_decomposition = copy.deepcopy(self.metric_decomposition)
        else:
            parameter_decomposition = self.parameter_decomposition
            metric_decomposition = self.metric_decomposition
        return ContextualDataset(
            datasets=datasets,
            parameter_decomposition=parameter_decomposition,
            metric_decomposition=metric_decomposition,
        )
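
Finally, a hedged sketch for the contextual case (hypothetical decomposition over two contexts with a single aggregated outcome; assumes this toy setup passes ContextualDataset validation):

import torch
from botorch.utils.datasets import ContextualDataset, SupervisedDataset

overall = SupervisedDataset(
    X=torch.rand(6, 2),
    Y=torch.rand(6, 1),
    feature_names=["x_c0", "x_c1"],
    outcome_names=["overall"],
)
cds = ContextualDataset(
    datasets=[overall],
    parameter_decomposition={"c0": ["x_c0"], "c1": ["x_c1"]},
)

# A shallow clone shares tensors and decomposition dicts;
# deepcopy=True also deep-copies both decompositions.
cds_copy = cds.clone(deepcopy=True)
assert cds_copy.parameter_decomposition == cds.parameter_decomposition
assert cds_copy.parameter_decomposition is not cds.parameter_decomposition
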
