forked from lightaime/deep_gcns_torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
128 lines (90 loc) · 3.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import __init__
from ogb.nodeproppred import Evaluator
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, add_self_loops
from args import ArgsInit
from ogb.nodeproppred import PygNodePropPredDataset
from model import DeeperGCN
from utils.ckpt_util import save_ckpt
import logging
import time
@torch.no_grad()
def test(model, x, edge_index, y_true, split_idx, evaluator):
model.eval()
out = model(x, edge_index)
y_pred = out.argmax(dim=-1, keepdim=True)
train_acc = evaluator.eval({
'y_true': y_true[split_idx['train']],
'y_pred': y_pred[split_idx['train']],
})['acc']
valid_acc = evaluator.eval({
'y_true': y_true[split_idx['valid']],
'y_pred': y_pred[split_idx['valid']],
})['acc']
test_acc = evaluator.eval({
'y_true': y_true[split_idx['test']],
'y_pred': y_pred[split_idx['test']],
})['acc']
return train_acc, valid_acc, test_acc
def train(model, x, edge_index, y_true, train_idx, optimizer):
    """Perform a single optimisation step and return the training loss.

    The model is run once on ``(x, edge_index)``; only the output rows
    selected by ``train_idx`` contribute to the loss.

    Args:
        model: Callable as ``model(x, edge_index)``; must provide ``train()``.
        x: Input passed to the model as its first argument.
        edge_index: Input passed to the model as its second argument.
        y_true: Labels carrying a size-1 axis at dimension 1, which is
            squeezed away before the loss is computed.
        train_idx: Index selecting the training rows.
        optimizer: Optimizer whose gradients are zeroed before the backward
            pass and which is stepped once afterwards.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Forward on everything, then keep only the training rows.
    output = model(x, edge_index)
    train_output = output[train_idx]
    train_labels = y_true.squeeze(1)[train_idx]

    # NOTE(review): nll_loss expects log-probabilities, so the model is
    # presumably ending in a log-softmax -- confirm against the model code.
    loss = F.nll_loss(train_output, train_labels)
    loss.backward()
    optimizer.step()

    return loss.item()
def main():
    """Train DeeperGCN on an OGB node-property-prediction dataset.

    Reads the experiment configuration from ``ArgsInit().save_exp()``
    (``use_gpu``, ``device``, ``dataset``, ``self_loop``, ``lr``, ``epochs``
    and ``model_save_path`` are used here), loads the dataset, makes the
    graph undirected (optionally adding self loops), and trains for
    ``args.epochs`` epochs. After every epoch the model is evaluated on all
    splits; whenever the validation accuracy strictly improves, the train and
    test accuracies of that epoch are recorded and a checkpoint is saved.

    Side effects: sets ``args.in_channels`` and ``args.num_tasks`` from the
    dataset, writes log records through the root logger, and saves
    checkpoints via ``save_ckpt``.
    """
    args = ArgsInit().save_exp()

    logging.getLogger().setLevel(logging.INFO)

    # Use the requested CUDA device only when CUDA is actually available.
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda:" + str(args.device))
    else:
        device = torch.device('cpu')

    dataset = PygNodePropPredDataset(name=args.dataset)
    data = dataset[0]
    split_idx = dataset.get_idx_split()

    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    y_true = data.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)

    if args.self_loop:
        # add_self_loops returns a tuple; only the edge index is needed.
        edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]

    # Checkpoints are grouped by whether self loops were used.
    sub_dir = 'SL_{}'.format(args.self_loop)

    # The model reads its input width and number of outputs from args.
    args.in_channels = data.x.size(-1)
    args.num_tasks = dataset.num_classes

    logging.info('%s' % args)

    model = DeeperGCN(args).to(device)

    logging.info(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # 'final_*' hold the accuracies of the epoch with the best validation
    # accuracy; 'highest_train' is tracked independently of validation.
    results = {'highest_valid': 0,
               'final_train': 0,
               'final_test': 0,
               'highest_train': 0}

    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        epoch_loss = train(model, x, edge_index, y_true, train_idx, optimizer)
        logging.info('Epoch {}, training loss {:.4f}'.format(epoch, epoch_loss))
        model.print_params(epoch=epoch)

        result = test(model, x, edge_index, y_true, split_idx, evaluator)
        # Log this epoch's (train, valid, test) accuracies. Logging the
        # running-best ``results`` dict here would only repeat the state
        # left by the previous epoch and hide the current evaluation.
        logging.info('%s', result)
        train_accuracy, valid_accuracy, test_accuracy = result

        if train_accuracy > results['highest_train']:
            results['highest_train'] = train_accuracy

        if valid_accuracy > results['highest_valid']:
            results['highest_valid'] = valid_accuracy
            results['final_train'] = train_accuracy
            results['final_test'] = test_accuracy

            save_ckpt(model, optimizer,
                      round(epoch_loss, 4), epoch,
                      args.model_save_path,
                      sub_dir, name_post='valid_best')

    logging.info("%s" % results)

    # Format the duration by hand: strftime on gmtime() wraps the hour field
    # back to 00 once a run exceeds 24 hours. Below 24 hours the output is
    # identical to the '%H:%M:%S' form.
    elapsed_seconds = int(time.time() - start_time)
    hours, remainder = divmod(elapsed_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    logging.info('Total time: {:02d}:{:02d}:{:02d}'.format(hours, minutes, seconds))
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()