# run_modelinfer_pytorch_jit_multimer.py
# Original listing: 161 lines (145 loc), 6.37 KB.
# NOTE: web-page navigation text and the line-number gutter that preceded the
# code in the extracted copy were not part of the program and have been
# replaced by this header.
from absl import logging
from absl import flags
from absl import app
import pathlib
from typing import Dict
import numpy as np
import pickle
import os
import torch
from time import time
from alphafold.common.residue_constants import atom_type_num
from alphafold_pytorch_jit.net import RunModel
from alphafold.common import protein
from alphafold.model import config
from runners.saver import load_feature_dict_if_exist
from runners.timmer import Timmers
import jax
import intel_extension_for_pytorch as ipex
# BF16 is opt-in: it is enabled only when the AF2_BF16 environment variable is
# exactly '1'. The value is read once at import time and later passed as the
# `enabled` switch of the CPU autocast context used during inference.
bf16 = (os.environ.get('AF2_BF16') == '1')
print("bf16 variable: ", bf16)
logging.set_verbosity(logging.INFO)

# Command-line interface. fasta_paths, output_dir, root_params and model_names
# are marked as required in the __main__ guard at the bottom of the file.
flags.DEFINE_list(
    'fasta_paths', None, 'Paths to FASTA files, suffix should be .fa. '
    'each containing a prediction target that will be folded one after another. '
    'If a FASTA file contains multiple sequences, then it will be folded as a multimer. '
    'Paths should be separated by commas. All FASTA paths must have a unique basename as the '
    'basename is used to name the output directories for each prediction.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
                    'read inputs and store results.')
flags.DEFINE_string('model_names', None, 'names of multimer model to use')
flags.DEFINE_string('root_params', None, 'root directory of model parameters')
# The help text previously claimed the seed was randomly generated by default,
# which contradicted the fixed default of 123 declared here.
flags.DEFINE_integer('random_seed', 123, 'The random seed for the data '
                     'pipeline. Defaults to 123. Note '
                     'that even if this is set, Alphafold may still not be '
                     'deterministic, because processes like GPU inference are '
                     'nondeterministic.')
flags.DEFINE_integer('num_multimer_predictions_per_model', 1, 'How many '
                     'predictions (each with a different random seed) will be '
                     'generated per model. E.g. if this is 2 and there are 5 '
                     'models then there will be 10 predictions per input. '
                     'Note: this FLAG only applies in multimer mode')
FLAGS = flags.FLAGS
# Optional acceleration: when the TPP extension is importable, swap the
# forward() implementations of GatingAttention and FusedTriangleMultiplication
# for the extension's versions. `is_tpp` records whether the swap happened.
try:
    # Resolve every import first so that a failure in any of them leaves both
    # classes untouched (no half-patched state while is_tpp is False).
    from alphafold_pytorch_jit.basics import GatingAttention
    from tpp_pytorch_extension.alphafold.Alpha_Attention import GatingAttentionOpti_forward
    from alphafold_pytorch_jit.backbones_multimer import FusedTriangleMultiplication
    from tpp_pytorch_extension.alphafold.Alpha_FusedTriangleMultiplication import FusedTriangleMultiplicationOpti_forward
except Exception as tpp_import_error:
    # `Exception` (rather than a bare except) keeps the best-effort fallback
    # for any import-time failure of the extension while letting
    # KeyboardInterrupt / SystemExit propagate.
    is_tpp = False
    print('[warning] No TPP extension detected, will fallback to imperative mode')
    print('[warning] TPP import error: {!r}'.format(tpp_import_error))
else:
    GatingAttention.forward = GatingAttentionOpti_forward
    FusedTriangleMultiplication.forward = FusedTriangleMultiplicationOpti_forward
    is_tpp = True
    print('Running with Intel Optimizations. TPP extension detected.')
