From c6616b1a9f04c0e894b9e045e66bd504db17ef5e Mon Sep 17 00:00:00 2001 From: Vincent Stimper Date: Sun, 25 Aug 2024 11:43:13 +0200 Subject: [PATCH] doc: clarified input variables of multiscale flow (#64) --- normflows/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/normflows/core.py b/normflows/core.py index 9920193..03f9566 100644 --- a/normflows/core.py +++ b/normflows/core.py @@ -463,7 +463,7 @@ def __init__(self, q0, flows, merges, transform=None, class_cond=True): Args: q0: List of base distribution - flows: List of list of flows for each level + flows: List of flows for each level merges: List of merge/split operations (forward pass must do merge) transform: Initial transformation of inputs class_cond: Flag, indicated whether model has class conditional @@ -478,11 +478,11 @@ def __init__(self, q0, flows, merges, transform=None, class_cond=True): self.class_cond = class_cond def forward_kld(self, x, y=None): - """Estimates forward KL divergence, see see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762) + """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762) Args: x: Batch sampled from target distribution - y: Batch of targets, if applicable + y: Batch of classes to condition on, if applicable Returns: Estimate of forward KL divergence averaged over batch @@ -494,7 +494,7 @@ def forward(self, x, y=None): Args: x: Batch of data - y: Batch of targets, if applicable + y: Batch of classes to condition on, if applicable Returns: Negative log-likelihood of the batch