-
Notifications
You must be signed in to change notification settings - Fork 1
/
matutils.py
168 lines (140 loc) · 4.75 KB
/
matutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
""" Utility functions for testing matrix stuff.
"""
import time
import jax
import jax.numpy as jnp
from jax.ops import index, index_update
@jax.jit
def dag(A):
    """
    Return the Hermitian conjugate (conjugate transpose) of ``A``.
    """
    transposed = jnp.transpose(A)
    return jnp.conjugate(transposed)
@jax.jit
def trimultmat(A, B, C):
    """
    Return the dense triple product ``(A @ B) @ C``.

    The left pair is multiplied first; both products use ``jnp.dot``.
    """
    left = jnp.dot(A, B)
    return jnp.dot(left, C)
@jax.jit
def trimultdag(A, D, C):
    """
    Return ``A @ diag(D) @ C`` without building ``diag(D)``.

    ``D`` is a vector of diagonal entries. Scaling row ``i`` of ``C`` by
    ``D[i]`` is the same as left-multiplying ``C`` by ``diag(D)``.
    """
    scaled_rows = jnp.expand_dims(D, axis=1) * C
    return jnp.dot(A, scaled_rows)
@jax.jit
def safe_divide_matmat(A, B):
    """
    Return C, where C = A / B wherever B != 0 and C = 0 wherever B == 0.

    The function uses masked ``jnp.where`` calls rather than a list of
    nonzero indices. The output shape is therefore static, which is required
    for the function to run under ``jax.jit``.

    Before dividing, the denominator is replaced by 1 wherever B == 0. No
    inf/NaN is ever formed, which is the standard JAX idiom for keeping
    gradients of a masked division finite.

    The result has the dtype of ``A / B``, so complex inputs stay complex.
    A and B may be any broadcast-compatible shapes; equal shapes give a
    result of that shape.
    """
    nonzero = B != 0
    safe_B = jnp.where(nonzero, B, jnp.ones_like(B))
    return jnp.where(nonzero, A / safe_B, 0)
def subblock_main_diagonal(A, bi=0):
    """
    Return (row_indices, col_indices) of A's main-diagonal entries that lie
    inside the trailing subblock A[bi:, bi:].

    Both index arrays are the same int32 range bi, bi + 1, ..., min(m, n) - 1,
    so the returned tuple can be used directly to index A. A must be 2-D
    (the shape is unpacked into exactly two dimensions).
    """
    m, n = A.shape
    diag_positions = jnp.arange(bi, min(m, n), dtype=jnp.int32)
    return diag_positions, diag_positions
def replace_diagonal(A, D, off=0):
    """
    Return a copy of the m x n matrix A with part of its main diagonal replaced.

    Let k = min(m, n). Starting at diagonal position ``off``, the entries
    A[off, off], A[off + 1, off + 1], ..., A[k - 1, k - 1] are overwritten.
    The first D.size of them receive the entries of D, and the remaining
    ones are set to zero. Diagonal entries before ``off`` and all
    off-diagonal entries are left unchanged.

    Parameters
    ----------
    A : 2-D array.
    D : 1-D array of new diagonal values, with D.size <= min(m, n) - off.
    off : first diagonal position to overwrite (default 0).

    Returns
    -------
    A new array. JAX arrays are immutable, so A itself is not modified.
    """
    m, n = A.shape
    k = min(m, n)
    didxs = subblock_main_diagonal(A, bi=off)
    # One slot per diagonal entry of A[off:, off:], matching the length of
    # didxs (k - off) so that off > 0 does not cause a shape mismatch.
    new_elements = jnp.zeros(max(k - off, 0), dtype=D.dtype)
    new_elements = new_elements.at[:D.size].set(D)
    return A.at[didxs].set(new_elements)
def matshape(A):
    """
    Return ``A.shape`` as an ``(m, n)`` tuple if A is two-dimensional.

    Raises
    ------
    ValueError
        If A.shape does not have exactly two entries. The message includes
        the offending shape.
    """
    try:
        m, n = A.shape
    except ValueError as err:
        # Unpacking fails for any rank other than 2. Re-raise with a single
        # formatted message; passing two args would render as a tuple repr.
        raise ValueError(f"A had invalid shape: {A.shape}") from err
    return (m, n)
