Skip to content

Commit

Permalink
fix: setting of p_target for evaluation (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
lollcat authored Apr 21, 2023
1 parent 57c9b7b commit eb26085
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
11 changes: 8 additions & 3 deletions fab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def __init__(self,
"flow_alpha_2_div_unbiased", "flow_alpha_2_div_nis",
"target_forward_kl"]
if loss_type in EXPERIMENTAL_LOSSES:
warnings.warn("Running using experiment loss not used within the main FAB paper.")
raise Exception("Running using experiment loss not used within the main FAB paper.")
if loss_type in ALPHA_DIV_TARGET_LOSSES:
assert alpha is not None, "Alpha must be specified if using the alpha div loss."
self.p_target = loss_type not in ALPHA_DIV_TARGET_LOSSES
self.alpha = alpha
self.loss_type = loss_type
self.flow = flow
Expand All @@ -69,7 +68,7 @@ def __init__(self,
transition_operator=self.transition_operator,
n_intermediate_distributions=self.n_intermediate_distributions,
distribution_spacing_type=self.ais_distribution_spacing,
p_target=self.p_target,
p_target=False,
alpha=self.alpha
)

Expand Down Expand Up @@ -192,8 +191,11 @@ def get_iter_info(self) -> Dict[str, Any]:
def get_eval_info(self,
outer_batch_size: int,
inner_batch_size: int,
set_p_target: bool = True
) -> Dict[str, Any]:
if hasattr(self, "annealed_importance_sampler"):
if set_p_target:
self.set_ais_target(min_is_target=False) # Evaluate with target=p.
base_samples, base_log_w, ais_samples, ais_log_w = \
self.annealed_importance_sampler.generate_eval_data(outer_batch_size,
inner_batch_size)
Expand All @@ -206,6 +208,9 @@ def get_eval_info(self,
info.update(flow_info)
info.update(ais_info)

# Back to target = p^\alpha & q^(1-\alpha).
self.set_ais_target(min_is_target=True)

else:
raise NotImplementedError
# TODO
Expand Down
2 changes: 1 addition & 1 deletion fab/sampling_methods/transition_operators/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self,
base_log_prob: LogProbFunc,
target_log_prob: LogProbFunc,
alpha: float = None,
p_target: bool = True,
p_target: bool = False,
epsilon: float = 1.0,
n_outer: int = 1,
L: int = 5,
Expand Down
2 changes: 1 addition & 1 deletion fab/sampling_methods/transition_operators/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self,
target_log_prob: LogProbFunc,
n_updates,
alpha: float = None,
p_target: bool = None,
p_target: bool = False,
max_step_size=1.0,
min_step_size=0.1,
adjust_step_size=True,
Expand Down
18 changes: 12 additions & 6 deletions fab/train_with_prioritised_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,21 @@ def make_and_save_plots(self, i, save):
plt.close(figure)

def perform_eval(self, i, eval_batch_size, batch_size):
# set ais distribution to target for evaluation of ess
# Set ais distribution to target for evaluation of ess, freeze transition operator params.
self.model.annealed_importance_sampler.transition_operator.set_eval_mode(True)
self.model.annealed_importance_sampler.p_target = True
eval_info_true_target = self.model.get_eval_info(outer_batch_size=eval_batch_size,
inner_batch_size=batch_size)
# set ais distribution back to p^\alpha q^{1-\alpha}.
self.model.annealed_importance_sampler.p_target = False
inner_batch_size=batch_size,
set_p_target=True)
# Double check the ais distribution has been set back to p^\alpha q^{1-\alpha}.
assert self.model.annealed_importance_sampler.p_target is False
assert self.model.annealed_importance_sampler.transition_operator.p_target is False
# Evaluation with the AIS ESS with target set as p^\alpha q^{1-\alpha}.
eval_info_practical_target = self.model.get_eval_info(outer_batch_size=eval_batch_size,
inner_batch_size=batch_size)
inner_batch_size=batch_size,
set_p_target=False)
self.model.annealed_importance_sampler.transition_operator.set_eval_mode(False)


eval_info = {}
eval_info.update({key + "_p_target": val for key, val in eval_info_true_target.items()})
eval_info.update(
Expand All @@ -96,6 +101,7 @@ def perform_eval(self, i, eval_batch_size, batch_size):
self.logger.write(eval_info)



def run(self,
n_iterations: int,
batch_size: int,
Expand Down

0 comments on commit eb26085

Please sign in to comment.