
Commit 26b86e0

esantorella authored and facebook-github-bot committed
Make optimizers raise an error when provided negative fixed features
Summary:

Context: See #2602

This PR:
* Adds a check for negative `fixed_features` keys to input validation for optimizers. This applies to all of the optimizers that take `fixed_features`.
* Updates docstrings.

Differential Revision: D65272024
1 parent 66660e3 commit 26b86e0
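For illustration (not part of the commit), a minimal sketch of how the new validation surfaces to callers; the toy `SingleTaskGP` model and `qExpectedImprovement` acquisition function below are assumptions made only so the call is complete and runnable:

    import torch
    from botorch.acquisition import qExpectedImprovement
    from botorch.models import SingleTaskGP
    from botorch.optim import optimize_acqf

    # A toy model and acquisition function, just so the call below is complete.
    train_X = torch.rand(8, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acq_func = qExpectedImprovement(model=model, best_f=train_Y.max())

    bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

    # After this change, a negative key in `fixed_features` fails fast with
    # "All indices (keys) in `fixed_features` must be >= 0." instead of being
    # passed through to candidate generation.
    optimize_acqf(
        acq_function=acq_func,
        bounds=bounds,
        q=1,
        num_restarts=4,
        raw_samples=16,
        fixed_features={-1: 0.0},  # raises ValueError after this commit
    )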

2 files changed (+55, -6 lines)


botorch/optim/optimize.py

Lines changed: 14 additions & 5 deletions
@@ -125,6 +125,10 @@ def __post_init__(self) -> None:
                 "Must specify `raw_samples` when "
                 "`batch_initial_conditions` is None`."
             )
+        if self.fixed_features is not None and any(
+            (k < 0 for k in self.fixed_features)
+        ):
+            raise ValueError("All indices (keys) in `fixed_features` must be >= 0.")
 
     def get_ic_generator(self) -> TGenInitialConditions:
         if self.ic_generator is not None:
@@ -467,7 +471,8 @@ def optimize_acqf(
             is set to 1, which will be done automatically if not specified in
             `options`.
         fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
+            should be fixed to a particular value during generation. All indices
+            should be non-negative.
         post_processing_func: A function that post-processes an optimization
             result appropriately (i.e., according to `round-trip`
             transformations).
@@ -610,7 +615,8 @@ def optimize_acqf_cyclic(
             with each tuple encoding an inequality constraint of the form
             `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`
         fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
+            should be fixed to a particular value during generation. All indices
+            should be non-negative.
         post_processing_func: A function that post-processes an optimization
             result appropriately (i.e., according to `round-trip`
             transformations).
@@ -758,11 +764,13 @@ def optimize_acqf_list(
             Using non-linear inequality constraints also requires that `batch_limit`
             is set to 1, which will be done automatically if not specified in
             `options`.
-        fixed_features: A map `{feature_index: value}` for features that
-            should be fixed to a particular value during generation.
+        fixed_features: A map `{feature_index: value}` for features that should
+            be fixed to a particular value during generation. All indices
+            (`feature_index`) should be non-negative.
         fixed_features_list: A list of maps `{feature_index: value}`. The i-th
             item represents the fixed_feature for the i-th optimization. If
             `fixed_features_list` is provided, `optimize_acqf_mixed` is invoked.
+            All indices (`feature_index`) should be non-negative.
         post_processing_func: A function that post-processes an optimization
             result appropriately (i.e., according to `round-trip`
             transformations).
@@ -872,7 +880,8 @@ def optimize_acqf_mixed(
         raw_samples: Number of samples for initialization. This is required
             if `batch_initial_conditions` is not specified.
         fixed_features_list: A list of maps `{feature_index: value}`. The i-th
-            item represents the fixed_feature for the i-th optimization.
+            item represents the fixed_feature for the i-th optimization. All
+            indices (`feature_index`) should be non-negative.
         options: Options for candidate generation.
         inequality constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an inequality constraint of the form
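If existing code passed a negative key such as `-1`, the fix is to use the corresponding non-negative column index. A minimal sketch (not from the commit), reusing the toy `acq_func` and `bounds` from the snippet above:

    d = bounds.shape[-1]  # number of features

    # Fix the last feature to 0.5 using its non-negative index.
    fixed_features = {d - 1: 0.5}

    candidate, acq_value = optimize_acqf(
        acq_function=acq_func,
        bounds=bounds,
        q=1,
        num_restarts=4,
        raw_samples=16,
        fixed_features=fixed_features,
    )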

test/optim/test_optimize.py

Lines changed: 41 additions & 1 deletion
@@ -6,6 +6,7 @@
 
 import itertools
 import warnings
+from functools import partial
 from itertools import product
 from typing import Any
 from unittest import mock
@@ -1119,6 +1120,45 @@ def __call__(self, x, f):
         self.assertEqual(f_np_wrapper.call_count, 2)
 
 
+class TestAllOptimizers(BotorchTestCase):
+    def test_raises_with_negative_fixed_features(self) -> None:
+        cases = {
+            "optimize_acqf": partial(
+                optimize_acqf,
+                acq_function=MockAcquisitionFunction(),
+                fixed_features={-1: 0.0},
+                q=1,
+            ),
+            "optimize_acqf_cyclic": partial(
+                optimize_acqf_cyclic,
+                acq_function=MockAcquisitionFunction(),
+                fixed_features={-1: 0.0},
+                q=1,
+            ),
+            "optimize_acqf_mixed": partial(
+                optimize_acqf_mixed,
+                acq_function=MockAcquisitionFunction(),
+                fixed_features_list=[{-1: 0.0}],
+                q=1,
+            ),
+            "optimize_acqf_list": partial(
+                optimize_acqf_list,
+                acq_function_list=[MockAcquisitionFunction()],
+                fixed_features={-1: 0.0},
+            ),
+        }
+
+        for name, func in cases.items():
+            with self.subTest(name), self.assertRaisesRegex(
+                ValueError, "must be >= 0."
+            ):
+                func(
+                    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], device=self.device),
+                    num_restarts=4,
+                    raw_samples=16,
+                )
+
+
 class TestOptimizeAcqfCyclic(BotorchTestCase):
     @mock.patch("botorch.optim.optimize._optimize_acqf")  # noqa: C901
     # TODO: make sure this runs without mock
@@ -1171,7 +1211,7 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
             "set_X_pending",
             wraps=mock_acq_function.set_X_pending,
         ) as mock_set_X_pending:
-            candidates, acq_value = optimize_acqf_cyclic(
+            candidates, _ = optimize_acqf_cyclic(
                 acq_function=mock_acq_function,
                 bounds=bounds,
                 q=q,
