Skip to content

Commit 9fa5c8d

Browse files
author
Paul Francoeur
committed
Adding 3 scripts for pipeline to generate counterexamples
1 parent 00b233a commit 9fa5c8d

4 files changed

+14878
-0
lines changed

Diff for: counterexample_generation_jobs.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#!/usr/bin/env python3
2+
3+
'''
4+
This is a script which will generate a file of gnina commands that use cnn_minimize to generate iterative training poses.
5+
6+
ASSUMPTIONS
7+
i) assumes all receptors are PDB files IE end in .pdb
8+
ii) Assumes all docked poses or outputs from gnina will be SDF files.
9+
iii) The crystal ligand filenames are formatted PDBid_LignameLIGSUFFIX
10+
iv) assumes file format is ROOT/POCKET/FILES
11+
v) Will generate a line for every identified crystal ligand with every identified receptor in POCKET -- i.e. crossdocking.
12+
vi) Assumes ligands will have the name of their corresponding crystal ligand file present in their filename. (This is especially important if using docked poses.)
13+
vii) Will generate REC_LIG_lig_it#_docked.sdf files as output. (If using docked poses as well, their names will contain extra _it#_ parts; the current it# will be the leftmost one.)
14+
'''
15+
16+
17+
import os, argparse, glob, re
18+
19+
def get_receptors(root,rec_id):
    '''
    Return the PDB files found directly under root whose filename matches rec_id.

    root is used as a raw glob prefix (root+'*.pdb'), so callers pass it with a
    trailing path separator. rec_id is a regular expression that is matched
    against the start of the filename only (not the directory part).
    '''
    pattern=re.compile(rec_id)
    matched=[]
    for pdb_path in glob.glob(root+'*.pdb'):
        #only the last path component is tested against the expression
        filename=pdb_path.split('/')[-1]
        if pattern.match(filename):
            matched.append(pdb_path)
    return matched
24+
25+
def get_ligands(root,lig_suffix):
    '''
    Return every path that starts with root and ends with lig_suffix.

    The glob pattern is root+'*'+lig_suffix, so root should carry its own
    trailing path separator.
    '''
    pattern=''.join((root,'*',lig_suffix))
    return glob.glob(pattern)
28+
29+
def generate_line(receptor,ligand,outname,crystal_ligand,seed,num_modes,builtin_cnn,supplied_cnn=None,supplied_weights=None):
    '''
    Build one newline-terminated gnina minimization command.

    When BOTH supplied_cnn and supplied_weights are truthy the command uses
    --cnn_model/--cnn_weights; in every other case it falls back to the
    built-in network named by builtin_cnn via --cnn.
    '''
    use_builtin=not (supplied_cnn and supplied_weights)
    if use_builtin:
        return f'gnina -r {receptor} -l {ligand} -o {outname} --autobox_ligand {crystal_ligand} --seed {seed} --gpu --minimize --cnn_scoring refinement --num_modes {num_modes} --cnn {builtin_cnn}\n'
    return f'gnina -r {receptor} -l {ligand} -o {outname} --autobox_ligand {crystal_ligand} --seed {seed} --gpu --minimize --cnn_scoring refinement --num_modes {num_modes} --cnn_model {supplied_cnn} --cnn_weights {supplied_weights}\n'
