Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
adjustText
anndata>=0.8.0
colorcet>=2.0.1
cvxopt>=1.2.3
# cvxopt>=1.2.3
csbdeep>=0.6.3
descartes
dynamo-release>=1.4.1
Expand All @@ -21,9 +21,9 @@ networkx>=2.6.3
numba>=0.46.0
numpy>=1.18.1
opencv-python>=4.5.4.60
pandana
# pandana
pandas>=0.25.1
paste-bio>=1.4.0
# paste-bio>=1.4.0
plotly>=5.1.0
POT>=0.8.1
psutil>=5.6.3
Expand All @@ -39,7 +39,7 @@ seaborn>=0.9.0
setuptools>=58.0.4
Shapely>=1.8.0
statsmodels>=0.9.0
tensorflow
# tensorflow
tqdm>=4.62.3
torch
trame>=2.2.5
Expand Down
117 changes: 117 additions & 0 deletions spateo/external/CAST/CAST_Mark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections import OrderedDict
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from tqdm import trange

from .model.aug import random_aug
from .utils import coords2adjacentmat


def train_seq(graphs, args, dump_epoch_list, out_prefix, model):
"""The CAST MARK training function

Args:
graphs (List[Tuple(str, dgl.Graph, torch.Tensor)]): List of 3-member tuples, each tuple represents one tissue sample, containing sample name, a DGL graph object, and a feature matrix in the torch.Tensor format
args (model_GCNII.Args): the Args object contains training parameters
dump_epoch_list (List): A list of epoch id you hope training snapshots to be dumped, for debug use, empty by default
out_prefix (str): file name prefix for the snapshot files
model (model_GCNII.CCA_SSG): the GNN model

Returns:
Tuple(Dict, List, CCA_SSG): returns a 3-member tuple, a dictionary containing the graph embeddings for each sample, a list of every loss value, and the trained model object
"""
model = model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

loss_log = []
time_now = timer()

t = trange(args.epochs, desc="", leave=True)
for epoch in t:

with torch.no_grad():
if epoch in dump_epoch_list:
model.eval()
dump_embedding = OrderedDict()
for name, graph, feat in graphs:
# graph = graph.to(args.device)
# feat = feat.to(args.device)
dump_embedding[name] = model.get_embedding(graph, feat)
torch.save(dump_embedding, f"{out_prefix}_embed_dict_epoch{epoch}.pt")
torch.save(loss_log, f"{out_prefix}_loss_log_epoch{epoch}.pt")
print(f"Successfully dumped epoch {epoch}")

losses = dict()
model.train()
optimizer.zero_grad()
# print(f'Epoch: {epoch}')

for name_, graph_, feat_ in graphs:
with torch.no_grad():
N = graph_.number_of_nodes()
graph1, feat1 = random_aug(graph_, feat_, args.dfr, args.der)
graph2, feat2 = random_aug(graph_, feat_, args.dfr, args.der)

graph1 = graph1.add_self_loop()
graph2 = graph2.add_self_loop()

z1, z2 = model(graph1, feat1, graph2, feat2)

c = torch.mm(z1.T, z2)
c1 = torch.mm(z1.T, z1)
c2 = torch.mm(z2.T, z2)

c = c / N
c1 = c1 / N
c2 = c2 / N

loss_inv = -torch.diagonal(c).sum()
iden = torch.eye(c.size(0), device=args.device)
loss_dec1 = (iden - c1).pow(2).sum()
loss_dec2 = (iden - c2).pow(2).sum()
loss = loss_inv + args.lambd * (loss_dec1 + loss_dec2)
loss.backward()
optimizer.step()

# del graph1, feat1, graph2, feat2
loss_log.append(loss.item())
time_step = timer() - time_now
time_now += time_step
# print(f'Loss: {loss.item()} step time={time_step:.3f}s')
t.set_description(f"Loss: {loss.item():.3f} step time={time_step:.3f}s")
t.refresh()

model.eval()
with torch.no_grad():
dump_embedding = OrderedDict()
for name, graph, feat in graphs:
dump_embedding[name] = model.get_embedding(graph, feat)
return dump_embedding, loss_log, model


# graph construction tools
def delaunay_dgl(sample_name, df, output_path, if_plot=True, strategy_t="convex"):
coords = np.column_stack((np.array(df)[:, 0], np.array(df)[:, 1]))
delaunay_graph = coords2adjacentmat(coords, output_mode="raw", strategy_t=strategy_t)
if if_plot:
positions = dict(zip(delaunay_graph.nodes, coords[delaunay_graph.nodes, :]))
_, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
nx.draw(
delaunay_graph,
positions,
ax=ax,
node_size=1,
node_color="#000000",
edge_color="#5A98AF",
alpha=0.6,
)
plt.axis("equal")
plt.savefig(f"{output_path}/delaunay_{sample_name}.png")
import dgl

return dgl.from_networkx(delaunay_graph)
Loading
Loading