Skip to content

Commit 4cc5ed5

Browse files
esantorellafacebook-github-bot
authored andcommitted
Stop some warnings in unit tests (#1992)
Summary: ## Motivation Warning output from unit tests sometimes indicates a serious problem and sometimes merely clutters the output so we can't notice the serious problems. These warnings are the latter: * InputDataWarning: Input data is not contained to the unit cube. Please consider min-max scaling the input data (occurred 194 times, now 0) * BadInitialCandidatesWarning: Unable to find non-zero acquisition function values - initial conditions are being selected randomly. (occurred 40 times, now 0) * The first positional argument of samplers, `num_samples`, has been deprecated and replaced with `sample_shape`, which expects a `torch.Size` object.' (occurred 35 times, now 0) ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #1992 Test Plan: Units ## Related PRs #1792, #1539 Reviewed By: saitcakmak Differential Revision: D48530764 Pulled By: esantorella fbshipit-source-id: c5b1898ce8156a6f02550acb09bd5eba0c157c5e
1 parent c43b074 commit 4cc5ed5

20 files changed

+99
-64
lines changed

botorch/utils/testing.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,25 +43,31 @@ class BotorchTestCase(TestCase):
4343

4444
device = torch.device("cpu")
4545

46-
def setUp(self):
46+
def setUp(self, suppress_input_warnings: bool = True) -> None:
4747
warnings.resetwarnings()
4848
settings.debug._set_state(False)
4949
warnings.simplefilter("always", append=True)
50-
warnings.filterwarnings(
51-
"ignore",
52-
message="The model inputs are of type",
53-
category=UserWarning,
54-
)
55-
warnings.filterwarnings(
56-
"ignore",
57-
message="Non-strict enforcement of botorch tensor conventions.",
58-
category=BotorchTensorDimensionWarning,
59-
)
60-
warnings.filterwarnings(
61-
"ignore",
62-
message="Input data is not standardized.",
63-
category=InputDataWarning,
64-
)
50+
if suppress_input_warnings:
51+
warnings.filterwarnings(
52+
"ignore",
53+
message="The model inputs are of type",
54+
category=UserWarning,
55+
)
56+
warnings.filterwarnings(
57+
"ignore",
58+
message="Non-strict enforcement of botorch tensor conventions.",
59+
category=BotorchTensorDimensionWarning,
60+
)
61+
warnings.filterwarnings(
62+
"ignore",
63+
message="Input data is not standardized.",
64+
category=InputDataWarning,
65+
)
66+
warnings.filterwarnings(
67+
"ignore",
68+
message="Input data is not contained to the unit cube.",
69+
category=InputDataWarning,
70+
)
6571

6672
def assertAllClose(
6773
self,

test/acquisition/test_input_constructors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class DummyAcquisitionFunction(AcquisitionFunction):
9797

9898
class InputConstructorBaseTestCase:
9999
def setUp(self) -> None:
100+
super().setUp()
100101
self.mock_model = MockModel(
101102
posterior=MockPosterior(mean=None, variance=None, base_shape=(1,))
102103
)

test/acquisition/test_monte_carlo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def test_cache_root(self):
556556
"prune_baseline": False,
557557
"cache_root": True,
558558
"posterior_transform": ScalarizedPosteriorTransform(weights=torch.ones(m)),
559-
"sampler": SobolQMCNormalSampler(5),
559+
"sampler": SobolQMCNormalSampler(sample_shape=torch.Size([5])),
560560
}
561561
acqf = qNoisyExpectedImprovement(**nei_args)
562562
X = torch.randn_like(X_baseline)

test/acquisition/test_preference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919

2020
class TestPreferenceAcquisitionFunctions(BotorchTestCase):
21-
def setUp(self):
21+
def setUp(self) -> None:
22+
super().setUp()
2223
self.twargs = {"dtype": torch.double}
2324
self.X_dim = 3
2425
self.Y_dim = 2

test/acquisition/test_prior_guided.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def get_weighted_val(ei_val, prob, exponent, use_log):
3939

4040
class TestPriorGuidedAcquisitionFunction(BotorchTestCase):
4141
def setUp(self):
42+
super().setUp()
4243
self.prior = DummyPrior()
4344
self.train_X = torch.rand(5, 3, dtype=torch.double, device=self.device)
4445
self.train_Y = self.train_X.norm(dim=-1, keepdim=True)

test/models/test_gpytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def test_fantasize_flag(self):
296296
self.assertFalse(model.last_fantasize_flag)
297297
model.posterior(test_X)
298298
self.assertFalse(model.last_fantasize_flag)
299-
model.fantasize(test_X, SobolQMCNormalSampler(2))
299+
model.fantasize(test_X, SobolQMCNormalSampler(sample_shape=torch.Size([2])))
300300
self.assertTrue(model.last_fantasize_flag)
301301
model.last_fantasize_flag = False
302302
with fantasize():

test/models/test_model_list_gp_regression.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,9 @@ def test_fantasize(self):
382382
m1 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
383383
m2 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
384384
modellist = ModelListGP(m1, m2)
385-
fm = modellist.fantasize(torch.rand(3, 2), sampler=IIDNormalSampler(2))
385+
fm = modellist.fantasize(
386+
torch.rand(3, 2), sampler=IIDNormalSampler(sample_shape=torch.Size([2]))
387+
)
386388
self.assertIsInstance(fm, ModelListGP)
387389
for i in range(2):
388390
fm_i = fm.models[i]
@@ -391,8 +393,8 @@ def test_fantasize(self):
391393
self.assertEqual(fm_i.train_targets.shape, torch.Size([2, 8]))
392394

393395
# test decoupled
394-
sampler1 = IIDNormalSampler(2)
395-
sampler2 = IIDNormalSampler(2)
396+
sampler1 = IIDNormalSampler(sample_shape=torch.Size([2]))
397+
sampler2 = IIDNormalSampler(sample_shape=torch.Size([2]))
396398
eval_mask = torch.tensor(
397399
[[1, 0], [0, 1], [1, 0]],
398400
dtype=torch.bool,
@@ -457,7 +459,7 @@ def _get_fant_mean(
457459
return fant.posterior(target_x).mean.mean(dim=(-2, -3))
458460

459461
# ~0
460-
sampler = IIDNormalSampler(10, seed=0)
462+
sampler = IIDNormalSampler(sample_shape=torch.Size([10]), seed=0)
461463
fant_mean_with_manual_transform = _get_fant_mean(
462464
model_manually_transformed, sampler=sampler
463465
)
@@ -490,8 +492,8 @@ def _get_fant_mean(
490492
)
491493
# test decoupled
492494
sampler = ListSampler(
493-
IIDNormalSampler(10, seed=0),
494-
IIDNormalSampler(10, seed=0),
495+
IIDNormalSampler(sample_shape=torch.Size([10]), seed=0),
496+
IIDNormalSampler(sample_shape=torch.Size([10]), seed=0),
495497
)
496498
fant_mean_with_manual_transform = _get_fant_mean(
497499
model_manually_transformed,
@@ -539,7 +541,7 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None:
539541
100 at x=0. If transforms are not properly applied, we'll get answers
540542
on the order of ~1. Answers between 99 and 101 are acceptable.
541543
"""
542-
n_fants = 20
544+
n_fants = torch.Size([20])
543545
y_at_low_x = 100.0
544546
y_at_high_x = -40.0
545547

test/models/utils/test_assorted.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ def test_add_output_dim(self):
7272

7373

7474
class TestInputDataChecks(BotorchTestCase):
75+
def setUp(self) -> None:
76+
# The super class usually disables input data warnings in unit tests.
77+
# Don't do that here.
78+
super().setUp(suppress_input_warnings=False)
79+
7580
def test_check_no_nans(self):
7681
check_no_nans(torch.tensor([1.0, 2.0]))
7782
with self.assertRaises(InputDataError):
@@ -87,12 +92,10 @@ def test_check_min_max_scaling(self):
8792
any(issubclass(w.category, InputDataWarning) for w in ws)
8893
)
8994
check_min_max_scaling(X=X, raise_on_fail=True)
90-
with warnings.catch_warnings(record=True) as ws:
95+
with self.assertWarnsRegex(
96+
expected_warning=InputDataWarning, expected_regex="not scaled"
97+
):
9198
check_min_max_scaling(X=X, strict=True)
92-
self.assertTrue(
93-
any(issubclass(w.category, InputDataWarning) for w in ws)
94-
)
95-
self.assertTrue(any("not scaled" in str(w.message) for w in ws))
9699
with self.assertRaises(InputDataError):
97100
check_min_max_scaling(X=X, strict=True, raise_on_fail=True)
98101
# check proper input

test/optim/test_fit.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525

2626

2727
class TestFitGPyTorchMLLScipy(BotorchTestCase):
28-
def setUp(self):
28+
def setUp(self) -> None:
29+
super().setUp()
2930
self.mlls = {}
3031
with torch.random.fork_rng():
3132
torch.manual_seed(0)
@@ -172,7 +173,8 @@ def _assert_np_array_is_float64_type(array) -> bool:
172173

173174

174175
class TestFitGPyTorchMLLTorch(BotorchTestCase):
175-
def setUp(self):
176+
def setUp(self) -> None:
177+
super().setUp()
176178
self.mlls = {}
177179
with torch.random.fork_rng():
178180
torch.manual_seed(0)
@@ -236,7 +238,8 @@ def _test_fit_gpytorch_mll_torch(self, mll):
236238

237239

238240
class TestFitGPyTorchScipy(BotorchTestCase):
239-
def setUp(self):
241+
def setUp(self) -> None:
242+
super().setUp()
240243
self.mlls = {}
241244
with torch.random.fork_rng():
242245
torch.manual_seed(0)
@@ -372,6 +375,7 @@ def _test_fit_gpytorch_scipy(self, mll):
372375

373376
class TestFitGPyTorchTorch(BotorchTestCase):
374377
def setUp(self):
378+
super().setUp()
375379
self.mlls = {}
376380
with torch.random.fork_rng():
377381
torch.manual_seed(0)

test/optim/test_initializers.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,10 @@ def test_gen_batch_initial_conditions(self):
188188
MockAcquisitionFunction,
189189
"__call__",
190190
wraps=mock_acqf.__call__,
191-
) as mock_acqf_call:
191+
) as mock_acqf_call, warnings.catch_warnings():
192+
warnings.simplefilter(
193+
"ignore", category=BadInitialCandidatesWarning
194+
)
192195
batch_initial_conditions = gen_batch_initial_conditions(
193196
acq_function=mock_acqf,
194197
bounds=bounds,
@@ -248,6 +251,9 @@ def test_gen_batch_initial_conditions_highdim(self):
248251
[True, False], [None, 1234], [None, ffs_map], [True, False]
249252
):
250253
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
254+
warnings.simplefilter(
255+
"ignore", category=BadInitialCandidatesWarning
256+
)
251257
batch_initial_conditions = gen_batch_initial_conditions(
252258
acq_function=MockAcquisitionFunction(),
253259
bounds=bounds,
@@ -279,19 +285,17 @@ def test_gen_batch_initial_conditions_highdim(self):
279285
torch.all(batch_initial_conditions[..., idx] == val)
280286
)
281287

282-
def test_gen_batch_initial_conditions_warning(self):
288+
def test_gen_batch_initial_conditions_warning(self) -> None:
283289
for dtype in (torch.float, torch.double):
284290
bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
285291
samples = torch.zeros(10, 1, 2, device=self.device, dtype=dtype)
286-
with ExitStack() as es:
287-
ws = es.enter_context(warnings.catch_warnings(record=True))
288-
es.enter_context(settings.debug(True))
289-
es.enter_context(
290-
mock.patch(
291-
"botorch.optim.initializers.draw_sobol_samples",
292-
return_value=samples,
293-
)
294-
)
292+
with self.assertWarnsRegex(
293+
expected_warning=BadInitialCandidatesWarning,
294+
expected_regex="Unable to find non-zero acquisition",
295+
), mock.patch(
296+
"botorch.optim.initializers.draw_sobol_samples",
297+
return_value=samples,
298+
):
295299
batch_initial_conditions = gen_batch_initial_conditions(
296300
acq_function=MockAcquisitionFunction(),
297301
bounds=bounds,
@@ -300,16 +304,12 @@ def test_gen_batch_initial_conditions_warning(self):
300304
raw_samples=10,
301305
options={"seed": 1234},
302306
)
303-
self.assertEqual(len(ws), 1)
304-
self.assertTrue(
305-
any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws)
306-
)
307-
self.assertTrue(
308-
torch.equal(
309-
batch_initial_conditions,
310-
torch.zeros(2, 1, 2, device=self.device, dtype=dtype),
311-
)
307+
self.assertTrue(
308+
torch.equal(
309+
batch_initial_conditions,
310+
torch.zeros(2, 1, 2, device=self.device, dtype=dtype),
312311
)
312+
)
313313

314314
def test_gen_batch_initial_conditions_transform_intra_point_constraint(self):
315315
for dtype in (torch.float, torch.double):
@@ -549,7 +549,10 @@ def test_gen_batch_initial_conditions_constraints(self):
549549
MockAcquisitionFunction,
550550
"__call__",
551551
wraps=mock_acqf.__call__,
552-
) as mock_acqf_call:
552+
) as mock_acqf_call, warnings.catch_warnings():
553+
warnings.simplefilter(
554+
"ignore", category=BadInitialCandidatesWarning
555+
)
553556
batch_initial_conditions = gen_batch_initial_conditions(
554557
acq_function=mock_acqf,
555558
bounds=bounds,
@@ -723,7 +726,10 @@ def generator(n: int, q: int, seed: int):
723726
MockAcquisitionFunction,
724727
"__call__",
725728
wraps=mock_acqf.__call__,
726-
):
729+
), warnings.catch_warnings():
730+
warnings.simplefilter(
731+
"ignore", category=BadInitialCandidatesWarning
732+
)
727733
batch_initial_conditions = gen_batch_initial_conditions(
728734
acq_function=mock_acqf,
729735
bounds=bounds,

0 commit comments

Comments
 (0)