34+
35+
#grabbing the arguments
parser=argparse.ArgumentParser(description='Create cnn_minimize jobs for a dataset. Assumes dataset file structure is <ROOT>/<Identifier>/<FILES>')
parser.add_argument('-o','--outfile',type=str,required=True,help='Name for gnina job commands output file.')
parser.add_argument('-r','--root',default='./',help='ROOT for data directory structure. Defaults to current working directory.')
parser.add_argument('-ri','--rec_id',default='...._._rec.pdb',help='Regular expression to identify the receptor PDB. Defaults to ...._._rec.pdb')
parser.add_argument('-cs','--crystal_suffix',default='_lig.pdb',help='Expresssion to glob the crystal ligand PDB. Defaults to _lig.pdb. Assumes filename is PDBid_LignameLIGSUFFIX')
parser.add_argument('-ds','--docked_suffix',default='_tt_docked.sdf',help='Expression to glob docked poses. These contain the poses that need to be minimized. Default is "_tt_docked.sdf"')
parser.add_argument('-i','--iteration',type=int,required=True,help='Sets what iteration number we are doing. Adds _it#_docked.sdf to the output file for the gnina job line.')
parser.add_argument('--num_modes',type=int,default=20,help='Sets the --num_modes argument for the gnina command. Defaults to 20.')
parser.add_argument('--cnn',type=str, default='dense',help='Sets the --cnn command for the gnina command. Defaults to dense. Must be dense, general_default2018, or crossdock_default2018.')
parser.add_argument('--cnn_model',type=str,default=None,help='Override --cnn with a user provided caffe model file. If used, requires the user to pass in a weights file as well.')
parser.add_argument('--cnn_weights',type=str,default=None,help='The weights file to use with the supplied caffemodel file.')
parser.add_argument('--seed',default=42,type=int,help='Seed for the gnina commands. Defaults to 42')
parser.add_argument('--dirs',type=str,default=None,help='Supplied directories to do a subset of the dataset. Default behavior is to do every directory.')
args=parser.parse_args()

#double checking that the arguments are compatible
#parser.error is used instead of assert so the checks still run under `python -O`
if args.cnn_model:
    if not args.cnn_weights:
        parser.error("Didn't set cnn_weights to go with cnn_model")
elif args.cnn not in ('dense','general_default2018','crossdock_default2018'):
    parser.error("Must have built-in cnn be dense, general_default2018, or crossdock_default2018")
#num_modes=1 is a positive integer, so it is accepted (the old check demanded >1, contradicting its message)
if args.num_modes<1:
    parser.error("Need to set num_modes to a positive integer.")
if args.seed<=0:
    parser.error("Need a positive seed.")
if args.iteration<=0:
    parser.error("Need an iteration number >=1.")


#now we begin.
#Step 1 -- assemble all of the directories that we will be using.
#joining with '' guarantees a trailing separator. (This used to call sys.path.join, which does not exist.)
dataroot=os.path.join(args.root,'')
todo=glob.glob(dataroot+'*/')

if args.dirs:
    #args.dirs is a text file with one pocket directory name per line
    with open(args.dirs) as dirfile:
        subdirs={x.rstrip() for x in dirfile}
    #every glob hit ends in '/', so the directory name is the second-to-last component
    todo=[x for x in todo if x.split('/')[-2] in subdirs]

#Step 2 -- main loop of the script
#set the iteration plugin variable
itname='_it'+str(args.iteration)

# We loop over the pockets
#TODO -- change to only do the docked poses
with open(args.outfile,'w') as outfile:
    for pocket_root in todo:
        #grab the receptors
        recs=get_receptors(pocket_root,args.rec_id)

        #grab all of the crystal ligands
        cr_ligs=get_ligands(pocket_root,args.crystal_suffix)

        #Grab all of the docked poses
        ligs=get_ligands(pocket_root,args.docked_suffix)
        for r in recs:
            #receptor identifier = receptor filename without the .pdb extension
            rec_name=r.split('/')[-1].split('.pdb')[0]
            #poses that carry this receptor's identifier in their name
            rec_ligs=[l for l in ligs if rec_name in l]
            for cl in cr_ligs:
                #crystal ligand identifier = crystal filename without the crystal suffix
                cl_name=cl.split('/')[-1].split(args.crystal_suffix)[0]
                for ligname in rec_ligs:
                    #only poses that also carry the crystal ligand identifier are minimized against it
                    if cl_name not in ligname:
                        continue
                    #generate the output filename by inserting _it# in front of the docked suffix
                    outname=ligname.replace(args.docked_suffix,itname+args.docked_suffix)
                    outfile.write(generate_line(receptor=r,ligand=ligname,outname=outname,crystal_ligand=cl,seed=args.seed,num_modes=args.num_modes,builtin_cnn=args.cnn,supplied_cnn=args.cnn_model,supplied_weights=args.cnn_weights))
