-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
46 lines (38 loc) · 1.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import torch
from scipy import sparse
import os.path
import time
import torch.nn as nn
from ase import Atoms, Atom
def initialize_model(model, device, load_save_file=False):
    """Load weights from a checkpoint or Xavier-initialize, then move to device.

    Args:
        model: torch.nn.Module to prepare.
        device: target torch device (torch.device or device string).
        load_save_file: falsy to initialize weights fresh; otherwise treated
            as a checkpoint path handed to torch.load / load_state_dict.

    Returns:
        The model (wrapped in nn.DataParallel when >1 GPU is visible),
        moved to `device`.
    """
    if load_save_file:
        # `load_save_file` doubles as the checkpoint path when truthy.
        model.load_state_dict(torch.load(load_save_file))
    else:
        for param in model.parameters():
            if param.dim() == 1:
                # Skip 1-D params (biases, norm scales); keep their defaults.
                # (Removed dead `nn.init.constant(param, 0)` that sat
                # unreachably after this `continue` — deprecated API, never ran.)
                continue
            else:
                nn.init.xavier_normal_(param)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)
    return model
def one_of_k_encoding(x, allowable_set):
    """One-hot encode `x` against `allowable_set`; raise if `x` is absent."""
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return [x == item for item in allowable_set]
def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode `x`; unknown values fall back to the set's last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == item for item in allowable_set]
def atom_feature(m, atom_i, i_donor, i_acceptor):
    """Build the 30-dim feature vector for atom `atom_i` of molecule `m`.

    Layout: symbol one-hot (10) + degree (7) + total H count (5) +
    implicit valence (7) + aromatic flag (1) = 30 entries.
    `i_donor` / `i_acceptor` are accepted for interface compatibility
    but are not used by this implementation.
    """
    atom = m.GetAtomWithIdx(atom_i)
    symbol = one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'B', 'H'])
    degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])
    num_h = one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
    valence = one_of_k_encoding_unk(
        atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])
    aromatic = [atom.GetIsAromatic()]
    return np.array(symbol + degree + num_h + valence + aromatic)