55
66from __future__ import annotations
77
8- from collections .abc import Callable
8+ from collections .abc import Callable , Mapping , Sequence
9+ from itertools import product
910from typing import Any
1011
1112import torch
1415from botorch .optim .homotopy import Homotopy
1516from botorch .optim .initializers import TGenInitialConditions
1617from botorch .optim .optimize import optimize_acqf , optimize_acqf_mixed
18+ from botorch .optim .optimize_mixed import (
19+ optimize_acqf_mixed_alternating ,
20+ should_use_mixed_alternating_optimizer ,
21+ )
1722from torch import Tensor
1823
1924
@@ -64,7 +69,9 @@ def optimize_acqf_homotopy(
6469 inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
6570 equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
6671 nonlinear_inequality_constraints : list [tuple [Callable , bool ]] | None = None ,
67- fixed_features_list : list [dict [int , float ]] | None = None ,
72+ fixed_features : dict [int , float ] | None = None ,
73+ discrete_dims : Mapping [int , Sequence [float ]] | None = None ,
74+ cat_dims : Mapping [int , Sequence [float ]] | None = None ,
6875 post_processing_func : Callable [[Tensor ], Tensor ] | None = None ,
6976 batch_initial_conditions : Tensor | None = None ,
7077 gen_candidates : TGenCandidates | None = None ,
@@ -124,10 +131,19 @@ def optimize_acqf_homotopy(
124131 Using non-linear inequality constraints also requires that ``batch_limit``
125132 is set to 1, which will be done automatically if not specified in
126133 ``options``.
127- fixed_features_list: A list of maps ``{feature_index: value}``. The i-th
128- item represents the fixed_feature for the i-th optimization. If
129- ``fixed_features_list`` is provided, ``optimize_acqf_mixed`` is invoked.
130- All indices (``feature_index``) should be non-negative.
134+ fixed_features: A map ``{feature_index: value}`` for features that should
135+ be fixed to a particular value during generation. Used with
136+ ``optimize_acqf`` or ``optimize_acqf_mixed_alternating``.
137+ discrete_dims: A dictionary mapping indices of discrete and binary
138+ dimensions to a list of allowed values for that dimension. If provided
139+ along with ``cat_dims``, the optimizer is chosen based on the number
140+ of discrete combinations: ``optimize_acqf_mixed_alternating`` is used
141+ if there are more than 10 combinations, otherwise ``optimize_acqf_mixed``
142+ is used.
143+ cat_dims: A dictionary mapping indices of categorical dimensions
144+ to a list of allowed values for that dimension. If provided
145+ along with ``discrete_dims``, the optimizer is chosen based on the
146+ number of discrete combinations (see ``discrete_dims``).
131147 post_processing_func: A function that post-processes an optimization
132148 result appropriately (i.e., according to ``round-trip``
133149 transformations).
@@ -151,27 +167,69 @@ def optimize_acqf_homotopy(
151167 ic_gen_kwargs: Additional keyword arguments passed to function specified by
152168 ``ic_generator``
153169 """
154- shared_optimize_acqf_kwargs = {
155- "num_restarts" : num_restarts ,
156- "inequality_constraints" : inequality_constraints ,
157- "equality_constraints" : equality_constraints ,
158- "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
159- "return_best_only" : False , # False to make n_restarts persist through homotopy.
160- "gen_candidates" : gen_candidates ,
161- "ic_generator" : ic_generator ,
162- "timeout_sec" : timeout_sec ,
163- "retry_on_optimization_warning" : retry_on_optimization_warning ,
164- ** ic_gen_kwargs ,
165- }
166-
167- if fixed_features_list and len (fixed_features_list ) > 1 :
170+ # Determine which optimization function to use based on the arguments.
171+ # This logic uses the shared `should_use_mixed_alternating_optimizer` utility
172+ # which is also used by `determine_optimizer` in Ax.
173+ use_mixed_alternating = should_use_mixed_alternating_optimizer (
174+ discrete_dims = discrete_dims ,
175+ cat_dims = cat_dims ,
176+ )
177+
178+ if use_mixed_alternating :
179+ # Use optimize_acqf_mixed_alternating for mixed discrete/continuous problems
180+ # with many discrete combinations.
181+ optimization_fn = optimize_acqf_mixed_alternating
182+ shared_optimize_acqf_kwargs = {
183+ "num_restarts" : num_restarts ,
184+ "inequality_constraints" : inequality_constraints ,
185+ "equality_constraints" : equality_constraints ,
186+ "discrete_dims" : discrete_dims ,
187+ "cat_dims" : cat_dims ,
188+ "fixed_features" : fixed_features ,
189+ }
190+ fixed_features_kwargs = {}
191+ elif discrete_dims is not None or cat_dims is not None :
192+ # Use optimize_acqf_mixed for mixed problems with few discrete combinations.
193+ # Build fixed_features_list from discrete_dims and cat_dims.
194+
195+ all_discrete_dims = {** (discrete_dims or {}), ** (cat_dims or {})}
196+ dim_indices = sorted (all_discrete_dims .keys ())
197+ value_lists = [all_discrete_dims [idx ] for idx in dim_indices ]
198+ fixed_features_list = [
199+ dict (zip (dim_indices , combo )) for combo in product (* value_lists )
200+ ]
168201 optimization_fn = optimize_acqf_mixed
202+ shared_optimize_acqf_kwargs = {
203+ "num_restarts" : num_restarts ,
204+ "inequality_constraints" : inequality_constraints ,
205+ "equality_constraints" : equality_constraints ,
206+ "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
207+ "return_best_only" : False , # to make n_restarts persist through homotopy.
208+ "gen_candidates" : gen_candidates ,
209+ "ic_generator" : ic_generator ,
210+ "timeout_sec" : timeout_sec ,
211+ "retry_on_optimization_warning" : retry_on_optimization_warning ,
212+ ** ic_gen_kwargs ,
213+ }
169214 fixed_features_kwargs = {"fixed_features_list" : fixed_features_list }
170215 else :
171216 optimization_fn = optimize_acqf
172- fixed_features_kwargs = {
173- "fixed_features" : fixed_features_list [0 ] if fixed_features_list else None
217+ shared_optimize_acqf_kwargs = {
218+ "num_restarts" : num_restarts ,
219+ "inequality_constraints" : inequality_constraints ,
220+ "equality_constraints" : equality_constraints ,
221+ "nonlinear_inequality_constraints" : nonlinear_inequality_constraints ,
222+ "return_best_only" : False , # to make n_restarts persist through homotopy.
223+ "gen_candidates" : gen_candidates ,
224+ "ic_generator" : ic_generator ,
225+ "timeout_sec" : timeout_sec ,
226+ "retry_on_optimization_warning" : retry_on_optimization_warning ,
227+ ** ic_gen_kwargs ,
174228 }
229+ if fixed_features is not None :
230+ fixed_features_kwargs = {"fixed_features" : fixed_features }
231+ else :
232+ fixed_features_kwargs = {}
175233
176234 candidate_list , acq_value_list = [], []
177235 if q > 1 :
@@ -183,16 +241,34 @@ def optimize_acqf_homotopy(
183241 homotopy .restart ()
184242
185243 while not homotopy .should_stop :
186- candidates , acq_values = optimization_fn (
187- acq_function = acq_function ,
188- bounds = bounds ,
189- q = 1 ,
190- options = options ,
191- batch_initial_conditions = candidates ,
192- raw_samples = q_raw_samples ,
193- ** fixed_features_kwargs ,
194- ** shared_optimize_acqf_kwargs ,
195- )
244+ if use_mixed_alternating :
245+ # optimize_acqf_mixed_alternating handles its own initialization
246+ # and always returns best only, so we get shape (1, d).
247+ candidates , acq_values = optimization_fn (
248+ acq_function = acq_function ,
249+ bounds = bounds ,
250+ q = 1 ,
251+ options = options ,
252+ raw_samples = q_raw_samples or raw_samples ,
253+ post_processing_func = post_processing_func ,
254+ ** fixed_features_kwargs ,
255+ ** shared_optimize_acqf_kwargs ,
256+ )
257+ # Reshape to (1, 1, d) to match optimize_acqf output shape and
258+ # ensure acq_values is 1-d for prune_candidates.
259+ candidates = candidates .unsqueeze (0 )
260+ acq_values = acq_values .view (- 1 )
261+ else :
262+ candidates , acq_values = optimization_fn (
263+ acq_function = acq_function ,
264+ bounds = bounds ,
265+ q = 1 ,
266+ options = options ,
267+ batch_initial_conditions = candidates ,
268+ raw_samples = q_raw_samples ,
269+ ** fixed_features_kwargs ,
270+ ** shared_optimize_acqf_kwargs ,
271+ )
196272
197273 homotopy .step ()
198274
@@ -208,19 +284,36 @@ def optimize_acqf_homotopy(
208284 ).unsqueeze (1 )
209285
210286 # Optimize one more time with the final options
211- candidates , acq_values = optimization_fn (
212- acq_function = acq_function ,
213- bounds = bounds ,
214- q = 1 ,
215- options = final_options ,
216- raw_samples = q_raw_samples ,
217- batch_initial_conditions = candidates ,
218- ** fixed_features_kwargs ,
219- ** shared_optimize_acqf_kwargs ,
220- )
287+ if use_mixed_alternating :
288+ candidates , acq_values = optimization_fn (
289+ acq_function = acq_function ,
290+ bounds = bounds ,
291+ q = 1 ,
292+ options = final_options ,
293+ raw_samples = raw_samples ,
294+ post_processing_func = post_processing_func ,
295+ ** fixed_features_kwargs ,
296+ ** shared_optimize_acqf_kwargs ,
297+ )
298+ # Reshape to (1, 1, d) to match optimize_acqf output shape.
299+ candidates = candidates .unsqueeze (0 )
300+ acq_values = acq_values .view (- 1 )
301+ else :
302+ candidates , acq_values = optimization_fn (
303+ acq_function = acq_function ,
304+ bounds = bounds ,
305+ q = 1 ,
306+ options = final_options ,
307+ raw_samples = q_raw_samples ,
308+ batch_initial_conditions = candidates ,
309+ ** fixed_features_kwargs ,
310+ ** shared_optimize_acqf_kwargs ,
311+ )
221312
222313 # Post-process the candidates and grab the best candidate
223- if post_processing_func is not None :
314+ # Note: post_processing_func is already applied within
315+ # optimize_acqf_mixed_alternating, so we skip it here in that case.
316+ if post_processing_func is not None and not use_mixed_alternating :
224317 candidates = post_processing_func (candidates )
225318 acq_values = acq_function (candidates )
226319
0 commit comments