103+

Diff for: generate_counterexample_typeslines.py

+252
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
#!/usr/bin/env python3
2+
3+
'''
4+
This script will generate the lines for a new types file with the iterative poses generated from counterexample_generation_jobs.py
5+
6+
!!WARNING!!
7+
Part of this process is to determine which newly generated poses are NOT REDUNDANT with the previously generated ones.
8+
This requires an O(n^2) calculation to calculate the RMSD between every pose...
9+
Ergo, depending on the number of poses in a given pocket, this calculation could take a very long time.
10+
This script also works on all ligands present in the pocket, so there is the potential for multiple O(n^2) calculations to take place.
11+
12+
We have done our best to avoid needless calculations, but this is why we generate the lines for each pocket independently
13+
14+
ASSUMPTIONS:
15+
i) Poses with <2 RMSD to the crystal pose will be labeled as positive poses
16+
ii) you have obrms installed, and can run it from your commandline
17+
iii) the jobfile provided as input contains the full PATH to the files specified.
18+
iv) the gninatypes files (generated by gninatyper) for the poses in args.input have ALREADY BEEN generated.
19+
v) The crystal ligand files are formatted PDBid_LignameLIGSUFFIX
20+
vi) The OLD sdf file with the unique poses is named LignameOLDUNIQUESUFFIX
21+
22+
INPUT:
23+
i) The path to the pocket you are working on
24+
ii) the threshold RMSD to determine if they are the same pose
25+
iii) the name for the txt file that contains the lines to write (will be written in the POCKET DIRECTORY)
26+
iv) the suffix of the NEW sdf file that contains all of the unique poses
27+
v) the commands file generated from counterexample_generation_jobs.py
28+
vi) --OPTIONAL-- the suffix of the OLD sdf file that contains all of the unique poses
29+
30+
OUTPUT:
31+
==Normal==
32+
i) the typesfile lines to add to generate the new types file
33+
ii) A SDF file containing all of the unique poses for a given ligand -- named LignameUNIQUE_SUFFIX
34+
iii) a ___.sdf file which will be the working file for obrms.
35+
'''
36+
37+
import argparse, re, subprocess, os, sys
38+
import pandas as pd
39+
from rdkit.Chem import AllChem as Chem
40+
41+
def check_exists(filename):
    '''
    Return True only when filename is an existing regular file with a non-zero size.
    '''
    if not os.path.isfile(filename):
        return False
    return os.path.getsize(filename)>0
46+
47+
def get_pocket_lines(filename,pocket):
    '''
    This function reads the lines from filename, and returns only the lines which contain pocket in them.

    The match is a plain substring test, and the returned lines keep their
    trailing newline. The file is opened in a with-block so the handle is
    closed as soon as the lines have been read.
    '''
    with open(filename) as infile:
        return [line for line in infile if pocket in line]
54+
55+
def calc_ligand_dic(lines,ligand_suffix):
    '''
    Parse gnina command lines into two lookup tables:
    1) ligand name -> [output (-o) files generated for that ligand]
    2) output (-o) file -> crystal ligand file (--autobox_ligand) used for that pose

    The ligand name is the second '_'-separated field of the crystal filename
    once ligand_suffix has been stripped, i.e. PATH/<PDBid>_<ligname><LIGSUFFIX>.
    '''
    poses_by_ligand={}
    crystal_by_pose={}
    for command in lines:
        #first token following the --autobox_ligand flag is the crystal ligand file
        crystal_path=re.split('--autobox_ligand ',command)[1].split()[0]

        #drop the directory and the suffix, then take the ligand part of <PDBid>_<ligname>
        crystal_stem=crystal_path.split('/')[-1].split(ligand_suffix)[0]
        ligand=crystal_stem.split('_')[1]

        #first token following the -o flag is the file gnina writes the poses to
        pose_file=re.split('-o ',command)[1].split()[0]

        poses_by_ligand.setdefault(ligand,[]).append(pose_file)
        crystal_by_pose[pose_file]=crystal_path

    return poses_by_ligand, crystal_by_pose
