Skip to content

Commit 666f421

Browse files
committed
Initial work on FlowDistribution.
1 parent 8eee4ed commit 666f421

File tree

8 files changed

+392
-160
lines changed

8 files changed

+392
-160
lines changed

examples/normalizing_flows/dlgm_nf.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import division
77
import os
88
import time
9+
from functools import partial
910

1011
import tensorflow as tf
1112
from six.moves import range
@@ -29,14 +30,17 @@ def vae(observed, n, x_dim, z_dim, n_particles):
2930
return model
3031

3132

32-
def q_net(x, z_dim, n_particles):
33+
def q_net(x, z_dim, n_particles, n_flows):
34+
def forward(samples):
35+
return zs.repeated_flow(zs.planar_normalizing_flow, samples, n_iters=n_flows)
36+
3337
with zs.BayesianNet() as variational:
3438
lz_x = tf.layers.dense(tf.to_float(x), 500, activation=tf.nn.relu)
3539
lz_x = tf.layers.dense(lz_x, 500, activation=tf.nn.relu)
3640
z_mean = tf.layers.dense(lz_x, z_dim)
3741
z_logstd = tf.layers.dense(lz_x, z_dim)
38-
z = zs.Normal('z', z_mean, logstd=z_logstd, group_ndims=1,
39-
n_samples=n_particles)
42+
z = zs.NormalFlow('z', forward, mean=z_mean, logstd=z_logstd, group_ndims=1,
43+
n_samples=n_particles)
4044
return variational
4145

4246

@@ -67,14 +71,9 @@ def log_joint(observed):
6771
log_pz, log_px_z = model.local_log_prob(['z', 'x'])
6872
return log_pz + log_px_z
6973

70-
variational = q_net(x, z_dim, n_particles)
74+
variational = q_net(x, z_dim, n_particles, n_planar_flows)
7175
qz_samples, log_qz = variational.query('z', outputs=True,
7276
local_log_prob=True)
73-
# TODO: add tests for repeated calls of flows
74-
qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
75-
n_iters=n_planar_flows)
76-
qz_samples, log_qz = zs.planar_normalizing_flow(qz_samples, log_qz,
77-
n_iters=n_planar_flows)
7877

