Skip to content

Commit 27d4bf8

Browse files
authored
Updating depricated torch.trtrs (#54)
1 parent 8272cdc commit 27d4bf8

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

normflows/flows/mixing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,9 @@ def weight_inverse(self):
505505
"""
506506
lower, upper = self._create_lower_upper()
507507
identity = torch.eye(self.features, self.features)
508-
lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True)
509-
weight_inverse, _ = torch.trtrs(
510-
lower_inverse, upper, upper=True, unitriangular=False
508+
lower_inverse = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True)
509+
weight_inverse = torch.linalg.solve_triangular(
510+
upper, lower_inverse, upper=True, unitriangular=False
511511
)
512512
return weight_inverse
513513

0 commit comments

Comments
 (0)