82+
83+
def run_obrms(ligand_file,crystal_file):
    '''
    This function returns a list of rmsds of the docked ligand file to the crystal file. The list is in the order of the poses.

    obrms is invoked with an argument list (no shell), so paths containing
    spaces or shell metacharacters are passed through untouched. The last
    whitespace-separated field of every non-blank output line is parsed as the
    RMSD; empty output yields an empty list.

    Raises subprocess.CalledProcessError if obrms exits with a non-zero status.
    '''
    output=subprocess.check_output(['obrms',ligand_file,crystal_file])
    rows=str(output,'utf-8').rstrip().split('\n')
    #blank rows (e.g. when obrms printed nothing at all) carry no RMSD and are skipped
    return [float(row.split()[-1]) for row in rows if row.strip()]
92+
93+
def get_lines_towrite(crystal_lookup,list_of_docked,affinity_lookup,crystal_suffix):
    '''
    This function will calculate the RMSD of every input pose, to the provided crystal pose.

    crystal_lookup  -- maps each docked pose file to its crystal ligand file
    list_of_docked  -- the docked pose files to process
    affinity_lookup -- maps '<PDBid>_<ligname>' to the affinity value
    crystal_suffix  -- suffix stripped from the crystal filename to form the affinity key

    Each generated line is '<label> <affinity> <rmsd> <receptor gninatypes> <ligand gninatypes>'.
    Poses under 2 RMSD get label 1; all others get label 0 and a '-' in front of the affinity.
    The affinity stays 0.0 when the crystal ligand has no entry in affinity_lookup.

    returns a dictionary of lines --> 'docked pose filename':[lines to write]
    '''
    lines={}

    for docked in list_of_docked:
        crystal=crystal_lookup[docked]

        #Figure out affinity.
        #The key is tried with and without its directory prefix: the lookup keys are bare
        #'<PDBid>_<ligname>' while crystal files are usually given with their path.
        #(Previously the whole lookup dict was assigned here instead of the matching value.)
        affinity=0.0
        cr_lookup=crystal.split(crystal_suffix)[0]
        for candidate in (cr_lookup,cr_lookup.split('/')[-1]):
            if candidate in affinity_lookup:
                affinity=affinity_lookup[candidate]
                break

        print(docked,crystal)
        rmsds=run_obrms(docked,crystal)

        #the receptor gninatypes name is derived from everything before the first 'rec' in the pose filename
        rec_gninatypes=docked.split('rec')[0]+'rec_0.gninatypes'
        pose_lines=[]
        for counter,r in enumerate(rmsds):
            if r < 2:
                label='1'
                neg_aff=''
            else:
                label='0'
                neg_aff='-'

            #pose number counter of the sdf corresponds to <sdf name>_<counter>.gninatypes
            lig_gninatypes=docked.replace('.sdf','_'+str(counter)+'.gninatypes')
            pose_lines.append(f'{label} {neg_aff}{affinity} {r} {rec_gninatypes} {lig_gninatypes}\n')
        lines[docked]=pose_lines
    return lines
125+
126+
def run_obrms_cross(filename):
    '''
    This function returns a pandas dataframe of the RMSD between every pose and every other pose, which is generated using obrms -x

    obrms is invoked with an argument list (no shell), so a filename containing
    spaces or shell metacharacters is passed through untouched. The first
    comma-separated field of every output row is dropped and the remaining
    fields become one float row of the dataframe.

    Raises subprocess.CalledProcessError if obrms exits with a non-zero status.
    '''
    output=subprocess.check_output(['obrms','-x',filename])
    rows=str(output,'utf-8').rstrip().split('\n')
    data=pd.DataFrame([row.split(',')[1:] for row in rows],dtype=float)
    return data
