From 9790fc7dcffe04d26e388caafa3afe9103543e69 Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Thu, 14 Nov 2024 13:34:55 -0800
Subject: [PATCH] add clone method to datasets

Summary: This makes it far easier to obtain (optionally masked) copies of the
different kinds of datasets (Supervised, MultiTask, Contextual), which will be
helpful for things like doing LOOCV with MBM in Ax.

Differential Revision: D65616941
---
 botorch/utils/containers.py |  14 ++
 botorch/utils/datasets.py   | 106 +++++++++++++-
 test/utils/test_datasets.py | 280 ++++++++++++++++++++++++++++--------
 3 files changed, 343 insertions(+), 57 deletions(-)

diff --git a/botorch/utils/containers.py b/botorch/utils/containers.py
index f4e4c01e80..8cd2aabe76 100644
--- a/botorch/utils/containers.py
+++ b/botorch/utils/containers.py
@@ -8,10 +8,14 @@

 from __future__ import annotations

+import dataclasses
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
 from typing import Any

+import torch
+
 from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor


@@ -102,6 +106,9 @@ def _validate(self) -> None:
                 f"`event shape` {self.event_shape}."
             )

+    def clone(self) -> DenseContainer:
+        return dataclasses.replace(self)
+

 @dataclass(eq=False)
 class SliceContainer(BotorchContainer):
@@ -149,3 +156,10 @@ def _validate(self) -> None:
                 f"Shapes of `values` {values.shape} and `indices` "
                 f"{indices.shape} incompatible with `event_shape` {event_shape}."
             )
+
+    def clone(self) -> SliceContainer:
+        return type(self)(
+            values=self.values.clone(),
+            indices=self.indices.clone(),
+            event_shape=torch.Size(self.event_shape),
+        )
diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py
index f11f5c80e7..65f87c2aa4 100644
--- a/botorch/utils/datasets.py
+++ b/botorch/utils/datasets.py
@@ -8,6 +8,8 @@

 from __future__ import annotations

+import copy
+
 import warnings
 from typing import Any

@@ -71,6 +73,7 @@ def __init__(
         self._Yvar = Yvar
         self.feature_names = feature_names
         self.outcome_names = outcome_names
+        self.validate_init = validate_init
         if validate_init:
             self._validate()

@@ -148,6 +151,50 @@ def __eq__(self, other: Any) -> bool:
             and self.outcome_names == other.outcome_names
         )

+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> SupervisedDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same
+                tensors/lists.
+            mask: An `n`-dim boolean mask indicating which rows to keep. This
+                is applied along the `-2` dimension.
+
+        Returns:
+            The new dataset.
+        """
+        new_X = self._X
+        new_Y = self._Y
+        new_Yvar = self._Yvar
+        feature_names = self.feature_names
+        outcome_names = self.outcome_names
+        if mask is not None:
+            if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]):
+                raise NotImplementedError(
+                    "Masking is not supported for BotorchContainers."
+                )
+            new_X = new_X[..., mask, :]
+            new_Y = new_Y[..., mask, :]
+            if new_Yvar is not None:
+                new_Yvar = new_Yvar[..., mask, :]
+        if deepcopy:
+            new_X = new_X.clone()
+            new_Y = new_Y.clone()
+            new_Yvar = new_Yvar.clone() if new_Yvar is not None else None
+            feature_names = copy.copy(self.feature_names)
+            outcome_names = copy.copy(self.outcome_names)
+        kwargs = {}
+        if new_Yvar is not None:
+            kwargs = {"Yvar": new_Yvar}
+        return type(self)(
+            X=new_X,
+            Y=new_Y,
+            feature_names=feature_names,
+            outcome_names=outcome_names,
+            validate_init=self.validate_init,
+            **kwargs,
+        )
+

 class FixedNoiseDataset(SupervisedDataset):
     r"""A SupervisedDataset with an additional field `Yvar` that stipulates
@@ -373,7 +420,7 @@ def from_joint_dataset(
                 outcome_names=[outcome_name],
             )
             datasets.append(new_dataset)
-        # Return the new
+        # Return the new dataset
         return cls(
             datasets=datasets,
             target_outcome_name=outcome_names_per_task.get(
@@ -500,6 +547,35 @@ def __eq__(self, other: Any) -> bool:
             and self.task_feature_index == other.task_feature_index
         )

+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> MultiTaskDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same
+                tensors/lists/datasets.
+            mask: An `n`-dim boolean mask indicating which rows of the target
+                dataset to keep. This is applied along the `-2` dimension.
+
+        Returns:
+            The new dataset.
+        """
+        datasets = list(self.datasets.values())
+        if mask is not None or deepcopy:
+            new_datasets = []
+            for outcome, ds in self.datasets.items():
+                new_datasets.append(
+                    ds.clone(
+                        deepcopy=deepcopy,
+                        mask=mask if outcome == self.target_outcome_name else None,
+                    )
+                )
+            datasets = new_datasets
+        return MultiTaskDataset(
+            datasets=datasets,
+            target_outcome_name=self.target_outcome_name,
+            task_feature_index=self.task_feature_index,
+        )
+

 class ContextualDataset(SupervisedDataset):
     """This is a contextual dataset that is constructed from either a single
@@ -661,3 +737,31 @@ def _validate_decompositions(self) -> None:
                     raise InputDataError(
                         f"{outcome} is missing in metric_decomposition."
                     )
+
+    def clone(
+        self, deepcopy: bool = False, mask: Tensor | None = None
+    ) -> ContextualDataset:
+        """Return a copy of the dataset.
+
+        Args:
+            deepcopy: If True, perform a deep copy. Otherwise, use the same
+                tensors/lists/datasets.
+            mask: An `n`-dim boolean mask indicating which rows to keep. This
+                is applied along the `-2` dimension, where `n` is the number
+                of rows in an individual sub-dataset.
+
+        Returns:
+            The new dataset.
+        """
+        datasets = list(self.datasets.values())
+        if mask is not None or deepcopy:
+            datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets]
+        if deepcopy:
+            parameter_decomposition = copy.deepcopy(self.parameter_decomposition)
+            metric_decomposition = copy.deepcopy(self.metric_decomposition)
+        else:
+            parameter_decomposition = self.parameter_decomposition
+            metric_decomposition = self.metric_decomposition
+        return ContextualDataset(
+            datasets=datasets,
+            parameter_decomposition=parameter_decomposition,
+            metric_decomposition=metric_decomposition,
+        )
diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py
index 22d8c24a50..3ec0aff2d4 100644
--- a/test/utils/test_datasets.py
+++ b/test/utils/test_datasets.py
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
+from itertools import product
+
 import torch
 from botorch.exceptions.errors import InputDataError, UnsupportedError
 from botorch.utils.containers import DenseContainer, SliceContainer
@@ -40,6 +42,60 @@ def make_dataset(
     )


+def make_contextual_dataset(
+    has_yvar: bool = False, contextual_outcome: bool = False
+) -> tuple[ContextualDataset, list[SupervisedDataset]]:
+    num_contexts = 3
+    feature_names = [f"x_c{i}" for i in range(num_contexts)]
+    parameter_decomposition = {
+        "context_2": ["x_c2"],
+        "context_1": ["x_c1"],
+        "context_0": ["x_c0"],
+    }
+    context_buckets = list(parameter_decomposition.keys())
+    if contextual_outcome:
+        context_outcome_list = [f"y:context_{i}" for i in range(num_contexts)]
+        metric_decomposition = {f"{c}": [f"y:{c}"] for c in context_buckets}
+
+        dataset_list2 = [
+            make_dataset(
+                d=num_contexts,
+                has_yvar=has_yvar,
+                feature_names=feature_names,
+                outcome_names=[context_outcome_list[0]],
+            )
+        ]
+        for mname in context_outcome_list[1:]:
+            dataset_list2.append(
+                SupervisedDataset(
+                    X=dataset_list2[0].X,
+                    Y=rand(dataset_list2[0].Y.size()),
+                    Yvar=rand(dataset_list2[0].Yvar.size()) if has_yvar else None,
+                    feature_names=feature_names,
+                    outcome_names=[mname],
+                )
+            )
+        context_dt = ContextualDataset(
+            datasets=dataset_list2,
+            parameter_decomposition=parameter_decomposition,
+            metric_decomposition=metric_decomposition,
+        )
+        return context_dt, dataset_list2
+    dataset_list1 = [
+        make_dataset(
+            d=num_contexts,
+            has_yvar=has_yvar,
+            feature_names=feature_names,
+            outcome_names=["y"],
+        )
+    ]
+    context_dt = ContextualDataset(
+        datasets=dataset_list1,
+        parameter_decomposition=parameter_decomposition,
+    )
+    return context_dt, dataset_list1
+
+
 class TestDatasets(BotorchTestCase):
     def test_supervised(self):
         # Generate some data
@@ -122,6 +178,70 @@ def test_supervised(self):
         self.assertNotEqual(dataset, dataset2)
         self.assertNotEqual(dataset2, dataset)

+    def test_clone(self, supervised: bool = True) -> None:
+        has_yvar_options = [False]
+        if supervised:
+            has_yvar_options.append(True)
+        for has_yvar in has_yvar_options:
+            if supervised:
+                dataset = make_dataset(has_yvar=has_yvar)
+            else:
+                X_val = rand(16, 2)
+                X_idx = stack([randperm(len(X_val))[:3] for _ in range(1)])
+                X = SliceContainer(
+                    X_val, X_idx, event_shape=Size([3 * X_val.shape[-1]])
+                )
+                dataset = RankingDataset(
+                    X=X,
+                    Y=tensor([[0, 1, 1]]),
+                    feature_names=["x1", "x2"],
+                    outcome_names=["ranking indices"],
+                )
+
+            for use_deepcopy in [False, True]:
+                dataset2 = dataset.clone(deepcopy=use_deepcopy)
+                self.assertEqual(dataset, dataset2)
+                self.assertTrue(torch.equal(dataset.X, dataset2.X))
+                self.assertTrue(torch.equal(dataset.Y, dataset2.Y))
+                if has_yvar:
+                    self.assertTrue(torch.equal(dataset.Yvar, dataset2.Yvar))
+                else:
+                    self.assertIsNone(dataset2.Yvar)
+                self.assertEqual(dataset.feature_names, dataset2.feature_names)
+                self.assertEqual(dataset.outcome_names, dataset2.outcome_names)
+                if use_deepcopy:
+                    self.assertIsNot(dataset.X, dataset2.X)
+                    self.assertIsNot(dataset.Y, dataset2.Y)
+                    if has_yvar:
+                        self.assertIsNot(dataset.Yvar, dataset2.Yvar)
+                    self.assertIsNot(dataset.feature_names, dataset2.feature_names)
+                    self.assertIsNot(dataset.outcome_names, dataset2.outcome_names)
+                else:
+                    self.assertIs(dataset._X, dataset2._X)
+                    self.assertIs(dataset._Y, dataset2._Y)
+                    self.assertIs(dataset._Yvar, dataset2._Yvar)
+                    self.assertIs(dataset.feature_names, dataset2.feature_names)
+                    self.assertIs(dataset.outcome_names, dataset2.outcome_names)
+                # test with mask
+                mask = torch.tensor([0, 1, 1], dtype=torch.bool)
+                if supervised:
+                    dataset2 = dataset.clone(deepcopy=use_deepcopy, mask=mask)
+                    self.assertTrue(torch.equal(dataset.X[1:], dataset2.X))
+                    self.assertTrue(torch.equal(dataset.Y[1:], dataset2.Y))
+                    if has_yvar:
+                        self.assertTrue(torch.equal(dataset.Yvar[1:], dataset2.Yvar))
+                    else:
+                        self.assertIsNone(dataset2.Yvar)
+                else:
+                    with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "Masking is not supported for BotorchContainers.",
+                    ):
+                        dataset.clone(deepcopy=use_deepcopy, mask=mask)
+
+    def test_clone_ranking(self) -> None:
+        self.test_clone(supervised=False)
+
     def test_fixedNoise(self):
         # Generate some data
         X = rand(3, 2)
@@ -365,6 +485,52 @@ def test_multi_task(self):
             MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"),
         )

+    def test_clone_multitask(self) -> None:
+        for has_yvar in [False, True]:
+            dataset_1 = make_dataset(outcome_names=["y"], has_yvar=has_yvar)
+            dataset_2 = make_dataset(outcome_names=["z"], has_yvar=has_yvar)
+            mt_dataset = MultiTaskDataset(
+                datasets=[dataset_1, dataset_2],
+                target_outcome_name="z",
+            )
+            for use_deepcopy in [False, True]:
+                mt_dataset2 = mt_dataset.clone(deepcopy=use_deepcopy)
+                self.assertEqual(mt_dataset, mt_dataset2)
+                self.assertTrue(torch.equal(mt_dataset.X, mt_dataset2.X))
+                self.assertTrue(torch.equal(mt_dataset.Y, mt_dataset2.Y))
+                if has_yvar:
+                    self.assertTrue(torch.equal(mt_dataset.Yvar, mt_dataset2.Yvar))
+                else:
+                    self.assertIsNone(mt_dataset2.Yvar)
+                self.assertEqual(mt_dataset.feature_names, mt_dataset2.feature_names)
+                self.assertEqual(mt_dataset.outcome_names, mt_dataset2.outcome_names)
+                if use_deepcopy:
+                    for ds, ds2 in zip(
+                        mt_dataset.datasets.values(), mt_dataset2.datasets.values()
+                    ):
+                        self.assertIsNot(ds, ds2)
+                else:
+                    for ds, ds2 in zip(
+                        mt_dataset.datasets.values(), mt_dataset2.datasets.values()
+                    ):
+                        self.assertIs(ds, ds2)
+                # test with mask
+                mask = torch.tensor([0, 1, 1], dtype=torch.bool)
+                mt_dataset2 = mt_dataset.clone(deepcopy=use_deepcopy, mask=mask)
+                # The mask should apply only to the target dataset;
+                # all rows of the non-target datasets should be kept.
+                full_mask = torch.tensor([1, 1, 1, 0, 1, 1], dtype=torch.bool)
+                self.assertTrue(torch.equal(mt_dataset.X[full_mask], mt_dataset2.X))
+                self.assertTrue(torch.equal(mt_dataset.Y[full_mask], mt_dataset2.Y))
+                if has_yvar:
+                    self.assertTrue(
+                        torch.equal(mt_dataset.Yvar[full_mask], mt_dataset2.Yvar)
+                    )
+                else:
+                    self.assertIsNone(mt_dataset2.Yvar)
+                self.assertEqual(mt_dataset.feature_names, mt_dataset2.feature_names)
+                self.assertEqual(mt_dataset.outcome_names, mt_dataset2.outcome_names)
+
     def test_contextual_datasets(self):
         num_contexts = 3
         feature_names = [f"x_c{i}" for i in range(num_contexts)]
@@ -378,17 +544,8 @@ def test_contextual_datasets(self):
         metric_decomposition = {f"{c}": [f"y:{c}"] for c in context_buckets}

         # test construction of agg outcome
-        dataset_list1 = [
-            make_dataset(
-                d=1 * num_contexts,
-                has_yvar=True,
-                feature_names=feature_names,
-                outcome_names=["y"],
-            )
-        ]
-        context_dt = ContextualDataset(
-            datasets=dataset_list1,
-            parameter_decomposition=parameter_decomposition,
+        context_dt, dataset_list1 = make_contextual_dataset(
+            has_yvar=True, contextual_outcome=False
         )
         self.assertEqual(len(context_dt.datasets), len(dataset_list1))
         self.assertListEqual(context_dt.context_buckets, context_buckets)
@@ -400,28 +557,8 @@ def test_contextual_datasets(self):
         self.assertIs(context_dt.Yvar, dataset_list1[0].Yvar)

         # test construction of context outcome
-        dataset_list2 = [
-            make_dataset(
-                d=1 * num_contexts,
-                has_yvar=True,
-                feature_names=feature_names,
-                outcome_names=[context_outcome_list[0]],
-            )
-        ]
-        for m in context_outcome_list[1:]:
-            dataset_list2.append(
-                SupervisedDataset(
-                    X=dataset_list2[0].X,
-                    Y=rand(dataset_list2[0].Y.size()),
-                    Yvar=rand(dataset_list2[0].Yvar.size()),
-                    feature_names=feature_names,
-                    outcome_names=[m],
-                )
-            )
-        context_dt = ContextualDataset(
-            datasets=dataset_list2,
-            parameter_decomposition=parameter_decomposition,
-            metric_decomposition=metric_decomposition,
+        context_dt, dataset_list2 = make_contextual_dataset(
+            has_yvar=True, contextual_outcome=True
         )
         self.assertEqual(len(context_dt.datasets), len(dataset_list2))
         # Ordering should match datasets, not parameter_decomposition
@@ -438,30 +575,10 @@ def test_contextual_datasets(self):
             self.assertIs(context_dt.datasets[dt.outcome_names[0]], dt)

         # Test handling None Yvar
-        dataset_list3 = [
-            make_dataset(
-                d=1 * num_contexts,
-                has_yvar=False,
-                feature_names=feature_names,
-                outcome_names=[context_outcome_list[0]],
-            )
-        ]
-        for m in context_outcome_list[1:]:
-            dataset_list3.append(
-                SupervisedDataset(
-                    X=dataset_list3[0].X,
-                    Y=rand(dataset_list3[0].Y.size()),
-                    Yvar=None,
-                    feature_names=feature_names,
-                    outcome_names=[m],
-                )
-            )
-        context_dt3 = ContextualDataset(
-            datasets=dataset_list3,
-            parameter_decomposition=parameter_decomposition,
-            metric_decomposition=metric_decomposition,
+        context_dt, dataset_list3 = make_contextual_dataset(
+            has_yvar=False, contextual_outcome=True
         )
-        self.assertIsNone(context_dt3.Yvar)
+        self.assertIsNone(context_dt.Yvar)

         # test dataset validation
         wrong_metric_decomposition1 = {
@@ -557,3 +674,54 @@ def test_contextual_datasets(self):
             parameter_decomposition=parameter_decomposition,
             metric_decomposition=wrong_metric_decomposition,
         )
+
+    def test_clone_contextual_dataset(self):
+        for has_yvar, contextual_outcome in product((False, True), (False, True)):
+            context_dt, _ = make_contextual_dataset(
+                has_yvar=has_yvar, contextual_outcome=contextual_outcome
+            )
+            for use_deepcopy in [False, True]:
+                context_dt2 = context_dt.clone(deepcopy=use_deepcopy)
+                self.assertEqual(context_dt, context_dt2)
+                self.assertTrue(torch.equal(context_dt.X, context_dt2.X))
+                self.assertTrue(torch.equal(context_dt.Y, context_dt2.Y))
+                if has_yvar:
+                    self.assertTrue(torch.equal(context_dt.Yvar, context_dt2.Yvar))
+                else:
+                    self.assertIsNone(context_dt2.Yvar)
+                self.assertEqual(context_dt.feature_names, context_dt2.feature_names)
+                self.assertEqual(context_dt.outcome_names, context_dt2.outcome_names)
+                if use_deepcopy:
+                    for ds, ds2 in zip(
+                        context_dt.datasets.values(), context_dt2.datasets.values()
+                    ):
+                        self.assertIsNot(ds, ds2)
+                else:
+                    for ds, ds2 in zip(
+                        context_dt.datasets.values(), context_dt2.datasets.values()
+                    ):
+                        self.assertIs(ds, ds2)
+                # test with mask
+                mask = torch.tensor([0, 1, 1], dtype=torch.bool)
+                context_dt2 = context_dt.clone(deepcopy=use_deepcopy, mask=mask)
+                self.assertTrue(torch.equal(context_dt.X[mask], context_dt2.X))
+                self.assertTrue(torch.equal(context_dt.Y[mask], context_dt2.Y))
+                if has_yvar:
+                    self.assertTrue(
+                        torch.equal(context_dt.Yvar[mask], context_dt2.Yvar)
+                    )
+                else:
+                    self.assertIsNone(context_dt2.Yvar)
+                self.assertEqual(context_dt.feature_names, context_dt2.feature_names)
+                self.assertEqual(context_dt.outcome_names, context_dt2.outcome_names)
+                self.assertEqual(
+                    context_dt.parameter_decomposition,
+                    context_dt2.parameter_decomposition,
+                )
+                if contextual_outcome:
+                    self.assertEqual(
+                        context_dt.metric_decomposition,
+                        context_dt2.metric_decomposition,
+                    )
+                else:
+                    self.assertIsNone(context_dt2.metric_decomposition)
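
Reviewer note: a minimal usage sketch of the new `clone` API. The data, names,
and held-out index below are illustrative only (not part of the patch); the
mask semantics follow the docstrings above (a boolean mask applied along the
`-2` dimension):

    import torch
    from botorch.utils.datasets import SupervisedDataset

    # Illustrative dataset: 5 observations, 2 features, 1 outcome.
    dataset = SupervisedDataset(
        X=torch.rand(5, 2),
        Y=torch.rand(5, 1),
        feature_names=["x1", "x2"],
        outcome_names=["y"],
    )

    # Shallow clone: the copy shares the underlying tensors and name lists.
    shallow = dataset.clone()
    assert shallow == dataset

    # Deep clone: tensors and name lists are copied.
    deep = dataset.clone(deepcopy=True)
    assert torch.equal(deep.X, dataset.X) and deep._X is not dataset._X

    # Masked clone, e.g. leaving out observation `i` for LOOCV:
    i = 0
    mask = torch.ones(5, dtype=torch.bool)
    mask[i] = False
    loo = dataset.clone(mask=mask)
    assert loo.X.shape == (4, 2) and loo.Y.shape == (4, 1)

Per the docstrings above, `MultiTaskDataset.clone` applies the mask only to the
target dataset's rows, while `ContextualDataset.clone` applies it to every
sub-dataset (which share the same rows).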