FlowDistribution #88
```diff
@@ -13,6 +13,7 @@
 __all__ = [
     'Empirical',
     'Implicit',
+    'FlowDistribution'
 ]
```
```diff
@@ -132,7 +133,7 @@ def _batch_shape(self):

     def _get_batch_shape(self):
         if self.samples.get_shape() == tf.TensorShape(None) or \
-                self.explicit_value_shape == tf.TensorShape(None):
+                self.explicit_value_shape == tf.TensorShape(None):
             return tf.TensorShape(None)
         else:
             d = self.explicit_value_shape.ndims
```
```diff
@@ -157,3 +158,77 @@ def _prob(self, given):
             return (2 * prob - 1) * inf_dtype
         else:
             return tf.cast(prob, tf.float32)
```

The remaining added lines in this hunk define the new `FlowDistribution` class:
```python
class FlowDistribution(Distribution):
    """
    The class of FlowDistribution distribution.
    The distribution describes a variable which is sampled from a base
    distribution and then passed through an invertible function.
    See :class:`~zhusuan.distributions.base.Distribution` for details.

    :param name: A string. The name of the `StochasticTensor`. Must be unique
        in the `BayesianNet` context.
```
> **Review comment:** The signature is incorrect
```python
    :param base: An instance of `Distribution` parametrizing the base
        distribution.
    :param forward: A forward function which describes how we transform the
        samples from the base distribution. The signature of the function
        should be:
        transformed, log_det = forward(base_samples)
```

> **Review comment:** We should restrict how the forward function works. Does it apply transformation on each value (of

> **Review comment:** So, I think that the function signature should be on shape
```python
    :param inverse: An inverse function which maps from the transformed
        samples to base samples. The signature of the function should be:
        base_samples, log_det = inverse(transformed_samples)
    :param group_ndims: A 0-D `int32` Tensor representing the number of
        dimensions in `batch_shape` (counted from the end) that are grouped
        into a single event, so that their probabilities are calculated
        together. Default is 0, which means a single value is an event.
        See :class:`~zhusuan.distributions.base.Distribution` for more
        detailed explanation.
    """
```
```python
    def __init__(self,
                 base,
                 forward,
                 inverse=None,
                 group_ndims=0,
                 **kwargs):
        self.base = base
        self.forward = forward
        self.inverse = inverse
        super(FlowDistribution, self).__init__(
            dtype=base.dtype,
            param_dtype=base.dtype,
            is_continuous=base.dtype.is_floating,
            group_ndims=group_ndims,
            is_reparameterized=False,
            **kwargs)

    def _value_shape(self):
        return self.base.value_shape()

    def _get_value_shape(self):
        return self.base.get_value_shape()
```
> **Review comment:** As stated above, the new value shape is not always the same as the original one, e.g., in the VAE case.
```python
    def _batch_shape(self):
        return self.base.batch_shape()

    def _get_batch_shape(self):
        return self.base.get_batch_shape()

    def _sample(self, n_samples):
        return self.sample_and_log_prob(n_samples)[0]
```
> **Review comment:** This seems to not follow our discussion here? Also the

> **Review comment:** I think I've decided to leave the
```python
    def _log_prob(self, given):
        if self.inverse is None:
            raise ValueError(
                "Flow distribution can only calculate log_prob through "
                "`sample_and_log_prob` if `inverse=None`.")
        else:
            base_given, log_det = self.inverse(given)
            log_prob = self.base.log_prob(base_given)
            return log_prob + log_det

    def _prob(self, given):
        return tf.exp(self.log_prob(given))

    def sample_and_log_prob(self, n_samples=None):
        base_sample, log_prob = self.base.sample_and_log_prob(n_samples)
        transformed, log_det = self.forward(base_sample)
        return transformed, log_prob - log_det
```
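For concreteness, here is a minimal sketch (not part of this diff) of a `forward`/`inverse` pair matching the documented signatures, using a fixed affine transform `y = a * x + b`. By the change-of-variables formula, `log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|`, which matches the conventions above: `forward` returns the log-determinant of the forward Jacobian (subtracted in `sample_and_log_prob`), and `inverse` returns the log-determinant of the inverse Jacobian (added in `_log_prob`). The `Normal` base and its import path are assumptions, not part of the PR.

```python
import math

import tensorflow as tf
from zhusuan.distributions import Normal  # assumed import path

a, b = 2.0, 1.0  # fixed affine transform: y = a * x + b

def forward(base_samples):
    # Elementwise transform, so log|det dy/dx| = log|a| per element.
    transformed = a * base_samples + b
    log_det = tf.fill(tf.shape(base_samples), math.log(abs(a)))
    return transformed, log_det

def inverse(transformed_samples):
    # log|det dx/dy| = -log|a|; _log_prob adds this correction.
    base_samples = (transformed_samples - b) / a
    log_det = tf.fill(tf.shape(transformed_samples), -math.log(abs(a)))
    return base_samples, log_det

base = Normal(mean=0., logstd=0.)
flow = FlowDistribution(base, forward, inverse=inverse)
samples, log_prob = flow.sample_and_log_prob(n_samples=10)
```

Since the affine map is exactly invertible, `flow.log_prob(samples)` should agree with the `log_prob` returned by `sample_and_log_prob`.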
In the second changed file, `StochasticTensor` gains a cached `local_log_prob` and a `sample_and_log_prob` passthrough:
```diff
@@ -115,9 +115,16 @@ def tensor(self):
                     "with its observed value. Error message: {}".format(
                         self._name, e))
             else:
-                self._tensor = self.sample(self._n_samples)
+                self._tensor, self._local_log_prob = self.sample_and_log_prob(self._n_samples)
         return self._tensor

+    @property
+    def local_log_prob(self):
+        tensor = self.tensor
+        if not hasattr(self, '_local_log_prob'):
+            self._local_log_prob = self.log_prob(tensor)
+        return self._local_log_prob
```
> **Review comment:** Nice job. Meanwhile can you remove the
```diff
+
     def get_shape(self):
         return self.tensor.get_shape()
```
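The subtle point in the new property is the `tensor = self.tensor` line: accessing `tensor` may itself populate `_local_log_prob` (when the tensor is freshly sampled via `sample_and_log_prob`), while observed tensors fall through to a plain `log_prob` call. A standalone toy illustration of the same caching pattern, with hypothetical names:

```python
class CachedLogProbSketch(object):
    """Toy stand-in for StochasticTensor's caching logic (illustrative only)."""

    def __init__(self, dist, n_samples=None, observed=None):
        self._dist = dist  # assumed to expose sample_and_log_prob / log_prob
        self._n_samples = n_samples
        self._observed = observed

    @property
    def tensor(self):
        if not hasattr(self, '_tensor'):
            if self._observed is not None:
                # Observed value: no sampling, so no cached log-prob yet.
                self._tensor = self._observed
            else:
                # Sampled value: the log-prob comes for free with the sample.
                self._tensor, self._local_log_prob = \
                    self._dist.sample_and_log_prob(self._n_samples)
        return self._tensor

    @property
    def local_log_prob(self):
        tensor = self.tensor  # may populate the cache as a side effect
        if not hasattr(self, '_local_log_prob'):
            self._local_log_prob = self._dist.log_prob(tensor)
        return self._local_log_prob
```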
```diff
@@ -149,6 +156,9 @@ def prob(self, given):
         """
         return self._distribution.prob(given)

+    def sample_and_log_prob(self, n_samples):
+        return self._distribution.sample_and_log_prob(n_samples)
+
     @staticmethod
     def _to_tensor(value, dtype=None, name=None, as_ref=False):
         if dtype and not dtype.is_compatible_with(value.dtype):
```
```diff
@@ -340,14 +350,10 @@ def local_log_prob(self, name_or_names):
         """
         name_or_names = self._check_names_exist(name_or_names)
         if isinstance(name_or_names, tuple):
-            ret = []
-            for name in name_or_names:
-                s_tensor = self._stochastic_tensors[name]
-                ret.append(s_tensor.log_prob(s_tensor.tensor))
+            return [self._stochastic_tensors[name].local_log_prob
+                    for name in name_or_names]
         else:
-            s_tensor = self._stochastic_tensors[name_or_names]
-            ret = s_tensor.log_prob(s_tensor.tensor)
-            return ret
+            return self._stochastic_tensors[name_or_names].local_log_prob

     def query(self, name_or_names, outputs=False, local_log_prob=False):
         """
```
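With this refactor, `BayesianNet.local_log_prob` simply forwards to each node's cached `local_log_prob` property. A hypothetical usage sketch (the net `bn` and the variable names `w`, `y` are made up):

```python
# Querying several nodes returns a list; a single name returns one tensor.
log_pw, log_py = bn.local_log_prob(['w', 'y'])
log_py_alone = bn.local_log_prob('y')
# Because of the caching, repeated queries reuse the log-prob computed
# when the node's tensor was first sampled via sample_and_log_prob.
```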
> **Review comment:** I feel `NormalFlow` is a bit too restricted. And it also makes the internal Normal node invisible. I'd prefer sth. like

> **Review comment:** However, this would require some extra internal machinery or require that the input to the flow is an instance of `StochasticTensor` in order to still be able to use its sampling and log-probability methods.