Skip to content

Commit 8eb7860

Browse files
committed
doc string for variational and importance
1 parent 7d128e3 commit 8eb7860

File tree

7 files changed

+148
-97
lines changed

7 files changed

+148
-97
lines changed

tests/framework/test_base.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ def build_meta_bn():
206206
self.assertNear(log_pc_2_out, log_pc_t_out, 1e-6)
207207

208208

209-
class TestReuse(tf.test.TestCase):
209+
class TestReuseVariables(tf.test.TestCase):
210210

211-
def test_legacy_reuse(self):
212-
@reuse("test")
211+
def test_reuse_variables(self):
212+
@reuse_variables("test")
213213
def f():
214214
w = tf.get_variable("w", shape=[])
215215
return w
@@ -233,7 +233,7 @@ def test_meta_bn(self):
233233
# the basic usage is tested in TestBayesianNet. corner cases here
234234
@meta_bayesian_net(scope='scp', reuse_variables=False)
235235
def build_mbn(var_to_return):
236-
return TestReuse._generate_bn(var_to_return)
236+
return TestReuseVariables._generate_bn(var_to_return)
237237

238238
with tf.variable_scope('you_might_want_do_this'):
239239
mbn = build_mbn('a_mean')
@@ -243,21 +243,21 @@ def build_mbn(var_to_return):
243243
self.assertNotEqual(m1.name, m2.name)
244244
with tf.variable_scope('when_you_are_perfectly_conscious'):
245245
_, m2 = build_mbn('a_mean').observe()
246-
self.assertNotEquals(m1.name, m2.name)
246+
self.assertNotEqual(m1.name, m2.name)
247247

248248
@meta_bayesian_net(scope='scp', reuse_variables=True)
249249
def build_mbn(var_to_return):
250-
return TestReuse._generate_bn(var_to_return)
250+
return TestReuseVariables._generate_bn(var_to_return)
251251

252252
meta_bn = build_mbn('a_mean')
253253
_, m1 = meta_bn.observe()
254254
_, m2 = meta_bn.observe()
255255
_, m3 = build_mbn('a_mean').observe()
256-
self.assertEquals(m1.name, m2.name)
256+
self.assertEqual(m1.name, m2.name)
257257
self.assertNotEqual(m1.name, m3.name)
258258

259259
with self.assertRaisesRegexp(ValueError, 'Cannot reuse'):
260260
@meta_bayesian_net(reuse_variables=True)
261261
def mbn(var_to_return):
262-
return TestReuse._generate_bn(var_to_return)
262+
return TestReuseVariables._generate_bn(var_to_return)
263263
mbn('a_mean')

tests/variational/test_inclusive_kl.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -41,35 +41,35 @@ def log_joint(observed):
4141
with self.assertRaisesRegexp(NotImplementedError, err_msg):
4242
sess.run(lower_bound)
4343

44-
def test_rws(self):
44+
def test_importance(self):
4545
eps_samples = tf.convert_to_tensor(self._n01_samples)
4646
mu = tf.constant(2.)
4747
sigma = tf.constant(3.)
4848
qx_samples = tf.stop_gradient(eps_samples * sigma + mu)
4949
q = Normal(mean=mu, std=sigma)
5050
log_qx = q.log_prob(qx_samples)
5151

52-
def _check_rws(x_mean, x_std, threshold):
52+
def _check_importance(x_mean, x_std, threshold):
5353
def log_joint(observed):
5454
p = Normal(mean=x_mean, std=x_std)
5555
return p.log_prob(observed['x'])
5656

5757
klpq_obj = klpq(log_joint, observed={},
5858
latent={'x': [qx_samples, log_qx]}, axis=0)
59-
cost = klpq_obj.rws()
60-
rws_grads = tf.gradients(cost, [mu, sigma])
59+
cost = klpq_obj.importance()
60+
importance_grads = tf.gradients(cost, [mu, sigma])
6161
true_cost = _kl_normal_normal(x_mean, x_std, mu, sigma)
6262
true_grads = tf.gradients(true_cost, [mu, sigma])
6363

6464
with self.session(use_gpu=True) as sess:
65-
g1 = sess.run(rws_grads)
65+
g1 = sess.run(importance_grads)
6666
g2 = sess.run(true_grads)
67-
# print('rws_grads:', g1)
67+
# print('importance_grads:', g1)
6868
# print('true_grads:', g2)
6969
self.assertAllClose(g1, g2, threshold, threshold)
7070

71-
_check_rws(0., 1., 0.01)
72-
_check_rws(2., 3., 0.02)
71+
_check_importance(0., 1., 0.01)
72+
_check_importance(2., 3., 0.02)
7373

7474
single_sample = tf.stop_gradient(tf.random_normal([]) * sigma + mu)
7575
single_log_q = q.log_prob(single_sample)
@@ -86,7 +86,7 @@ def log_joint(observed):
8686
# Cause all warnings to always be triggered.
8787
warnings.simplefilter("always")
8888
# Trigger a warning.
89-
single_sample_obj.rws()
89+
single_sample_obj.importance()
9090
self.assertTrue(issubclass(w[-1].category, UserWarning))
9191
self.assertTrue("biased and inaccurate when you're using only "
9292
"a single sample" in str(w[-1].message))

