# Public API of this module.
__all__ = [
    'Empirical',
    'Implicit',
    'FlowDistribution',
]
1718
1819
@@ -132,7 +133,7 @@ def _batch_shape(self):
132133
133134 def _get_batch_shape (self ):
134135 if self .samples .get_shape () == tf .TensorShape (None ) or \
135- self .explicit_value_shape == tf .TensorShape (None ):
136+ self .explicit_value_shape == tf .TensorShape (None ):
136137 return tf .TensorShape (None )
137138 else :
138139 d = self .explicit_value_shape .ndims
@@ -157,3 +158,77 @@ def _prob(self, given):
157158 return (2 * prob - 1 ) * inf_dtype
158159 else :
159160 return tf .cast (prob , tf .float32 )
161+
162+
class FlowDistribution(Distribution):
    """
    The class of FlowDistribution distribution.

    The distribution describes a variable that is drawn from a base
    distribution and then pushed through an invertible transformation
    (a normalizing flow).
    See :class:`~zhusuan.distributions.base.Distribution` for details.

    :param name: A string. The name of the `StochasticTensor`. Must be unique
        in the `BayesianNet` context.
    :param base: An instance of `Distribution` parametrizing the base
        distribution.
    :param forward: A function mapping base samples to transformed samples.
        Signature: ``transformed, log_det = forward(base_samples)``.
    :param inverse: A function mapping transformed samples back to base
        samples. Signature: ``base_samples, log_det = inverse(transformed_samples)``.
        May be ``None``, in which case only `sample_and_log_prob` can
        compute densities.
    :param group_ndims: A 0-D `int32` Tensor representing the number of
        dimensions in `batch_shape` (counted from the end) that are grouped
        into a single event, so that their probabilities are calculated
        together. Default is 0, which means a single value is an event.
        See :class:`~zhusuan.distributions.base.Distribution` for more
        detailed explanation.
    """

    def __init__(self, base, forward, inverse=None, group_ndims=0, **kwargs):
        # Keep the flow components accessible on the instance.
        self.base = base
        self.forward = forward
        self.inverse = inverse
        # dtype/continuity are inherited from the base distribution; the
        # transformed variable is not marked reparameterized.
        super(FlowDistribution, self).__init__(
            dtype=base.dtype,
            param_dtype=base.dtype,
            is_continuous=base.dtype.is_floating,
            group_ndims=group_ndims,
            is_reparameterized=False,
            **kwargs)

    # All shape queries delegate to the base distribution: an invertible
    # transform preserves value and batch shapes.
    def _value_shape(self):
        return self.base.value_shape()

    def _get_value_shape(self):
        return self.base.get_value_shape()

    def _batch_shape(self):
        return self.base.batch_shape()

    def _get_batch_shape(self):
        return self.base.get_batch_shape()

    def _sample(self, n_samples):
        # Discard the log-probability; only the transformed samples are needed.
        samples, _ = self.sample_and_log_prob(n_samples)
        return samples

    def _log_prob(self, given):
        # Without an inverse we cannot evaluate densities at arbitrary points.
        if self.inverse is None:
            raise ValueError("Flow distribution can only calculate log_prob through `sample_and_log_prob` "
                             "if `inverse=None`.")
        # Change of variables: log p(x) = log p_base(f^{-1}(x)) + log|det J|.
        pulled_back, log_det = self.inverse(given)
        return self.base.log_prob(pulled_back) + log_det

    def _prob(self, given):
        # Density is the exponential of the log-density.
        return tf.exp(self.log_prob(given))

    def sample_and_log_prob(self, n_samples=None):
        """
        Draw samples from the base distribution, push them through the
        forward transform, and return the transformed samples together
        with their log-probability under the flow.
        """
        base_sample, base_log_prob = self.base.sample_and_log_prob(n_samples)
        pushed_forward, log_det = self.forward(base_sample)
        # Change of variables: subtract the forward log-determinant.
        return pushed_forward, base_log_prob - log_det
0 commit comments