Skip to content

Commit

Permalink
switch back to LU decomposition of inverse transform
Browse files Browse the repository at this point in the history
  • Loading branch information
fariedabuzaid committed Sep 19, 2023
1 parent 39929f4 commit bf18e9f
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions veriflow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float:
return self.scale.abs().log().sum()

def sign(self) -> int:
return self.prod().sign()
return 1 if (self.scale < 0).int().sum() % 2 == 0 else -1

def is_feasible(self) -> bool:
"""Checks if the layer is feasible, i.e. if the diagonal elements of $\mathbf{U}$ are all positive"""
Expand Down Expand Up @@ -198,7 +198,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
:type x: torch.Tensor
:return: transformed tensor $(LU)x + \mathrm{bias}$
"""
return F.linear(y, self.weight, self.bias)
return F.linear(x, self.weight, self.bias)


def backward(self, y: torch.Tensor) -> torch.Tensor:
Expand All @@ -207,8 +207,7 @@ def backward(self, y: torch.Tensor) -> torch.Tensor:
:param y: input tensor
:type y: torch.Tensor
:return: transformed tensor $(LU)^{-1}(y - \mathrm{bias})$"""
M_inv = LA.inv(self.weight)
return torch.functional.F.linear(x - self.bias, M_inv)
return torch.functional.F.linear(y - self.bias, self.inv_weight)

@property
def L(self):
Expand All @@ -220,10 +219,16 @@ def U(self):
"""The upper triangular matrix $\mathbf{U}$ of the layers LU decomposition"""
return self.U_raw.triu()

@property
def inv_weight(self):
"""Inverse weight matrix of the affine transform"""
return LA.matmul(self.L, self.U)

@property
def weight(self):
"""Weight matrix of the affine transform"""
return LA.matmul(self.L, self.U)
return LA.inv(LA.matmul(self.L, self.U))


def _call(self, x: torch.Tensor) -> torch.Tensor:
return self.forward(x)
Expand Down

0 comments on commit bf18e9f

Please sign in to comment.