Skip to content

Commit 8a297ad

Browse files
author
giovanni-marchetti
committed
JAX implementation and refactoring
1 parent 309f547 commit 8a297ad

17 files changed

+357
-198
lines changed

README.md

+25-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,25 @@
1-
# spectral-universality
1+
# Harmonics of Learning
2+
<p align="center">
3+
<img src="weights_rot.png" alt="Rotational harmonics" width="700" />
4+
</p>
5+
6+
## Description
7+
Implementation of a complex-valued Power-Spectral Network trained via contrastive learning on an invariance objective for a finite group.
8+
As shown in the companion paper, at convergence the network learns all the irreducible unitary representations of the group. In particular, the multiplication table can be extracted from its weights.
9+
10+
11+
We provide implementations of the model and its training in both `PyTorch` and in `JAX`.
12+
13+
14+
## Setup
15+
```
16+
python 3.8+
17+
pip install -r requirements.txt
18+
```
19+
20+
21+
## Groups
22+
The file `groups.py` provides implementations of various finite groups, including cyclic, dihedral and symmetric.
23+
24+
## Train
25+
In order to train the models in `PyTorch` and in `JAX`, run the files `train_torch.py` and `train_JAX.py` respectively. The training parameters are set up at the beginning of these files.

datasets.py

+13-43
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,31 @@
11
import torch
2-
from torchvision import transforms, datasets
3-
import numpy as np
2+
from numpy import random
43

54
from groups import *
65

76

7+
"""
8+
Contrastive learning dataset for groups. A datapoint is a pair of (noisy) complex vectors in the same orbit.
9+
"""
810
class group_dset(torch.utils.data.Dataset):
    """Contrastive dataset of complex vector pairs related by a group action.

    Every access draws a fresh random pair ``(x, y)``: a complex Gaussian
    vector of length ``group.order`` is sampled, ``y`` is obtained by passing
    it to ``group.act``, and complex Gaussian noise is then added to ``x``.

    Args:
        group: object exposing an integer ``order`` and an ``act(x)`` method.
        std: standard deviation of the real and imaginary parts of the sample.
        noise: standard deviation of the real and imaginary parts of the
            perturbation added to ``x``.
        length: value reported by ``len()``. Items are sampled on the fly and
            the index is ignored, so this only fixes how many draws make up
            one pass over the dataset.
    """

    def __init__(self, group, std=1., noise=0., length=1000):
        self.group = group
        self.std = std
        self.noise = noise
        self.length = length

    @staticmethod
    def _complex_normal(scale, size):
        """Return a complex vector whose real and imaginary parts are N(0, scale^2)."""
        # The real part is drawn before the imaginary part so that the stream
        # of random numbers consumed matches the original implementation.
        real = scale * random.randn(size)
        imag = scale * random.randn(size)
        return real + 1j * imag

    def __getitem__(self, index):
        """Sample and return a pair ``(x, y)``; ``index`` is not used."""
        size = self.group.order
        x = self._complex_normal(self.std, size)
        # y is computed from the clean sample, before x is perturbed.
        # NOTE(review): the in-place update below assumes ``act`` returns a
        # copy rather than a view of x - confirm for custom group classes.
        y = self.group.act(x)
        x += self._complex_normal(self.noise, size)
        return x, y

    def __len__(self):
        return self.length
2530

2631

27-
28-
# class RegBiCyclic(torch.utils.data.Dataset):
29-
# def __init__(self, A, B):
30-
# self.A = A
31-
# self.B = B
32-
33-
# def __getitem__(self, index):
34-
# x_re_A = torch.randn((self.A,))
35-
# x_im_A = torch.randn((self.A,))
36-
37-
# shift = torch.randint(low=0, high=self.A, size=(1,)).item() #The index needs to start from 0 since in a product of groups the identities matter
38-
39-
# y_re_A = torch.roll(x_re_A, shift)
40-
# y_im_A = torch.roll(x_im_A, shift)
41-
42-
# x_re_B = torch.randn((self.B,))
43-
# x_im_B = torch.randn((self.B,))
44-
45-
# shift = torch.randint(low=0, high=self.B, size=(1,)).item() #The index needs to start from 0 since in a product of groups the identities matter
46-
47-
# y_re_B = torch.roll(x_re_B, shift)
48-
# y_im_B = torch.roll(x_im_B, shift)
49-
50-
51-
# x_re = torch.cat([x_re_A, x_re_B], dim=-1)
52-
# x_im =torch.cat([x_im_A, x_im_B], dim=-1)
53-
# y_re = torch.cat([y_re_A, y_re_B], dim=-1)
54-
# y_im =torch.cat([y_im_A, y_im_B], dim=-1)
55-
# x = torch.complex(x_re, x_im)
56-
# y = torch.complex(y_re, y_im)
57-
58-
# return x, y
59-
60-
# def __len__(self):
61-
# return 1000

