Commit 3c01986

Merge pull request #250 from Starlitnightly/main
Optimized `SCC` Implementation and Removed `TensorFlow` Dependencies
2 parents 1b4f67d + 882d8b4 commit 3c01986

File tree

25 files changed: 4544 additions, 60 deletions

requirements.txt

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 adjustText
 anndata>=0.8.0
 colorcet>=2.0.1
-cvxopt>=1.2.3
+# cvxopt>=1.2.3
 csbdeep>=0.6.3
 descartes
 dynamo-release>=1.4.1
@@ -21,9 +21,9 @@ networkx>=2.6.3
 numba>=0.46.0
 numpy>=1.18.1
 opencv-python>=4.5.4.60
-pandana
+# pandana
 pandas>=0.25.1
-paste-bio>=1.4.0
+# paste-bio>=1.4.0
 plotly>=5.1.0
 POT>=0.8.1
 psutil>=5.6.3
@@ -39,7 +39,7 @@ seaborn>=0.9.0
 setuptools>=58.0.4
 Shapely>=1.8.0
 statsmodels>=0.9.0
-tensorflow
+# tensorflow
 tqdm>=4.62.3
 torch
 trame>=2.2.5
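
These four dependencies (cvxopt, pandana, paste-bio, tensorflow) are commented out rather than deleted, so a plain `pip install -r requirements.txt` no longer pulls them in; workflows that still need the corresponding code paths can install them manually. Code that touches a now-optional package is usually kept importable with an import guard; a minimal sketch of that generic pattern is shown below (illustrative only, not taken from the spateo code base, and `_require_tensorflow` is a hypothetical helper name):

# Generic import-guard sketch for a dependency that is now optional.
# Illustrative pattern only; not spateo's actual handling.
try:
    import tensorflow as tf
except ImportError:  # tensorflow is no longer installed by default
    tf = None


def _require_tensorflow():
    # Raise an actionable error only when the optional backend is actually used.
    if tf is None:
        raise ImportError(
            "This code path needs tensorflow, which is no longer a default "
            "dependency; install it with `pip install tensorflow`."
        )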

spateo/external/CAST/CAST_Mark.py

Lines changed: 117 additions & 0 deletions
from collections import OrderedDict
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from tqdm import trange

from .model.aug import random_aug
from .utils import coords2adjacentmat


def train_seq(graphs, args, dump_epoch_list, out_prefix, model):
    """The CAST MARK training function.

    Args:
        graphs (List[Tuple(str, dgl.Graph, torch.Tensor)]): List of 3-member tuples; each tuple represents one tissue sample and contains the sample name, a DGL graph object, and a feature matrix as a torch.Tensor
        args (model_GCNII.Args): the Args object containing the training parameters
        dump_epoch_list (List): A list of epoch ids at which training snapshots should be dumped, for debugging; empty by default
        out_prefix (str): file name prefix for the snapshot files
        model (model_GCNII.CCA_SSG): the GNN model

    Returns:
        Tuple(Dict, List, CCA_SSG): a 3-member tuple: a dictionary mapping each sample name to its graph embedding, a list of per-epoch loss values, and the trained model object
    """
    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    loss_log = []
    time_now = timer()

    t = trange(args.epochs, desc="", leave=True)
    for epoch in t:

        with torch.no_grad():
            if epoch in dump_epoch_list:
                model.eval()
                dump_embedding = OrderedDict()
                for name, graph, feat in graphs:
                    # graph = graph.to(args.device)
                    # feat = feat.to(args.device)
                    dump_embedding[name] = model.get_embedding(graph, feat)
                torch.save(dump_embedding, f"{out_prefix}_embed_dict_epoch{epoch}.pt")
                torch.save(loss_log, f"{out_prefix}_loss_log_epoch{epoch}.pt")
                print(f"Successfully dumped epoch {epoch}")

        losses = dict()
        model.train()
        optimizer.zero_grad()
        # print(f'Epoch: {epoch}')

        for name_, graph_, feat_ in graphs:
            with torch.no_grad():
                N = graph_.number_of_nodes()
                graph1, feat1 = random_aug(graph_, feat_, args.dfr, args.der)
                graph2, feat2 = random_aug(graph_, feat_, args.dfr, args.der)

                graph1 = graph1.add_self_loop()
                graph2 = graph2.add_self_loop()

            z1, z2 = model(graph1, feat1, graph2, feat2)

            c = torch.mm(z1.T, z2)
            c1 = torch.mm(z1.T, z1)
            c2 = torch.mm(z2.T, z2)

            c = c / N
            c1 = c1 / N
            c2 = c2 / N

            loss_inv = -torch.diagonal(c).sum()
            iden = torch.eye(c.size(0), device=args.device)
            loss_dec1 = (iden - c1).pow(2).sum()
            loss_dec2 = (iden - c2).pow(2).sum()
            loss = loss_inv + args.lambd * (loss_dec1 + loss_dec2)
            loss.backward()
            optimizer.step()

            # del graph1, feat1, graph2, feat2

        loss_log.append(loss.item())
        time_step = timer() - time_now
        time_now += time_step
        # print(f'Loss: {loss.item()} step time={time_step:.3f}s')
        t.set_description(f"Loss: {loss.item():.3f} step time={time_step:.3f}s")
        t.refresh()

    model.eval()
    with torch.no_grad():
        dump_embedding = OrderedDict()
        for name, graph, feat in graphs:
            dump_embedding[name] = model.get_embedding(graph, feat)
    return dump_embedding, loss_log, model


# graph construction tools
def delaunay_dgl(sample_name, df, output_path, if_plot=True, strategy_t="convex"):
    coords = np.column_stack((np.array(df)[:, 0], np.array(df)[:, 1]))
    delaunay_graph = coords2adjacentmat(coords, output_mode="raw", strategy_t=strategy_t)
    if if_plot:
        positions = dict(zip(delaunay_graph.nodes, coords[delaunay_graph.nodes, :]))
        _, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
        nx.draw(
            delaunay_graph,
            positions,
            ax=ax,
            node_size=1,
            node_color="#000000",
            edge_color="#5A98AF",
            alpha=0.6,
        )
        plt.axis("equal")
        plt.savefig(f"{output_path}/delaunay_{sample_name}.png")
    import dgl

    return dgl.from_networkx(delaunay_graph)

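For orientation, a hypothetical usage sketch of the two functions added in this file follows. The coordinates, features, hyperparameter values, and the `_ToyEncoder` stand-in are illustrative assumptions made only so the example is self-contained; in real use `model` would be a `model_GCNII.CCA_SSG` instance and `args` its matching `Args` object (only the fields `train_seq` actually reads are populated here).

# Hypothetical usage sketch (not part of the commit); the data, hyperparameter
# values, and the stand-in model are illustrative assumptions.
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch

from spateo.external.CAST.CAST_Mark import delaunay_dgl, train_seq


class _ToyEncoder(torch.nn.Module):
    """Minimal stand-in exposing the interface train_seq expects:
    forward(g1, f1, g2, f2) -> (z1, z2) and get_embedding(g, f).
    The real model is model_GCNII.CCA_SSG."""

    def __init__(self, in_dim, out_dim=32):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def get_embedding(self, graph, feat):
        return self.lin(feat)

    def forward(self, graph1, feat1, graph2, feat2):
        return self.lin(feat1), self.lin(feat2)


# Toy spatial coordinates and per-cell features for 500 cells.
coords_df = pd.DataFrame(np.random.rand(500, 2), columns=["x", "y"])
feat = torch.randn(500, 50)

# Build a Delaunay neighborhood graph and wrap it as a DGL graph.
graph = delaunay_dgl("sample1", coords_df, "./cast_out", if_plot=False)

# Example hyperparameters covering only the Args fields used inside train_seq.
args = SimpleNamespace(device="cpu", epochs=50, lr1=1e-3, wd1=0.0,
                       lambd=1e-3, dfr=0.3, der=0.5)

embeddings, loss_log, model = train_seq(
    graphs=[("sample1", graph, feat)],
    args=args,
    dump_epoch_list=[],               # no intermediate snapshots
    out_prefix="./cast_out/sample1",
    model=_ToyEncoder(in_dim=feat.shape[1]),
)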