Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions aepsych/acquisition/lookahead.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ def __init__(
"""
super().__init__(model=model, target=target, lookahead_type=lookahead_type)
self.posterior_transform = posterior_transform
assert (
Xq is not None or query_set_size is not None
), "Must pass either query set size or a query set!"
assert Xq is not None or query_set_size is not None, (
"Must pass either query set size or a query set!"
)
if Xq is not None and query_set_size is not None:
assert Xq.shape[0] == query_set_size, (
"If passing both Xq and query_set_size,"
Expand Down Expand Up @@ -360,9 +360,9 @@ def __init__(
query_set_size (int, optional): Number of points in the query set.
Xq (torch.Tensor, optional): (m x d) global reference set.
"""
assert (
lookahead_type == "levelset"
), f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
assert lookahead_type == "levelset", (
f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
)
super().__init__(
lb=lb,
ub=ub,
Expand Down
2 changes: 1 addition & 1 deletion aepsych/benchmark/pathos_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def run_benchmarks_with_checkpoints(
temp_results["rep"] = temp_results["rep"] + n_reps_per_chunk * chunk
temp_results.to_csv(intermediate_fname)
print(
f"Collate done in {time.time()-collate_start} seconds, {len(bench.futures)}/{bench.num_benchmarks} left"
f"Collate done in {time.time() - collate_start} seconds, {len(bench.futures)}/{bench.num_benchmarks} left"
)

print(f"{benchmark_name} chunk {chunk} fully done!")
Expand Down
6 changes: 3 additions & 3 deletions aepsych/benchmark/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def evaluate(
# always eval f
f_hat = self.f_hat(model)
p_hat = self.p_hat(model)
assert (
self.f_true.shape == f_hat.shape
), f"self.f_true.shape=={self.f_true.shape} != f_hat.shape=={f_hat.shape}"
assert self.f_true.shape == f_hat.shape, (
f"self.f_true.shape=={self.f_true.shape} != f_hat.shape=={f_hat.shape}"
)

mae_f = torch.mean(torch.abs(self.f_true - f_hat))
mse_f = torch.mean((self.f_true - f_hat) ** 2)
Expand Down
6 changes: 3 additions & 3 deletions aepsych/factory/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def default_mean_covar_factory(
stacklevel=2,
)

assert (config is not None) or (
dim is not None
), "Either config or dim must be provided!"
assert (config is not None) or (dim is not None), (
"Either config or dim must be provided!"
)

assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!"

Expand Down
12 changes: 6 additions & 6 deletions aepsych/factory/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ def pairwise_mean_covar_factory(
stacklevel=2,
)

assert (
stimuli_per_trial == 1
), f"pairwise_mean_covar_factory must have stimuli_per_trial == 1, but {stimuli_per_trial} was passed instead!"
assert stimuli_per_trial == 1, (
f"pairwise_mean_covar_factory must have stimuli_per_trial == 1, but {stimuli_per_trial} was passed instead!"
)
lb = config.gettensor("common", "lb")
ub = config.gettensor("common", "ub")
assert lb.shape[0] == ub.shape[0], "bounds shape mismatch!"
Expand Down Expand Up @@ -162,9 +162,9 @@ def pairwise_mean_covar_factory(

if len(shared_dims) > 0:
active_dims = [i for i in range(config_dim) if i not in shared_dims]
assert (
len(active_dims) % 2 == 0
), "dimensionality of non-shared dims must be even!"
assert len(active_dims) % 2 == 0, (
"dimensionality of non-shared dims must be even!"
)
mean = _get_default_mean_function(config, zero_mean)
cov1 = _get_default_cov_function(
config, len(active_dims) // 2, stimuli_per_trial=1
Expand Down
2 changes: 1 addition & 1 deletion aepsych/generators/acqf_grid_search_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ def _gen(
_, idxs = torch.topk(acqf_vals, num_points)
new_candidate = grid[idxs]

logger.info(f"Gen done, time={time.time()-starttime}")
logger.info(f"Gen done, time={time.time() - starttime}")
return new_candidate
2 changes: 1 addition & 1 deletion aepsych/generators/acqf_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def _gen(
)
new_candidate = grid[candidate_idx]

logger.info(f"Gen done, time={time.time()-starttime}")
logger.info(f"Gen done, time={time.time() - starttime}")
return new_candidate
6 changes: 3 additions & 3 deletions aepsych/models/semi_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ def __init__(
)

likelihood = likelihood or LinearBernoulliLikelihood()
assert isinstance(
likelihood, LinearBernoulliLikelihood
), "SemiP model only supports linear Bernoulli likelihoods!"
assert isinstance(likelihood, LinearBernoulliLikelihood), (
"SemiP model only supports linear Bernoulli likelihoods!"
)

super().__init__(
dim=dim,
Expand Down
14 changes: 7 additions & 7 deletions aepsych/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,9 @@ def plot_strat(
DeprecationWarning,
stacklevel=2,
)
assert (
"binary" in strat.outcome_types
), f"Plotting not supported for outcome_type {strat.outcome_types[0]}"
assert "binary" in strat.outcome_types, (
f"Plotting not supported for outcome_type {strat.outcome_types[0]}"
)

if target_level is not None and not hasattr(strat.model, "monotonic_idxs"):
warnings.warn(
Expand Down Expand Up @@ -873,7 +873,7 @@ def _plot_strat_1d(
alpha=0.3,
hatch="///",
edgecolor="gray",
label=f"{cred_level*100:.0f}% posterior mass",
label=f"{cred_level * 100:.0f}% posterior mass",
)
if target_level is not None:
from aepsych.utils import interpolate_monotonic
Expand All @@ -892,7 +892,7 @@ def _plot_strat_1d(
xerr=np.r_[thresh_med - thresh_lower, thresh_upper - thresh_med][:, None],
capsize=5,
elinewidth=1,
label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass marked)",
label=f"Est. {target_level * 100:.0f}% threshold \n(with {cred_level * 100:.0f}% posterior \nmass marked)",
)

if true_testfun is not None:
Expand All @@ -911,7 +911,7 @@ def _plot_strat_1d(
true_thresh,
target_level,
"o",
label=f"True {target_level*100:.0f}% threshold",
label=f"True {target_level * 100:.0f}% threshold",
)

ax.scatter(
Expand Down Expand Up @@ -1031,7 +1031,7 @@ def _plot_strat_2d(
ax.plot(
context_grid,
thresh_75.cpu().numpy(),
label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)",
label=f"Est. {target_level * 100:.0f}% threshold \n(with {cred_level * 100:.0f}% posterior \nmass shaded)",
)
ax.fill_between(
context_grid,
Expand Down
12 changes: 6 additions & 6 deletions aepsych/strategy/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def _make_next_strat(self) -> None:
return

# populate new model with final data from last model
assert (
self.x is not None and self.y is not None
), "Cannot initialize next strategy; no data has been given!"
assert self.x is not None and self.y is not None, (
"Cannot initialize next strategy; no data has been given!"
)
self.strat_list[self._strat_idx + 1].add_data(self.x, self.y)

self._suggest_count = 0
Expand Down Expand Up @@ -146,9 +146,9 @@ def get_config_options(
strat_names = config.getlist("common", "strategy_names", element_type=str)

# ensure strat_names are unique
assert len(strat_names) == len(
set(strat_names)
), f"Strategy names {strat_names} are not all unique!"
assert len(strat_names) == len(set(strat_names)), (
f"Strategy names {strat_names} are not all unique!"
)

strats = []
for name in strat_names:
Expand Down
58 changes: 30 additions & 28 deletions aepsych/strategy/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,23 @@ def __init__(
)

if model is not None:
assert (
len(outcome_types) == model._num_outputs
), f"Strategy has {len(outcome_types)} outcomes, but model {type(model).__name__} supports {model._num_outputs}!"
assert (
stimuli_per_trial == model.stimuli_per_trial
), f"Strategy has {stimuli_per_trial} stimuli_per_trial, but model {type(model).__name__} supports {model.stimuli_per_trial}!"
assert len(outcome_types) == model._num_outputs, (
f"Strategy has {len(outcome_types)} outcomes, but model {type(model).__name__} supports {model._num_outputs}!"
)
assert stimuli_per_trial == model.stimuli_per_trial, (
f"Strategy has {stimuli_per_trial} stimuli_per_trial, but model {type(model).__name__} supports {model.stimuli_per_trial}!"
)

if isinstance(model.outcome_type, str):
assert (
len(outcome_types) == 1 and outcome_types[0] == model.outcome_type
), f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!"
), (
f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!"
)
else:
assert (
set(outcome_types) == set(model.outcome_type)
), f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!"
assert set(outcome_types) == set(model.outcome_type), (
f"Strategy outcome types is {outcome_types} but model outcome type is {model.outcome_type}!"
)

if use_gpu_modeling:
if not torch.cuda.is_available():
Expand Down Expand Up @@ -165,9 +167,9 @@ def __init__(
self.min_post_range = min_post_range
self.log_post_var = log_post_var
if self.min_post_range is not None or self.log_post_var:
assert (
model is not None
), "posterior range cannot be evaluated if model is None!"
assert model is not None, (
"posterior range cannot be evaluated if model is None!"
)
self.eval_grid = make_scaled_sobol(
lb=self.lb, ub=self.ub, size=self._n_eval_points
)
Expand Down Expand Up @@ -218,9 +220,9 @@ def normalize_inputs(
y (torch.Tensor): training outputs, normalized
n (int): number of observations
"""
assert (
x.shape == self.event_shape or x.shape[1:] == self.event_shape
), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"
assert x.shape == self.event_shape or x.shape[1:] == self.event_shape, (
f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"
)

# Handle scalar y values
if y.ndim == 0:
Expand Down Expand Up @@ -284,9 +286,9 @@ def get_max(
Returns:
tuple[torch.Tensor, torch.Tensor]: Tuple containing the max and its location (argmax).
"""
assert (
self.model is not None
), "model is None! Cannot get the max without a model!"
assert self.model is not None, (
"model is None! Cannot get the max without a model!"
)
self.model.to(self.model_device)

val, arg = get_max(
Expand Down Expand Up @@ -316,9 +318,9 @@ def get_min(
Returns:
tuple[torch.Tensor, torch.Tensor]: Tuple containing the min and its location (argmin).
"""
assert (
self.model is not None
), "model is None! Cannot get the min without a model!"
assert self.model is not None, (
"model is None! Cannot get the min without a model!"
)
self.model.to(self.model_device)

val, arg = get_min(
Expand Down Expand Up @@ -350,9 +352,9 @@ def inv_query(
Returns:
tuple[torch.Tensor, torch.Tensor]: The input that corresponds to the given output value and the corresponding output.
"""
assert (
self.model is not None
), "model is None! Cannot get the inv_query without a model!"
assert self.model is not None, (
"model is None! Cannot get the inv_query without a model!"
)
self.model.to(self.model_device)

val, arg = inv_query(
Expand Down Expand Up @@ -435,9 +437,9 @@ def finished(self) -> bool:
sufficient_outcomes = True

if self.min_post_range is not None or self.log_post_var:
assert (
self.model is not None
), "model is None! Cannot predict without a model!"
assert self.model is not None, (
"model is None! Cannot predict without a model!"
)
fmean, fvar = self.model.predict(self.eval_grid, probability_space=True)
post_range = fmean.max() - fmean.min()
else:
Expand Down
2 changes: 1 addition & 1 deletion aepsych/strategy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def wrapper(self, *args, **kwargs):
logger.info("Starting fitting (warm start)...")
# warm start
self.update()
logger.info(f"Fitting done, took {time.time()-starttime}")
logger.info(f"Fitting done, took {time.time() - starttime}")
self._model_is_fresh = True
return f(self, *args, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions aepsych/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def _process_bounds(
dim = lb.shape[0]

for i, (lower, upper) in enumerate(zip(lb, ub)):
assert (
lower <= upper
), f"Lower bound {lower} is not less than or equal to upper bound {upper} on dimension {i}!"
assert lower <= upper, (
f"Lower bound {lower} is not less than or equal to upper bound {upper} on dimension {i}!"
)

return lb, ub, dim

Expand Down
18 changes: 9 additions & 9 deletions clients/python/aepsych_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ def configure(
with open(config_path, "r") as f:
config_str = f.read()
elif config_str is not None:
assert (
config_path is None
), "if config_str is passed, don't pass config_path"
assert config_path is None, (
"if config_str is passed, don't pass config_path"
)
request = {
"type": "setup",
"message": {"config_str": config_str},
Expand Down Expand Up @@ -221,14 +221,14 @@ def resume(
"""
if config_id is not None:
assert config_name is None, "if config_id is passed, don't pass config_name"
assert (
config_id in self.configs
), f"No strat with index {config_id} was created!"
assert config_id in self.configs, (
f"No strat with index {config_id} was created!"
)
elif config_name is not None:
assert config_id is None, "if config_name is passed, don't pass config_id"
assert (
config_name in self.config_names.keys()
), f"{config_name} not known, know {self.config_names.keys()}!"
assert config_name in self.config_names.keys(), (
f"{config_name} not known, know {self.config_names.keys()}!"
)
config_id = self.config_names[config_name]
request = {
"type": "resume",
Expand Down
6 changes: 3 additions & 3 deletions pubs/owenetal/code/stratplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,9 @@ def f(self, x):
plotting_axes = [ax[0, 1], ax[1, 0], ax[1, 1]]

titles = [
f"Monotonic RBF Model,\n BALV, after {sobol_trials+opt_trials} total trials",
f"Monotonic RBF Model,\n BALD, after {sobol_trials+opt_trials} total trials",
f"Monotonic RBF Model,\n LSE (ours) after {sobol_trials+opt_trials} total trials",
f"Monotonic RBF Model,\n BALV, after {sobol_trials + opt_trials} total trials",
f"Monotonic RBF Model,\n BALD, after {sobol_trials + opt_trials} total trials",
f"Monotonic RBF Model,\n LSE (ours) after {sobol_trials + opt_trials} total trials",
]

_ = [
Expand Down
8 changes: 4 additions & 4 deletions requirements-fmt.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by `pyfmt --requirements`
black==24.4.2
ruff-api==0.1.0
black==25.11.0
ruff-api==0.2.0
stdlibs==2024.1.28
ufmt==2.8.0
usort==1.0.8.post1
ufmt==2.9.0
usort==1.1.0
6 changes: 3 additions & 3 deletions tests/test_datafetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def print_par_data(data):
participant_id = pid3

[common]
parnames = [{', '.join(name for name in par_data)}]
parnames = [{", ".join(name for name in par_data)}]
stimuli_per_trial = {num_stim}
outcome_names = [{', '.join(name for name in outcome_names)}]
outcome_types = [{', '.join(type for type in outcome_types)}]
outcome_names = [{", ".join(name for name in outcome_names)}]
outcome_types = [{", ".join(type for type in outcome_types)}]
strategy_names = [my_strat]

{print_par_data(par_data)}
Expand Down
Loading
Loading