groups.py

+44-27
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,47 @@
11
import numpy as np
2-
import torch
2+
from numpy import random
33
import math
44
import itertools as it
55

66
from utils import *
77

8+
9+
"""
10+
Abstract class representing a finite group.
11+
"""
812
class abstr_group():
    """Finite group given by its order, Cayley table and irrep dimensions."""

    def __init__(self, order, cayley_table, irrep_dims):
        self.order, self.cayley_table, self.irrep_dims = order, cayley_table, irrep_dims

    def act(self, x):
        """Re-index ``x`` by the Cayley-table row of a randomly drawn element.

        The element index is drawn with ``random.randint`` from
        ``[0, order)``; the matching table row is used to index ``x``.
        """
        row = self.cayley_table[random.randint(low=0, high=self.order)]
        return x[row]

    def check_dims(self):
        """Assert that the squared irrep dimensions sum to the group order."""
        squared = np.square(np.array(self.irrep_dims))
        assert squared.sum() == self.order
2125

2226

23-
27+
"""
28+
Cyclic groups
29+
"""
2430
class cyclic(abstr_group):
    """Cyclic group with N elements; every irrep dimension is 1."""

    def __init__(self, N):
        base = np.arange(0, N)
        table = np.zeros((N, N), dtype=int)
        # Row ``shift`` is the index vector rolled left by ``shift`` places,
        # so entry (i, j) equals (i + j) mod N.
        for shift in range(N):
            table[shift, :] = np.roll(base, -shift)
        super().__init__(N, table, [1] * N)
3339

3440

3541

36-
42+
"""
43+
Dihedral groups
44+
"""
3745
class dihedral(abstr_group):
3846
def __init__(self, N):
3947
self.order = 2*N
@@ -44,19 +52,19 @@ def __init__(self, N):
4452
self.irrep_dims = [1]*2 + [2]*int((N - 1) / 2)
4553

4654

47-
reflection = torch.Tensor([0] + [N-i for i in range(1, N)]).long()
48-
self.group_elems = torch.zeros(2*N, N)
55+
reflection = np.array([0] + [N-i for i in range(1, N)]).astype(int)
56+
self.group_elems = np.zeros((2*N, N))
4957
for i in range(N):
50-
cycle = torch.roll(torch.arange(0, N), i)
58+
cycle = np.roll(np.arange(0, N), i)
5159
self.group_elems[i] = cycle
5260
self.group_elems[N+i] = cycle[reflection]
53-
self.group_elems = self.group_elems.long()
61+
self.group_elems = self.group_elems.astype(int)
5462

55-
self.cayley_table = torch.zeros(2*N, 2*N)
63+
self.cayley_table = np.zeros((2*N, 2*N))
5664
for i in range(2*N):
5765
for j in range(2*N):
5866
comp = self.group_elems[i][self.group_elems[j]]
59-
self.cayley_table[i, j] = torch.argmin( ((comp.unsqueeze(0) - self.group_elems)**2).sum(-1) )
67+
self.cayley_table[i, j] = np.argmin( ((np.expand_dims(comp, 0) - self.group_elems)**2).sum(-1) )
6068

6169
if N == 2:
6270
C = [
@@ -65,48 +73,54 @@ def __init__(self, N):
6573
[2, 3, 0, 1],
6674
[3, 2, 1, 0]
6775
]
68-
self.cayley_table = torch.Tensor(C)
69-
70-
self.cayley_table = self.cayley_table.long()
76+
self.cayley_table = np.array(C)
7177

78+
self.cayley_table = self.cayley_table.astype(int)
7279

7380

