
Commit eb26085

fix: setting of p_target for evaluation (#72)
1 parent 57c9b7b commit eb26085

File tree

4 files changed: +22 -11 lines changed


fab/core.py

Lines changed: 8 additions & 3 deletions
@@ -48,10 +48,9 @@ def __init__(self,
                       "flow_alpha_2_div_unbiased", "flow_alpha_2_div_nis",
                       "target_forward_kl"]
         if loss_type in EXPERIMENTAL_LOSSES:
-            warnings.warn("Running using experiment loss not used within the main FAB paper.")
+            raise Exception("Running using experiment loss not used within the main FAB paper.")
         if loss_type in ALPHA_DIV_TARGET_LOSSES:
             assert alpha is not None, "Alpha must be specified if using the alpha div loss."
-        self.p_target = loss_type not in ALPHA_DIV_TARGET_LOSSES
         self.alpha = alpha
         self.loss_type = loss_type
         self.flow = flow

@@ -69,7 +68,7 @@ def __init__(self,
             transition_operator=self.transition_operator,
             n_intermediate_distributions=self.n_intermediate_distributions,
             distribution_spacing_type=self.ais_distribution_spacing,
-            p_target=self.p_target,
+            p_target=False,
             alpha=self.alpha
         )

@@ -192,8 +191,11 @@ def get_iter_info(self) -> Dict[str, Any]:
     def get_eval_info(self,
                       outer_batch_size: int,
                       inner_batch_size: int,
+                      set_p_target: bool = True
                       ) -> Dict[str, Any]:
         if hasattr(self, "annealed_importance_sampler"):
+            if set_p_target:
+                self.set_ais_target(min_is_target=False)  # Evaluate with target=p.
             base_samples, base_log_w, ais_samples, ais_log_w = \
                 self.annealed_importance_sampler.generate_eval_data(outer_batch_size,
                                                                     inner_batch_size)

@@ -206,6 +208,9 @@ def get_eval_info(self,
             info.update(flow_info)
             info.update(ais_info)

+            # Back to target = p^\alpha & q^(1-\alpha).
+            self.set_ais_target(min_is_target=True)
+
         else:
             raise NotImplementedError
             # TODO
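
Note: after this change, callers no longer flip p_target on the sampler directly; get_eval_info toggles the AIS target itself via set_ais_target and restores it before returning. A minimal usage sketch under that reading of the diff (fab_model and the batch sizes below are placeholders, not values from this repository):

    # Evaluate AIS with the target temporarily set to p; get_eval_info restores
    # the p^alpha q^(1-alpha) target before returning.
    eval_info_p_target = fab_model.get_eval_info(outer_batch_size=1000,
                                                 inner_batch_size=100,
                                                 set_p_target=True)

    # Evaluate with the training-time target p^alpha q^(1-alpha) left in place.
    eval_info_practical = fab_model.get_eval_info(outer_batch_size=1000,
                                                  inner_batch_size=100,
                                                  set_p_target=False)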

fab/sampling_methods/transition_operators/hmc.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ def __init__(self,
                  base_log_prob: LogProbFunc,
                  target_log_prob: LogProbFunc,
                  alpha: float = None,
-                 p_target: bool = True,
+                 p_target: bool = False,
                  epsilon: float = 1.0,
                  n_outer: int = 1,
                  L: int = 5,

fab/sampling_methods/transition_operators/metropolis.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def __init__(self,
                  target_log_prob: LogProbFunc,
                  n_updates,
                  alpha: float = None,
-                 p_target: bool = None,
+                 p_target: bool = False,
                  max_step_size=1.0,
                  min_step_size=0.1,
                  adjust_step_size=True,
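
Both transition operators now default to p_target=False, i.e. they target the minimum-variance importance-sampling distribution p^\alpha q^{1-\alpha} rather than p itself, matching the explicit p_target=False now passed in fab/core.py. A rough sketch of the distinction, assuming log_p and log_q are the target and flow log-densities (this helper is illustrative only, not the library's implementation):

    import torch

    def final_ais_log_prob(x: torch.Tensor, log_p, log_q,
                           alpha: float = 2.0, p_target: bool = False) -> torch.Tensor:
        # Illustrative: p_target=True means the final AIS distribution is p itself;
        # p_target=False means it is p^alpha * q^(1-alpha) (up to normalisation),
        # the new default set by this commit.
        if p_target:
            return log_p(x)
        return alpha * log_p(x) + (1.0 - alpha) * log_q(x)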

fab/train_with_prioritised_buffer.py

Lines changed: 12 additions & 6 deletions
@@ -77,16 +77,21 @@ def make_and_save_plots(self, i, save):
         plt.close(figure)

     def perform_eval(self, i, eval_batch_size, batch_size):
-        # set ais distribution to target for evaluation of ess
+        # Set ais distribution to target for evaluation of ess, freeze transition operator params.
         self.model.annealed_importance_sampler.transition_operator.set_eval_mode(True)
-        self.model.annealed_importance_sampler.p_target = True
         eval_info_true_target = self.model.get_eval_info(outer_batch_size=eval_batch_size,
-                                                          inner_batch_size=batch_size)
-        # set ais distribution back to p^\alpha q^{1-\alpha}.
-        self.model.annealed_importance_sampler.p_target = False
+                                                          inner_batch_size=batch_size,
+                                                          set_p_target=True)
+        # Double check the ais distribution has been set back to p^\alpha q^{1-\alpha}.
+        assert self.model.annealed_importance_sampler.p_target is False
+        assert self.model.annealed_importance_sampler.transition_operator.p_target is False
+        # Evaluation with the AIS ESS with target set as p^\alpha q^{1-\alpha}.
         eval_info_practical_target = self.model.get_eval_info(outer_batch_size=eval_batch_size,
-                                                               inner_batch_size=batch_size)
+                                                               inner_batch_size=batch_size,
+                                                               set_p_target=False)
         self.model.annealed_importance_sampler.transition_operator.set_eval_mode(False)
+
+
         eval_info = {}
         eval_info.update({key + "_p_target": val for key, val in eval_info_true_target.items()})
         eval_info.update(

@@ -96,6 +101,7 @@ def perform_eval(self, i, eval_batch_size, batch_size):
         self.logger.write(eval_info)


+
     def run(self,
             n_iterations: int,
             batch_size: int,
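
The assertions above double-check that get_eval_info has restored the p^\alpha q^{1-\alpha} target after the first evaluation. An alternative, purely illustrative way to guarantee that restore (not code from this repository; it only assumes the set_ais_target method added in fab/core.py):

    from contextlib import contextmanager

    @contextmanager
    def evaluation_target_p(model):
        # Temporarily point AIS at the true target p, then restore the
        # minimum-variance target p^alpha q^(1-alpha) even if evaluation raises.
        model.set_ais_target(min_is_target=False)
        try:
            yield model
        finally:
            model.set_ais_target(min_is_target=True)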
