moe_lb_loss should be divided by gradient_accumulation_steps for reporting. #1483

Open
bzantium opened this issue Mar 26, 2025 · 2 comments

bzantium commented Mar 26, 2025

The value reported for moe_lb_loss is currently read straight from aux:

  moe_lb_loss = aux["moe_lb_loss"]

For reporting, moe_lb_loss should be divided by gradient_accumulation_steps:

  moe_lb_loss = aux["moe_lb_loss"] / config.gradient_accumulation_steps

bzantium changed the title from "moe_lb_loss should be divided by config.gradient_accumulation_steps for reporting." to "moe_lb_loss should be divided by gradient_accumulation_steps for reporting." on Mar 26, 2025
RissyRan self-assigned this on Mar 28, 2025

RissyRan (Collaborator) commented

Thanks for reaching out! I think gradient_accumulation_steps is already handled here, inside the if config.gradient_accumulation_steps > 1 branch:

maxtext/MaxText/train.py

Lines 451 to 454 in db89bbb

loss = (
  grad_and_loss["loss"] / grad_and_loss["total_weights"]
  + grad_and_loss["moe_lb_loss"] / config.gradient_accumulation_steps
)
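
For readers following along, here is a minimal sketch of what that expression computes. It is plain Python with made-up numbers, not MaxText code, and it assumes that each field of grad_and_loss holds a sum over the accumulation steps; the microbatches list and its values are hypothetical.

```python
# Hypothetical per-microbatch values for 2 accumulation steps (not from MaxText).
gradient_accumulation_steps = 2
microbatches = [
    {"loss": 8.0, "total_weights": 4.0, "moe_lb_loss": 0.5},
    {"loss": 4.0, "total_weights": 4.0, "moe_lb_loss": 0.25},
]

# Assumption: every field is accumulated by summation across the steps.
grad_and_loss = {"loss": 0.0, "total_weights": 0.0, "moe_lb_loss": 0.0}
for mb in microbatches:
    for key in grad_and_loss:
        grad_and_loss[key] += mb[key]

# Same shape as the quoted expression: the token loss is normalized by the
# total weights, and the summed load-balancing loss is averaged over the steps.
loss = (
    grad_and_loss["loss"] / grad_and_loss["total_weights"]
    + grad_and_loss["moe_lb_loss"] / gradient_accumulation_steps
)
print(loss)  # 1.875
```

Under these assumptions the load-balancing term contributes its per-step mean (0.375) to the combined loss, which is the behavior this comment describes.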

RissyRan assigned bzantium and unassigned RissyRan on Mar 31, 2025

bzantium (Author) commented Mar 31, 2025

You are right about the loss calculation, but not about the moe_lb_loss value that is logged to TensorBoard.

maxtext/MaxText/train.py

Lines 468 to 486 in db89bbb

moe_lb_loss = aux["moe_lb_loss"]
if config.gradient_clipping_threshold > 0:
  grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold)
else:
  grads = raw_grads
if config.optimizer_memory_host_offload:
  state = state.replace(
      opt_state=jax.device_put(
          state.opt_state,
          jax.tree_util.tree_map(lambda x: x.with_memory_kind(kind="device"), state_mesh_shardings.opt_state),
      )
  )
new_state = state.apply_gradients(grads=grads)
scalar_metrics = {
    "learning/loss": loss,
    "learning/moe_lb_loss": moe_lb_loss,
    "learning/total_weights": total_weights,

At L485, the current code logs moe_lb_loss, which has already been summed across the gradient accumulation steps when config.gradient_accumulation_steps > 1, because of the following line:

aux = jax.tree_map(lambda x: jnp.sum(x, axis=0), aux)

So the logged value needs to be fixed as suggested above.
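
To make the discrepancy concrete, here is a small sketch in plain Python rather than JAX. The per-step values are invented, aux_per_step is a hypothetical name, and the dictionary comprehension only stands in for the jnp.sum(x, axis=0) reduction quoted above.

```python
# Hypothetical per-step load-balancing losses for 4 accumulation steps.
gradient_accumulation_steps = 4
aux_per_step = {"moe_lb_loss": [0.25, 0.5, 0.125, 0.125]}

# Stand-in for: aux = jax.tree_map(lambda x: jnp.sum(x, axis=0), aux)
aux = {name: sum(values) for name, values in aux_per_step.items()}

# What gets logged today: the sum over all accumulation steps.
logged_now = aux["moe_lb_loss"]

# What the suggested fix would log: the mean over the steps.
logged_fixed = aux["moe_lb_loss"] / gradient_accumulation_steps

print(logged_now, logged_fixed)  # 1.0 0.25
```

In this toy case the logged metric is inflated by a factor equal to gradient_accumulation_steps, while the combined loss already uses the averaged value, so the two curves would not be comparable across runs with different accumulation settings.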