def run_model_inference(
    fasta_name: str,
    output_dir_base: str,
    model_runners: Dict[str, RunModel]
):
    """Run every model runner on one target and write its predictions.

    Features are loaded from
    ``<output_dir_base>/<fasta_name>/intermediates/features.npz``. For each
    runner the raw prediction is pickled to ``result_<model_name>.pkl`` in the
    target directory. Once all runners have finished, one
    ``unrelaxed_<model_name>_rank<idx>.pdb`` file is written per runner, where
    rank 0 is the runner with the highest ``ranking_confidence``.

    Args:
      fasta_name: Target name. A literal trailing ``.fasta`` or ``.fa`` is
        removed before the name is used as the sub-directory name.
      output_dir_base: Directory that contains one sub-directory per target.
      model_runners: Mapping from runner name to a callable that takes the
        loaded feature dict and returns the prediction result.

    Raises:
      AssertionError: If the target directory, its ``intermediates``
        directory or the ``features.npz`` file does not exist.
    """
    # str.rstrip() removes a *set of characters*, not a suffix: the previous
    # rstrip('.fa').rstrip('.fasta') turned "alpha" into "alph" and
    # "test.fasta" into "te". Remove only an exact extension instead.
    for suffix in ('.fasta', '.fa'):
        if fasta_name.endswith(suffix):
            fasta_name = fasta_name[:-len(suffix)]
            break
    logging.info('run model prediction of {}'.format(fasta_name))

    output_dir = os.path.join(output_dir_base, fasta_name)
    assert os.path.exists(output_dir), (
        'missing output directory: {}'.format(output_dir))
    preproc_dir = os.path.join(output_dir, 'intermediates')
    fp_features = os.path.join(preproc_dir, 'features.npz')
    assert os.path.isdir(preproc_dir) and os.path.isfile(fp_features), (
        'missing preprocessed features: {}'.format(fp_features))

    # load features
    df_features = load_feature_dict_if_exist(fp_features)

    # run model inference
    ranking_confidences = {}
    unrelaxed_pdbs = {}
    for model_name, model_runner in model_runners.items():
        logging.info('use model {} on {}'.format(
            model_name, fasta_name))
        t0 = time()
        # Autocast is only active when the module-level bf16 switch is set.
        with torch.inference_mode(), torch.cpu.amp.autocast(enabled=bf16):
            prediction_result = model_runner(df_features)
        dt = time() - t0
        logging.info('complete model {} inference with duration = {}'.format(
            model_name, dt))

        plddt = prediction_result['plddt']
        print("plddts score = ", np.mean(plddt))
        ranking_confidences[model_name] = prediction_result['ranking_confidence']

        fp_output = os.path.join(output_dir, f'result_{model_name}.pkl')
        with open(fp_output, 'wb') as h:
            pickle.dump(prediction_result, h, protocol=4)

        # Repeat the plddt values across atom_type_num columns to obtain the
        # per-atom b-factor array passed to from_prediction().
        plddt_b_factors = np.repeat(
            plddt[:, None], atom_type_num, axis=-1)
        # jax.tree_map is deprecated and removed in recent JAX releases;
        # jax.tree_util.tree_map is the stable equivalent.
        unrelaxed_protein = protein.from_prediction(
            jax.tree_util.tree_map(lambda x: x.detach().numpy(), df_features),
            prediction_result,
            plddt_b_factors,
            remove_leading_feature_dimension=False)
        unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)

    # Rank runners by ranking_confidence, best first.
    ranked_order = sorted(
        ranking_confidences, key=ranking_confidences.get, reverse=True)
    for idx, model_name in enumerate(ranked_order):
        unrelaxed_pdb_path = os.path.join(
            output_dir, f'unrelaxed_{model_name}_rank{idx}.pdb')
        with open(unrelaxed_pdb_path, 'w') as h:
            h.write(unrelaxed_pdbs[model_name])  # save unrelaxed pdb
def main(argv):
    """Build the model runners from the flags and fold every FASTA target.

    One RunModel is created per (model name, prediction index) pair and stored
    under the key ``<model_name>_pred_<i>``; all runners are then applied to
    each target named by --fasta_paths.

    Args:
      argv: Positional command-line arguments handed over by ``app.run``;
        not used here.

    Raises:
      ValueError: If two FASTA paths share a basename, or if --model_names
        contains no model name.
    """
    del argv  # unused
    num_ensemble = 1
    num_prediction_per_model = FLAGS.num_multimer_predictions_per_model
    torch.manual_seed(FLAGS.random_seed)

    # The file stem names the per-target output directory, so it must be unique.
    fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
    if len(fasta_names) != len(set(fasta_names)):
        raise ValueError('All FASTA paths must have a unique basename.')

    # --model_names accepts "a,b" as well as "[a, b]"; surrounding whitespace
    # and empty entries are ignored.
    raw_names = FLAGS.model_names.strip().strip('[]').split(',')
    model_list = [name.strip() for name in raw_names if name.strip()]
    if not model_list:
        raise ValueError('No model name found in --model_names.')
    print(model_list)

    model_runners = {}
    for model_name in model_list:
        # Plain string concatenation: --root_params is used as a prefix, so it
        # has to carry its own trailing path separator.
        root_params = FLAGS.root_params + model_name
        model_config = config.model_config(model_name)
        model_config['model']['num_ensemble_eval'] = num_ensemble
        fp_timmer = os.path.join(FLAGS.output_dir, f'timmers_{model_name}.txt')
        h_timmer = Timmers(fp_timmer)
        # NOTE(review): every prediction of a model is constructed with the
        # same FLAGS.random_seed, while the num_multimer_predictions_per_model
        # help text promises a different seed per prediction - confirm how
        # RunModel uses its last argument.
        for i in range(num_prediction_per_model):
            model_runners[f'{model_name}_pred_{i}'] = RunModel(
                model_config, root_params, h_timmer, FLAGS.random_seed)

    for fasta_name in fasta_names:
        run_model_inference(
            fasta_name,
            FLAGS.output_dir,
            model_runners)
if __name__ == '__main__':
    # Script entry point: these four flags have no usable default, so flag
    # parsing must fail when any of them is missing; then main() is run
    # through absl.
    flags.mark_flags_as_required(
        ['fasta_paths', 'output_dir', 'root_params', 'model_names'])
    app.run(main)