135+
136+
137+
parser=argparse.ArgumentParser(description='Create lines to add to types files from counterexample generation. Assumes data file structure is ROOT/POCKET/FILES.')
parser.add_argument('-p','--pocket',type=str,required=True,help='Name of the pocket that you will be generating the lines for.')
parser.add_argument('-r','--root',type=str,required=True,help='PATH to the ROOT of the pockets.')
parser.add_argument('-i','--input',type=str,required=True,help='File that is output from counterexample_generation_jobs.py')
parser.add_argument('-cs','--crystal_suffix',default='_lig.pdb',help='Expresssion to glob the crystal ligand PDB. Defaults to _lig.pdb. Needs to match what was used with counterexample_generation_jobs.py')
parser.add_argument('--old_unique_suffix',type=str,default=None,help='Suffix for the unique ligand sdf file from a previous run. If set we will load that in and add to it. Default behavior is to generate it from provided input file.')
parser.add_argument('-us','--unique_suffix',type=str,default='_it1___.sdf',help='Suffix for the unique ligand sdf file for this run. Defaults to _it1___.sdf. One will be created for each ligand in the pocket.')
#the thresholds are declared as floats: without type=float any value given on the command line
#arrives as a string and the numeric comparisons below fail
parser.add_argument('--unique_threshold',type=float,default=0.25,help='RMSD threshold for unique poses. IE poses with RMSD > thresh are considered unique. Defaults to 0.25.')
parser.add_argument('--lower_confusing_threshold',type=float,default=0.5,help='CNNscore threshold for identifying confusing good poses. Score < thresh & under 2RMSD is kept and labelled 1. 0<thresh<1. Default 0.5')
parser.add_argument('--upper_confusing_threshold',type=float,default=0.9,help='CNNscore threshold for identifying confusing poor poses. If CNNscore > thresh & over 2RMSD pose is kept and labelled 0. lower<thresh<1. Default 0.9')
parser.add_argument('-o','--outname',type=str,required=True,help='Name of the text file to write the new lines in. DO NOT WRITE THE FULL PATH!')
parser.add_argument('-a','--affinity_lookup',default='pdbbind2017_affs.txt',help='File mapping the PDBid and ligname of the ligand to its pK value. Assmes space delimited "PDBid ligname pK". Defaults to pdbbind2017_affs.txt')
args=parser.parse_args()

#Setting the myroot and root remove variable for use in the script
#joining with '' guarantees both prefixes end in a path separator
myroot=os.path.join(args.root,args.pocket,'')
root_remove=os.path.join(args.root,'')


#sanity check threshold
#parser.error is used instead of assert so the checks still run under `python -O`
if not args.unique_threshold > 0:
    parser.error("Unique RMSD threshold needs to be positive")
if not 0<args.lower_confusing_threshold <1:
    parser.error("Lower_confusing_threshold needs to be in (0,1)")
if not args.lower_confusing_threshold<args.upper_confusing_threshold<1:
    parser.error("Upper_confusing_threshold needs to be in (lower_confusing_threshold,1)")

#generating our affinity lookup dictionary, keyed by '<PDBid>_<ligname>'
affinity_lookup={}
with open(args.affinity_lookup) as infile:
    for line in infile:
        items=line.split()
        #blank or truncated rows cannot provide "PDBid ligname pK" and are skipped
        if len(items)<3:
            continue
        affinity_lookup[items[0]+'_'+items[1]]=items[2]

#first we will generate the dictionary for the ligand - poses we will use.
tocheck=get_pocket_lines(args.input, args.pocket)
datadic, docked_to_crystal_lookup=calc_ligand_dic(tocheck,args.crystal_suffix)

