Skip to content

Commit

Permalink
lint: get full coverage and run ufmt
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Nov 15, 2024
1 parent 88667b7 commit 374e313
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
30 changes: 15 additions & 15 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def __post_init__(self) -> None:
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

elif (len(batch_initial_conditions_shape) == 3 # should always be true
elif (
len(batch_initial_conditions_shape) == 3 # should always be true
and self.num_restarts is not None
and batch_initial_conditions_shape[0] not in [1, self.num_restarts]
):
Expand Down Expand Up @@ -289,7 +290,6 @@ def _combine_initial_conditions(
generated_initial_conditions: Tensor | None = None,
num_restarts: int | None = None,
) -> Tensor:

if (
provided_initial_conditions is not None
and generated_initial_conditions is not None
Expand Down Expand Up @@ -1327,18 +1327,15 @@ def _gen_starting_points_local_search(
acqvals_init.topk(k=min_points, largest=True, dim=0).indices
]

if provided_X0 is not None and generated_X0 is not None:
X0 = torch.cat([provided_X0, generated_X0], dim=0)
elif provided_X0 is not None:
X0 = provided_X0
elif generated_X0 is not None:
X0 = generated_X0
else:
raise ValueError(
"Either `batch_initial_conditions` or `raw_samples` must be set."
)

return X0
# permute to match the required behavior of _combine_initial_conditions
return _combine_initial_conditions(
provided_initial_conditions=provided_X0.permute(1, 0, 2)
if provided_X0 is not None
else None,
generated_initial_conditions=generated_X0.permute(1, 0, 2)
if generated_X0 is not None
else None,
).permute(1, 0, 2)


def optimize_acqf_discrete_local_search(
Expand Down Expand Up @@ -1399,7 +1396,10 @@ def optimize_acqf_discrete_local_search(
len(batch_initial_conditions.shape) == 3
and batch_initial_conditions.shape[-2] == 1
):
raise ValueError("batch_initial_conditions must have shape `n x 1 x d` if given.")
raise ValueError(
"batch_initial_conditions must have shape `n x 1 x d` if "
f"given (recieved {batch_initial_conditions})."
)

candidate_list = []
base_X_pending = acq_function.X_pending if q > 1 else None
Expand Down
12 changes: 11 additions & 1 deletion test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def test_optimize_acqf_batch_limit(self) -> None:

for gen_candidates, (ic_shape, expected_shape) in product(
[gen_candidates_scipy, gen_candidates_torch],
zip(initial_conditions, expected_shapes, strict=True)
zip(initial_conditions, expected_shapes, strict=True),
):
ics = torch.ones(ic_shape) if ic_shape is not None else None
with self.subTest(gen_candidates=gen_candidates, initial_conditions=ics):
Expand Down Expand Up @@ -1989,6 +1989,16 @@ def test_optimize_acqf_discrete_local_search(self):
)
)

# test ValueError for batch_initial_conditions shape
with self.assertRaisesRegex(ValueError, "must have shape `n x 1 x d`"):
candidates, acq_value = optimize_acqf_discrete_local_search(
acq_function=mock_acq_function,
q=q,
discrete_choices=discrete_choices,
X_avoid=torch.tensor([[6, 4, 9]], **tkwargs),
batch_initial_conditions=torch.tensor([[0, 2, 5]], **tkwargs),
)

# test _gen_batch_initial_conditions_local_search
with self.assertRaisesRegex(RuntimeError, "Failed to generate"):
_gen_batch_initial_conditions_local_search(
Expand Down

0 comments on commit 374e313

Please sign in to comment.