Skip to content

Commit

Permalink
Bugfix LUtransform.log_abs_det_jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
fariedabuzaid committed Sep 28, 2023
1 parent ef3ee95 commit 5f39f52
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/veriflow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,14 @@ def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float:
Returns:
float: log absolute determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$
"""
# log |Det(LU)| = sum(log(|diag(U)|))
# (as L is lower triangular with all 1s on the diag, i.e. log|Det(L)| = 0, and U is upper triangular)
# However, since onnx export of diag() is currently not supported, we have use
# a reformulation. Note dU keeps the quadratic structure but replace all values
# outside the diagonal with 1. Then sum(log(|diag(U)|)) = sum(log(|U|))
U = self.U
dU = U - U.triu(1)
return dU.abs().prod().log()
dU = U - U.triu(1) + (torch.ones_like(U) - torch.eye(self.dim))
return dU.abs().log().sum()

def sign(self) -> int:
""" Computes the sign of the determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$.
Expand Down

0 comments on commit 5f39f52

Please sign in to comment.