@@ -180,12 +180,18 @@ def calculate_token_position_to_id(block_position_indices, tokens_indices,
180180 group_indices = jnp .arange (G )[None , :, None , None ]
181181 group_indices = jnp .broadcast_to (group_indices , (O , G , num_tokens , E ))
182182
183+ # Clamp block_position_indices to prevent out-of-bounds access
184+ max_valid_index = num_blocks * block_size
185+ block_position_indices = jnp .clip (block_position_indices , 0 , max_valid_index )
186+
183187 token_position_to_id = jnp .zeros ((O , G , num_blocks * block_size + 1 ), dtype = jnp .int32 )
184188 token_position_to_id = token_position_to_id .at [batch_indices , group_indices , block_position_indices ].set (tokens_indices + 1 )
185189
186190 token_position_to_id = token_position_to_id [:, :, 1 :]
187191 token_position_to_id = token_position_to_id - 1
188192 token_position_to_id = jnp .where (token_position_to_id == - 1 , total_tokens , token_position_to_id )
193+ # Clamp final result to prevent out-of-bounds access
194+ token_position_to_id = jnp .clip (token_position_to_id , 0 , total_tokens )
189195 dest_output = dest_output .at [0 ].set (token_position_to_id )
190196 return dest_output
191197
@@ -215,6 +221,8 @@ def blockwise_mm_per_group_native(hidden_states, expert_affinities_masked, gate_
215221 def body_fun (b , carry ):
216222 output_jax = carry
217223 local_token_position_to_id = token_position_to_id [b , :]
224+ # Clamp indices to prevent out-of-bounds access on Neuron hardware
225+ local_token_position_to_id = jnp .clip (local_token_position_to_id , 0 , hidden_states .shape [0 ] - 1 )
218226 hidden_states_padded = hidden_states
219227 expert_affinities_padded = expert_affinities
220228 local_hidden_states = hidden_states_padded [local_token_position_to_id ].astype (jnp .float32 )
@@ -965,6 +973,9 @@ def compute_token_assignments(token_permutation_idx, num_experts, expert_capacit
965973 group_indices = group_indices .reshape (O , G , - 1 )
966974
967975 token_permutation_idx = token_permutation_idx .reshape (O , G , - 1 )
976+ # Clamp token_permutation_idx to prevent out-of-bounds scatter access
977+ max_valid_index = expert_capacity * num_experts
978+ token_permutation_idx = jnp .clip (token_permutation_idx , 0 , max_valid_index )
968979
969980 # Create scatter indices
970981 scatter_indices = jnp .stack (
@@ -1222,8 +1233,9 @@ def get_token_position_to_id(
12221233 group_indices = jnp .arange (G )[None , :, None , None ]
12231234 group_indices = jnp .broadcast_to (group_indices , (O , G , num_tokens , E ))
12241235
1225- # (O, G, S*top_k, E)
1226- # block_position_indices
1236+ # Clamp block_position_indices to prevent out-of-bounds scatter access
1237+ max_valid_index = num_blocks * block_size
1238+ block_position_indices = jnp .clip (block_position_indices , 0 , max_valid_index )
12271239
12281240 # Create scatter indices
12291241 scatter_indices = jnp .stack ([batch_indices , group_indices , block_position_indices ], axis = - 1 , dtype = jnp .int32 )
@@ -1247,6 +1259,8 @@ def get_token_position_to_id(
12471259
12481260 token_position_to_id = token_position_to_id - 1
12491261 token_position_to_id = jnp .where (token_position_to_id == - 1 , num_tokens ,token_position_to_id )
1262+ # Clamp final token_position_to_id to prevent out-of-bounds access
1263+ token_position_to_id = jnp .clip (token_position_to_id , 0 , num_tokens )
12501264 token_position_to_id = self ._remat_name (token_position_to_id , "blockwisegating.token_position_to_id" )
12511265 return token_position_to_id
12521266
@@ -1473,6 +1487,12 @@ def forward(self, logits):
14731487 check_rep = False
14741488 )
14751489 token_position_to_id = token_position_to_id_sm (expert_capacity , block_position_indices , local_num_experts )
1490+ # Clamp token_position_to_id indices to the range [0, S - 1]
1491+ token_position_to_id = jnp .clip (token_position_to_id , 0 , S - 1 )
1492+
1494+ # Clamp block_to_expert indices to the range [0, cfg.num_experts - 1]
1494+ block_to_expert = jnp .clip (block_to_expert , 0 , cfg .num_experts - 1 )
1495+
14761496 router_z_loss = _router_z_loss (logits )
14771497
14781498 return self .Output (
0 commit comments