[BUG] EVO2 7b_arc_longcontext can't use context parallelism (2 GPUs) for long sequence prediction: shape mismatch #1268

@Oblynx

Description

BioNeMo Framework Version

v2.7

Bug Description

I tried running prediction with Evo2 on a FASTA file containing long sequences, up to 500k tokens. A single H200 went OOM even with batch size 1, so I then set context parallelism to 2 and used two H200s.
This fails with a shape mismatch, which suggests the context-parallel split is wrong:

summary: `[rank0]: RuntimeError: shape '[1, 4, 20906]' is invalid for input of size 83626`

Note that 20906 * 4 = 83624, two elements short of 83626. Does this mean each sequence has to be trimmed to a length divisible by 4? Why?
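For context, Megatron-style context parallelism typically views each sequence as `[batch, 2 * cp_size, seq_len // (2 * cp_size)]` (two chunks per CP rank, for causal load balancing), which would explain the divisibility-by-4 requirement at `cp_size=2`. A minimal sketch of a possible workaround, assuming this chunking scheme: pad (or trim) each tokenized sequence to a multiple of `2 * cp_size` before prediction. The function name `pad_to_cp_multiple` and the pad token id are hypothetical, not part of the BioNeMo API.

```python
def pad_to_cp_multiple(tokens, cp_size, pad_id=0):
    """Pad a token list so its length is divisible by 2 * cp_size.

    Hypothetical helper: Megatron-style context parallelism reshapes the
    sequence to [batch, 2 * cp_size, seq_len // (2 * cp_size)], so the
    length must be a multiple of 2 * cp_size.
    """
    multiple = 2 * cp_size
    remainder = len(tokens) % multiple
    if remainder == 0:
        return tokens
    return tokens + [pad_id] * (multiple - remainder)

# The failing sequence above had 83626 tokens with cp_size=2:
padded = pad_to_cp_multiple(list(range(83626)), cp_size=2)
assert len(padded) == 83628          # 2 pad tokens added
assert len(padded) % 4 == 0          # now divisible by 2 * cp_size
```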

Steps to Reproduce

predict_evo2 --fp8 --fasta <GENES.FASTA> --ckpt-dir evo2/nemo2_evo2_7b_1m --output-dir results --model-size 7b_arc_longcontext --tensor-parallel-size 1 --pipeline-model-parallel-size 1 --context-parallel-size 2 --output-log-prob-seqs

Error Messages and Logs

summary: `[rank0]: RuntimeError: shape '[1, 4, 20906]' is invalid for input of size 83626`

Details:


[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Using byte-level tokenization
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has data parallel group : [0]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0, 1]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0, 1]]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Ranks 0 has data parallel rank: 0
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has context parallel group: [0, 1]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] All context parallel group ranks: [[0, 1]]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Ranks 0 has context parallel rank: 0
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has model parallel group: [0]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] All model parallel group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has tensor model parallel group: [0]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] All tensor model parallel group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has tensor model parallel rank: 0
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has embedding group: [0]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] All pipeline model parallel group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has pipeline model parallel rank 0
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] All embedding group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:32 nemo_logging:393] Rank 0 has embedding rank: 0
Import of quick_gelu from megatron.core.fusions.fused_bias_geglu failed with: Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/nemo/utils/import_utils.py", line 319, in safe_import_from
    return getattr(imported_module, symbol), True
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'megatron.core.fusions.fused_bias_geglu' has no attribute 'quick_gelu'

INFO:nemo.utils.import_utils:Import of quick_gelu from megatron.core.fusions.fused_bias_geglu failed with: Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/nemo/utils/import_utils.py", line 319, in safe_import_from
    return getattr(imported_module, symbol), True
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'megatron.core.fusions.fused_bias_geglu' has no attribute 'quick_gelu'

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
[NeMo W 2025-10-12 23:00:32 nemo_logging:405] No version folders would be created under the log folder as 'resume_if_exists' is enabled.
[NeMo W 2025-10-12 23:00:32 nemo_logging:405] "update_logger_directory" is True. Overwriting tensorboard logger "save_dir" to /tmp/tmpoh8x6z5y
[W1012 23:00:32.300362536 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost.localdomain]:53394 (errno: 97 - Address family not supported by protocol).
Import of quick_gelu from megatron.core.fusions.fused_bias_geglu failed with: Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/nemo/utils/import_utils.py", line 319, in safe_import_from
    return getattr(imported_module, symbol), True
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'megatron.core.fusions.fused_bias_geglu' has no attribute 'quick_gelu'

INFO:nemo.utils.import_utils:Import of quick_gelu from megatron.core.fusions.fused_bias_geglu failed with: Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/nemo/utils/import_utils.py", line 319, in safe_import_from
    return getattr(imported_module, symbol), True
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'megatron.core.fusions.fused_bias_geglu' has no attribute 'quick_gelu'

[NeMo W 2025-10-12 23:00:44 nemo_logging:405] No version folders would be created under the log folder as 'resume_if_exists' is enabled.
[NeMo W 2025-10-12 23:00:44 nemo_logging:405] "update_logger_directory" is True. Overwriting tensorboard logger "save_dir" to /tmp/tmpb6c0k_2g
[W1012 23:00:45.323335656 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost.localdomain]:53394 (errno: 97 - Address family not supported by protocol).
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

[NeMo W 2025-10-12 23:00:46 nemo_logging:405] Multi-GPU predictions could result in shuffled inputs. Verify that the original indices are included in the model's predictions as outputs are not ordered and batch indices do not track input order.
[NeMo W 2025-10-12 23:00:46 nemo_logging:405] Multi-GPU predictions could result in shuffled inputs. Verify that the original indices are included in the model's predictions as outputs are not ordered and batch indices do not track input order.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
[NeMo W 2025-10-12 23:00:46 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.
[NeMo W 2025-10-12 23:00:46 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at evo2/nemo2_evo2_7b_1m/weights
[NeMo W 2025-10-12 23:00:47 serialization:184] DEPRECATED: Passing 'checkpoint_dir' as a Path object in load_common_state_dict will no longer be supported in a future release. Please pass it as a string instead.
[NeMo W 2025-10-12 23:00:47 serialization:184] DEPRECATED: Passing 'checkpoint_dir' as a Path object in load_common_state_dict will no longer be supported in a future release. Please pass it as a string instead.
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at evo2/nemo2_evo2_7b_1m/weights
[NeMo I 2025-10-12 23:00:44 nemo_logging:393] Experiments will be logged at /tmp/tmpb6c0k_2g/default
[NeMo I 2025-10-12 23:00:44 nemo_logging:393] Using byte-level tokenization
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has data parallel group : [1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has combined group of data parallel and context parallel : [0, 1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0, 1]]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Ranks 1 has data parallel rank: 0
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has context parallel group: [0, 1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] All context parallel group ranks: [[0, 1]]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Ranks 1 has context parallel rank: 1
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has model parallel group: [1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] All model parallel group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has tensor model parallel group: [1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] All tensor model parallel group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has tensor model parallel rank: 0
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has pipeline model parallel group: [1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has embedding group: [1]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] All pipeline model parallel group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has pipeline model parallel rank 0
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] All embedding group ranks: [[0], [1]]
[NeMo I 2025-10-12 23:00:45 nemo_logging:393] Rank 1 has embedding rank: 0
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[NeMo I 2025-10-12 23:00:45 num_microbatches_calculator:228] setting number of microbatches to constant 1
[NeMo I 2025-10-12 23:00:46 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.
[NeMo I 2025-10-12 23:00:46 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.
[NeMo I 2025-10-12 23:00:46 nemo_logging:393]  > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 6582312704
[NeMo I 2025-10-12 23:00:46 nemo_logging:393]  > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 6582312704
[NeMo I 2025-10-12 23:00:46 utils:661] Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=False, align_param_gather=False, use_distributed_optimizer=False, num_distributed_optimizer_instances=1, check_for_nan_in_grad=True, check_for_large_grads=False, bucket_size=None, pad_buckets_for_high_nccl_busbw=False, average_in_collective=False, fp8_param_gather=False, reuse_grad_buf_for_mxfp8_param_ag=False, use_custom_fsdp=False, data_parallel_sharding_strategy='no_shard', gradient_reduce_div_fusion=True, suggested_communication_unit_size=None, preserve_fp32_weights=True, keep_fp8_transpose_cache_when_using_custom_fsdp=False, nccl_ub=False, fsdp_double_buffer=False)
[NeMo I 2025-10-12 23:00:46 utils:682] Number of buckets for gradient all-reduce / reduce-scatter: 1
    Params for bucket 1 (6582312704 elements, 6582312704 padded size):
        module.decoder.layers.16.mlp.linear_fc2.weight
        module.decoder.layers.28.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.13.mixer.mixer.filter.R
        module.decoder.layers.24.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.10.self_attention.linear_proj.weight
        module.decoder.layers.4.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.0.mlp.linear_fc1.weight
        module.decoder.layers.24.self_attention.linear_proj.bias
        module.decoder.layers.26.mixer.dense.weight
        module.decoder.layers.20.mlp.linear_fc1.weight
        module.decoder.layers.3.self_attention.linear_proj.bias
        module.decoder.layers.22.mlp.linear_fc2.weight
        module.decoder.layers.2.mixer.mixer.filter.gamma
        module.decoder.layers.28.mixer.dense_projection.weight
        module.decoder.layers.13.mixer.mixer.filter.p
        module.decoder.layers.25.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.19.mixer.dense_projection.weight
        module.decoder.layers.4.mixer.dense_projection.weight
        module.decoder.layers.30.mixer.dense_projection.weight
        module.decoder.layers.16.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.26.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.20.mlp.linear_fc2.weight
        module.decoder.layers.11.mlp.linear_fc1.weight
        module.decoder.layers.10.self_attention.linear_proj.bias
        module.decoder.layers.6.mixer.dense_projection.weight
        module.decoder.layers.2.mixer.mixer.filter.R
        module.decoder.layers.23.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.8.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.28.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.14.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.24.mlp.linear_fc1.weight
        module.decoder.layers.9.mixer.dense.weight
        module.decoder.layers.4.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.30.mixer.mixer.conv_bias
        module.decoder.layers.21.mixer.dense.bias
        module.decoder.layers.11.mlp.linear_fc2.weight
        module.decoder.layers.6.mixer.mixer.conv_bias
        module.decoder.layers.19.mixer.mixer.conv_bias
        module.decoder.layers.8.mixer.mixer.conv_bias
        module.decoder.layers.2.mixer.mixer.filter.p
        module.decoder.layers.28.mixer.dense.weight
        module.decoder.layers.4.mixer.dense.weight
        module.decoder.layers.24.mlp.linear_fc2.weight
        module.decoder.layers.20.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.9.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.16.mixer.dense.bias
        module.decoder.layers.0.mlp.linear_fc2.weight
        module.decoder.layers.26.mlp.linear_fc1.weight
        module.decoder.layers.12.mixer.dense.bias
        module.decoder.layers.23.mixer.dense.bias
        module.decoder.layers.8.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.5.mixer.mixer.conv_bias
        module.decoder.layers.29.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.13.mixer.dense.weight
        module.decoder.layers.10.self_attention.linear_qkv.layer_norm_weight
        module.decoder.layers.5.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.30.mixer.mixer.filter.gamma
        module.decoder.layers.15.mixer.dense.weight
        module.decoder.layers.1.mixer.dense.bias
        module.decoder.layers.26.mlp.linear_fc2.weight
        module.decoder.layers.21.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.6.mixer.mixer.filter.gamma
        module.decoder.layers.17.self_attention.linear_qkv.weight
        module.decoder.layers.29.mixer.mixer.filter.h
        module.decoder.layers.13.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.20.mixer.dense.bias
        module.decoder.layers.9.mlp.linear_fc1.weight
        module.decoder.layers.5.mixer.mixer.filter.h
        module.decoder.layers.30.mixer.mixer.filter.R
        module.decoder.layers.15.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.21.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.12.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.6.mixer.mixer.filter.R
        module.decoder.layers.2.mixer.dense.weight
        module.decoder.layers.23.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.17.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.8.mixer.dense_projection.weight
        module.decoder.layers.25.mixer.dense.bias
        module.decoder.layers.19.mixer.dense.weight
        module.decoder.layers.9.mlp.linear_fc2.weight
        module.decoder.layers.30.mixer.mixer.filter.p
        module.decoder.layers.1.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.21.mixer.dense_projection.weight
        module.decoder.layers.6.mixer.mixer.filter.p
        module.decoder.final_norm.weight
        module.decoder.layers.18.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.9.mixer.dense.bias
        module.decoder.layers.2.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.28.mlp.linear_fc1.weight
        module.decoder.layers.13.mlp.linear_fc1.weight
        module.decoder.layers.25.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.19.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.4.mlp.linear_fc1.weight
        module.decoder.layers.31.self_attention.linear_proj.weight
        module.decoder.layers.15.mlp.linear_fc1.weight
        module.decoder.layers.27.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.21.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.12.mixer.dense_projection.weight
        module.decoder.layers.7.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.23.mixer.dense_projection.weight
        module.decoder.layers.17.mlp.linear_fc1.weight
        module.decoder.layers.9.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.28.mlp.linear_fc2.weight
        module.decoder.layers.13.mlp.linear_fc2.weight
        module.decoder.layers.31.self_attention.linear_proj.bias
        module.decoder.layers.25.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.4.mlp.linear_fc2.weight
        module.decoder.layers.1.mixer.dense_projection.weight
        module.decoder.layers.15.mlp.linear_fc2.weight
        module.decoder.layers.29.mixer.mixer.conv_bias
        module.decoder.layers.21.mixer.dense.weight
        module.decoder.layers.23.mixer.mixer.conv_bias
        module.decoder.layers.17.mlp.linear_fc2.weight
        module.decoder.layers.8.mixer.dense.bias
        module.decoder.layers.2.mlp.linear_fc1.weight
        module.decoder.layers.29.mixer.dense.bias
        module.decoder.layers.14.mixer.dense.bias
        module.decoder.layers.25.mixer.dense_projection.weight
        module.decoder.layers.19.mlp.linear_fc1.weight
        module.decoder.layers.12.mixer.mixer.conv_bias
        module.decoder.layers.5.mixer.dense.bias
        module.decoder.layers.30.mixer.dense.weight
        module.decoder.layers.2.mixer.dense.bias
        module.decoder.layers.27.mixer.dense_projection.weight
        module.decoder.layers.22.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.13.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.6.mixer.dense.weight
        module.decoder.layers.2.mlp.linear_fc2.weight
        module.decoder.layers.25.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.19.mlp.linear_fc2.weight
        module.decoder.layers.30.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.27.mixer.mixer.conv_bias
        module.decoder.layers.22.mixer.mixer.filter.h
        module.decoder.layers.6.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.23.mixer.mixer.filter.gamma
        module.decoder.layers.8.mixer.dense.weight
        module.decoder.layers.3.self_attention.linear_proj.weight
        module.decoder.layers.29.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.14.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.25.mixer.dense.weight
        module.decoder.layers.10.self_attention.linear_qkv.weight
        module.decoder.layers.5.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.31.self_attention.linear_qkv.layer_norm_weight
        module.decoder.layers.16.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.13.mixer.dense.bias
        module.decoder.layers.1.mlp.linear_fc1.weight
        module.decoder.layers.23.mixer.mixer.filter.R
        module.decoder.layers.18.mixer.dense.bias
        module.decoder.layers.8.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.14.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.26.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.10.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.30.mlp.linear_fc1.weight
        module.decoder.layers.27.mixer.mixer.filter.gamma
        module.decoder.layers.21.mlp.linear_fc1.weight
        module.decoder.layers.12.mixer.dense.weight
        module.decoder.layers.6.mlp.linear_fc1.weight
        module.decoder.layers.23.mixer.mixer.filter.p
        module.decoder.layers.18.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.3.self_attention.linear_qkv.layer_norm_weight
        module.decoder.layers.29.mixer.dense_projection.weight
        module.decoder.layers.14.mixer.dense_projection.weight
        module.decoder.layers.1.mixer.mixer.filter.h
        module.decoder.layers.26.mixer.mixer.filter.h
        module.decoder.layers.20.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.11.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.5.mixer.dense_projection.weight
        module.decoder.layers.30.mlp.linear_fc2.weight
        module.decoder.layers.16.mixer.dense_projection.weight
        module.decoder.layers.1.mixer.dense.weight
        module.decoder.layers.27.mixer.mixer.filter.R
        module.decoder.layers.21.mlp.linear_fc2.weight
        module.decoder.layers.12.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.6.mlp.linear_fc2.weight
        module.decoder.layers.24.self_attention.linear_proj.weight
        module.decoder.layers.18.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.8.mlp.linear_fc1.weight
        module.decoder.layers.14.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.22.mixer.mixer.conv_bias
        module.decoder.layers.10.mlp.linear_fc1.weight
        module.decoder.layers.16.mixer.mixer.conv_bias
        module.decoder.layers.27.mixer.mixer.filter.p
        module.decoder.layers.22.mixer.dense.bias
        module.decoder.layers.7.mixer.dense.bias
        module.decoder.layers.18.mixer.dense_projection.weight
        module.decoder.layers.8.mlp.linear_fc2.weight
        module.decoder.layers.30.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.14.mixer.dense.weight
        module.decoder.layers.25.mlp.linear_fc1.weight
        module.decoder.layers.20.mixer.dense_projection.weight
        module.decoder.layers.10.mlp.linear_fc2.weight
        module.decoder.layers.6.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.28.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.12.mlp.linear_fc1.weight
        module.decoder.layers.23.mixer.dense.weight
        module.decoder.layers.18.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.1.mixer.mixer.conv_bias
        module.decoder.layers.15.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.25.mlp.linear_fc2.weight
        module.decoder.layers.20.mixer.mixer.conv_bias
        module.decoder.layers.16.mixer.mixer.filter.gamma
        module.decoder.layers.1.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.22.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.12.mlp.linear_fc2.weight
        module.decoder.layers.7.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.23.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.18.mixer.dense.weight
        module.decoder.layers.3.self_attention.linear_qkv.weight
        module.decoder.layers.30.mixer.dense.bias
        module.decoder.layers.15.mixer.mixer.filter.h
        module.decoder.layers.0.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.26.mixer.dense.bias
        module.decoder.layers.6.mixer.dense.bias
        module.decoder.layers.16.mixer.mixer.filter.R
        module.decoder.layers.2.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.27.mixer.dense.weight
        module.decoder.layers.7.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.24.self_attention.linear_qkv.layer_norm_weight
        module.decoder.layers.19.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.9.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.3.mlp.linear_fc1.layer_norm_weight
        module.embedding.word_embeddings.weight
        module.decoder.layers.29.mixer.dense.weight
        module.decoder.layers.0.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.5.mixer.dense.weight
        module.decoder.layers.20.mixer.mixer.filter.gamma
        module.decoder.layers.11.mixer.dense.bias
        module.decoder.layers.1.mlp.linear_fc2.weight
        module.decoder.layers.31.self_attention.linear_qkv.weight
        module.decoder.layers.16.mixer.mixer.filter.p
        module.decoder.layers.27.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.22.mixer.dense_projection.weight
        module.decoder.layers.7.mixer.dense_projection.weight
        module.decoder.layers.23.mlp.linear_fc1.weight
        module.decoder.layers.19.mixer.mixer.filter.h
        module.decoder.layers.4.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.29.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.14.mlp.linear_fc1.weight
        module.decoder.layers.26.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.20.mixer.mixer.filter.R
        module.decoder.layers.11.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.5.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.31.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.17.self_attention.linear_proj.weight
        module.decoder.layers.13.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.7.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.23.mlp.linear_fc2.weight
        module.decoder.layers.9.mixer.dense_projection.weight
        module.decoder.layers.3.mlp.linear_fc1.weight
        module.decoder.layers.14.mlp.linear_fc2.weight
        module.decoder.layers.0.mixer.dense.bias
        module.decoder.layers.20.mixer.mixer.filter.p
        module.decoder.layers.11.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.17.self_attention.linear_proj.bias
        module.decoder.layers.7.mixer.dense.weight
        module.decoder.layers.27.mlp.linear_fc1.weight
        module.decoder.layers.22.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.15.mixer.mixer.conv_bias
        module.decoder.layers.3.mlp.linear_fc2.weight
        module.decoder.layers.18.mlp.linear_fc1.weight
        module.decoder.layers.9.mixer.mixer.conv_bias
        module.decoder.layers.29.mlp.linear_fc1.weight
        module.decoder.layers.15.mixer.dense.bias
        module.decoder.layers.0.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.26.mixer.dense_projection.weight
        module.decoder.layers.21.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.11.mixer.dense_projection.weight
        module.decoder.layers.5.mlp.linear_fc1.weight
        module.decoder.layers.31.mlp.linear_fc1.weight
        module.decoder.layers.16.mixer.dense.weight
        module.decoder.layers.2.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.27.mlp.linear_fc2.weight
        module.decoder.layers.13.mixer.dense_projection.weight
        module.decoder.layers.7.mlp.linear_fc1.weight
        module.decoder.layers.18.mlp.linear_fc2.weight
        module.decoder.layers.29.mlp.linear_fc2.weight
        module.decoder.layers.1.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.11.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.5.mlp.linear_fc2.weight
        module.decoder.layers.31.mlp.linear_fc2.weight
        module.decoder.layers.16.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.28.mixer.dense.bias
        module.decoder.layers.13.mixer.mixer.conv_bias
        module.decoder.layers.8.mixer.mixer.filter.h
        module.decoder.layers.26.mixer.mixer.conv_bias
        module.decoder.layers.19.mixer.dense.bias
        module.decoder.layers.9.mixer.mixer.filter.gamma
        module.decoder.layers.15.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.0.mixer.dense_projection.weight
        module.decoder.layers.27.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.20.mixer.dense.weight
        module.decoder.layers.11.mixer.dense.weight
        module.decoder.layers.17.self_attention.linear_qkv.layer_norm_weight
        module.decoder.layers.2.mixer.dense_projection.weight
        module.decoder.layers.22.mixer.dense.weight
        module.decoder.layers.7.mlp.linear_fc2.weight
        module.decoder.layers.9.mixer.mixer.filter.R
        module.decoder.layers.4.mixer.dense.bias
        module.decoder.layers.0.mixer.mixer.short_conv.short_conv_weight
        module.decoder.layers.20.mlp.linear_fc1.layer_norm_weight
        module.decoder.layers.12.mixer.hyena_proj_conv.short_conv_weight
        module.decoder.layers.16.mlp.linear_fc1.weight
        module.decoder.layers.2.mixer.mixer.conv_bias
        module.decoder.layers.28.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.22.mlp.linear_fc1.weight
        module.decoder.layers.13.mixer.mixer.filter.gamma
        module.decoder.layers.24.self_attention.linear_qkv.weight
        module.decoder.layers.19.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.9.mixer.mixer.filter.p
        module.decoder.layers.4.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.30.mixer.dense_projection.layer_norm_weight
        module.decoder.layers.15.mixer.dense_projection.weight
        module.decoder.layers.0.mixer.dense.weight
        module.decoder.layers.27.mixer.dense.bias
        module.decoder.layers.12.mixer.mixer.filter.h
        module.decoder.layers.6.mixer.dense_projection.layer_norm_weight
[NeMo I 2025-10-12 23:00:47 nemo_logging:393] Loaded sharded_state_dict_metadata from checkpoint: {'distrib_optim_sharding_type': 'fully_sharded_model_space'}
[NeMo I 2025-10-12 23:00:47 nemo_logging:393] Loaded sharded_state_dict_metadata from checkpoint: {'distrib_optim_sharding_type': 'fully_sharded_model_space'}
[NeMo I 2025-10-12 23:00:47 nemo_logging:393] Using <megatron.core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper object at 0x746c99862720> dist-ckpt load strategy.
[NeMo I 2025-10-12 23:00:47 nemo_logging:393] Using <megatron.core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper object at 0x7c410c187dd0> dist-ckpt load strategy.
[NeMo I 2025-10-12 23:00:51 nemo_logging:393] Global Checkpoint Load : Rank : 1 : Start time : 1760302847.604s : Time spent in load_checkpoint: 3.940s
[NeMo I 2025-10-12 23:00:51 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1760302847.607s : Time spent in load_checkpoint: 3.938s
WARNING:bionemo.llm.data.collate:Extra keys in batch that will not be padded: {'seq_idx'}. Missing keys in batch: set()
(the warning above is repeated 15 more times)
WARNING:DotProductAttention:flash-attn v3 may provide important feature support or performance improvement. Please install flash-attn v3 by 
(1) git clone https://github.com/Dao-AILab/flash-attention.git
(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
[rank1]: Traceback (most recent call last):
[rank1]:   File "/usr/local/bin/predict_evo2", line 10, in <module>
[rank1]:     sys.exit(main())
[rank1]:              ^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/predict.py", line 552, in main
[rank1]:     predict(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/predict.py", line 543, in predict
[rank1]:     trainer.predict(model, datamodule=datamodule)  # TODO return_predictions=False
[rank1]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 858, in predict
[rank1]:     return call._call_and_handle_interrupt(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
[rank1]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank1]:     return function(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 897, in _predict_impl
[rank1]:     results = self._run(model, ckpt_path=ckpt_path)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank1]:     results = self._run_stage()
[rank1]:               ^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 1020, in _run_stage
[rank1]:     return self.predict_loop.run()
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
[rank1]:     return loop_run(self, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/prediction_loop.py", line 124, in run
[rank1]:     self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/prediction_loop.py", line 253, in _predict_step
[rank1]:     predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
[rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/lightning/pytorch/strategies/megatron_strategy.py", line 896, in predict_step
[rank1]:     return self.model.predict_step(dataloader_iter, *args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/lightning/megatron_parallel.py", line 423, in predict_step
[rank1]:     return self._step(
[rank1]:            ^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/lightning/megatron_parallel.py", line 458, in _step
[rank1]:     return self.forward(
[rank1]:            ^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/lightning/megatron_parallel.py", line 308, in forward
[rank1]:     microbatch_outputs = step()
[rank1]:                          ^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/lightning/megatron_parallel.py", line 1248, in __call__
[rank1]:     return self.forward_backward_func(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 518, in forward_backward_no_pipelining
[rank1]:     output_tensor, num_tokens = forward_step(
[rank1]:                                 ^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 289, in forward_step
[rank1]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank1]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/lightning/megatron_parallel.py", line 498, in wrapped_forward_step_func
[rank1]:     batch = _data_step(dataloader_iter)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/gpt/model/base.py", line 661, in data_step
[rank1]:     return self.config.data_step_fn(dataloader_iter)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/predict.py", line 337, in hyena_predict_data_step
[rank1]:     output = get_batch_on_this_cp_rank(_batch_required_keys)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/utils.py", line 1850, in get_batch_on_this_cp_rank
[rank1]:     val = val.view(
[rank1]:           ^^^^^^^^^
[rank1]: RuntimeError: shape '[1, 4, 20906]' is invalid for input of size 83626
[rank0]: (traceback identical to rank1 above, ending in the same error)
[rank0]: RuntimeError: shape '[1, 4, 20906]' is invalid for input of size 83626
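For context, the failing `view` in `megatron.core.utils.get_batch_on_this_cp_rank` splits the sequence dimension into `2 * cp_size` chunks (each CP rank takes one chunk from the front and one from the back for load balancing), so the sequence length must be divisible by `2 * cp_size` — with `--context-parallel-size 2`, that is 4. A minimal sketch of the arithmetic (the helper `cp_pad_amount` is hypothetical, not part of BioNeMo or Megatron):

```python
# Hypothetical helper illustrating the divisibility constraint behind the
# error: Megatron views tokens of shape [batch, seq] as
# [batch, 2 * cp_size, seq // (2 * cp_size)], which requires
# seq % (2 * cp_size) == 0.

def cp_pad_amount(seq_len: int, cp_size: int) -> int:
    """Number of pad tokens needed so seq_len is divisible by 2 * cp_size."""
    multiple = 2 * cp_size
    remainder = seq_len % multiple
    return 0 if remainder == 0 else multiple - remainder

# The failing case from the traceback: seq_len 83626 with cp_size 2.
print(cp_pad_amount(83626, 2))  # 2 pad tokens would give 83628, divisible by 4
```

This is consistent with the numbers in the error: 83626 % 4 == 2, and 20906 * 4 == 83624 != 83626, so the reshape fails unless the input is padded (or trimmed) to a multiple of 4 before the CP split.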

Docker Image

SIF container built locally from nvcr.io/nvidia/clara/bionemo-framework:2.7

System Information

GPU Details:

  • GPU Model: NVIDIA H200 (×2)

Additional Context

No response

Metadata

Labels: bug (Something isn't working)