7978
lower_bound = zs.variational.elbo(log_joint,
8079
observed={'x': x},

tests/model/test_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ def test_init(self):
2121
probs = Mock()
2222
sample_func = Mock(return_value=samples)
2323
log_prob_func = Mock(return_value=log_probs)
24+
sample_and_log_prob_func = Mock(return_value=(samples, log_probs))
2425
prob_func = Mock(return_value=probs)
2526
distribution = Mock(sample=sample_func,
2627
log_prob=log_prob_func,
28+
sample_and_log_prob=sample_and_log_prob_func,
2729
prob=prob_func,
2830
dtype=tf.int32)
2931
with BayesianNet() as model:
@@ -86,9 +88,11 @@ def test_session_run(self):
8688
probs = Mock()
8789
sample_func = Mock(return_value=samples)
8890
log_prob_func = Mock(return_value=log_probs)
91+
sample_and_log_prob_func = Mock(return_value=(samples, log_probs))
8992
prob_func = Mock(return_value=probs)
9093
distribution = Mock(sample=sample_func,
9194
log_prob=log_prob_func,
95+
sample_and_log_prob=sample_and_log_prob_func,
9296
prob=prob_func,
9397
dtype=tf.int32)
9498

tests/test_transform.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,25 @@ def test_planar_normalizing_flow(self):
2020
z.append(np.array([[vz[i]]]))
2121
z[i] = tf.constant(z[i], dtype=tf.float32)
2222
z_0 = tf.concat(z, axis=1)
23-
z_1, n_log_det_ja = planar_normalizing_flow(
24-
z_0, [0.0], n_iters=10)
25-
26-
n_log_det_ja = tf.reshape(n_log_det_ja, [])
23+
z_1, n_log_det_ja = repeated_flow(planar_normalizing_flow, z_0, n_iters=10)
2724

2825
grad = []
2926
for i in range(len(vz)):
3027
z_1i = z_1[0, i]
3128
grad.append(tf.gradients(z_1i, z_0)[0])
32-
jocabian = tf.concat(grad, axis=0)
33-
log_det_jacobian = tf.log(tf.matrix_determinant(jocabian))
29+
jacobian = tf.concat(grad, axis=0)
30+
log_det_jacobian = tf.log(tf.matrix_determinant(jacobian))
3431

3532
sess.run(tf.global_variables_initializer())
36-
test_value, true_value = sess.run([-log_det_jacobian,
37-
n_log_det_ja])
33+
test_value, true_value = sess.run([log_det_jacobian,
34+
tf.squeeze(n_log_det_ja)])
3835
self.assertAllClose(test_value, true_value)
3936

4037
def test_flow_shape(self):
4138
z = tf.random_normal(shape=(2, 10, 6), mean=0, stddev=0.05)
4239
log_pz = tf.random_normal(shape=(2, 10), mean=0, stddev=0.05)
43-
t_z, t_log_pz = planar_normalizing_flow(z, log_pz, n_iters=10)
40+
t_z, log_det = repeated_flow(planar_normalizing_flow, z, n_iters=10)
41+
t_log_pz = log_pz - log_det
4442
with self.test_session(use_gpu=True) as sess:
4543
sess.run(tf.global_variables_initializer())
4644
o_z, o_log_pz = sess.run([t_z, t_log_pz])

zhusuan/distributions/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,8 @@ def _prob(self, given):
333333
Private method for subclasses to rewrite the :meth:`prob` method.
334334
"""
335335
raise NotImplementedError()
336+
337+
def sample_and_log_prob(self, n_samples=None):
338+
samples = self.sample(n_samples=n_samples)
339+
log_p = self.log_prob(samples)
340+
return samples, log_p

zhusuan/distributions/special.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
__all__ = [
1414
'Empirical',
1515
'Implicit',
16+
'FlowDistribution'
1617
]
1718

1819

@@ -132,7 +133,7 @@ def _batch_shape(self):
132133

133134
def _get_batch_shape(self):
134135
if self.samples.get_shape() == tf.TensorShape(None) or \
135-
self.explicit_value_shape == tf.TensorShape(None):
136+
self.explicit_value_shape == tf.TensorShape(None):
136137
return tf.TensorShape(None)
137138
else:
138139
d = self.explicit_value_shape.ndims
@@ -157,3 +158,77 @@ def _prob(self, given):
157158
return (2 * prob - 1) * inf_dtype
158159
else:
159160
return tf.cast(prob, tf.float32)
161+
162+
163+
class FlowDistribution(Distribution):
164+
"""
165+
The class of FlowDistribution distribution.
166+
The distribution describes variable which is sampled from a base
167+
distribution and then is passed through an invertible function.
168+
See :class:`~zhusuan.distributions.base.Distribution` for details.
169+
170+
:param name: A string. The name of the `StochasticTensor`. Must be unique
171+
in the `BayesianNet` context.
172+
:param base: An instance of `Distribution` parametrizing the base distribution.
173+
:param forward: A forward function which describes how we transform the samples
174+
from the base distribution. The signature of the function should be:
175+
transformed, log_det = forward(base_samples)
176+
:param inverse: An inverse function which maps from the transformed samples to
177+
to base samples. The signature of the function should be:
178+
base_samples, log_det = inverse(transformed_samples)
179+
:param group_ndims: A 0-D `int32` Tensor representing the number of
180+
dimensions in `batch_shape` (counted from the end) that are grouped
181+
into a single event, so that their probabilities are calculated
182+
together. Default is 0, which means a single value is an event.
183+
See :class:`~zhusuan.distributions.base.Distribution` for more detailed
184+
explanation.
185+
"""
186+
187+
def __init__(self,
188+
base,
189+
forward,
190+
inverse=None,
191+
group_ndims=0,
192+
**kwargs):
193+
self.base = base
194+
self.forward = forward
195+
self.inverse = inverse
196+
super(FlowDistribution, self).__init__(
197+
dtype=base.dtype,
198+
param_dtype=base.dtype,
199+
is_continuous=base.dtype.is_floating,
200+
group_ndims=group_ndims,
201+
is_reparameterized=False,
202+
**kwargs)
203+
204+
def _value_shape(self):
205+
return self.base.value_shape()
206+
207+
def _get_value_shape(self):
208+
return self.base.get_value_shape()
209+
210+
def _batch_shape(self):
211+
return self.base.batch_shape()
212+
213+
def _get_batch_shape(self):
214+
return self.base.get_batch_shape()
215+
216+
def _sample(self, n_samples):
217+
return self.sample_and_log_prob(n_samples)[0]
218+
219+
def _log_prob(self, given):
220+
if self.inverse is None:
221+
raise ValueError("Flow distribution can only calculate log_prob through `sample_and_log_prob` "
222+
"if `inverse=None`.")
223+
else:
224+
base_given, log_det = self.inverse(given)
225+
log_prob = self.base.log_prob(base_given)
226+
return log_prob + log_det
227+
228+
def _prob(self, given):
229+
return tf.exp(self.log_prob(given))
230+
231+
def sample_and_log_prob(self, n_samples=None):
232+
base_sample, log_prob = self.base.sample_and_log_prob(n_samples)
233+
transformed, log_det = self.forward(base_sample)
234+
return transformed, log_prob - log_det

zhusuan/model/base.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,16 @@ def tensor(self):
115115
"with its observed value. Error message: {}".format(
116116
self._name, e))
117117
else:
118-
self._tensor = self.sample(self._n_samples)
118+
self._tensor, self._local_log_prob = self.sample_and_log_prob(self._n_samples)
119119
return self._tensor
120120

121+
@property
122+
def local_log_prob(self):
123+
tensor = self.tensor
124+
if not hasattr(self, '_local_log_prob'):
125+
self._local_log_prob = self.log_prob(tensor)
126+
return self._local_log_prob
127+
121128
def get_shape(self):
122129
return self.tensor.get_shape()
123130

@@ -149,6 +156,9 @@ def prob(self, given):
149156
"""
150157
return self._distribution.prob(given)
151158

159+
def sample_and_log_prob(self, n_samples):
160+
return self._distribution.sample_and_log_prob(n_samples)
161+
152162
@staticmethod
153163
def _to_tensor(value, dtype=None, name=None, as_ref=False):
154164
if dtype and not dtype.is_compatible_with(value.dtype):
@@ -340,14 +350,10 @@ def local_log_prob(self, name_or_names):
340350
"""
341351
name_or_names = self._check_names_exist(name_or_names)
342352
if isinstance(name_or_names, tuple):
343-
ret = []
344-
for name in name_or_names:
345-
s_tensor = self._stochastic_tensors[name]
346-
ret.append(s_tensor.log_prob(s_tensor.tensor))
353+
return [self._stochastic_tensors[name].local_log_prob
354+
for name in name_or_names]
347355
else:
348-
s_tensor = self._stochastic_tensors[name_or_names]
349-
ret = s_tensor.log_prob(s_tensor.tensor)
350-
return ret
356+
return self._stochastic_tensors[name_or_names].local_log_prob
351357

352358
def query(self, name_or_names, outputs=False, local_log_prob=False):
353359
"""

zhusuan/model/stochastic.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
'GumbelSoftmax',
3838
'Empirical',
3939
'Implicit',
40+
'NormalFlow'
4041
]
4142