zhusuan/evaluation.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import tensorflow as tf
1010
import numpy as np
1111

12-
from zhusuan.utils import log_mean_exp, merge_dicts
12+
from zhusuan.utils import merge_dicts
1313
from zhusuan.variational import ImportanceWeightedObjective
1414

1515

@@ -20,24 +20,29 @@
2020

2121

2222
def is_loglikelihood(meta_bn, observed, latent=None, axis=None,
23-
proposal=None, allow_default=False):
23+
proposal=None):
2424
"""
2525
Marginal log likelihood (:math:`\log p(x)`) estimates using self-normalized
2626
importance sampling.
2727
28-
:param log_joint: A function that accepts a dictionary argument of
28+
:param meta_bn: A :class:`~zhusuan.framework.meta_bn.MetaBayesianNet`
29+
instance or a log joint probability function.
30+
For the latter, it must accepts a dictionary argument of
2931
``(string, Tensor)`` pairs, which are mappings from all
30-
`StochasticTensor` names in the model to their observed values. The
32+
node names in the model to their observed values. The
3133
function should return a Tensor, representing the log joint likelihood
3234
of the model.
3335
:param observed: A dictionary of ``(string, Tensor)`` pairs. Mapping from
34-
names of observed `StochasticTensor` s to their values
36+
names of observed stochastic nodes to their values.
3537
:param latent: A dictionary of ``(string, (Tensor, Tensor))`` pairs.
36-
Mapping from names of latent `StochasticTensor` s to their samples and
37-
log probabilities.
38+
Mapping from names of latent stochastic nodes to their samples and
39+
log probabilities. `latent` and `proposal` are mutually exclusive.
3840
:param axis: The sample dimension(s) to reduce when computing the
39-
outer expectation in the importance sampling estimation. If None, no
40-
dimension is reduced.
41+
outer expectation in the objective. If ``None``, no dimension is
42+
reduced.
43+
:param proposal: A :class:`~zhusuan.framework.bn.BayesianNet` instance
44+
that defines the proposal distributions of latent nodes.
45+
`proposal` and `latent` are mutually exclusive.
4146
4247
:return: A Tensor. The estimated log likelihood of observed data.
4348
"""
@@ -46,8 +51,7 @@ def is_loglikelihood(meta_bn, observed, latent=None, axis=None,
4651
observed,
4752
latent=latent,
4853
axis=axis,
49-
variational=proposal,
50-
allow_default=allow_default).tensor
54+
variational=proposal).tensor
5155

5256

5357
class AIS:

zhusuan/framework/utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import print_function
66
from __future__ import division
77
from collections import deque, OrderedDict
8+
import warnings
89

910
import tensorflow as tf
1011

@@ -109,5 +110,8 @@ def reuse(scope):
109110
"""
110111
(Deprecated) Alias of :func:`reuse_variables`.
111112
"""
112-
# TODO: raise warning
113+
warnings.warn(
114+
"The `reuse()` function has been renamed to `reuse_variables()`, "
115+
"`reuse()` will be removed in the coming version (0.4.1)",
116+
DeprecationWarning)
113117
return reuse_variables(scope)

zhusuan/variational/exclusive_kl.py

