diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index a07bf0f7b..0dad02d25 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -295,7 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: num_local_tokens_per_expert, self.experts.num_experts, 1, - token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M, ALIGN_SIZE_M, ) token_indices = torch.vstack(