#main loop of the script
with open(myroot+args.outname,'w') as outfile:
    #loop over the ligands
    #(a leftover debugging filter used to skip every ligand except 'iqz'; all ligands are processed now)
    for cr_name, list_o_ligs in datadic.items():
        #0) Make sure that the working sdf is free.
        sdf_name=myroot+'___.sdf'
        sdf_tmp=myroot+'___tmp.sdf'
        #if this "___sdf" file already exists, we need to delete it and make a new one.
        if check_exists(sdf_name):
            os.remove(sdf_name)

        #1) Figure out ALL of the lines to write
        line_dic=get_lines_towrite(crystal_lookup=docked_to_crystal_lookup,list_of_docked=list_o_ligs,affinity_lookup=affinity_lookup,crystal_suffix=args.crystal_suffix)

        #2) Set up the 'working sdf' for the obrms -x calculations, consisting of the confusing examples + any possible previously generated examples
        #   i) iterate over the possible lines for this ligand, keep only the confusing ones,
        #      and write the confusing poses into the working sdf file.
        w=Chem.SDWriter(sdf_name)
        keys=list(line_dic)
        for key in keys:
            kept_lines=[]
            supply=Chem.SDMolSupplier(key,sanitize=False)
            #pose i of the docked sdf corresponds to line i generated for that file
            for i,mol in enumerate(supply):
                curr_line=line_dic[key][i]
                score=float(mol.GetProp('CNNscore'))
                label=curr_line.split()[0]
                scored_well_but_bad=score > args.upper_confusing_threshold and label=='0'
                scored_poor_but_good=score < args.lower_confusing_threshold and label=='1'
                if scored_well_but_bad or scored_poor_but_good:
                    kept_lines.append(curr_line)
                    w.write(mol)
            #after the lines have been checked, we overwrite and only store the lines we kept.
            line_dic[key]=kept_lines
        #close explicitly so the working sdf is fully on disk before it is read or moved
        w.close()

        #   ii) Prepend ___.sdf with the previously existing unique poses sdf
        #offset = number of old unique poses sitting in front of the new poses in the working sdf
        offset=0

        if args.old_unique_suffix:
            print('Prepending existing similarity sdf to working sdf file')
            old_sdfname=myroot+cr_name+args.old_unique_suffix
            supply=Chem.SDMolSupplier(old_sdfname,sanitize=False)
            offset=len(supply)
            #move the new poses aside, then rebuild the working sdf as old uniques followed by new poses.
            #done with file operations rather than shell strings so unusual characters in paths are safe.
            os.replace(sdf_name,sdf_tmp)
            with open(sdf_name,'wb') as merged:
                for part in (old_sdfname,sdf_tmp):
                    with open(part,'rb') as src:
                        for chunk in src:
                            merged.write(chunk)

        #3) run obrms -x working_sdf to calculate the rmsd between each pose. This is the O(n^2) calculation
        unique_data=run_obrms_cross(sdf_name)

        #4) determine the newly found "unique" poses
        #every pose not yet claimed becomes the representative of all unclaimed poses within the threshold of it
        assignments={}
        for (r,row) in unique_data.iterrows():
            if r not in assignments:
                for simi in row[row<args.unique_threshold].index:
                    if simi not in assignments:
                        assignments[simi]=r

        #a pose assigned to a representative other than itself is redundant
        to_remove={k for (k,v) in assignments.items() if k!=v}

        #5) write the remaining lines for the newly found "unique" poses.
        #the new poses start at index offset inside the working sdf
        counter=offset
        for key in keys:
            for line in line_dic[key]:
                if counter not in to_remove:
                    outfile.write(line.replace(root_remove,''))
                counter+=1

        #6) Write out the new "uniques" sdf file to allow for easier future generation
        new_unique_sdfname=myroot+cr_name+args.unique_suffix
        w=Chem.SDWriter(new_unique_sdfname)
        supply=Chem.SDMolSupplier(sdf_name,sanitize=False)
        for i,mol in enumerate(supply):
            if i not in to_remove:
                w.write(mol)
        #close explicitly so the last ligand's unique sdf is flushed before the script exits
        w.close()

0 commit comments

Comments
 (0)