Commit

Deduplicate shared logic in prune_inferior_points(_multi_objective) (#2629)

Summary:
Pull Request resolved: #2629

There was a lot of repetition between these two methods. This diff extracts the shared parts into a helper method.

Also cleaned up some Pyre complaints while at it.

Reviewed By: Balandat

Differential Revision: D66006629

fbshipit-source-id: 81a1f23d80dfa095d71a79cbd552517a72fef0a8
saitcakmak authored and facebook-github-bot committed Nov 15, 2024
1 parent 9c7521f commit 3c2ce15
Showing 3 changed files with 105 additions and 97 deletions.
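
Before the per-file diffs, a minimal, self-contained sketch of the refactoring pattern this commit applies: validation and processing that two near-identical functions duplicated is hoisted into one private helper. The names and the trivial stand-in "objective" below are illustrative only, not BoTorch's actual implementation.

import math

import torch
from torch import Tensor


def _shared_processing(X: Tensor, max_frac: float) -> tuple[int, Tensor]:
    # Validation and objective computation that both public entry points
    # previously duplicated line-for-line.
    if X.ndim > 2:
        raise ValueError("Batched inputs `X` are unsupported.")
    if X.size(-2) == 0:
        raise ValueError("X must have at least one point.")
    if max_frac <= 0 or max_frac > 1.0:
        raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
    max_points = math.ceil(max_frac * X.size(-2))
    obj_vals = X.sum(dim=-1)  # stand-in for posterior sampling + objective
    return max_points, obj_vals


def prune_single(X: Tensor, max_frac: float = 1.0) -> Tensor:
    max_points, obj_vals = _shared_processing(X=X, max_frac=max_frac)
    return X[obj_vals.topk(max_points).indices]


def prune_multi(X: Tensor, max_frac: float = 1.0) -> Tensor:
    # The real multi-objective variant ranks by Pareto dominance instead
    # of a scalar top-k, but consumes the helper's outputs the same way.
    max_points, obj_vals = _shared_processing(X=X, max_frac=max_frac)
    return X[obj_vals.topk(max_points).indices]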
78 changes: 25 additions & 53 deletions botorch/acquisition/multi_objective/utils.py
@@ -10,24 +10,19 @@

from __future__ import annotations

import math
import warnings
from collections.abc import Callable
from math import ceil
from typing import Any

import torch
from botorch.acquisition import monte_carlo # noqa F401
from botorch.acquisition.multi_objective.objective import (
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.acquisition.utils import _prune_inferior_shared_processing
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.fully_bayesian import MCMC_DIM
from botorch.models.model import Model
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
BoxDecomposition,
@@ -39,9 +34,8 @@
DominatedPartitioning,
)
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.sampling import draw_sobol_samples
from botorch.utils.transforms import is_ensemble
from pyre_extensions import assert_is_instance
from torch import Tensor


@@ -115,40 +109,14 @@ def prune_inferior_points_multi_objective(
with `N_nz` the number of points in `X` that have non-zero (empirical,
under `num_samples` samples) probability of being pareto optimal.
"""
if marginalize_dim is None and is_ensemble(model):
# TODO: Properly deal with marginalizing fully Bayesian models
marginalize_dim = MCMC_DIM

if X.ndim > 2:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Batched inputs `X` are currently unsupported by "
"prune_inferior_points_multi_objective"
)
if X.size(-2) == 0:
raise ValueError("X must have at least one point.")
if max_frac <= 0 or max_frac > 1.0:
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
max_points = math.ceil(max_frac * X.size(-2))
with torch.no_grad():
posterior = model.posterior(X=X)
sampler = get_sampler(posterior, sample_shape=torch.Size([num_samples]))
samples = sampler(posterior)
if objective is None:
objective = IdentityMCMultiOutputObjective()
obj_vals = objective(samples, X=X)
if obj_vals.ndim > 3:
if obj_vals.ndim == 4 and marginalize_dim is not None:
obj_vals = obj_vals.mean(dim=marginalize_dim)
else:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Models with multiple batch dims are currently unsupported by"
" prune_inferior_points_multi_objective."
)
infeas = ~compute_feasibility_indicator(
max_points, obj_vals, infeas = _prune_inferior_shared_processing(
model=model,
X=X,
is_moo=True,
objective=objective,
constraints=constraints,
samples=samples,
num_samples=num_samples,
max_frac=max_frac,
marginalize_dim=marginalize_dim,
)
if infeas.any():
@@ -168,9 +136,9 @@ def prune_inferior_points_multi_objective(

def compute_sample_box_decomposition(
pareto_fronts: Tensor,
partitioning: BoxDecomposition = DominatedPartitioning,
partitioning: type[BoxDecomposition] = DominatedPartitioning,
maximize: bool = True,
num_constraints: int | None = 0,
num_constraints: int = 0,
) -> Tensor:
r"""Computes the box decomposition associated with some sampled optimal
objectives. This also supports the single-objective and constrained optimization
@@ -195,7 +163,10 @@ def compute_sample_box_decomposition(
the hyper-rectangles. The number `J` is the smallest number of boxes needed
to partition all the Pareto samples.
"""
tkwargs = {"dtype": pareto_fronts.dtype, "device": pareto_fronts.device}
tkwargs: dict[str, Any] = {
"dtype": pareto_fronts.dtype,
"device": pareto_fronts.device,
}
# We will later compute `norm.log_prob(NEG_INF)`, this is `-inf` if `NEG_INF` is
# too small.
NEG_INF = -1e10
@@ -214,16 +185,18 @@

if M == 1:
# Only consider a Pareto front with one element.
extreme_values = weight * torch.max(weight * pareto_fronts, dim=-2).values
extreme_values = assert_is_instance(
weight * torch.max(weight * pareto_fronts, dim=-2).values, Tensor
)
ref_point = weight * ref_point.expand(extreme_values.shape)

if maximize:
hypercell_bounds = torch.stack(
[ref_point, extreme_values], axis=-2
[ref_point, extreme_values], dim=-2
).unsqueeze(-1)
else:
hypercell_bounds = torch.stack(
[extreme_values, ref_point], axis=-2
[extreme_values, ref_point], dim=-2
).unsqueeze(-1)
else:
bd_list = []
@@ -244,17 +217,15 @@ def compute_sample_box_decomposition(
# Add an extra box for the inequality constraint.
if K > 0:
# `num_pareto_samples x 2 x (J - 1) x K`
feasible_boxes = torch.zeros(
hypercell_bounds.shape[:-1] + torch.Size([K]), **tkwargs
)
feasible_boxes = torch.zeros(hypercell_bounds.shape[:-1] + (K,), **tkwargs)

feasible_boxes[..., 0, :, :] = NEG_INF
# `num_pareto_samples x 2 x (J - 1) x (M + K)`
hypercell_bounds = torch.cat([hypercell_bounds, feasible_boxes], dim=-1)

# `num_pareto_samples x 2 x 1 x (M + K)`
infeasible_box = torch.zeros(
hypercell_bounds.shape[:-2] + torch.Size([1, M + K]), **tkwargs
hypercell_bounds.shape[:-2] + (1, M + K), **tkwargs
)
infeasible_box[..., 1, :, M:] = -NEG_INF
infeasible_box[..., 0, :, 0:M] = NEG_INF
@@ -292,11 +263,12 @@ def random_search_optimizer(
- A `num_points x M`-dim Tensor containing the collection of optimal
objectives.
"""
tkwargs = {"dtype": bounds.dtype, "device": bounds.device}
tkwargs: dict[str, Any] = {"dtype": bounds.dtype, "device": bounds.device}
weight = 1.0 if maximize else -1.0
optimal_inputs = torch.tensor([], **tkwargs)
optimal_outputs = torch.tensor([], **tkwargs)
num_tries = 0
num_found = 0
ratio = 2
while ratio > 1 and num_tries < max_tries:
X = draw_sobol_samples(bounds=bounds, n=pop_size, q=1).squeeze(-2)
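
Two of the smaller changes above belong to the Pyre cleanup mentioned in the summary: `partitioning` is now annotated as `type[BoxDecomposition]` (the parameter receives a class, not an instance), and `torch.stack` is called with `dim` instead of `axis`. `axis` works at runtime as a NumPy-compatibility alias, but `dim` is the canonical PyTorch keyword and the one type stubs recognize. A quick check of the equivalence:

import torch

a = torch.zeros(3, 2)
b = torch.ones(3, 2)

stacked = torch.stack([a, b], dim=-2)   # canonical PyTorch keyword
aliased = torch.stack([a, b], axis=-2)  # NumPy-style alias, same result
assert torch.equal(stacked, aliased)
print(stacked.shape)  # torch.Size([3, 2, 2])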
122 changes: 80 additions & 42 deletions botorch/acquisition/utils.py
@@ -15,7 +15,6 @@

import torch
from botorch.acquisition.objective import (
IdentityMCObjective,
MCAcquisitionObjective,
PosteriorTransform,
ScalarizedPosteriorTransform,
@@ -34,6 +33,7 @@
from botorch.utils.sampling import optimize_posterior_samples
from botorch.utils.transforms import is_ensemble, normalize_indices
from gpytorch.models import GP
from pyre_extensions import none_throws
from torch import Tensor


@@ -244,6 +244,76 @@ def objective(Y: Tensor, X: Tensor | None = None):
return -(lb.clamp_max(0.0))


def _prune_inferior_shared_processing(
model: Model,
X: Tensor,
is_moo: bool,
objective: MCAcquisitionObjective | None = None,
posterior_transform: PosteriorTransform | None = None,
constraints: list[Callable[[Tensor], Tensor]] | None = None,
num_samples: int = 2048,
max_frac: float = 1.0,
sampler: MCSampler | None = None,
marginalize_dim: int | None = None,
) -> tuple[int, Tensor, Tensor]:
r"""Shared data processing for `prune_inferior_points` and
`prune_inferior_points_multi_objective`.
Returns:
- max_points: The maximum number of points to keep.
- obj_vals: The objective values of the points in `X`.
- infeas: A boolean tensor indicating feasibility of `X`.
"""
func_name = (
"prune_inferior_points_multi_objective" if is_moo else "prune_inferior_points"
)
if marginalize_dim is None and is_ensemble(model):
marginalize_dim = MCMC_DIM

if X.ndim > 2:
raise UnsupportedError(
f"Batched inputs `X` are currently unsupported by `{func_name}`"
)
if X.size(-2) == 0:
raise ValueError("X must have at least one point.")
if max_frac <= 0 or max_frac > 1.0:
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
max_points = math.ceil(max_frac * X.size(-2))
with torch.no_grad():
posterior = model.posterior(X=X, posterior_transform=posterior_transform)
if sampler is None:
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([num_samples])
)
samples = sampler(posterior)
if objective is not None:
obj_vals = objective(samples=samples, X=X)
elif is_moo:
obj_vals = samples
else:
obj_vals = samples.squeeze(-1)
if obj_vals.ndim > (2 + is_moo):
if obj_vals.ndim == (3 + is_moo) and marginalize_dim is not None:
if marginalize_dim < 0:
# Update `marginalize_dim` to be positive while accounting for
# removal of output dimension in SOO.
marginalize_dim = (not is_moo) + none_throws(
normalize_indices([marginalize_dim], d=obj_vals.ndim)
)[0]
obj_vals = obj_vals.mean(dim=marginalize_dim)
else:
raise UnsupportedError(
"Models with multiple batch dims are currently unsupported by "
f"`{func_name}`."
)
infeas = ~compute_feasibility_indicator(
constraints=constraints,
samples=samples,
marginalize_dim=marginalize_dim,
)
return max_points, obj_vals, infeas


def prune_inferior_points(
model: Model,
X: Tensor,
@@ -292,48 +362,16 @@ def prune_inferior_points(
with `N_nz` the number of points in `X` that have non-zero (empirical,
under `num_samples` samples) probability of being the best point.
"""
if marginalize_dim is None and is_ensemble(model):
# TODO: Properly deal with marginalizing fully Bayesian models
marginalize_dim = MCMC_DIM

if X.ndim > 2:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Batched inputs `X` are currently unsupported by prune_inferior_points"
)
if X.size(-2) == 0:
raise ValueError("X must have at least one point.")
if max_frac <= 0 or max_frac > 1.0:
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
max_points = math.ceil(max_frac * X.size(-2))
with torch.no_grad():
posterior = model.posterior(X=X, posterior_transform=posterior_transform)
if sampler is None:
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([num_samples])
)
samples = sampler(posterior)
if objective is None:
objective = IdentityMCObjective()
obj_vals = objective(samples, X=X)
if obj_vals.ndim > 2:
if obj_vals.ndim == 3 and marginalize_dim is not None:
if marginalize_dim < 0:
# we do this again in compute_feasibility_indicator, but that will
# have no effect since marginalize_dim will be non-negative
marginalize_dim = (
1 + normalize_indices([marginalize_dim], d=obj_vals.ndim)[0]
)
obj_vals = obj_vals.mean(dim=marginalize_dim)
else:
# TODO: support batched inputs (req. dealing with ragged tensors)
raise UnsupportedError(
"Models with multiple batch dims are currently unsupported by"
" prune_inferior_points."
)
infeas = ~compute_feasibility_indicator(
max_points, obj_vals, infeas = _prune_inferior_shared_processing(
model=model,
X=X,
is_moo=False,
objective=objective,
posterior_transform=posterior_transform,
constraints=constraints,
samples=samples,
num_samples=num_samples,
max_frac=max_frac,
sampler=sampler,
marginalize_dim=marginalize_dim,
)
if infeas.any():
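
One subtlety the new helper carries over from `prune_inferior_points`: a negative `marginalize_dim` indexes the raw sample tensor, but in the single-objective path the objective has already squeezed away the output dimension, so the index is normalized and shifted by one before taking the mean. A standalone illustration, with a hand-rolled normalizer standing in for BoTorch's `normalize_indices`:

import torch


def normalize_dim(dim: int, ndim: int) -> int:
    # Map a possibly-negative dim to its non-negative form, mirroring
    # `normalize_indices([dim], d=ndim)[0]`.
    return dim + ndim if dim < 0 else dim


# Single-objective samples: `num_samples x ensemble x q x m`.
samples = torch.randn(16, 4, 5, 1)
obj_vals = samples.squeeze(-1)  # objective drops the m-dim: 16 x 4 x 5

# `marginalize_dim = -3` pointed at the ensemble dim of the 4-d samples.
# Against the squeezed 3-d tensor the same negative index would be off by
# one, hence the `+ 1` offset after normalization.
pos_dim = 1 + normalize_dim(-3, ndim=obj_vals.ndim)
marginalized = obj_vals.mean(dim=pos_dim)  # ensemble dim averaged out
print(pos_dim, marginalized.shape)  # 1 torch.Size([16, 5])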
2 changes: 0 additions & 2 deletions test/acquisition/test_utils.py
@@ -9,7 +9,6 @@
from unittest.mock import patch

import torch

from botorch.acquisition.objective import (
ExpectationPosteriorTransform,
GenericMCObjective,
@@ -34,7 +33,6 @@
UnsupportedError,
)
from botorch.models import SingleTaskGP

from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultivariateNormal

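
For context, a minimal usage sketch of the two public entry points whose internals this commit unifies. The model setup is illustrative only; the pruning calls are the functions touched here:

import torch
from botorch.acquisition.multi_objective.utils import (
    prune_inferior_points_multi_objective,
)
from botorch.acquisition.utils import prune_inferior_points
from botorch.models import SingleTaskGP

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

# Single-objective: keep only candidates with non-zero empirical
# probability of being the best point under posterior samples.
model = SingleTaskGP(train_X, train_Y)
pruned = prune_inferior_points(model=model, X=train_X, max_frac=0.5)

# Multi-objective: "best" becomes "Pareto optimal"; requires a
# multi-output model and a reference point.
model_mo = SingleTaskGP(train_X, torch.cat([train_Y, -train_Y], dim=-1))
pruned_mo = prune_inferior_points_multi_objective(
    model=model_mo,
    X=train_X,
    ref_point=torch.tensor([-2.0, -2.0], dtype=torch.double),
)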
