
[Solved] Error in function._extend when running ESM #260

Open
duongttr opened this issue Jun 3, 2025 · 0 comments

duongttr commented Jun 3, 2025

I'm not sure whether anyone else has run into this issue, but I want to share a simple fix in case someone gets stuck at the same step.

The code that triggers the error:

from torchdrug import data, models
esm = models.ESM(path='./pretrained', model="ESM-2-650M")

sequence = "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"
protein = data.Protein.from_sequence(sequence)

output = esm(protein, None)
print(output)

And the error is:

Traceback (most recent call last):
  File "/data/ProtTranslator/test_esm.py", line 22, in <module>
    output = esm(protein, None)
  File "/home/cbbl2/.conda/envs/esm_gearnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cbbl2/.conda/envs/esm_gearnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cbbl2/.conda/envs/esm_gearnet/lib/python3.10/site-packages/torchdrug/models/esm.py", line 151, in forward
    input, size_ext = functional._extend(bos, torch.ones_like(size_ext), input, size_ext)
  File "/home/cbbl2/.conda/envs/esm_gearnet/lib/python3.10/site-packages/torchdrug/layers/functional/functional.py", line 153, in _extend
    new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
IndexError: index -1 is out of bounds for dimension 0 with size 0

The issue comes from the types of num_atom, num_bond, and num_residue. The IndexError is raised inside _extend when it reads new_cum_size[-1] from an empty tensor, which suggests these counts need to be tensors with one entry per protein rather than plain scalars. To solve this, wrap each of them in a one-element tensor:

import torch

# Wrap each scalar count in a 1-D tensor with a single entry (one protein)
protein.num_atom = torch.tensor([protein.num_atom])
protein.num_residue = torch.tensor([protein.num_residue])
protein.num_bond = torch.tensor([protein.num_bond])
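To see why an empty size container ends in this particular IndexError, here is a minimal pure-Python sketch of the cumulative-size pattern visible in the traceback (per-protein sizes are summed, cumulated, and the last cumulative value is read as the total length). The function extend_packed is hypothetical: it only models prepending a BOS token to each protein in a flat, packed batch and is not torchdrug's implementation.

```python
from itertools import accumulate


def extend_packed(prefix, prefix_size, data, data_size):
    """Prepend prefix items to each sequence of a flat, packed batch.

    Hypothetical pure-Python model of the cumulative-size pattern seen in the
    traceback (new_cum_size[-1]); NOT torchdrug's actual implementation.
    `prefix` and `data` are flat lists; each *_size list holds one length per protein.
    """
    new_size = [p + d for p, d in zip(prefix_size, data_size)]
    new_cum_size = list(accumulate(new_size))
    # Total output length: raises IndexError when there are no per-protein sizes.
    total = new_cum_size[-1]

    out = []
    p_start = d_start = 0
    for p_len, d_len in zip(prefix_size, data_size):
        out.extend(prefix[p_start:p_start + p_len])
        out.extend(data[d_start:d_start + d_len])
        p_start += p_len
        d_start += d_len
    assert len(out) == total
    return out, new_size


# Two packed sequences of 3 residues each, with one BOS token prepended to each.
tokens, sizes = extend_packed(["<cls>", "<cls>"], [1, 1], list("MKTAYI"), [3, 3])
print(tokens)  # ['<cls>', 'M', 'K', 'T', '<cls>', 'A', 'Y', 'I']
print(sizes)   # [4, 4]

# With no per-protein sizes, reading the last cumulative size fails.
try:
    extend_packed([], [], [], [])
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: list index out of range
```

With one size entry per protein the total length is well defined; with none, the [-1] lookup has nothing to read.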

One more thing: when you work with a batch of proteins (for example, one created with data.Protein.pack), the counts are exposed as num_atoms, num_bonds, and num_residues instead of num_atom, num_bond, and num_residue:

from torchdrug import data, models
esm = models.ESM(path='./pretrained', model="ESM-2-650M")

sequence = "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"
protein = data.Protein.from_sequence(sequence)

protein_pack = data.Protein.pack([protein, protein])
print(protein_pack) # PackedProtein(batch_size=2, num_atoms=[321, 321], num_bonds=[662, 662], num_residues=[37, 37])
output = esm(protein_pack, None)
print(output)

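The fix is simple: copy the batched counts onto the singular attribute names before calling the model:

# Point the singular names at the per-protein counts of the batch
protein_pack.num_atom = protein_pack.num_atoms
protein_pack.num_residue = protein_pack.num_residues
protein_pack.num_bond = protein_pack.num_bonds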
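If you switch between single proteins and batches, both workarounds can be folded into one function. The sketch below is a hypothetical helper (align_size_attributes is not part of torchdrug) that copies the batched counts onto the singular names when packed=True and otherwise wraps each count with a caller-supplied to_1d function. It is demonstrated on SimpleNamespace stand-ins, not on real torchdrug objects.

```python
from types import SimpleNamespace

# (singular attribute, batched attribute) pairs mentioned in this issue.
SIZE_ATTRIBUTES = (
    ("num_atom", "num_atoms"),
    ("num_bond", "num_bonds"),
    ("num_residue", "num_residues"),
)


def align_size_attributes(protein, packed, to_1d=None):
    """Hypothetical helper applying both workarounds from this issue.

    packed=True: copy num_atoms/num_bonds/num_residues onto the singular names.
    packed=False: wrap each singular count with `to_1d`
    (for a torchdrug Protein, presumably `lambda n: torch.tensor([n])`).
    """
    if not packed and to_1d is None:
        raise ValueError("to_1d is required when packed=False")
    for singular, plural in SIZE_ATTRIBUTES:
        if packed:
            setattr(protein, singular, getattr(protein, plural))
        else:
            setattr(protein, singular, to_1d(getattr(protein, singular)))
    return protein


# Stand-in objects (not torchdrug classes) to show the effect on the attributes.
single = SimpleNamespace(num_atom=321, num_bond=662, num_residue=37)
align_size_attributes(single, packed=False, to_1d=lambda n: [n])
print(single.num_residue)  # [37]

batch = SimpleNamespace(num_atoms=[321, 321], num_bonds=[662, 662], num_residues=[37, 37])
align_size_attributes(batch, packed=True)
print(batch.num_residue)  # [37, 37]
```

With torchdrug you would presumably call align_size_attributes(protein, packed=False, to_1d=lambda n: torch.tensor([n])) for a single Protein and align_size_attributes(protein_pack, packed=True) for a packed batch; this pairing is an assumption and has not been tested against torchdrug.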

Hope this helps!

@duongttr duongttr changed the title [Solved] Error of in function._extend when running ESM [Solved] Error in function._extend when running ESM Jun 3, 2025