81+
"""
82+
Symmetric groups
83+
"""
7484
class symmetric(abstr_group):
    """Symmetric group on N points.

    Attributes set by the constructor:
        order: N!.
        irrep_dims: ``hook_length(P, N)`` for every ``P`` yielded by
            ``gen_partitions(N)`` (presumably the hook-length dimension of the
            irrep indexed by the partition P - confirm in ``utils``).
        group_elems: integer array of shape (N!, N) whose rows are the
            permutations of ``range(N)`` in ``itertools.permutations`` order.
        cayley_table: integer array of shape (N!, N!); entry (i, j) is the row
            index in ``group_elems`` of ``group_elems[i][group_elems[j]]``.
    """

    def __init__(self, N):
        self.order = math.factorial(N)
        self.irrep_dims = [hook_length(P, N) for P in gen_partitions(N)]

        perms = list(it.permutations(range(N)))
        self.group_elems = np.array(perms, dtype=int).reshape(self.order, N)

        # Exact lookup from a permutation to its row index. This replaces a
        # nearest-row search over all N! permutations for each table entry.
        index_of = {perm: idx for idx, perm in enumerate(perms)}

        self.cayley_table = np.zeros((self.order, self.order), dtype=int)
        for i, p_i in enumerate(self.group_elems):
            # Row j of ``composed`` is group_elems[i][group_elems[j]].
            composed = p_i[self.group_elems]
            self.cayley_table[i] = [index_of[tuple(c)] for c in composed.tolist()]
90102

91-
self.cayley_table = self.cayley_table.long()
92103

93104

105+
"""
106+
Direct product of groups
107+
"""
94108
def direct_product(group_1, group_2):
95109
order_1 = group_1.order
96110
order_2 = group_2.order
97111
order_res = order_1 * order_2
98112

99113
cayley_1 = group_1.cayley_table
100114
cayley_2 = group_2.cayley_table
101-
cayley_res = torch.zeros(order_res, order_res)
115+
cayley_res = np.zeros((order_res, order_res))
102116
for i_1 in range(order_1):
103117
for i_2 in range(order_2):
104118
for j_1 in range(order_1):
105119
for j_2 in range(order_2):
106120
g_1 = cayley_1[i_1, j_1]
107121
g_2 = cayley_2[i_2, j_2]
108122
cayley_res[i_1*order_2 + i_2, j_1*order_2 + j_2] = g_1*order_2 + g_2
109-
cayley_res = cayley_res.long()
123+
cayley_res = cayley_res.astype(int)
110124

111125
irrep_dims_1 = group_1.irrep_dims
112126
irrep_dim_2 = group_2.irrep_dims
@@ -115,4 +129,7 @@ def direct_product(group_1, group_2):
115129
for d_2 in irrep_dim_2:
116130
irrep_dims_res.append(d_1 * d_2)
117131

118-
return abstr_group(order_res, cayley_res, irrep_dims_res)
132+
return abstr_group(order_res, cayley_res, irrep_dims_res)
133+
134+
135+

models_JAX.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import jax.numpy as jnp
2+
import jax
3+
from jax.lax import complex
4+
5+
from utils import *
6+
7+
8+
9+
# initializer = jax.nn.initializers.glorot_uniform(in_axis=-3, out_axis=-2)
10+
11+
initializer = jax.nn.initializers.uniform(scale=1.)
12+
13+
def init_weights(group_order, irrep_dims, seed=42):
    """Create the initial real-valued weight arrays, one per irrep.

    Args:
        group_order: number of group elements.
        irrep_dims: sequence with the dimension ``d_i`` of every irrep.
        seed: integer seed for the JAX PRNG key; the default reproduces the
            previously hard-coded initialisation.

    Returns:
        A list with one float32 array per irrep, of shape
        ``(group_order - 1, d_i, d_i, 2)``. The module-level ``initializer``
        output ``u`` is mapped affinely to ``(2 / d_i) * u - 1 / d_i``.
    """
    # One independent sub-key per irrep so the blocks are not correlated.
    keys = jax.random.split(jax.random.PRNGKey(seed), len(irrep_dims))
    return [
        (2. / d_i) * initializer(k, (group_order - 1, d_i, d_i, 2), jnp.float32) - (1. / d_i)
        for k, d_i in zip(keys, irrep_dims)
    ]
18+
19+
20+
def pad_eye(W_i):
    """Prepend a complex identity matrix along the leading axis of ``W_i``.

    The identity has side ``W_i.shape[-1]``; the result has one more entry
    than ``W_i`` along axis 0.
    """
    dim = W_i.shape[-1]
    identity = complex(jnp.eye(dim), jnp.zeros((dim, dim)))
    return jnp.concatenate((identity[None], W_i), axis=0)
