diff --git a/normflows/flows/mixing.py b/normflows/flows/mixing.py index d35cbd5..5822836 100644 --- a/normflows/flows/mixing.py +++ b/normflows/flows/mixing.py @@ -505,9 +505,9 @@ def weight_inverse(self): """ lower, upper = self._create_lower_upper() identity = torch.eye(self.features, self.features) - lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True) - weight_inverse, _ = torch.trtrs( - lower_inverse, upper, upper=True, unitriangular=False + lower_inverse = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True) + weight_inverse = torch.linalg.solve_triangular( + upper, lower_inverse, upper=True, unitriangular=False ) return weight_inverse