Skip to content

Commit 23efc91

Browse files
sdaulton authored and facebook-github-bot committed
Update optimize_acqf_homotopy and refactor mixed optimizer dispatch
Summary: see title. This refactors dispatch to mixed optimizers into a shared utility for botorch and ax. This also cleans up optimize_acqf_homotopy by removing fixed_features_list and leveraging the new dispatch util. This adds support for using optimize_acqf_mixed_alternating. Reviewed By: bletham Differential Revision: D91913317
1 parent 03dba0c commit 23efc91

File tree

3 files changed

+348
-59
lines changed

3 files changed

+348
-59
lines changed

botorch/optim/optimize_homotopy.py

Lines changed: 136 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from __future__ import annotations
77

8-
from collections.abc import Callable
8+
from collections.abc import Callable, Mapping, Sequence
9+
from itertools import product
910
from typing import Any
1011

1112
import torch
@@ -14,6 +15,10 @@
1415
from botorch.optim.homotopy import Homotopy
1516
from botorch.optim.initializers import TGenInitialConditions
1617
from botorch.optim.optimize import optimize_acqf, optimize_acqf_mixed
18+
from botorch.optim.optimize_mixed import (
19+
optimize_acqf_mixed_alternating,
20+
should_use_mixed_alternating_optimizer,
21+
)
1722
from torch import Tensor
1823

1924

