
CUDA op getrows fails for long sequences #11189

Open — wants to merge 1 commit into base: master
Conversation

milot-mirdita

I have integrated the ProstT5 protein language model into Foldseek. Thanks a lot for the great library! I am upstreaming a few fixes for issues I found in ggml during the integration. I hope that it's okay to push the changes here and that they get synced to the main ggml repo at some point.

The T5 encoder has a square input pos tensor (llm_build_pos_bucket(cause = false)), which quickly exceeds the 65,535 grid-dimension limit (on most GPUs?) of the CUDA GET_ROWS op.

lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);

I have implemented this only for the _float op, as I don't feel very confident in CUDA programming. I tested it for my specific use case against the reference implementation of my model, but I don't have models ready to test the quantized versions.

T5 embeddings have a square input pos tensor which quickly exceeds the 65k limit of getrows

Implemented only for _float, need other implementations
@github-actions bot added labels: Nvidia GPU (Issues specific to Nvidia GPUs), ggml (changes relating to the ggml tensor library for machine learning) — Jan 11, 2025
Comment on lines +122 to +123
static const int64_t MAX_GRID_Y = 65535;
for (int64_t startY = 0; startY < ne10; startY += MAX_GRID_Y) {
Collaborator
This code is incorrect. The grid y dimension uses 16 bits and ranges from 0 to 65535 (inclusive). So the correct stride would be 65536. With this code two threads per grid write to the same address (though this should result in identical results). The correct way to fix this would be to modify the CUDA kernel and have it iterate with a stride of 65536 over the y dimension. This will also avoid issues with the number of nodes in a CUDA graph varying depending on input parameters.