+35-24
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ class EvidenceLowerBoundObjective(VariationalObjective):
2424
calling :func:`elbo`::
2525
2626
# lower_bound is an EvidenceLowerBoundObjective instance
27-
lower_bound = zs.variational.elbo(log_joint, observed, latent)
27+
lower_bound = zs.variational.elbo(
28+
meta_bn, observed, variational=variational, axis=0)
29+
30+
Here ``meta_bn`` is a :class:`~zhusuan.framework.meta_bn.MetaBayesianNet`
31+
instance representing the model to be inferred. ``variational`` is
32+
a :class:`~zhusuan.framework.bn.BayesianNet` instance that defines the
33+
variational family. ``axis`` is the index of the sample dimension used
34+
to estimate the expectation when computing the objective.
2835
2936
Instances of :class:`EvidenceLowerBoundObjective` are Tensor-like. They
3037
can be automatically or manually cast into Tensors when fed into Tensorflow
@@ -60,8 +67,7 @@ class EvidenceLowerBoundObjective(VariationalObjective):
6067
6168
# optimize the surrogate cost wrt. variational parameters
6269
optimizer = tf.train.AdamOptimizer(learning_rate)
63-
infer_op = optimizer.minimize(cost,
64-
var_list=variational_parameters)
70+
infer_op = optimizer.minimize(cost, var_list=variational_parameters)
6571
with tf.Session() as sess:
6672
for _ in range(n_iters):
6773
_, lb = sess.run([infer_op, lower_bound], feed_dict=...)
@@ -81,11 +87,9 @@ class EvidenceLowerBoundObjective(VariationalObjective):
8187
optimize the class instance::
8288
8389
# optimize wrt. model parameters
84-
learn_op = optimizer.minimize(-lower_bound,
85-
var_list=model_parameters)
90+
learn_op = optimizer.minimize(-lower_bound, var_list=model_parameters)
8691
# or
87-
# learn_op = optimizer.minimize(cost,
88-
# var_list=model_parameters)
92+
# learn_op = optimizer.minimize(cost, var_list=model_parameters)
8993
# both ways are correct
9094
9195
Or we can do inference and learning jointly by optimize over both
@@ -95,30 +99,34 @@ class EvidenceLowerBoundObjective(VariationalObjective):
9599
infer_and_learn_op = optimizer.minimize(
96100
cost, var_list=model_and_variational_parameters)
97101
98-
:param log_joint: A function that accepts a dictionary argument of
102+
:param meta_bn: A :class:`~zhusuan.framework.meta_bn.MetaBayesianNet`
103+
instance or a log joint probability function.
104+
For the latter, it must accepts a dictionary argument of
99105
``(string, Tensor)`` pairs, which are mappings from all
100-
`StochasticTensor` names in the model to their observed values. The
106+
node names in the model to their observed values. The
101107
function should return a Tensor, representing the log joint likelihood
102108
of the model.
103109
:param observed: A dictionary of ``(string, Tensor)`` pairs. Mapping from
104-
names of observed `StochasticTensor` s to their values.
110+
names of observed stochastic nodes to their values.
105111
:param latent: A dictionary of ``(string, (Tensor, Tensor))`` pairs.
106-
Mapping from names of latent `StochasticTensor` s to their samples and
107-
log probabilities.
112+
Mapping from names of latent stochastic nodes to their samples and
113+
log probabilities. `latent` and `variational` are mutually exclusive.
108114
:param axis: The sample dimension(s) to reduce when computing the
109115
outer expectation in the objective. If ``None``, no dimension is
110116
reduced.
117+
:param variational: A :class:`~zhusuan.framework.bn.BayesianNet` instance
118+
that defines the variational family.
119+
`variational` and `latent` are mutually exclusive.
111120
"""
112121

113122
def __init__(self, meta_bn, observed, latent=None, axis=None,
114-
variational=None, allow_default=False):
123+
variational=None):
115124
self._axis = axis
116125
super(EvidenceLowerBoundObjective, self).__init__(
117126
meta_bn,
118127
observed,
119128
latent=latent,
120-
variational=variational,
121-
allow_default=allow_default)
129+
variational=variational)
122130

123131
def _objective(self):
124132
lower_bound = self._log_joint_term()
@@ -223,27 +231,31 @@ def reinforce(self,
223231
return cost
224232

225233

226-
def elbo(meta_bn, observed, latent=None, axis=None, variational=None,
227-
allow_default=False):
234+
def elbo(meta_bn, observed, latent=None, axis=None, variational=None):
228235
"""
229236
The evidence lower bound (ELBO) objective for variational inference. The
230237
returned value is a :class:`EvidenceLowerBoundObjective` instance.
231238
232239
See :class:`EvidenceLowerBoundObjective` for examples of usage.
233240
234-
:param log_joint: A function that accepts a dictionary argument of
241+
:param meta_bn: A :class:`~zhusuan.framework.meta_bn.MetaBayesianNet`
242+
instance or a log joint probability function.
243+
For the latter, it must accepts a dictionary argument of
235244
``(string, Tensor)`` pairs, which are mappings from all
236-
`StochasticTensor` names in the model to their observed values. The
245+
node names in the model to their observed values. The
237246
function should return a Tensor, representing the log joint likelihood
238247
of the model.
239248
:param observed: A dictionary of ``(string, Tensor)`` pairs. Mapping from
240-
names of observed `StochasticTensor` s to their values.
249+
names of observed stochastic nodes to their values.
241250
:param latent: A dictionary of ``(string, (Tensor, Tensor))`` pairs.
242-
Mapping from names of latent `StochasticTensor` s to their samples and
243-
log probabilities.
251+
Mapping from names of latent stochastic nodes to their samples and
252+
log probabilities. `latent` and `variational` are mutually exclusive.
244253
:param axis: The sample dimension(s) to reduce when computing the
245254
outer expectation in the objective. If ``None``, no dimension is
246255
reduced.
256+
:param variational: A :class:`~zhusuan.framework.bn.BayesianNet` instance
257+
that defines the variational family.
258+
`variational` and `latent` are mutually exclusive.
247259
248260
:return: An :class:`EvidenceLowerBoundObjective` instance.
249261
"""
@@ -252,5 +264,4 @@ def elbo(meta_bn, observed, latent=None, axis=None, variational=None,
252264
observed,
253265
latent=latent,
254266
axis=axis,
255-
variational=variational,
256-
allow_default=allow_default)
267+
variational=variational)

0 commit comments

Comments
 (0)