Skip to content

Commit

Permalink
Make input & outcome transforms first class attributes of Model
Browse files Browse the repository at this point in the history
Summary: These were previously only set when they were not None, which lead to a lot of `hasattr`, `getattr` usage throughout the codebase to check for them. This diff adds them as attributes to base `Model` class with default values of `None`. With this change, we can now access `model.input/outcome_transform` on all models.

Differential Revision: D66012223
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 15, 2024
1 parent 3c2ce15 commit 9688ef5
Show file tree
Hide file tree
Showing 36 changed files with 117 additions and 164 deletions.
10 changes: 4 additions & 6 deletions botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,9 +1107,8 @@ def _get_noiseless_fantasy_model(

# Set the outcome and input transforms of the fantasy model.
# The transforms should already be in eval mode but just set them to be sure
outcome_transform = getattr(model, "outcome_transform", None)
if outcome_transform is not None:
outcome_transform = deepcopy(outcome_transform).eval()
if model.outcome_transform is not None:
outcome_transform = deepcopy(model.outcome_transform).eval()
fantasy_model.outcome_transform = outcome_transform
# Need to transform the outcome just as in the SingleTaskGP constructor.
# Need to unsqueeze for BoTorch and then squeeze again for GPyTorch.
Expand All @@ -1119,9 +1118,8 @@ def _get_noiseless_fantasy_model(
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
)
Y_fantasized = Y_fantasized.squeeze(-1)
input_transform = getattr(model, "input_transform", None)
if input_transform is not None:
fantasy_model.input_transform = deepcopy(input_transform).eval()
if model.input_transform is not None:
fantasy_model.input_transform = deepcopy(model.input_transform).eval()

# update training inputs/targets to be batch mode fantasies
fantasy_model.set_train_data(
Expand Down
2 changes: 1 addition & 1 deletion botorch/acquisition/cached_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def supports_cache_root(model: Model) -> bool:
) or not isinstance(model, GPyTorchModel):
return False
# Models that return a TransformedPosterior are not supported.
if hasattr(model, "outcome_transform") and (not model.outcome_transform._is_linear):
if model.outcome_transform is not None and not model.outcome_transform._is_linear:
return False
return True

Expand Down
3 changes: 1 addition & 2 deletions botorch/acquisition/multi_objective/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def __init__(
"Multi-Objective MC acquisition functions."
)
if (
hasattr(model, "input_transform")
and isinstance(model.input_transform, InputPerturbation)
isinstance(model.input_transform, InputPerturbation)
and constraints is not None
):
raise UnsupportedError(
Expand Down
20 changes: 5 additions & 15 deletions botorch/acquisition/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,19 @@ def forward(self, X: Tensor) -> Tensor:
if isinstance(model, BatchedMultiOutputGPyTorchModel) and model.num_outputs > 1:
train_inputs = train_inputs[0]

input_transform = _get_input_transform(model)

last_X = train_inputs[-1].reshape(1, 1, -1)

# if transformed_weighting, transform X to calculate diff
# (proximal weighting in transformed space)
# otherwise,un-transform the last observed point to real space
# (proximal weighting in real space)
if input_transform is not None:
if model.input_transform is not None:
if self.transformed_weighting:
# transformed space weighting
diff = input_transform.transform(X) - last_X
diff = model.input_transform.transform(X) - last_X
else:
# real space weighting
diff = X - input_transform.untransform(last_X)
diff = X - model.input_transform.untransform(last_X)

else:
# no transformation
Expand Down Expand Up @@ -173,7 +171,7 @@ def _validate_model(model: Model, proximal_weights: Tensor) -> None:
# check to make sure that the training inputs and input transformers for each
# model match and are reversible
train_inputs = model.train_inputs[0][0]
input_transform = _get_input_transform(model.models[0])
input_transform = model.models[0].input_transform

for i in range(len(model.train_inputs)):
if not torch.equal(train_inputs, model.train_inputs[i][0]):
Expand All @@ -182,7 +180,7 @@ def _validate_model(model: Model, proximal_weights: Tensor) -> None:
"training inputs"
)

if not input_transform == _get_input_transform(model.models[i]):
if input_transform != model.models[i].input_transform:
raise UnsupportedError(
"Proximal acquisition function does not support non-identical "
"input transforms"
Expand All @@ -207,11 +205,3 @@ def _validate_model(model: Model, proximal_weights: Tensor) -> None:
"`proximal_weights` must be a one dimensional tensor with "
"same feature dimension as model."
)


def _get_input_transform(model: Model) -> InputTransform | None:
"""get input transform if defined"""
try:
return model.input_transform
except AttributeError:
return None
7 changes: 3 additions & 4 deletions botorch/models/approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def posterior(
dist = self.likelihood(dist)

posterior = GPyTorchPosterior(distribution=dist)
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
posterior = self.outcome_transform.untransform_posterior(posterior)
if posterior_transform is not None:
posterior = posterior_transform(posterior)
Expand Down Expand Up @@ -449,15 +449,14 @@ def __init__(

super().__init__(model=model, likelihood=likelihood, num_outputs=num_outputs)

if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
warnings.warn(
TRANSFORM_WARNING.format(ttype="input"),
UserInputWarning,
stacklevel=3,
)
self.input_transform = input_transform
self.outcome_transform = outcome_transform
self.input_transform = input_transform

# for model fitting utilities
# TODO: make this a flag?
Expand Down
25 changes: 12 additions & 13 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _check_compatibility(models: ModuleList) -> None:
)

# TODO: Add support for outcome transforms.
if any(getattr(m, "outcome_transform", None) is not None for m in models):
if any(m.outcome_transform is not None for m in models):
raise UnsupportedError(
"Conversion of models with outcome transforms is unsupported. "
"To fix this error, explicitly pass `outcome_transform=None`.",
Expand All @@ -111,15 +111,14 @@ def _check_compatibility(models: ModuleList) -> None:
# check that there are no batched input transforms
default_size = torch.Size([])
for m in models:
if hasattr(m, "input_transform"):
if (
m.input_transform is not None
and len(getattr(m.input_transform, "batch_shape", default_size)) != 0
):
raise UnsupportedError("Batched input_transforms are not supported.")
if (
m.input_transform is not None
and len(getattr(m.input_transform, "batch_shape", default_size)) != 0
):
raise UnsupportedError("Batched input_transforms are not supported.")

# check that all models have the same input transforms
if any(hasattr(m, "input_transform") for m in models):
if any(m.input_transform is not None for m in models):
if not all(
m.input_transform.equals(models[0].input_transform) for m in models[1:]
):
Expand Down Expand Up @@ -180,7 +179,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
kwargs["outcome_transform"] = None

# construct the batched GP model
input_transform = getattr(models[0], "input_transform", None)
input_transform = models[0].input_transform
batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
Expand Down Expand Up @@ -286,8 +285,8 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
raise NotImplementedError(
"Conversion of MixedSingleTaskGP is currently not supported."
)
input_transform = getattr(batch_model, "input_transform", None)
outcome_transform = getattr(batch_model, "outcome_transform", None)
input_transform = batch_model.input_transform
outcome_transform = batch_model.outcome_transform
batch_sd = batch_model.state_dict()

adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
Expand Down Expand Up @@ -388,11 +387,11 @@ def batched_multi_output_to_single_output(
raise NotImplementedError(
"Conversion of models with custom likelihoods is currently unsupported."
)
input_transform = getattr(batch_mo_model, "input_transform", None)
input_transform = batch_mo_model.input_transform
batch_sd = batch_mo_model.state_dict()

# TODO: add support for outcome transforms.
if hasattr(batch_mo_model, "outcome_transform"):
if batch_mo_model.outcome_transform is not None:
raise NotImplementedError(
"Converting batched multi-output models with outcome transforms "
"is not currently supported."
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def posterior(
# NOTE: The `outcome_transform` `untransform`s the predictions rather than the
# `posterior` (as is done in GP models). This is more general since it works
# even if the transform doesn't support `untransform_posterior`.
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
values, _ = self.outcome_transform.untransform(values)
if output_indices is not None:
values = values[..., output_indices]
Expand Down
6 changes: 2 additions & 4 deletions botorch/models/fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,8 @@ def __init__(
train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
)
self.pyro_model: PyroModel = pyro_model
if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
self.input_transform = input_transform
self.outcome_transform = outcome_transform
self.input_transform = input_transform

def _check_if_fitted(self):
r"""Raise an exception if the model hasn't been fitted."""
Expand Down
6 changes: 2 additions & 4 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,8 @@ def __init__(
task_rank=self._rank,
)
self.pyro_model: MultitaskSaasPyroModel = pyro_model
if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
self.input_transform = input_transform
self.outcome_transform = outcome_transform
self.input_transform = input_transform

def train(self, mode: bool = True) -> None:
r"""Puts the model in `train` mode."""
Expand Down
8 changes: 3 additions & 5 deletions botorch/models/gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,10 @@ def __init__(
}
if train_Yvar is None:
self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
self.covar_module: Module = covar_module
# TODO: Allow subsetting of other covar modules
if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
self.input_transform = input_transform
self.covar_module: Module = covar_module
self.outcome_transform = outcome_transform
self.input_transform = input_transform
self.to(train_X)

@classmethod
Expand Down
16 changes: 7 additions & 9 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def posterior(
else:
mvn = self.likelihood(mvn, X)
posterior = GPyTorchPosterior(distribution=mvn)
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
posterior = self.outcome_transform.untransform_posterior(posterior)
if posterior_transform is not None:
return posterior_transform(posterior)
Expand Down Expand Up @@ -239,7 +239,7 @@ def condition_on_observations(
"""
Yvar = noise

if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
# pass the transformed data to get_fantasy_model below
# (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
if not isinstance(self, BatchedMultiOutputGPyTorchModel):
Expand Down Expand Up @@ -464,7 +464,7 @@ def posterior(
mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)

posterior = GPyTorchPosterior(distribution=mvn)
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
posterior = self.outcome_transform.untransform_posterior(posterior)
if posterior_transform is not None:
return posterior_transform(posterior)
Expand Down Expand Up @@ -505,7 +505,7 @@ def condition_on_observations(
>>> model = model.condition_on_observations(X=new_X, Y=new_Y)
"""
noise = kwargs.get("noise")
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
# We need to apply transforms before shifting batch indices around.
# `noise` is assumed to already be outcome-transformed.
Y, _ = self.outcome_transform(Y)
Expand Down Expand Up @@ -585,11 +585,9 @@ def subset_output(self, idcs: list[int]) -> BatchedMultiOutputGPyTorchModel:
mod_batch_shape(new_model, mod_name, m if m > 1 else 0)

# subset outcome transform if present
try:
if new_model.outcome_transform is not None:
subset_octf = new_model.outcome_transform.subset_output(idcs=idcs)
new_model.outcome_transform = subset_octf
except AttributeError:
pass

# Subset fixed noise likelihood if present.
if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
Expand Down Expand Up @@ -681,7 +679,7 @@ def posterior(
# Nonlinear transforms untransform to a `TransformedPosterior`,
# which can't be made into a `GPyTorchPosterior`
returns_untransformed = any(
hasattr(mod, "outcome_transform") and (not mod.outcome_transform._is_linear)
mod.outcome_transform is not None and not mod.outcome_transform._is_linear
for mod in self.models
)
# NOTE: We're not passing in the posterior transform here. We'll apply it later.
Expand Down Expand Up @@ -918,7 +916,7 @@ def posterior(
interleaved=False,
)
posterior = GPyTorchPosterior(distribution=mtmvn)
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
posterior = self.outcome_transform.untransform_posterior(posterior)
if posterior_transform is not None:
return posterior_transform(posterior)
Expand Down
10 changes: 4 additions & 6 deletions botorch/models/higher_order_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,8 @@ def __init__(
dtype=train_Y.dtype,
)

if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
self.input_transform = input_transform
self.outcome_transform = outcome_transform
self.input_transform = input_transform

def _initialize_latents(
self,
Expand Down Expand Up @@ -414,7 +412,7 @@ def condition_on_observations(
conditioned on the new observations `(X, Y)` (and possibly noise
observations passed in via kwargs).
"""
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
# we need to apply transforms before shifting batch indices around
Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
# Do not check shapes when fantasizing as they are not expected to match.
Expand Down Expand Up @@ -539,7 +537,7 @@ def posterior(
output_shape=X.shape[:-1] + self.target_shape,
num_outputs=self._num_outputs,
)
if hasattr(self, "outcome_transform"):
if self.outcome_transform is not None:
posterior = self.outcome_transform.untransform_posterior(posterior)
return posterior

Expand Down
Loading

0 comments on commit 9688ef5

Please sign in to comment.