Make input & outcome transforms first class attributes of Model #2630

Closed
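This PR makes `outcome_transform` and `input_transform` attributes that always exist on every `Model`, defaulting to `None`, so downstream code can replace defensive `hasattr`/`getattr` probes with plain `is not None` checks. A minimal sketch of the pattern (illustrative only, not the actual BoTorch source):

```python
# Illustrative sketch only (not the actual BoTorch source): the transforms
# become attributes that exist on every model, defaulting to None.
class Model:
    outcome_transform = None  # None means "no outcome transform"
    input_transform = None    # None means "no input transform"


class MyGP(Model):
    def __init__(self, outcome_transform=None, input_transform=None):
        # Assign unconditionally, even when None, so call sites never
        # need hasattr()/getattr() probes.
        self.outcome_transform = outcome_transform
        self.input_transform = input_transform


model = MyGP()
# Old defensive pattern that this PR removes:
octf = getattr(model, "outcome_transform", None)
# New pattern, safe because the attribute is guaranteed to exist:
print(model.outcome_transform is not None)  # False: no transform passed
```

The per-file changes below apply this pattern throughout the codebase.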
10 changes: 4 additions & 6 deletions botorch/acquisition/analytic.py
@@ -1107,9 +1107,8 @@ def _get_noiseless_fantasy_model(
 
     # Set the outcome and input transforms of the fantasy model.
     # The transforms should already be in eval mode but just set them to be sure
-    outcome_transform = getattr(model, "outcome_transform", None)
-    if outcome_transform is not None:
-        outcome_transform = deepcopy(outcome_transform).eval()
+    if model.outcome_transform is not None:
+        outcome_transform = deepcopy(model.outcome_transform).eval()
         fantasy_model.outcome_transform = outcome_transform
         # Need to transform the outcome just as in the SingleTaskGP constructor.
         # Need to unsqueeze for BoTorch and then squeeze again for GPyTorch.
@@ -1119,9 +1118,8 @@ def _get_noiseless_fantasy_model(
             Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
         )
         Y_fantasized = Y_fantasized.squeeze(-1)
-    input_transform = getattr(model, "input_transform", None)
-    if input_transform is not None:
-        fantasy_model.input_transform = deepcopy(input_transform).eval()
+    if model.input_transform is not None:
+        fantasy_model.input_transform = deepcopy(model.input_transform).eval()
 
     # update training inputs/targets to be batch mode fantasies
     fantasy_model.set_train_data(
2 changes: 1 addition & 1 deletion botorch/acquisition/cached_cholesky.py
@@ -45,7 +45,7 @@ def supports_cache_root(model: Model) -> bool:
     ) or not isinstance(model, GPyTorchModel):
         return False
     # Models that return a TransformedPosterior are not supported.
-    if hasattr(model, "outcome_transform") and (not model.outcome_transform._is_linear):
+    if model.outcome_transform is not None and not model.outcome_transform._is_linear:
         return False
     return True
 
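The `supports_cache_root` check above now reads uniformly: no outcome transform, or a linear one, keeps the cached-root path. A usage sketch (the model and transform classes are real BoTorch APIs; the expected results follow my reading of the check above, not verified output):

```python
import torch
from botorch.acquisition.cached_cholesky import supports_cache_root
from botorch.models import SingleTaskGP
from botorch.models.transforms.outcome import Log, Standardize

X = torch.rand(8, 2, dtype=torch.float64)
Y = 1.0 + X.sum(dim=-1, keepdim=True)  # strictly positive, so Log is valid

# Standardize is linear (_is_linear is True), so caching should be supported:
print(supports_cache_root(SingleTaskGP(X, Y, outcome_transform=Standardize(m=1))))
# Log is nonlinear and untransforms to a TransformedPosterior, so it should not be:
print(supports_cache_root(SingleTaskGP(X, Y, outcome_transform=Log())))
```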
3 changes: 1 addition & 2 deletions botorch/acquisition/multi_objective/base.py
@@ -121,8 +121,7 @@ def __init__(
                 "Multi-Objective MC acquisition functions."
             )
         if (
-            hasattr(model, "input_transform")
-            and isinstance(model.input_transform, InputPerturbation)
+            isinstance(model.input_transform, InputPerturbation)
             and constraints is not None
         ):
             raise UnsupportedError(
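Dropping the `hasattr` guard is safe here because `isinstance` on `None` is simply `False`; a one-line illustration with a stand-in class:

```python
class InputPerturbation:  # stand-in for botorch's class, for illustration only
    pass

# A model whose input_transform is None just fails the check, it never raises:
print(isinstance(None, InputPerturbation))  # False
```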
20 changes: 5 additions & 15 deletions botorch/acquisition/proximal.py
@@ -113,21 +113,19 @@ def forward(self, X: Tensor) -> Tensor:
         if isinstance(model, BatchedMultiOutputGPyTorchModel) and model.num_outputs > 1:
             train_inputs = train_inputs[0]
 
-        input_transform = _get_input_transform(model)
-
         last_X = train_inputs[-1].reshape(1, 1, -1)
 
         # if transformed_weighting, transform X to calculate diff
         # (proximal weighting in transformed space)
         # otherwise, un-transform the last observed point to real space
         # (proximal weighting in real space)
-        if input_transform is not None:
+        if model.input_transform is not None:
             if self.transformed_weighting:
                 # transformed space weighting
-                diff = input_transform.transform(X) - last_X
+                diff = model.input_transform.transform(X) - last_X
             else:
                 # real space weighting
-                diff = X - input_transform.untransform(last_X)
+                diff = X - model.input_transform.untransform(last_X)
 
         else:
             # no transformation
@@ -173,7 +171,7 @@ def _validate_model(model: Model, proximal_weights: Tensor) -> None:
     # check to make sure that the training inputs and input transformers for each
     # model match and are reversible
     train_inputs = model.train_inputs[0][0]
-    input_transform = _get_input_transform(model.models[0])
+    input_transform = model.models[0].input_transform
 
     for i in range(len(model.train_inputs)):
         if not torch.equal(train_inputs, model.train_inputs[i][0]):
@@ -182,7 +180,7 @@ def _validate_model(model: Model, proximal_weights: Tensor) -> None:
                 "training inputs"
             )
 
-        if not input_transform == _get_input_transform(model.models[i]):
+        if input_transform != model.models[i].input_transform:
             raise UnsupportedError(
                 "Proximal acquisition function does not support non-identical "
                 "input transforms"
@@ -207,11 +205,3 @@ def _validate_model(model: Model, proximal_weights: Tensor) -> None:
             "`proximal_weights` must be a one dimensional tensor with "
             "same feature dimension as model."
         )
-
-
-def _get_input_transform(model: Model) -> InputTransform | None:
-    """get input transform if defined"""
-    try:
-        return model.input_transform
-    except AttributeError:
-        return None
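With the attribute guaranteed to exist, the `_get_input_transform` helper above becomes dead code. For the two weighting modes in `forward`, here is a toy sketch (the transform class is hypothetical; real BoTorch input transforms expose the same `transform`/`untransform` pair):

```python
import torch


class DoubleTransform:
    """Toy input transform: x -> 2 * x (hypothetical, for illustration)."""

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        return 2 * X

    def untransform(self, X: torch.Tensor) -> torch.Tensor:
        return X / 2


tf = DoubleTransform()
X = torch.tensor([[1.0, 1.0]])  # candidate point, in real space
last_X = tf.transform(torch.tensor([[0.5, 0.5]]))  # stored in transformed space

# transformed_weighting=True: compare in transformed space
diff_transformed = tf.transform(X) - last_X  # tensor([[1., 1.]])
# transformed_weighting=False: compare in real space
diff_real = X - tf.untransform(last_X)  # tensor([[0.5000, 0.5000]])
print(diff_transformed, diff_real)
```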
7 changes: 3 additions & 4 deletions botorch/models/approximate_gp.py
@@ -171,7 +171,7 @@ def posterior(
             dist = self.likelihood(dist)
 
         posterior = GPyTorchPosterior(distribution=dist)
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             posterior = self.outcome_transform.untransform_posterior(posterior)
         if posterior_transform is not None:
             posterior = posterior_transform(posterior)
@@ -449,15 +449,14 @@ def __init__(
 
         super().__init__(model=model, likelihood=likelihood, num_outputs=num_outputs)
 
-        if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
         if input_transform is not None:
             warnings.warn(
                 TRANSFORM_WARNING.format(ttype="input"),
                 UserInputWarning,
                 stacklevel=3,
             )
-            self.input_transform = input_transform
+        self.outcome_transform = outcome_transform
+        self.input_transform = input_transform
 
         # for model fitting utilities
         # TODO: make this a flag?
25 changes: 12 additions & 13 deletions botorch/models/converter.py
@@ -90,7 +90,7 @@ def _check_compatibility(models: ModuleList) -> None:
         )
 
     # TODO: Add support for outcome transforms.
-    if any(getattr(m, "outcome_transform", None) is not None for m in models):
+    if any(m.outcome_transform is not None for m in models):
         raise UnsupportedError(
             "Conversion of models with outcome transforms is unsupported. "
             "To fix this error, explicitly pass `outcome_transform=None`.",
@@ -111,15 +111,14 @@ def _check_compatibility(models: ModuleList) -> None:
     # check that there are no batched input transforms
    default_size = torch.Size([])
     for m in models:
-        if hasattr(m, "input_transform"):
-            if (
-                m.input_transform is not None
-                and len(getattr(m.input_transform, "batch_shape", default_size)) != 0
-            ):
-                raise UnsupportedError("Batched input_transforms are not supported.")
+        if (
+            m.input_transform is not None
+            and len(getattr(m.input_transform, "batch_shape", default_size)) != 0
+        ):
+            raise UnsupportedError("Batched input_transforms are not supported.")
 
     # check that all models have the same input transforms
-    if any(hasattr(m, "input_transform") for m in models):
+    if any(m.input_transform is not None for m in models):
         if not all(
             m.input_transform.equals(models[0].input_transform) for m in models[1:]
         ):
@@ -180,7 +179,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
         kwargs["outcome_transform"] = None
 
     # construct the batched GP model
-    input_transform = getattr(models[0], "input_transform", None)
+    input_transform = models[0].input_transform
     batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
     adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
         batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
@@ -286,8 +285,8 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
         raise NotImplementedError(
             "Conversion of MixedSingleTaskGP is currently not supported."
         )
-    input_transform = getattr(batch_model, "input_transform", None)
-    outcome_transform = getattr(batch_model, "outcome_transform", None)
+    input_transform = batch_model.input_transform
+    outcome_transform = batch_model.outcome_transform
     batch_sd = batch_model.state_dict()
 
     adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
@@ -388,11 +387,11 @@ def batched_multi_output_to_single_output(
         raise NotImplementedError(
             "Conversion of models with custom likelihoods is currently unsupported."
         )
-    input_transform = getattr(batch_mo_model, "input_transform", None)
+    input_transform = batch_mo_model.input_transform
     batch_sd = batch_mo_model.state_dict()
 
     # TODO: add support for outcome transforms.
-    if hasattr(batch_mo_model, "outcome_transform"):
+    if batch_mo_model.outcome_transform is not None:
         raise NotImplementedError(
             "Converting batched multi-output models with outcome transforms "
             "is not currently supported."
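A usage sketch of the list-to-batched conversion with shared input transforms (these are real BoTorch classes and functions; `outcome_transform=None` is passed explicitly, as the error message above suggests, and the carried-over transform follows the code above):

```python
import torch
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.converter import model_list_to_batched
from botorch.models.transforms.input import Normalize

X = torch.rand(8, 2, dtype=torch.float64)
m1 = SingleTaskGP(
    X, X.sum(-1, keepdim=True), input_transform=Normalize(d=2), outcome_transform=None
)
m2 = SingleTaskGP(
    X, X.prod(-1, keepdim=True), input_transform=Normalize(d=2), outcome_transform=None
)
# Both models share equal input transforms, identical training inputs, and no
# outcome transform, so the compatibility checks above should pass:
batched = model_list_to_batched(ModelListGP(m1, m2))
print(batched.input_transform)  # the shared Normalize transform, carried over
```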
2 changes: 1 addition & 1 deletion botorch/models/ensemble.py
@@ -78,7 +78,7 @@ def posterior(
         # NOTE: The `outcome_transform` `untransform`s the predictions rather than the
         # `posterior` (as is done in GP models). This is more general since it works
         # even if the transform doesn't support `untransform_posterior`.
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             values, _ = self.outcome_transform.untransform(values)
         if output_indices is not None:
             values = values[..., output_indices]
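A toy sketch of the distinction drawn in the NOTE above (names are illustrative, not the BoTorch API): untransforming raw values works for any transform, even one that lacks `untransform_posterior`:

```python
import torch


class ToyStandardize:
    """Hypothetical outcome transform that only knows how to untransform values."""

    def __init__(self, mean: float, std: float):
        self.mean, self.std = mean, std

    def untransform(self, Y: torch.Tensor):
        return Y * self.std + self.mean, None  # (values, Yvar)


octf = ToyStandardize(mean=10.0, std=2.0)
values = torch.tensor([[0.0], [1.0]])  # standardized ensemble predictions
values, _ = octf.untransform(values)   # ensemble path: untransform the values
print(values)  # tensor([[10.], [12.]])
```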
6 changes: 2 additions & 4 deletions botorch/models/fully_bayesian.py
@@ -396,10 +396,8 @@ def __init__(
             train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
         )
         self.pyro_model: PyroModel = pyro_model
-        if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
-        if input_transform is not None:
-            self.input_transform = input_transform
+        self.outcome_transform = outcome_transform
+        self.input_transform = input_transform
 
     def _check_if_fitted(self):
         r"""Raise an exception if the model hasn't been fitted."""
6 changes: 2 additions & 4 deletions botorch/models/fully_bayesian_multitask.py
@@ -282,10 +282,8 @@ def __init__(
             task_rank=self._rank,
         )
         self.pyro_model: MultitaskSaasPyroModel = pyro_model
-        if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
-        if input_transform is not None:
-            self.input_transform = input_transform
+        self.outcome_transform = outcome_transform
+        self.input_transform = input_transform
 
     def train(self, mode: bool = True) -> None:
         r"""Puts the model in `train` mode."""
8 changes: 3 additions & 5 deletions botorch/models/gp_regression.py
@@ -203,12 +203,10 @@ def __init__(
         }
         if train_Yvar is None:
             self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
-        self.covar_module: Module = covar_module
         # TODO: Allow subsetting of other covar modules
-        if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
-        if input_transform is not None:
-            self.input_transform = input_transform
+        self.covar_module: Module = covar_module
+        self.outcome_transform = outcome_transform
+        self.input_transform = input_transform
         self.to(train_X)
 
     @classmethod
16 changes: 7 additions & 9 deletions botorch/models/gpytorch.py
@@ -197,7 +197,7 @@ def posterior(
         else:
             mvn = self.likelihood(mvn, X)
         posterior = GPyTorchPosterior(distribution=mvn)
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             posterior = self.outcome_transform.untransform_posterior(posterior)
         if posterior_transform is not None:
             return posterior_transform(posterior)
@@ -239,7 +239,7 @@ def condition_on_observations(
         """
         Yvar = noise
 
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             # pass the transformed data to get_fantasy_model below
             # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
             if not isinstance(self, BatchedMultiOutputGPyTorchModel):
@@ -464,7 +464,7 @@ def posterior(
             mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
 
         posterior = GPyTorchPosterior(distribution=mvn)
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             posterior = self.outcome_transform.untransform_posterior(posterior)
         if posterior_transform is not None:
             return posterior_transform(posterior)
@@ -505,7 +505,7 @@ def condition_on_observations(
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
         noise = kwargs.get("noise")
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             # We need to apply transforms before shifting batch indices around.
             # `noise` is assumed to already be outcome-transformed.
             Y, _ = self.outcome_transform(Y)
@@ -585,11 +585,9 @@ def subset_output(self, idcs: list[int]) -> BatchedMultiOutputGPyTorchModel:
             mod_batch_shape(new_model, mod_name, m if m > 1 else 0)
 
         # subset outcome transform if present
-        try:
+        if new_model.outcome_transform is not None:
             subset_octf = new_model.outcome_transform.subset_output(idcs=idcs)
             new_model.outcome_transform = subset_octf
-        except AttributeError:
-            pass
 
         # Subset fixed noise likelihood if present.
         if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
@@ -681,7 +679,7 @@ def posterior(
         # Nonlinear transforms untransform to a `TransformedPosterior`,
         # which can't be made into a `GPyTorchPosterior`
         returns_untransformed = any(
-            hasattr(mod, "outcome_transform") and (not mod.outcome_transform._is_linear)
+            mod.outcome_transform is not None and not mod.outcome_transform._is_linear
             for mod in self.models
        )
         # NOTE: We're not passing in the posterior transform here. We'll apply it later.
@@ -918,7 +916,7 @@ def posterior(
             interleaved=False,
         )
         posterior = GPyTorchPosterior(distribution=mtmvn)
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             posterior = self.outcome_transform.untransform_posterior(posterior)
         if posterior_transform is not None:
             return posterior_transform(posterior)
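The `subset_output` change above also swaps EAFP for an explicit `None` check, which is safer than it looks: a broad `except AttributeError` could mask an unrelated bug inside `subset_output` itself. A toy illustration:

```python
class ToyOctf:
    """Toy stand-in for an outcome transform with an internal bug."""

    def subset_output(self, idcs):
        return self.missing_attr  # bug: raises AttributeError


class ToyModel:
    outcome_transform = ToyOctf()


m = ToyModel()

# Old pattern: the internal AttributeError is indistinguishable from
# "model has no outcome_transform" and is silently swallowed.
try:
    m.outcome_transform = m.outcome_transform.subset_output(idcs=[0])
except AttributeError:
    pass

# New pattern: only the intended case is guarded, so the bug would surface.
# (Uncommenting the call below raises AttributeError, as it should.)
if m.outcome_transform is not None:
    pass  # m.outcome_transform.subset_output(idcs=[0])
```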
10 changes: 4 additions & 6 deletions botorch/models/higher_order_gp.py
@@ -274,10 +274,8 @@ def __init__(
             dtype=train_Y.dtype,
         )
 
-        if outcome_transform is not None:
-            self.outcome_transform = outcome_transform
-        if input_transform is not None:
-            self.input_transform = input_transform
+        self.outcome_transform = outcome_transform
+        self.input_transform = input_transform
 
     def _initialize_latents(
         self,
@@ -414,7 +412,7 @@ def condition_on_observations(
             conditioned on the new observations `(X, Y)` (and possibly noise
             observations passed in via kwargs).
         """
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             # we need to apply transforms before shifting batch indices around
             Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
         # Do not check shapes when fantasizing as they are not expected to match.
@@ -539,7 +537,7 @@ def posterior(
             output_shape=X.shape[:-1] + self.target_shape,
             num_outputs=self._num_outputs,
         )
-        if hasattr(self, "outcome_transform"):
+        if self.outcome_transform is not None:
             posterior = self.outcome_transform.untransform_posterior(posterior)
         return posterior
 