@@ -64,7 +69,9 @@ def optimize_acqf_homotopy(
6469
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
6570
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
6671
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
67-
fixed_features_list: list[dict[int, float]] | None = None,
72+
fixed_features: dict[int, float] | None = None,
73+
discrete_dims: Mapping[int, Sequence[float]] | None = None,
74+
cat_dims: Mapping[int, Sequence[float]] | None = None,
6875
post_processing_func: Callable[[Tensor], Tensor] | None = None,
6976
batch_initial_conditions: Tensor | None = None,
7077
gen_candidates: TGenCandidates | None = None,
@@ -124,10 +131,19 @@ def optimize_acqf_homotopy(
124131
Using non-linear inequality constraints also requires that ``batch_limit``
125132
is set to 1, which will be done automatically if not specified in
126133
``options``.
127-
fixed_features_list: A list of maps ``{feature_index: value}``. The i-th
128-
item represents the fixed_feature for the i-th optimization. If
129-
``fixed_features_list`` is provided, ``optimize_acqf_mixed`` is invoked.
130-
All indices (``feature_index``) should be non-negative.
134+
fixed_features: A map ``{feature_index: value}`` for features that should
135+
be fixed to a particular value during generation. Used with
136+
``optimize_acqf`` or ``optimize_acqf_mixed_alternating``.
137+
discrete_dims: A dictionary mapping indices of discrete and binary
138+
dimensions to a list of allowed values for that dimension. If provided
139+
along with ``cat_dims``, the optimizer is chosen based on the number
140+
of discrete combinations: ``optimize_acqf_mixed_alternating`` is used
141+
if there are more than 10 combinations, otherwise ``optimize_acqf_mixed``
142+
is used.
143+
cat_dims: A dictionary mapping indices of categorical dimensions
144+
to a list of allowed values for that dimension. If provided
145+
along with ``discrete_dims``, the optimizer is chosen based on the
146+
number of discrete combinations (see ``discrete_dims``).
131147
post_processing_func: A function that post-processes an optimization
132148
result appropriately (i.e., according to ``round-trip``
133149
transformations).
@@ -151,27 +167,69 @@ def optimize_acqf_homotopy(
151167
ic_gen_kwargs: Additional keyword arguments passed to function specified by
152168
``ic_generator``
153169
"""
154-
shared_optimize_acqf_kwargs = {
155-
"num_restarts": num_restarts,
156-
"inequality_constraints": inequality_constraints,
157-
"equality_constraints": equality_constraints,
158-
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
159-
"return_best_only": False, # False to make n_restarts persist through homotopy.
160-
"gen_candidates": gen_candidates,
161-
"ic_generator": ic_generator,
162-
"timeout_sec": timeout_sec,
163-
"retry_on_optimization_warning": retry_on_optimization_warning,
164-
**ic_gen_kwargs,
165-
}
166-
167-
if fixed_features_list and len(fixed_features_list) > 1:
170+
# Determine which optimization function to use based on the arguments.
171+
# This logic uses the shared `should_use_mixed_alternating_optimizer` utility
172+
# which is also used by `determine_optimizer` in Ax.
173+
use_mixed_alternating = should_use_mixed_alternating_optimizer(
174+
discrete_dims=discrete_dims,
175+
cat_dims=cat_dims,
176+
)
177+
178+
if use_mixed_alternating:
179+
# Use optimize_acqf_mixed_alternating for mixed discrete/continuous problems
180+
# with many discrete combinations.
181+
optimization_fn = optimize_acqf_mixed_alternating
182+
shared_optimize_acqf_kwargs = {
183+
"num_restarts": num_restarts,
184+
"inequality_constraints": inequality_constraints,
185+
"equality_constraints": equality_constraints,
186+
"discrete_dims": discrete_dims,
187+
"cat_dims": cat_dims,
188+
"fixed_features": fixed_features,
189+
}
190+
fixed_features_kwargs = {}
191+
elif discrete_dims is not None or cat_dims is not None:
192+
# Use optimize_acqf_mixed for mixed problems with few discrete combinations.
193+
# Build fixed_features_list from discrete_dims and cat_dims.
194+
195+
all_discrete_dims = {**(discrete_dims or {}), **(cat_dims or {})}
196+
dim_indices = sorted(all_discrete_dims.keys())
197+
value_lists = [all_discrete_dims[idx] for idx in dim_indices]
198+
fixed_features_list = [
199+
dict(zip(dim_indices, combo)) for combo in product(*value_lists)
200+
]
168201
optimization_fn = optimize_acqf_mixed
202+
shared_optimize_acqf_kwargs = {
203+
"num_restarts": num_restarts,
204+
"inequality_constraints": inequality_constraints,
205+
"equality_constraints": equality_constraints,
206+
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
207+
"return_best_only": False, # to make n_restarts persist through homotopy.
208+
"gen_candidates": gen_candidates,
209+
"ic_generator": ic_generator,
210+
"timeout_sec": timeout_sec,
211+
"retry_on_optimization_warning": retry_on_optimization_warning,
212+
**ic_gen_kwargs,
213+
}
169214
fixed_features_kwargs = {"fixed_features_list": fixed_features_list}
170215
else:
171216
optimization_fn = optimize_acqf
172-
fixed_features_kwargs = {
173-
"fixed_features": fixed_features_list[0] if fixed_features_list else None
217+
shared_optimize_acqf_kwargs = {
218+
"num_restarts": num_restarts,
219+
"inequality_constraints": inequality_constraints,
220+
"equality_constraints": equality_constraints,
221+
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
222+
"return_best_only": False, # to make n_restarts persist through homotopy.
223+
"gen_candidates": gen_candidates,
224+
"ic_generator": ic_generator,
225+
"timeout_sec": timeout_sec,
226+
"retry_on_optimization_warning": retry_on_optimization_warning,
227+
**ic_gen_kwargs,
174228
}
229+
if fixed_features is not None:
230+
fixed_features_kwargs = {"fixed_features": fixed_features}
231+
else:
232+
fixed_features_kwargs = {}
175233

176234
candidate_list, acq_value_list = [], []
177235
if q > 1:
@@ -183,16 +241,34 @@ def optimize_acqf_homotopy(
183241
homotopy.restart()
184242

185243
while not homotopy.should_stop:
186-
candidates, acq_values = optimization_fn(
187-
acq_function=acq_function,
188-
bounds=bounds,
189-
q=1,
190-
options=options,
191-
batch_initial_conditions=candidates,
192-
raw_samples=q_raw_samples,
193-
**fixed_features_kwargs,
194-
**shared_optimize_acqf_kwargs,
195-
)
244+
if use_mixed_alternating:
245+
# optimize_acqf_mixed_alternating handles its own initialization
246+
# and always returns best only, so we get shape (1, d).
247+
candidates, acq_values = optimization_fn(
248+
acq_function=acq_function,
249+
bounds=bounds,
250+
q=1,
251+
options=options,
252+
raw_samples=q_raw_samples or raw_samples,
253+
post_processing_func=post_processing_func,
254+
**fixed_features_kwargs,
255+
**shared_optimize_acqf_kwargs,
256+
)
257+
# Reshape to (1, 1, d) to match optimize_acqf output shape and
258+
# ensure acq_values is 1-d for prune_candidates.
259+
candidates = candidates.unsqueeze(0)
260+
acq_values = acq_values.view(-1)
261+
else:
262+
candidates, acq_values = optimization_fn(
263+
acq_function=acq_function,
264+
bounds=bounds,
265+
q=1,
266+
options=options,
267+
batch_initial_conditions=candidates,
268+
raw_samples=q_raw_samples,
269+
**fixed_features_kwargs,
270+
**shared_optimize_acqf_kwargs,
271+
)
196272

197273
homotopy.step()
198274

@@ -208,19 +284,36 @@ def optimize_acqf_homotopy(
208284
).unsqueeze(1)
209285

210286
# Optimize one more time with the final options
211-
candidates, acq_values = optimization_fn(
212-
acq_function=acq_function,
213-
bounds=bounds,
214-
q=1,
215-
options=final_options,
216-
raw_samples=q_raw_samples,
217-
batch_initial_conditions=candidates,
218-
**fixed_features_kwargs,
219-
**shared_optimize_acqf_kwargs,
220-
)
287+
if use_mixed_alternating:
288+
candidates, acq_values = optimization_fn(
289+
acq_function=acq_function,
290+
bounds=bounds,
291+
q=1,
292+
options=final_options,
293+
raw_samples=raw_samples,
294+
post_processing_func=post_processing_func,
295+
**fixed_features_kwargs,
296+
**shared_optimize_acqf_kwargs,
297+
)
298+
# Reshape to (1, 1, d) to match optimize_acqf output shape.
299+
candidates = candidates.unsqueeze(0)
300+
acq_values = acq_values.view(-1)
301+
else:
302+
candidates, acq_values = optimization_fn(
303+
acq_function=acq_function,
304+
bounds=bounds,
305+
q=1,
306+
options=final_options,
307+
raw_samples=q_raw_samples,
308+
batch_initial_conditions=candidates,
309+
**fixed_features_kwargs,
310+
**shared_optimize_acqf_kwargs,
311+
)
221312

222313
# Post-process the candidates and grab the best candidate
223-
if post_processing_func is not None:
314+
# Note: post_processing_func is already applied within
315+
# optimize_acqf_mixed_alternating, so we skip it here in that case.
316+
if post_processing_func is not None and not use_mixed_alternating:
224317
candidates = post_processing_func(candidates)
225318
acq_values = acq_function(candidates)
226319

botorch/optim/optimize_mixed.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@
5858
# Convergence is defined as the improvements of one discrete, followed by a scalar
5959
# optimization yield less than ``CONVERGENCE_TOL`` improvements.
6060

61+
# Threshold for choosing between optimize_acqf_mixed and
62+
# optimize_acqf_mixed_alternating in mixed (not fully discrete) search spaces.
63+
ALTERNATING_OPTIMIZER_THRESHOLD = 10
64+
65+
# For fully discrete search spaces.
66+
MAX_CHOICES_ENUMERATE = 10_000
67+
MAX_CARDINALITY_FOR_LOCAL_SEARCH = 100
68+
6169
SUPPORTED_OPTIONS = {
6270
"initialization_strategy",
6371
"tol",
@@ -74,6 +82,41 @@
7482
SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}
7583

7684

85+
def should_use_mixed_alternating_optimizer(
86+
discrete_dims: Mapping[int, Sequence[float]] | None = None,
87+
cat_dims: Mapping[int, Sequence[float]] | None = None,
88+
) -> bool:
89+
r"""Determine whether to use ``optimize_acqf_mixed_alternating`` for a mixed
90+
(not fully discrete) search space based on the number of discrete combinations.
91+
92+
For mixed search spaces, if there are more than ``ALTERNATING_OPTIMIZER_THRESHOLD``
93+
combinations of discrete choices, we use ``optimize_acqf_mixed_alternating``,
94+
which alternates between continuous and discrete optimization steps. Otherwise,
95+
we use ``optimize_acqf_mixed``, which enumerates all discrete combinations and
96+
optimizes the continuous features with discrete features being fixed.
97+
98+
Args:
99+
discrete_dims: A dictionary mapping indices of discrete (ordinal)
100+
dimensions to their respective sets of values provided as a sequence.
101+
cat_dims: A dictionary mapping indices of categorical dimensions
102+
to their respective sets of values provided as a sequence.
103+
104+
Returns:
105+
``True`` if ``optimize_acqf_mixed_alternating`` should be used, ``False``
106+
if ``optimize_acqf_mixed`` should be used instead.
107+
"""
108+
if discrete_dims is None and cat_dims is None:
109+
return False
110+
111+
n_combos = 1
112+
for values in (discrete_dims or {}).values():
113+
n_combos *= len(values)
114+
for values in (cat_dims or {}).values():
115+
n_combos *= len(values)
116+
117+
return n_combos > ALTERNATING_OPTIMIZER_THRESHOLD
118+
119+
77120
def _setup_continuous_relaxation(
78121
discrete_dims: dict[int, list[float]],
79122
max_discrete_values: int,

0 commit comments

Comments (0)