From 301bce427584679cb9edf7a27aa5efbe89195bbd Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Fri, 22 Oct 2021 11:25:19 +0800 Subject: [PATCH] Fix a bug in matmul; Add ComplexTensor.unbind; Update ComplexTensor.squeeze --- torch_complex/functional.py | 1 + torch_complex/tensor.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/torch_complex/functional.py b/torch_complex/functional.py index d13c305..eef46e5 100644 --- a/torch_complex/functional.py +++ b/torch_complex/functional.py @@ -218,6 +218,7 @@ def matmul( if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): return a @ b elif not isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): + o_real = torch.matmul(a, b.real) o_imag = torch.matmul(a, b.imag) elif isinstance(a, ComplexTensor) and not isinstance(b, ComplexTensor): return a @ b diff --git a/torch_complex/tensor.py b/torch_complex/tensor.py index d2c10ed..fbdb16e 100644 --- a/torch_complex/tensor.py +++ b/torch_complex/tensor.py @@ -619,8 +619,11 @@ def ndim(self): def sqrt(self) -> "ComplexTensor": return self ** 0.5 - def squeeze(self, dim) -> "ComplexTensor": - return ComplexTensor(self.real.squeeze(dim), self.imag.squeeze(dim)) + def squeeze(self, dim=None) -> "ComplexTensor": + if dim is None: + return ComplexTensor(self.real.squeeze(), self.imag.squeeze()) + else: + return ComplexTensor(self.real.squeeze(dim), self.imag.squeeze(dim)) def sum(self, *args, **kwargs) -> "ComplexTensor": """ @@ -668,6 +671,14 @@ def type(self, *args, **kwargs) -> str: self.real.type(*args, **kwargs), self.imag.type(*args, **kwargs) ) + def unbind(self, dim=0) -> "ComplexTensor": + return tuple( + map( + lambda x: ComplexTensor(*x), + zip(self.real.unbind(dim=dim), self.imag.unbind(dim=dim)) + ) + ) + def unfold(self, dim, size, step): return ComplexTensor( self.real.unfold(dim, size, step), self.imag.unfold(dim, size, step)