Skip to content

Commit 3225b24

Browse files
author
Akshi22
committed
Clamped indices to defend against out-of-bounds (OOB) access
1 parent 314b68d commit 3225b24

File tree

10 files changed

+50
-18
lines changed

10 files changed

+50
-18
lines changed

axlearn/common/mixture_of_experts.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,18 @@ def calculate_token_position_to_id(block_position_indices, tokens_indices,
180180
group_indices = jnp.arange(G)[None, :, None, None]
181181
group_indices = jnp.broadcast_to(group_indices, (O, G, num_tokens, E))
182182

183+
# Clamp block_position_indices to prevent out-of-bounds access
184+
max_valid_index = num_blocks * block_size
185+
block_position_indices = jnp.clip(block_position_indices, 0, max_valid_index)
186+
183187
token_position_to_id = jnp.zeros((O, G, num_blocks * block_size + 1), dtype=jnp.int32)
184188
token_position_to_id = token_position_to_id.at[batch_indices, group_indices, block_position_indices].set(tokens_indices+1)
185189

186190
token_position_to_id = token_position_to_id[:, :, 1:]
187191
token_position_to_id = token_position_to_id - 1
188192
token_position_to_id = jnp.where(token_position_to_id==-1, total_tokens, token_position_to_id)
193+
# Clamp final result to prevent out-of-bounds access
194+
token_position_to_id = jnp.clip(token_position_to_id, 0, total_tokens)
189195
dest_output = dest_output.at[0].set(token_position_to_id)
190196
return dest_output
191197

@@ -215,6 +221,8 @@ def blockwise_mm_per_group_native(hidden_states, expert_affinities_masked, gate_
215221
def body_fun(b, carry):
216222
output_jax = carry
217223
local_token_position_to_id = token_position_to_id[b, :]
224+
# Clamp indices to prevent out-of-bounds access on Neuron hardware
225+
local_token_position_to_id = jnp.clip(local_token_position_to_id, 0, hidden_states.shape[0] - 1)
218226
hidden_states_padded = hidden_states
219227
expert_affinities_padded = expert_affinities
220228
local_hidden_states = hidden_states_padded[local_token_position_to_id].astype(jnp.float32)
@@ -965,6 +973,9 @@ def compute_token_assignments(token_permutation_idx, num_experts, expert_capacit
965973
group_indices = group_indices.reshape(O, G, -1)
966974

967975
token_permutation_idx = token_permutation_idx.reshape(O, G, -1)
976+
# Clamp token_permutation_idx to prevent out-of-bounds scatter access
977+
max_valid_index = expert_capacity * num_experts
978+
token_permutation_idx = jnp.clip(token_permutation_idx, 0, max_valid_index)
968979

969980
# Create scatter indices
970981
scatter_indices = jnp.stack(
@@ -1222,8 +1233,9 @@ def get_token_position_to_id(
12221233
group_indices = jnp.arange(G)[None, :, None, None]
12231234
group_indices = jnp.broadcast_to(group_indices, (O, G, num_tokens, E))
12241235

1225-
# (O, G, S*top_k, E)
1226-
# block_position_indices
1236+
# Clamp block_position_indices to prevent out-of-bounds scatter access
1237+
max_valid_index = num_blocks * block_size
1238+
block_position_indices = jnp.clip(block_position_indices, 0, max_valid_index)
12271239

12281240
# Create scatter indices
12291241
scatter_indices = jnp.stack([batch_indices, group_indices, block_position_indices], axis=-1, dtype=jnp.int32)
@@ -1247,6 +1259,8 @@ def get_token_position_to_id(
12471259

12481260
token_position_to_id = token_position_to_id - 1
12491261
token_position_to_id = jnp.where(token_position_to_id==-1, num_tokens,token_position_to_id)
1262+
# Clamp final token_position_to_id to prevent out-of-bounds access
1263+
token_position_to_id = jnp.clip(token_position_to_id, 0, num_tokens)
12501264
token_position_to_id = self._remat_name(token_position_to_id, "blockwisegating.token_position_to_id")
12511265
return token_position_to_id
12521266

@@ -1473,6 +1487,12 @@ def forward(self, logits):
14731487
check_rep=False
14741488
)
14751489
token_position_to_id = token_position_to_id_sm(expert_capacity, block_position_indices, local_num_experts)
1490+
# Clamp token_position_to_id indices
1491+
token_position_to_id = jnp.clip(token_position_to_id, 0, S - 1)
1492+
1493+
# Clamp block_to_expert indices
1494+
block_to_expert = jnp.clip(block_to_expert, 0, cfg.num_experts - 1)
1495+
14761496
router_z_loss = _router_z_loss(logits)
14771497

14781498
return self.Output(

axlearn/experiments/text/gpt/envy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797

9898
MAX_SEQUENCE_LENGTH = {
9999
"test": 8192,
100-
"Switch-Base": 8192,
100+
"Switch-Base": 2048,
101101
"Switch-Large": 8192,
102102
"Switch-XXL": 8192,
103103
"Mistral-toy": 256,

profile.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ export PROFILE_JOB_NAME=$2
1212
export PROFILE_JOB_ID=$1
1313
export AXLEARN_PROFILE_MODE=capture
1414

15-
srun -l setup_node.sh ../may-artifacts/
15+
srun -l setup_node.sh /fsx/akshiaws/jul-end-artifacts
1616
srun -l runner.sh

runner.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ export NEURON_CC_FLAGS="${NEURON_CC_FLAGS} --dump=${NEURON_DUMP_PATH}"
140140

141141
# use to add debug logging at module level in xla
142142
export TF_CPP_MIN_LOG_LEVEL=0
143-
export TF_CPP_VMODULE="neuron_token_threading=2"
143+
export TF_CPP_VMODULE="neuron_token_threading=2,neuron_repeated_dus_to_concat=3"
144144

145145
# JAX Cache
146146
# export JAX_COMPILATION_CACHE_DIR="cache/"

runners/full_convergence_16x10b.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ export RENAME_JOB=true
2828
export RENAME_JOB_PREFIX=rh
2929

3030
if [ ${1:-1} = "1" ]; then
31-
srun -l ./setup_node.sh ../may-artifacts/
31+
srun -l ./setup_node.sh /fsx/akshiaws/jul-end-artifacts
3232
else
3333
echo "Skip installing"
3434
fi

runners/full_convergence_8x20b.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export RENAME_JOB=true
3535
export RENAME_JOB_PREFIX=rh
3636

3737
if [ ${1:-1} = "1" ]; then
38-
srun -l ./setup_node.sh /fsx/aahila/jul-artifacts/
38+
srun -l ./setup_node.sh /fsx/akshiaws/jul-end-artifacts
3939
else
4040
echo "Skip installing"
4141
fi

runners/full_convergence_8x7b.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export RENAME_JOB=true
3535
export RENAME_JOB_PREFIX=rh
3636

3737
if [ ${1:-1} = "1" ]; then
38-
srun -l ./setup_node.sh ../may-artifacts/
38+
srun -l ./setup_node.sh /fsx/akshiaws/jul-end-artifacts
3939
else
4040
echo "Skip installing"
4141
fi

runners/run.slurm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ elif [ $mode = "repeated" ]; then
2929
echo "Using repeated"
3030
export AXLEARN_REPEATED=1
3131
export VENV_NAME=jaxmoe
32-
# Set to use repeated, make sure to use /fsx/huilgolr/may-artifacts/repeated/libneuronxla-2.2.20250521+7e624b6.dev-py3-none-linux_x86_64
32+
# Set to use repeated, make sure to use /fsx/huilgolr/jul-end-artifacts/repeated/libneuronxla-2.2.20250521+7e624b6.dev-py3-none-linux_x86_64
3333
fi
3434

3535
if [ ${2:-1} = "1" ]; then
3636
echo "Installing"
37-
srun -l ./setup_node.sh ../may-artifacts/
37+
srun -l ./setup_node.sh /fsx/akshiaws/jul-end-artifacts
3838
else
3939
echo "Skip installing"
4040
fi

runners/run_full.slurm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ elif [ $mode = "repeated" ]; then
2929
echo "Using repeated"
3030
export AXLEARN_REPEATED=1
3131
export VENV_NAME=jaxmoe
32-
# Set to use repeated, make sure to use /fsx/huilgolr/may-artifacts/repeated/libneuronxla-2.2.20250521+7e624b6.dev-py3-none-linux_x86_64
32+
# Set to use repeated, make sure to use /fsx/huilgolr/jul-end-artifacts/repeated/libneuronxla-2.2.20250521+7e624b6.dev-py3-none-linux_x86_64
3333
fi
3434

3535
if [ ${2:-1} = "1" ]; then
3636
echo "Installing"
37-
srun -l ./setup_node.sh ../may-artifacts/
37+
srun -l ./setup_node.sh /fsx/akshiaws/jul-end-artifacts
3838
else
3939
echo "Skip installing"
4040
fi

runners/switch_source.sh

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
#!/bin/bash
2+
#export AXLEARN_JAX_BACKEND="cpu"
3+
24
# comment this out to run full model
3-
export AXLEARN_NUM_LAYERS=2
5+
export AXLEARN_NUM_LAYERS=4
46
export AXLEARN_REMAT_LAYER=selective
57
export AXLEARN_MODEL_NAME="envy-Switch-Base"
68
export AXLEARN_TP_DEGREE=4
79
export AXLEARN_EP_DEGREE=4
810
export AXLEARN_SEQ_DEGREE=4
11+
export AXLEARN_FSDP_DEGREE=1
912
export AXLEARN_TRAIN_BATCH_SIZE=4
13+
1014
# use v2 index calc
1115
export AXLEARN_USE_BLOCKWISE=2
1216

17+
# 0: dense, 1: sparse, 2: alternating
18+
export AXLEARN_MOE_LAYER_FREQ=2
19+
20+
export EP_WITHIN_NODE=1
21+
# export AXLEARN_PROFILE_MODE="tracerun"
22+
# export NEURON_RT_LOCAL_CORE_DUMP_DIRECTORY=""
23+
1324
if [ "${AXLEARN_SEQ_DEGREE:-0}" -gt 1 ]; then
1425
export AXLEARN_FLASH_ATTENTION=0
1526
else
@@ -18,11 +29,12 @@ fi
1829

1930
export AXLEARN_REPEATED=1
2031
# it expects the env to be at ../$VENV_NAME
21-
export VENV_NAME=jaxmoe
22-
32+
export VENV_NAME=akshiaws/jaxmoe
2333
# to simulate slurm job run
2434
export SLURM_PROCID=0
2535
# to output artifacts at this path ./artifacts/JOB_ID/
26-
export JOB_ID=switchbaseep64
27-
rm -rf ./artifacts/$JOB_ID/
28-
bash runner.sh 2>&1 | tee log_$JOB_ID.out
36+
export JOB_ID=dummy
37+
38+
rm -rf /fsx/akshiaws/artifacts/$JOB_ID/
39+
bash /fsx/akshiaws/axlearn/runner.sh
40+
2>&1 | tee log_$JOB_ID.out

0 commit comments

Comments (0)