"""
@author: Hamed Hojatian
Complex aljebra for pytorch in case that pytorch
version doesn't support complex tensor (V 1.8 and
later support complex tensor and aljebra)
"""
import torch
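
# The helpers below rely on the standard real block-matrix embedding of
# complex arithmetic: a complex matrix A = Ar + j*Ai is kept as the two real
# tensors (Ar, Ai) and stacked into
#     [[Ar, -Ai],
#      [Ai,  Ar]]
# so that a single real torch.matmul on the stacked blocks reproduces the
# complex product, with its real part in the top-left block and its
# imaginary part in the bottom-left block.
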
def Th_comp_matmul(Ar, Ai, Br, Bi):  # Complex matmul pytorch function ########
    if Ar.ndim == 3 and Br.ndim == 3:
        a_th = torch.cat((torch.cat((Ar, -Ai), dim=2), torch.cat((Ai, Ar), dim=2)), dim=1)
        b_th = torch.cat((torch.cat((Br, -Bi), dim=2), torch.cat((Bi, Br), dim=2)), dim=1)
        c_th = torch.matmul(a_th, b_th)
        c_th_r = c_th[:, 0:int(c_th.shape[1] / 2), 0:int(c_th.shape[2] / 2)]
        c_th_i = c_th[:, int(c_th.shape[1] / 2):, 0:int(c_th.shape[2] / 2)]
    elif Ar.ndim == 2 and Br.ndim == 2:
        a_th = torch.cat((torch.cat((Ar, -Ai), dim=1), torch.cat((Ai, Ar), dim=1)), dim=0)
        b_th = torch.cat((torch.cat((Br, -Bi), dim=1), torch.cat((Bi, Br), dim=1)), dim=0)
        c_th = torch.matmul(a_th, b_th)
        c_th_r = c_th[0:int(c_th.shape[0] / 2), 0:int(c_th.shape[1] / 2)]
        c_th_i = c_th[int(c_th.shape[0] / 2):, 0:int(c_th.shape[1] / 2)]
    elif Ar.ndim == 4 and Br.ndim == 4:
        a_th = torch.cat((torch.cat((Ar, -Ai), dim=3), torch.cat((Ai, Ar), dim=3)), dim=2)
        b_th = torch.cat((torch.cat((Br, -Bi), dim=3), torch.cat((Bi, Br), dim=3)), dim=2)
        c_th = torch.matmul(a_th, b_th)
        c_th_r = c_th[:, :, 0:int(c_th.shape[2] / 2), 0:int(c_th.shape[3] / 2)]
        c_th_i = c_th[:, :, int(c_th.shape[2] / 2):, 0:int(c_th.shape[3] / 2)]
    elif Ar.ndim == 5 and Br.ndim == 5:
        a_th = torch.cat((torch.cat((Ar, -Ai), dim=4), torch.cat((Ai, Ar), dim=4)), dim=3)
        b_th = torch.cat((torch.cat((Br, -Bi), dim=4), torch.cat((Bi, Br), dim=4)), dim=3)
        c_th = torch.matmul(a_th, b_th)
        c_th_r = c_th[:, :, :, 0:int(c_th.shape[3] / 2), 0:int(c_th.shape[4] / 2)]
        c_th_i = c_th[:, :, :, int(c_th.shape[3] / 2):, 0:int(c_th.shape[4] / 2)]
    elif Ar.ndim * Br.ndim == 12:
        if Ar.ndim == 4:
            a_th = torch.cat((torch.cat((Ar, -Ai), dim=3), torch.cat((Ai, Ar), dim=3)), dim=2)
            b_th = torch.cat((torch.cat((Br, -Bi), dim=2), torch.cat((Bi, Br), dim=2)), dim=1)
            c_th = torch.matmul(a_th, b_th)
            c_th_r = c_th[:, :, 0:int(c_th.shape[2] / 2), 0:int(c_th.shape[3] / 2)]
            c_th_i = c_th[:, :, int(c_th.shape[2] / 2):, 0:int(c_th.shape[3] / 2)]
        elif Br.ndim == 4:
            a_th = torch.cat((torch.cat((Ar, -Ai), dim=2), torch.cat((Ai, Ar), dim=2)), dim=1)
            b_th = torch.cat((torch.cat((Br, -Bi), dim=3), torch.cat((Bi, Br), dim=3)), dim=2)
            c_th = torch.matmul(a_th, b_th)
            c_th_r = c_th[:, :, 0:int(c_th.shape[2] / 2), 0:int(c_th.shape[3] / 2)]
            c_th_i = c_th[:, :, int(c_th.shape[2] / 2):, 0:int(c_th.shape[3] / 2)]
    elif Ar.ndim * Br.ndim == 20:
        if Ar.ndim == 5:
            a_th = torch.cat((torch.cat((Ar, -Ai), dim=4), torch.cat((Ai, Ar), dim=4)), dim=3)
            b_th = torch.cat((torch.cat((Br, -Bi), dim=3), torch.cat((Bi, Br), dim=3)), dim=2)
            c_th = torch.matmul(a_th, b_th)
            c_th_r = c_th[:, :, :, 0:int(c_th.shape[3] / 2), 0:int(c_th.shape[4] / 2)]
            c_th_i = c_th[:, :, :, int(c_th.shape[3] / 2):, 0:int(c_th.shape[4] / 2)]
        elif Br.ndim == 5:
            a_th = torch.cat((torch.cat((Ar, -Ai), dim=3), torch.cat((Ai, Ar), dim=3)), dim=2)
            b_th = torch.cat((torch.cat((Br, -Bi), dim=4), torch.cat((Bi, Br), dim=4)), dim=3)
            c_th = torch.matmul(a_th, b_th)
            c_th_r = c_th[:, :, :, 0:int(c_th.shape[3] / 2), 0:int(c_th.shape[4] / 2)]
            c_th_i = c_th[:, :, :, int(c_th.shape[3] / 2):, 0:int(c_th.shape[4] / 2)]
    else:
        raise Exception('the dimension is not defined for Th_comp_matmul.')
    return c_th_r, c_th_i
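
# Illustrative usage sketch for Th_comp_matmul (not part of the original
# interface). On PyTorch >= 1.8 the result can be cross-checked against the
# native complex matmul; the shapes here are arbitrary:
#     Ar, Ai = torch.randn(4, 3), torch.randn(4, 3)
#     Br, Bi = torch.randn(3, 5), torch.randn(3, 5)
#     Cr, Ci = Th_comp_matmul(Ar, Ai, Br, Bi)
#     C_ref = torch.matmul(torch.complex(Ar, Ai), torch.complex(Br, Bi))
#     # Cr ~ C_ref.real and Ci ~ C_ref.imag up to floating-point error.
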
def Th_inv(Ar, Ai):  # Complex inverse pytorch function ########
    Ar_inv = torch.inverse(Ar + torch.matmul(torch.matmul(Ai, torch.inverse(Ar)), Ai))
    Ai_inv = - torch.matmul(torch.matmul(torch.inverse(Ar), Ai), Ar_inv)
    return Ar_inv, Ai_inv
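
# Th_inv implements the real/imaginary split of the inverse of A = Ar + j*Ai,
# which requires Ar itself to be invertible:
#     Re(A^-1) = (Ar + Ai @ Ar^-1 @ Ai)^-1
#     Im(A^-1) = -(Ar^-1 @ Ai) @ Re(A^-1)
# On PyTorch >= 1.8 it can be sanity-checked against torch.inverse applied to
# torch.complex(Ar, Ai).
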
def Th_pinv(Ar, Ai):  # Complex pseudo-inverse pytorch function ########
    if Ar.ndim == 2:
        if Ar.shape[0] < Ar.shape[1]:
            Tempr, Tempi = Th_comp_matmul(Ar, Ai, Ar.T, -Ai.T)
            Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
            return Th_comp_matmul(Ar.T, -Ai.T, Ar_inv, Ai_inv)
        elif Ar.shape[0] > Ar.shape[1]:
            Tempr, Tempi = Th_comp_matmul(Ar.T, -Ai.T, Ar, Ai)
            Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
            return Th_comp_matmul(Ar_inv, Ai_inv, Ar.T, -Ai.T)
        elif Ar.shape[0] == Ar.shape[1]:
            return Th_inv(Ar, Ai)
    elif Ar.ndim == 3:
        if Ar.shape[1] < Ar.shape[2]:
            Tempr, Tempi = Th_comp_matmul(Ar, Ai, Ar.permute(0, 2, 1), -Ai.permute(0, 2, 1))
            Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
            return Th_comp_matmul(Ar.permute(0, 2, 1), -Ai.permute(0, 2, 1), Ar_inv, Ai_inv)
        elif Ar.shape[1] > Ar.shape[2]:
            Tempr, Tempi = Th_comp_matmul(Ar.permute(0, 2, 1), -Ai.permute(0, 2, 1), Ar, Ai)
            Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
            return Th_comp_matmul(Ar_inv, Ai_inv, Ar.permute(0, 2, 1), -Ai.permute(0, 2, 1))
        elif Ar.shape[1] == Ar.shape[2]:
            return Th_inv(Ar, Ai)
    elif Ar.ndim == 4:
        if Ar.shape[2] < Ar.shape[3]:
            Tempr, Tempi = Th_comp_matmul(Ar, Ai, Ar.permute(0, 1, 3, 2), -Ai.permute(0, 1, 3, 2))
            Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
            return Th_comp_matmul(Ar.permute(0, 1, 3, 2), -Ai.permute(0, 1, 3, 2), Ar_inv, Ai_inv)
        elif Ar.shape[2] > Ar.shape[3]:
            Tempr, Tempi = Th_comp_matmul(Ar.permute(0, 1, 3, 2), -Ai.permute(0, 1, 3, 2), Ar, Ai)
            Ar_inv, Ai_inv = Th_inv(Tempr, Tempi)
            return Th_comp_matmul(Ar_inv, Ai_inv, Ar.permute(0, 1, 3, 2), -Ai.permute(0, 1, 3, 2))
        elif Ar.shape[2] == Ar.shape[3]:
            return Th_inv(Ar, Ai)
    else:
        raise Exception('5-D is not defined for Th_pinv.')
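

if __name__ == "__main__":
    # Minimal self-test sketch (added for illustration, not part of the
    # original module): verify the Moore-Penrose identity A @ A^+ @ A == A
    # for a random full-row-rank wide matrix in double precision.
    torch.manual_seed(0)
    Ar = torch.randn(3, 5, dtype=torch.float64)
    Ai = torch.randn(3, 5, dtype=torch.float64)
    Pr, Pi = Th_pinv(Ar, Ai)                    # A^+ has shape (5, 3)
    APr, APi = Th_comp_matmul(Ar, Ai, Pr, Pi)   # A @ A^+
    Rr, Ri = Th_comp_matmul(APr, APi, Ar, Ai)   # (A @ A^+) @ A
    assert torch.allclose(Rr, Ar, atol=1e-8) and torch.allclose(Ri, Ai, atol=1e-8)
    print("Th_pinv sanity check passed.")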