Skip to content

Commit f108005

Browse files
committed
add some missing scripts
1 parent 8e4089b commit f108005

File tree

3 files changed

+133
-2
lines changed

3 files changed

+133
-2
lines changed

Diff for: calccenters.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
3+
'''Glob through files in current directory looking for */*_ligand.sdf and */*.gninatypes (assuming PDBbind layout).
4+
Calculate the distance between centers. If types files are passed, create versions with this information,
5+
optionally filtering.
6+
'''
7+
8+
import sys,glob,argparse,os
9+
import numpy as np
10+
import pybel
11+
import struct
12+
import openbabel
13+
14+
openbabel.obErrorLog.StopLogging()
15+
16+
parser = argparse.ArgumentParser()
17+
18+
parser.add_argument('typefiles',metavar='file',type=str, nargs='+',help='Types files to process')
19+
parser.add_argument('--filter',type=float,default=100.0,help='Filter out examples greater the specified value')
20+
parser.add_argument('--suffix',type=str,default='_wc',help='Suffix for new types files')
21+
args = parser.parse_args()
22+
23+
centerinfo = dict()
24+
#first process all gninatypes files in current directory tree
25+
for ligfile in glob.glob('*/*_ligand.sdf'):
26+
mol = next(pybel.readfile('sdf',ligfile))
27+
#calc center
28+
center = np.mean([a.coords for a in mol.atoms],axis=0)
29+
dir = ligfile.split('/')[0]
30+
for gtypes in glob.glob('%s/*.gninatypes'%dir):
31+
buf = open(gtypes,'rb').read()
32+
n = len(buf)/4
33+
vals = np.array(struct.unpack('f'*n,buf)).reshape(n/4,4)
34+
lcenter = np.mean(vals,axis=0)[0:3]
35+
dist = np.linalg.norm(center-lcenter)
36+
centerinfo[gtypes] = dist
37+
38+
for tfile in args.typefiles:
39+
fname,ext = os.path.splitext(tfile)
40+
outname = fname+args.suffix+ext
41+
out = open(outname,'w')
42+
for line in open(tfile):
43+
lfile = line.split('#')[0].split()[-1]
44+
if lfile not in centerinfo:
45+
print("Missing",lfile,tfile)
46+
sys.exit(0)
47+
else:
48+
d = centerinfo[lfile]
49+
if d < args.filter:
50+
out.write(line.rstrip()+" %f\n"%d)

Diff for: cgo_arrow.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
'''
2+
http://pymolwiki.org/index.php/cgo_arrow
3+
4+
(c) 2013 Thomas Holder, Schrodinger Inc.
5+
6+
License: BSD-2-Clause
7+
'''
8+
9+
from pymol import cmd, cgo, CmdException
10+
11+
12+
def cgo_arrow(atom1='pk1', atom2='pk2', radius=0.5, gap=0.0, hlength=-1, hradius=-1,
13+
color='blue red', name=''):
14+
'''
15+
DESCRIPTION
16+
17+
Create a CGO arrow between two picked atoms.
18+
19+
ARGUMENTS
20+
21+
atom1 = string: single atom selection or list of 3 floats {default: pk1}
22+
23+
atom2 = string: single atom selection or list of 3 floats {default: pk2}
24+
25+
radius = float: arrow radius {default: 0.5}
26+
27+
gap = float: gap between arrow tips and the two atoms {default: 0.0}
28+
29+
hlength = float: length of head
30+
31+
hradius = float: radius of head
32+
33+
color = string: one or two color names {default: blue red}
34+
35+
name = string: name of CGO object
36+
'''
37+
from chempy import cpv
38+
39+
radius, gap = float(radius), float(gap)
40+
hlength, hradius = float(hlength), float(hradius)
41+
42+
try:
43+
color1, color2 = color.split()
44+
except:
45+
color1 = color2 = color
46+
color1 = list(cmd.get_color_tuple(color1))
47+
color2 = list(cmd.get_color_tuple(color2))
48+
49+
def get_coord(v):
50+
if not isinstance(v, str):
51+
return v
52+
if v.startswith('['):
53+
return cmd.safe_list_eval(v)
54+
return cmd.get_atom_coords(v)
55+
56+
xyz1 = get_coord(atom1)
57+
xyz2 = get_coord(atom2)
58+
normal = cpv.normalize(cpv.sub(xyz1, xyz2))
59+
60+
if hlength < 0:
61+
hlength = radius * 3.0
62+
if hradius < 0:
63+
hradius = hlength * 0.6
64+
65+
if gap:
66+
diff = cpv.scale(normal, gap)
67+
xyz1 = cpv.sub(xyz1, diff)
68+
xyz2 = cpv.add(xyz2, diff)
69+
70+
xyz3 = cpv.add(cpv.scale(normal, hlength), xyz2)
71+
72+
obj = [cgo.CYLINDER] + xyz1 + xyz3 + [radius] + color1 + color2 + \
73+
[cgo.CONE] + xyz3 + xyz2 + [hradius, 0.0] + color2 + color2 + \
74+
[1.0, 0.0]
75+
76+
if not name:
77+
name = cmd.get_unused_name('arrow')
78+
79+
cmd.load_cgo(obj, name)
80+
81+
cmd.extend('cgo_arrow', cgo_arrow)

Diff for: train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ def update_from_result(name, test, result):
588588

589589
#check if we improved on the test set, if so write a snapshot
590590
bests['test_rmsd'], _ , to_snap = check_improvement(test_rmsd, args.update_ratio, bests['test_rmsd'], best_train_interval, i, False)
591-
if args.keep_best and to_snap:
591+
if args.keep_best and to_snap and not keepsnap: #don't write if already written
592592
keepsnap = True
593593
print("Writing snapshot because rmsd is better")
594594
solver.snapshot() #a bit too much - gigabytes of data
@@ -605,7 +605,7 @@ def update_from_result(name, test, result):
605605

606606
#checking if test rmsd_rmse has improved
607607
bests['test_rmsd_rmse'], _ , to_snap = check_improvement(test_rmsd_rmse, args.update_ratio, bests['test_rmsd_rmse'], best_train_interval, i, False)
608-
if args.keep_best and to_snap:
608+
if args.keep_best and to_snap and not keepsnap:
609609
keepsnap = True
610610
print("Writing snapshot because rmsd_rmse is better")
611611
solver.snapshot() #a bit too much - gigabytes of data

0 commit comments

Comments
 (0)