Skip to content

Commit

Permalink
Fix a bug in matmul; Add ComplexTensor.unbind; Update ComplexTensor.s…
Browse files Browse the repository at this point in the history
…queeze
  • Loading branch information
Emrys365 committed Oct 22, 2021
1 parent 83dc599 commit 301bce4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
1 change: 1 addition & 0 deletions torch_complex/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def matmul(
if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
return a @ b
elif not isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
o_real = torch.matmul(a, b.real)
o_imag = torch.matmul(a, b.imag)
elif isinstance(a, ComplexTensor) and not isinstance(b, ComplexTensor):
return a @ b
Expand Down
15 changes: 13 additions & 2 deletions torch_complex/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,8 +619,11 @@ def ndim(self):
def sqrt(self) -> "ComplexTensor":
return self ** 0.5

def squeeze(self, dim) -> "ComplexTensor":
return ComplexTensor(self.real.squeeze(dim), self.imag.squeeze(dim))
def squeeze(self, dim=None) -> "ComplexTensor":
if dim is None:
return ComplexTensor(self.real.squeeze(), self.imag.squeeze())
else:
return ComplexTensor(self.real.squeeze(dim), self.imag.squeeze(dim))

def sum(self, *args, **kwargs) -> "ComplexTensor":
"""
Expand Down Expand Up @@ -668,6 +671,14 @@ def type(self, *args, **kwargs) -> str:
self.real.type(*args, **kwargs), self.imag.type(*args, **kwargs)
)

def unbind(self, dim=0) -> "ComplexTensor":
return tuple(
map(
lambda x: ComplexTensor(*x),
zip(self.real.unbind(dim=dim), self.imag.unbind(dim=dim))
)
)

def unfold(self, dim, size, step):
return ComplexTensor(
self.real.unfold(dim, size, step), self.imag.unfold(dim, size, step)
Expand Down

0 comments on commit 301bce4

Please sign in to comment.