|
| 1 | +from collections import OrderedDict |
| 2 | +from timeit import default_timer as timer |
| 3 | + |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import networkx as nx |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | +from tqdm import trange |
| 9 | + |
| 10 | +from .model.aug import random_aug |
| 11 | +from .utils import coords2adjacentmat |
| 12 | + |
| 13 | + |
| 14 | +def train_seq(graphs, args, dump_epoch_list, out_prefix, model): |
| 15 | + """The CAST MARK training function |
| 16 | +
|
| 17 | + Args: |
| 18 | + graphs (List[Tuple(str, dgl.Graph, torch.Tensor)]): List of 3-member tuples, each tuple represents one tissue sample, containing sample name, a DGL graph object, and a feature matrix in the torch.Tensor format |
| 19 | + args (model_GCNII.Args): the Args object contains training parameters |
| 20 | + dump_epoch_list (List): A list of epoch id you hope training snapshots to be dumped, for debug use, empty by default |
| 21 | + out_prefix (str): file name prefix for the snapshot files |
| 22 | + model (model_GCNII.CCA_SSG): the GNN model |
| 23 | +
|
| 24 | + Returns: |
| 25 | + Tuple(Dict, List, CCA_SSG): returns a 3-member tuple, a dictionary containing the graph embeddings for each sample, a list of every loss value, and the trained model object |
| 26 | + """ |
| 27 | + model = model.to(args.device) |
| 28 | + |
| 29 | + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1) |
| 30 | + |
| 31 | + loss_log = [] |
| 32 | + time_now = timer() |
| 33 | + |
| 34 | + t = trange(args.epochs, desc="", leave=True) |
| 35 | + for epoch in t: |
| 36 | + |
| 37 | + with torch.no_grad(): |
| 38 | + if epoch in dump_epoch_list: |
| 39 | + model.eval() |
| 40 | + dump_embedding = OrderedDict() |
| 41 | + for name, graph, feat in graphs: |
| 42 | + # graph = graph.to(args.device) |
| 43 | + # feat = feat.to(args.device) |
| 44 | + dump_embedding[name] = model.get_embedding(graph, feat) |
| 45 | + torch.save(dump_embedding, f"{out_prefix}_embed_dict_epoch{epoch}.pt") |
| 46 | + torch.save(loss_log, f"{out_prefix}_loss_log_epoch{epoch}.pt") |
| 47 | + print(f"Successfully dumped epoch {epoch}") |
| 48 | + |
| 49 | + losses = dict() |
| 50 | + model.train() |
| 51 | + optimizer.zero_grad() |
| 52 | + # print(f'Epoch: {epoch}') |
| 53 | + |
| 54 | + for name_, graph_, feat_ in graphs: |
| 55 | + with torch.no_grad(): |
| 56 | + N = graph_.number_of_nodes() |
| 57 | + graph1, feat1 = random_aug(graph_, feat_, args.dfr, args.der) |
| 58 | + graph2, feat2 = random_aug(graph_, feat_, args.dfr, args.der) |
| 59 | + |
| 60 | + graph1 = graph1.add_self_loop() |
| 61 | + graph2 = graph2.add_self_loop() |
| 62 | + |
| 63 | + z1, z2 = model(graph1, feat1, graph2, feat2) |
| 64 | + |
| 65 | + c = torch.mm(z1.T, z2) |
| 66 | + c1 = torch.mm(z1.T, z1) |
| 67 | + c2 = torch.mm(z2.T, z2) |
| 68 | + |
| 69 | + c = c / N |
| 70 | + c1 = c1 / N |
| 71 | + c2 = c2 / N |
| 72 | + |
| 73 | + loss_inv = -torch.diagonal(c).sum() |
| 74 | + iden = torch.eye(c.size(0), device=args.device) |
| 75 | + loss_dec1 = (iden - c1).pow(2).sum() |
| 76 | + loss_dec2 = (iden - c2).pow(2).sum() |
| 77 | + loss = loss_inv + args.lambd * (loss_dec1 + loss_dec2) |
| 78 | + loss.backward() |
| 79 | + optimizer.step() |
| 80 | + |
| 81 | + # del graph1, feat1, graph2, feat2 |
| 82 | + loss_log.append(loss.item()) |
| 83 | + time_step = timer() - time_now |
| 84 | + time_now += time_step |
| 85 | + # print(f'Loss: {loss.item()} step time={time_step:.3f}s') |
| 86 | + t.set_description(f"Loss: {loss.item():.3f} step time={time_step:.3f}s") |
| 87 | + t.refresh() |
| 88 | + |
| 89 | + model.eval() |
| 90 | + with torch.no_grad(): |
| 91 | + dump_embedding = OrderedDict() |
| 92 | + for name, graph, feat in graphs: |
| 93 | + dump_embedding[name] = model.get_embedding(graph, feat) |
| 94 | + return dump_embedding, loss_log, model |
| 95 | + |
| 96 | + |
| 97 | +# graph construction tools |
| 98 | +def delaunay_dgl(sample_name, df, output_path, if_plot=True, strategy_t="convex"): |
| 99 | + coords = np.column_stack((np.array(df)[:, 0], np.array(df)[:, 1])) |
| 100 | + delaunay_graph = coords2adjacentmat(coords, output_mode="raw", strategy_t=strategy_t) |
| 101 | + if if_plot: |
| 102 | + positions = dict(zip(delaunay_graph.nodes, coords[delaunay_graph.nodes, :])) |
| 103 | + _, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10)) |
| 104 | + nx.draw( |
| 105 | + delaunay_graph, |
| 106 | + positions, |
| 107 | + ax=ax, |
| 108 | + node_size=1, |
| 109 | + node_color="#000000", |
| 110 | + edge_color="#5A98AF", |
| 111 | + alpha=0.6, |
| 112 | + ) |
| 113 | + plt.axis("equal") |
| 114 | + plt.savefig(f"{output_path}/delaunay_{sample_name}.png") |
| 115 | + import dgl |
| 116 | + |
| 117 | + return dgl.from_networkx(delaunay_graph) |
0 commit comments