24+
25+
26+
def total_weight(W, irrep_dims, group_order):
    """Assemble all irrep weights into a single complex matrix.

    Each real array in ``W`` (last axis holding real and imaginary parts) is
    turned into a complex array, padded with ``pad_eye`` and flattened to
    ``(group_order, d_i * d_i)``. The blocks are concatenated along the last
    axis.
    """
    def flatten_irrep(W_i, d_i):
        as_complex = complex(W_i[..., 0], W_i[..., 1])
        return jnp.reshape(pad_eye(as_complex), (group_order, d_i * d_i))

    blocks = [flatten_irrep(W_i, d_i) for W_i, d_i in zip(W, irrep_dims)]
    return jnp.concatenate(blocks, -1)
33+
34+
35+
36+
def forward(W, x):
    """Compute, for every irrep, the product of ``W_i x`` with its adjoint.

    ``x`` is broadcast against the padded complex weights and summed over
    axis 1, giving one matrix per leading entry of ``x`` (x is presumably of
    shape (batch, group_order) - confirm against the caller). Each matrix is
    then multiplied by its conjugate transpose.

    Returns:
        A list with one array per entry of ``W``.
    """
    signal = jnp.expand_dims(x, (-2, -1))
    outputs = []
    for W_i in W:
        padded = pad_eye(complex(W_i[..., 0], W_i[..., 1]))
        projected = jnp.sum(padded[None] * signal, axis=1)
        adjoint = jnp.transpose(projected, axes=(0, -1, -2)).conj()
        outputs.append(projected @ adjoint)
    return outputs
47+
48+
49+
50+
def loss(W, x, y):
    """Per-sample squared distance between the network outputs on ``x`` and ``y``.

    For every irrep the squared modulus of the output difference is averaged
    over the two matrix axes. The per-irrep values are summed and divided by
    the number of irreps.

    Returns:
        An array of length ``x.shape[0]``.
    """
    feats_x, feats_y = forward(W, x), forward(W, y)

    per_sample = jnp.zeros(x.shape[0])
    for fx, fy in zip(feats_x, feats_y):
        sq_err = jnp.abs(fx - fy) ** 2
        per_sample = per_sample + sq_err.mean(-1).mean(-1)

    return per_sample / len(feats_x)
59+
60+
61+
def reg(W, irrep_dims, group_order):
    """Penalise the deviation of ``W_tot W_tot^H`` from a scaled identity.

    ``W_tot`` is the matrix built by ``total_weight``. The target is the
    complex identity of side ``group_order`` multiplied by the sum of the
    irrep dimensions. The return value is the mean squared modulus of the
    difference.
    """
    W_tot = total_weight(W, irrep_dims, group_order)
    gram = W_tot @ jnp.conjugate(jnp.transpose(W_tot, axes=(-1, -2)))

    scale = jnp.array(irrep_dims).sum()
    target = scale * complex(jnp.eye(group_order), jnp.zeros((group_order, group_order)))

    return (jnp.abs(target - gram) ** 2).mean()
68+
69+
70+
71+
"""
72+
Function recovering the Cayley table from the weights of the model
73+
"""
74+
def get_table(W, group_order):
    """Recover a Cayley table from the weights of the model.

    For every pair ``(g, h)`` the product of the conjugate-transposed padded
    weights at ``g`` and ``h`` is compared with the conjugate-transposed
    padded weight of every element. The index with the smallest squared
    difference, summed over irreps, is stored.

    Args:
        W: list of real weight arrays whose last axis holds real and
            imaginary parts.
        group_order: number of group elements.

    Returns:
        A ``(group_order, group_order)`` array in the default float dtype
        holding the selected indices.
    """
    # The conjugate-transposed padded weights do not depend on (g, h), so
    # they are built once here rather than inside the double loop.
    adjoints = [
        jnp.conjugate(
            jnp.transpose(pad_eye(complex(W_i[..., 0], W_i[..., 1])), axes=(0, -1, -2))
        )
        for W_i in W
    ]

    res = jnp.zeros((group_order, group_order))
    for g in range(group_order):
        for h in range(group_order):
            diffs = jnp.zeros(group_order)
            for W_cm_ext in adjoints:
                W_gh = W_cm_ext[g] @ W_cm_ext[h]
                diffs += (jnp.abs(jnp.expand_dims(W_gh, 0) - W_cm_ext) ** 2).mean(-1).mean(-1)

            res = res.at[g, h].set(jnp.argmin(diffs))

    return res
90+
91+
92+
93+

0 commit comments

Comments
 (0)