4243

@@ -983,14 +984,14 @@ def __init__(self,
983984
is_continuous=None,
984985
n_samples=None,
985986
**kwargs):
986-
norm = distributions.Empirical(
987+
empirical = distributions.Empirical(
987988
dtype, batch_shape,
988989
value_shape=value_shape,
989990
group_ndims=group_ndims,
990991
is_continous=is_continuous,
991992
**kwargs
992993
)
993-
super(Empirical, self).__init__(name, norm, n_samples)
994+
super(Empirical, self).__init__(name, empirical, n_samples)
994995

995996

996997
class Implicit(StochasticTensor):
@@ -1021,10 +1022,70 @@ def __init__(self,
10211022
group_ndims=0,
10221023
n_samples=None,
10231024
**kwargs):
1024-
norm = distributions.Implicit(
1025+
implicit = distributions.Implicit(
10251026
samples,
10261027
value_shape=value_shape,
10271028
group_ndims=group_ndims,
10281029
**kwargs
10291030
)
1030-
super(Implicit, self).__init__(name, norm, n_samples)
1031+
super(Implicit, self).__init__(name, implicit, n_samples)
1032+
1033+
1034+
class NormalFlow(StochasticTensor):
1035+
"""
1036+
The class of univariate Normal `StochasticTensor` with a invertible flow
1037+
transformation.
1038+
See :class:`~zhusuan.model.stochastic.Normal` and
1039+
:class:`~zhusuan.distributions.special.FlowDistribution` for details.
1040+
1041+
:param name: A string. The name of the `StochasticTensor`. Must be unique
1042+
in the `BayesianNet` context.
1043+
:param forward: A forward function which describes how we transform the samples
1044+
from the base distribution. The signature of the function should be:
1045+
transformed, log_det = forward(base_samples)
1046+
:param inverse: An inverse function which maps from the transformed samples to
1047+
to base samples. The signature of the function should be:
1048+
base_samples, log_det = inverse(transformed_samples)
1049+
:param mean: A `float` Tensor. The mean of the Normal distribution.
1050+
Should be broadcastable to match `logstd`.
1051+
:param logstd: A `float` Tensor. The log standard deviation of the Normal
1052+
distribution. Should be broadcastable to match `mean`.
1053+
:param std: A `float` Tensor. The standard deviation of the Normal
1054+
distribution. Should be positive and broadcastable to match `mean`.
1055+
:param n_samples: A 0-D `int32` Tensor or None. Number of samples
1056+
generated by this `StochasticTensor`.
1057+
:param group_ndims: A 0-D `int32` Tensor representing the number of
1058+
dimensions in `batch_shape` (counted from the end) that are grouped
1059+
into a single event, so that their probabilities are calculated
1060+
together. Default is 0, which means a single value is an event.
1061+
See :class:`~zhusuan.distributions.base.Distribution` for more detailed
1062+
explanation.
1063+
:param is_reparameterized: A Bool. If True, gradients on samples from this
1064+
`StochasticTensor` are allowed to propagate into inputs, using the
1065+
reparametrization trick from (Kingma, 2013).
1066+
:param check_numerics: Bool. Whether to check numeric issues.
1067+
"""
1068+
1069+
def __init__(self,
1070+
name,
1071+
forward,
1072+
inverse=None,
1073+
mean=0.,
1074+
logstd=None,
1075+
std=None,
1076+
n_samples=None,
1077+
group_ndims=0,
1078+
is_reparameterized=True,
1079+
check_numerics=False,
1080+
**kwargs):
1081+
normal = distributions.Normal(
1082+
mean,
1083+
logstd=logstd,
1084+
std=std,
1085+
group_ndims=group_ndims,
1086+
is_reparameterized=is_reparameterized,
1087+
check_numerics=check_numerics,
1088+
**kwargs
1089+
)
1090+
flow = distributions.FlowDistribution(normal, forward, inverse, group_ndims=group_ndims)
1091+
super(NormalFlow, self).__init__(name, flow, n_samples)

0 commit comments

Comments
 (0)