diff --git a/normflows/core.py b/normflows/core.py index 9920193..03f9566 100644 --- a/normflows/core.py +++ b/normflows/core.py @@ -463,7 +463,7 @@ def __init__(self, q0, flows, merges, transform=None, class_cond=True): Args: q0: List of base distribution - flows: List of list of flows for each level + flows: List of flows for each level merges: List of merge/split operations (forward pass must do merge) transform: Initial transformation of inputs class_cond: Flag, indicated whether model has class conditional @@ -478,11 +478,11 @@ def __init__(self, q0, flows, merges, transform=None, class_cond=True): self.class_cond = class_cond def forward_kld(self, x, y=None): - """Estimates forward KL divergence, see see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762) + """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762) Args: x: Batch sampled from target distribution - y: Batch of targets, if applicable + y: Batch of classes to condition on, if applicable Returns: Estimate of forward KL divergence averaged over batch @@ -494,7 +494,7 @@ def forward(self, x, y=None): Args: x: Batch of data - y: Batch of targets, if applicable + y: Batch of classes to condition on, if applicable Returns: Negative log-likelihood of the batch