def frob(A, B):
    """
    Return sqrt(sum(|A - B|**2)) / A.size.

    This is the Frobenius norm of the difference, divided by the number of
    elements. It is not the RMS error, which would divide inside the sqrt.

    A and B are compared elementwise in flattened (row-major) order, so they
    only need the same number of elements. For example, shapes (n,) and
    (n, 1) are compared entry by entry instead of being broadcast to (n, n).

    Raises
    ------
    ValueError
        If A and B have different sizes.
    """
    if A.size != B.size:
        raise ValueError(
            f"A and B must have the same size, got {A.size} and {B.size}")
    diff = jnp.ravel(A) - jnp.ravel(B)
    # |d**2| == |d|**2, so this also handles complex differences.
    return jnp.sqrt(jnp.sum(jnp.abs(jnp.square(diff)))) / A.size
def gaussian_random_complex64(key=None, shape=()):
    """
    Generate a Gaussian random array of complex64 type with jax.random.

    The key is split in two. The real and imaginary parts are drawn
    independently from those subkeys as standard-normal float32 values.

    If ``key`` is None, a key is seeded from the current system time in whole
    seconds. Calls made within the same second without an explicit key
    therefore return the same values.
    """
    if key is None:
        key = jax.random.PRNGKey(int(time.time()))
    subkey1, subkey2 = jax.random.split(key, 2)
    realpart = jax.random.normal(subkey1, shape=shape, dtype=jnp.float32)
    imagpart = jax.random.normal(subkey2, shape=shape, dtype=jnp.float32)
    return realpart + 1.0j * imagpart
@jax.jit
def gaussian_random_complex_arr(arglist):
    """
    Return a complex64 Gaussian random array shaped like ``arr``.

    ``arglist`` is a ``(key, arr)`` pair, and only ``arr.shape`` is read.
    The key is split in two. The halves seed independent standard-normal
    float32 real and imaginary parts.
    """
    key, arr = arglist
    real_key, imag_key = jax.random.split(key, 2)
    target_shape = arr.shape
    re = jax.random.normal(real_key, shape=target_shape, dtype=jnp.float32)
    im = jax.random.normal(imag_key, shape=target_shape, dtype=jnp.float32)
    return re + 1.0j * im
@jax.jit
def gaussian_random_real_arr(arglist):
    """
    Return a float32 standard-normal random array shaped like ``arr``.

    ``arglist`` is a ``(key, arr)`` pair. ``key`` is used directly without
    being split, and only ``arr.shape`` is read.
    """
    key, arr = arglist
    return jax.random.normal(key, shape=arr.shape, dtype=jnp.float32)
@jax.jit
def _gaussian_random_fill(key, work_array):
    """Jitted kernel for gaussian_random_fill: random values like work_array."""
    subkey1, subkey2 = jax.random.split(key, 2)
    # The dtype is static while tracing, so a plain Python branch is enough.
    # lax.cond would trace (and cast) both branches.
    if jnp.iscomplexobj(work_array):
        output = gaussian_random_complex_arr((subkey1, work_array))
    else:
        output = gaussian_random_real_arr((subkey2, work_array))
    return output.astype(work_array.dtype)


def gaussian_random_fill(work_array, key=None):
    """
    Return Gaussian random values with work_array's shape and dtype.

    JAX arrays are immutable, so work_array itself is not modified; use the
    return value. Complex dtypes get independent standard-normal real and
    imaginary parts. Other dtypes get standard-normal float32 values cast to
    work_array's dtype.

    Parameters
    ----------
    work_array : array whose shape and dtype the result copies.
    key : optional jax.random key. If None, a new key is seeded from the
        current system time (whole seconds) on each call.
    """
    if key is None:
        # Seed here rather than inside the jitted kernel. Under jit,
        # time.time() runs only once at trace time, which would freeze the
        # key and make repeated calls return identical arrays.
        key = jax.random.PRNGKey(int(time.time()))
    return _gaussian_random_fill(key, work_array)
def gaussian_random(key=None, shape=(), dtype=jnp.float32):
    """
    Generate a Gaussian random array of the given shape and dtype.

    Unlike plain ``jax.random.normal``, ``dtype`` may be complex64, in which
    case the work is delegated to ``gaussian_random_complex64``. complex128
    raises NotImplementedError.

    For any other dtype, standard-normal values are drawn with
    ``jax.random.normal``. If ``key`` is None, a key is seeded from the
    current system time.
    """
    if dtype == jnp.complex64:
        return gaussian_random_complex64(key=key, shape=shape)
    if dtype == jnp.complex128:
        raise NotImplementedError("double precision complex isn't supported")
    if key is None:
        key = jax.random.PRNGKey(int(time.time()))
    return jax.random.normal(key, shape=shape, dtype=dtype)