Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating deprecated torch.trtrs #54

Merged
merged 1 commit into from
Dec 4, 2023

Conversation

mattcleigh
Copy link
Contributor

@mattcleigh mattcleigh commented Nov 30, 2023

Issue

The function torch.trtrs is an alias for the now deprecated torch.triangular_solve.
https://pytorch.org/docs/stable/generated/torch.triangular_solve.html

The alias itself has been completely removed in later pytorch releases (>2.1).

normflows calls this unsupported alias which results in an attribute error when running the _LULiner layers with use_cache=True.

lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True)

Solution

Replace the unsupported alias with the recommended function torch.linalg.solve_triangular (already used elsewhere in the project).
https://pytorch.org/docs/stable/generated/torch.linalg.solve_triangular.html

This requires swapping the order of the inputs while also removing extra (and ignored) returned copies.

@VincentStimper VincentStimper self-requested a review November 30, 2023 13:43
@VincentStimper
Copy link
Owner

Sounds good, thanks!

There is the same problem with the test that also occurred for this PR. We'll try to fix it soon so we can merge these changes.

@mattcleigh
Copy link
Contributor Author

Thanks a bunch. Let me know if you need anything more from me.

@VincentStimper VincentStimper merged commit 27d4bf8 into VincentStimper:master Dec 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants