From bf18e9f7e39996f15463059a10c0bd400d8cc4b2 Mon Sep 17 00:00:00 2001 From: Faried Abu Zaid Date: Tue, 19 Sep 2023 23:06:45 +0200 Subject: [PATCH] switch back to LU decomposition of inverse transform --- veriflow/transforms.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/veriflow/transforms.py b/veriflow/transforms.py index 40a6a49..f438efa 100644 --- a/veriflow/transforms.py +++ b/veriflow/transforms.py @@ -60,7 +60,7 @@ def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float: return self.scale.abs().log().sum() def sign(self) -> int: - return self.prod().sign() + return 1 if (self.scale < 0).int().sum() % 2 == 0 else -1 def is_feasible(self) -> bool: """Checks if the layer is feasible, i.e. if the diagonal elements of $\mathbf{U}$ are all positive""" @@ -198,7 +198,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: :type x: torch.Tensor :return: transformed tensor $(LU)x + \mathrm{bias}$ """ - return F.linear(y, self.weight, self.bias) + return F.linear(x, self.weight, self.bias) def backward(self, y: torch.Tensor) -> torch.Tensor: @@ -207,8 +207,7 @@ def backward(self, y: torch.Tensor) -> torch.Tensor: :param y: input tensor :type y: torch.Tensor :return: transformed tensor $(LU)^{-1}(y - \mathrm{bias})$""" - M_inv = LA.inv(self.weight) - return torch.functional.F.linear(x - self.bias, M_inv) + return torch.functional.F.linear(y - self.bias, self.inv_weight) @property def L(self): @@ -220,10 +219,16 @@ def U(self): """The upper triangular matrix $\mathbf{U}$ of the layers LU decomposition""" return self.U_raw.triu() + @property + def inv_weight(self): + """Inverse weight matrix of the affine transform""" + return LA.matmul(self.L, self.U) + @property def weight(self): """Weight matrix of the affine transform""" - return LA.matmul(self.L, self.U) + return LA.inv(LA.matmul(self.L, self.U)) + def _call(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x)