diff --git a/utils/inference_utils.py b/utils/inference_utils.py index dc5d09e71..ac11ec857 100644 --- a/utils/inference_utils.py +++ b/utils/inference_utils.py @@ -1,48 +1,39 @@ -import copy import os -import pickle - -import torch -from Bio.PDB import PDBParser from esm import FastaBatchedDataset, pretrained from rdkit.Chem import AddHs, MolFromSmiles from torch_geometric.data import Dataset, HeteroData +import numpy as np +import torch +import prody as pr import esm -from datasets.constants import three_to_one from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure +from datasets.parse_chi import aa_idx2aa_short, get_onehot_sequence def get_sequences_from_pdbfile(file_path): - biopython_parser = PDBParser() - structure = biopython_parser.get_structure('random_id', file_path) - structure = structure[0] sequence = None - for i, chain in enumerate(structure): - seq = '' - for res_idx, residue in enumerate(chain): - if residue.get_resname() == 'HOH': - continue - residue_coords = [] - c_alpha, n, c = None, None, None - for atom in residue: - if atom.name == 'CA': - c_alpha = list(atom.get_vector()) - if atom.name == 'N': - n = list(atom.get_vector()) - if atom.name == 'C': - c = list(atom.get_vector()) - if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid - try: - seq += three_to_one[residue.get_resname()] - except Exception as e: - seq += '-' - print("encountered unknown AA: ", residue.get_resname(), ' in the complex. Replacing it with a dash - .') + + pdb = pr.parsePDB(file_path) + seq = pdb.ca.getSequence() + one_hot = get_onehot_sequence(seq) + + chain_ids = np.zeros(len(one_hot)) + res_chain_ids = pdb.ca.getChids() + res_seg_ids = pdb.ca.getSegnames() + res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)]) + ids = np.unique(res_chain_ids) + + for i, id in enumerate(ids): + chain_ids[res_chain_ids == id] = i + + s_temp = np.argmax(one_hot[res_chain_ids == id], axis=1) + s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s_temp]) if sequence is None: - sequence = seq + sequence = s else: - sequence += (":" + seq) + sequence += (":" + s) return sequence