
Commit 1697317

Carl Hvarfner authored and facebook-github-bot committed
Structuring arguments in gen_candidates_torch (#3019)
Summary:
X-link: #3019

Structuring the optimizer and stopping_criterion arguments in gen_candidates_torch. Fixes #2994.

Reviewed By: sdaulton

Differential Revision: D82839737
1 parent af48840 commit 1697317

File tree

2 files changed: +73 additions, -4 deletions


botorch/generation/gen.py

Lines changed: 24 additions & 4 deletions
@@ -532,7 +532,9 @@ def gen_candidates_torch(
         optimizer (Optimizer): The pytorch optimizer to use to perform
             candidate search.
         options: Options used to control the optimization. Includes
-            maxiter: Maximum number of iterations
+            optimizer_options: Dict of additional options to pass to the optimizer
+                (e.g. lr, weight_decay)
+            stopping_criterion_options: Dict of options for the stopping criterion.
         callback: A callback function accepting the current iteration, loss,
             and gradients as arguments. This function is executed after computing
             the loss and gradients, but before calling the optimizer.
@@ -571,7 +573,6 @@ def gen_candidates_torch(
     # the 1st order optimizers implemented in this method.
     # Here, it does not matter whether one combines multiple optimizations into
     # one or not.
-    options.pop("max_optimization_problem_aggregation_size", None)
     _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
     clamped_candidates = _clamp(initial_conditions)
     if fixed_features:
@@ -580,11 +581,30 @@ def gen_candidates_torch(
             [i for i in range(clamped_candidates.shape[-1]) if i not in fixed_features],
         ]
     clamped_candidates = clamped_candidates.requires_grad_(True)
-    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
+
+    # Extract optimizer-specific options from the options dict
+    optimizer_options = options.get("optimizer_options", {}).copy()
+    stopping_criterion_options = options.get("stopping_criterion_options", {}).copy()
+
+    # Backward compatibility: if old 'maxiter' parameter is passed, move it to
+    # stopping_criterion_options with a deprecation warning
+    if "maxiter" in options:
+        warnings.warn(
+            "Passing 'maxiter' directly in options is deprecated. "
+            "Please use options['stopping_criterion_options']['maxiter'] instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # For backward compatibility, pass to stopping_criterion_options
+        if "maxiter" not in stopping_criterion_options:
+            stopping_criterion_options["maxiter"] = options["maxiter"]
+
+    optimizer_options.setdefault("lr", 0.025)
+    _optimizer = optimizer(params=[clamped_candidates], **optimizer_options)
 
     i = 0
     stop = False
-    stopping_criterion = ExpMAStoppingCriterion(**options)
+    stopping_criterion = ExpMAStoppingCriterion(**stopping_criterion_options)
     while not stop:
         i += 1
         with torch.no_grad():
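For context, a rough usage sketch (not part of the commit) of how a caller might pass the restructured options after this change. The surrounding setup is assumed: acq_func and batch_initial_conditions are placeholders for the usual BoTorch acquisition function and starting points.

import torch
from botorch.generation.gen import gen_candidates_torch

# acq_func and batch_initial_conditions are assumed to come from the usual
# BoTorch setup (e.g. a fitted model + an acquisition function and
# gen_batch_initial_conditions); they are placeholders here.
candidates, acq_values = gen_candidates_torch(
    initial_conditions=batch_initial_conditions,
    acquisition_function=acq_func,
    lower_bounds=0.0,
    upper_bounds=1.0,
    optimizer=torch.optim.Adam,
    options={
        # Forwarded to the optimizer constructor (in addition to params).
        "optimizer_options": {"lr": 0.01, "weight_decay": 1e-5},
        # Forwarded to ExpMAStoppingCriterion.
        "stopping_criterion_options": {"maxiter": 100},
    },
)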

test/generation/test_gen.py

Lines changed: 49 additions & 0 deletions
@@ -324,6 +324,55 @@ def test_gen_candidates_torch_timeout_behavior(self):
         self.assertFalse(any(issubclass(w.category, OptimizationWarning) for w in ws))
         self.assertTrue("Optimization timed out" in logs.output[-1])
 
+    def test_gen_candidates_torch_optimizer_with_optimizer_args(self):
+        """Test that optimizer is created with correct args."""
+        self._setUp(double=False)
+        qEI = qExpectedImprovement(self.model, best_f=self.f_best)
+
+        # Test new structured API
+        with self.subTest(api="structured"):
+            # Create a mock optimizer class
+            mock_optimizer_class = mock.MagicMock()
+            mock_optimizer_instance = mock.MagicMock()
+            mock_optimizer_class.return_value = mock_optimizer_instance
+
+            gen_candidates_torch(
+                initial_conditions=self.initial_conditions,
+                acquisition_function=qEI,
+                lower_bounds=0,
+                upper_bounds=1,
+                optimizer=mock_optimizer_class,
+                options={
+                    "optimizer_options": {"lr": 0.02, "weight_decay": 1e-5},
+                    "stopping_criterion_options": {"maxiter": 1},
+                },
+            )
+
+            # Verify that the optimizer was called with the correct arguments
+            mock_optimizer_class.assert_called_once()
+            call_args = mock_optimizer_class.call_args
+            self.assertIn("params", call_args.kwargs)
+            self.assertEqual(call_args.kwargs["lr"], 0.02)
+            self.assertEqual(call_args.kwargs["weight_decay"], 1e-5)
+
+        # Test backward compatibility with old maxiter parameter
+        with self.subTest(api="backward_compat"):
+            with warnings.catch_warnings(record=True) as ws:
+                warnings.simplefilter("always", category=DeprecationWarning)
+                gen_candidates_torch(
+                    initial_conditions=self.initial_conditions,
+                    acquisition_function=qEI,
+                    lower_bounds=0,
+                    upper_bounds=1,
+                    options={"maxiter": 1},
+                )
+            # Verify deprecation warning was raised
+            deprecation_warnings = [
+                w for w in ws if issubclass(w.category, DeprecationWarning)
+            ]
+            self.assertTrue(len(deprecation_warnings) > 0)
+            self.assertIn("maxiter", str(deprecation_warnings[0].message))
+
     def test_gen_candidates_scipy_warns_opt_no_res(self):
         ckwargs = {"dtype": torch.float, "device": self.device}
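As the backward-compatibility subtest above exercises, existing call sites that pass maxiter at the top level of options keep working but now emit a DeprecationWarning, with the value forwarded to stopping_criterion_options when that key is not already set. A minimal migration sketch (setup variables are placeholders):

# Old style (still works, but raises a DeprecationWarning):
gen_candidates_torch(
    initial_conditions=batch_initial_conditions,
    acquisition_function=acq_func,
    options={"maxiter": 50},
)

# New style:
gen_candidates_torch(
    initial_conditions=batch_initial_conditions,
    acquisition_function=acq_func,
    options={"stopping_criterion_options": {"maxiter": 50}},
)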
