Skip to content

Commit

Permalink
Merge pull request #10 from Emrys365/fixup_wangyou
Browse files Browse the repository at this point in the history
Fix backward issue with torch.lu_solve
  • Loading branch information
kamo-naoyuki authored Feb 25, 2022
2 parents a2db295 + ef377fa commit 2d3bfa8
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions torch_complex/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,20 @@ def matmul(
return ComplexTensor(o_real, o_imag)


def solve(b: ComplexTensor, a: ComplexTensor) -> ComplexTensor:
def solve(b: ComplexTensor, a: ComplexTensor, return_LU=False) -> ComplexTensor:
"""Solve ax = b"""
a = complex_matrix2real_matrix(a)
b = complex_vector2real_vector(b)
if LooseVersion(torch.__version__) >= LooseVersion("1.8"):
LU, pivots = torch.lu(a)
x = torch.lu_solve(b, LU, pivots)
if return_LU:
LU, pivots = torch.lu(a)
x = torch.lu_solve(b, LU, pivots)
else:
x = torch.linalg.solve(a, b)
else:
x, LU = torch.solve(b, a)
return real_vector2complex_vector(x), real_matrix2complex_matrix(LU)
if return_LU:
return real_vector2complex_vector(x), real_matrix2complex_matrix(LU)
else:
return real_vector2complex_vector(x)

0 comments on commit 2d3bfa8

Please sign in to comment.