Currently we are not saving the random number generator (RNG) state. For a model like Flux, parts of the inputs are randomly generated. If we save a checkpoint at step=x and later load it at step=x, the RNG states will not be the same, so the generated noise (part of the Flux input) differs and the loss is not deterministic.
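As a minimal, standalone illustration of the problem (not torchtitan code): re-seeding a fresh process does not recover the point in the RNG stream that the uninterrupted run had reached, so the next noise draw differs unless the RNG state is explicitly restored.

```python
import torch

# Uninterrupted run: seed once, consume noise for steps 1..x, then draw for step x+1.
torch.manual_seed(0)
_ = torch.randn(4)                    # noise consumed by steps 1..x
saved_rng = torch.get_rng_state()     # what a checkpoint at step x should capture
expected = torch.randn(4)             # noise the uninterrupted run draws at step x+1

# Resumed run: a fresh process re-seeds but does not replay steps 1..x.
torch.manual_seed(0)
wrong = torch.randn(4)                # differs from `expected` -> loss curve diverges

# Restoring the saved RNG state makes the resumed draw match exactly.
torch.set_rng_state(saved_rng)
resumed = torch.randn(4)
assert not torch.equal(expected, wrong)
assert torch.equal(expected, resumed)
```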
Example Implementation
In `train.py`:
```python
def state_dict(self) -> dict[str, Any]:
    # Save the training step and RNG states for reproducibility.
    device_module = utils.device_module
    self.device_rng_state = device_module.get_rng_state()
    self.cpu_rng_state = torch.get_rng_state()

    def compute_rng_hash(state: torch.ByteTensor) -> int:
        """Hash the first 32 bytes of an RNG state tensor for logging."""
        return int.from_bytes(state.cpu().numpy().tobytes()[:32], "little")

    logger.info(
        f"In trainer.state_dict(), saved RNG states at step {self.step}: "
        f"CPU {compute_rng_hash(self.cpu_rng_state)} "
        f"device {compute_rng_hash(self.device_rng_state)}"
    )
    return {
        "step": self.step,
        "device_rng_states": self.device_rng_state,
        "cpu_rng_states": self.cpu_rng_state,
    }

def load_state_dict(self, state_dict: dict[str, Any]):
    self.step = state_dict["step"]
    self.device_rng_state = state_dict["device_rng_states"]
    self.cpu_rng_state = state_dict["cpu_rng_states"]

    # Restore the RNG states from the checkpoint.
    device_module = utils.device_module
    device_module.set_rng_state(self.device_rng_state)
    torch.set_rng_state(self.cpu_rng_state)

    def compute_rng_hash(state: torch.ByteTensor) -> int:
        """Hash the first 32 bytes of an RNG state tensor for logging."""
        return int.from_bytes(state.cpu().numpy().tobytes()[:32], "little")

    logger.info(
        f"Loaded RNG states at step {self.step}: "
        f"CPU {compute_rng_hash(self.cpu_rng_state)} "
        f"device {compute_rng_hash(self.device_rng_state)}"
    )
```
TODOs
The example implementation above is not correct as written: based on testing, the RNG states read back after load do not match the ones that were saved.
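One way to narrow this down (a diagnostic sketch, not a fix; `trainer` and the checkpoint save/load step are placeholders) is to compare the raw state tensors before save and after load, independent of the hash used for logging:

```python
import torch

# Keep copies of the states captured at save time.
saved_cpu = trainer.cpu_rng_state.clone()
saved_device = trainer.device_rng_state.clone()

# ... checkpoint save + load happens here (elided) ...

# If either of these prints False, the state was altered somewhere in the round trip.
print("cpu state identical:", torch.equal(saved_cpu, trainer.cpu_rng_state))
print("device state identical:", torch.equal(saved_device, trainer.device_rng_state))
```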
## Context:
1. Change flux-dev / flux-schnell model training to ~30,000 steps, based
on current MAST training results.
2. Enable checkpointing. We enabled `reshard_after_forward` on `final_layer` to
solve the issue described
[here](#1167 (comment)) (see the sketch after this list).
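For reference, a minimal sketch of what enabling `reshard_after_forward` on the final layer looks like with FSDP2's `fully_shard` (module and mesh names are placeholders rather than the actual torchtitan wiring, and the import path varies by PyTorch version):

```python
# On some versions this lives under torch.distributed._composable.fsdp instead.
from torch.distributed.fsdp import fully_shard

# Placeholder names: `model.final_layer` and `mesh` stand in for the actual
# torchtitan modules/device mesh; the flag itself is the point here.
fully_shard(model.final_layer, mesh=mesh, reshard_after_forward=True)
fully_shard(model, mesh=mesh, reshard_after_forward=True)
```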
## Test
If we run the following two runs, the training loss curves should be identical
with `deterministic = True`:
1. No checkpoint save and load, total steps = 10.
2. Save a checkpoint at step 5, load it at step 5, and continue
training.
Currently issue #1194 keeps the training loss from being strictly identical. To
exclude the influence of #1194, we reset the seeds (by calling
`set_deterministic()`) at the beginning of step 6. With that in place, checkpoint
save/load leaves the training loss identical.
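The re-seeding workaround looks roughly like the following (a sketch; the real `set_deterministic()` in the codebase may have a different signature or also configure deterministic algorithms, and `step`/`seed` are placeholders):

```python
import random
import numpy as np
import torch

def set_deterministic(seed: int) -> None:
    # Re-seed every RNG source so both runs continue from the same stream.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Workaround for #1194: reset the seeds at the start of step 6 in both runs so the
# loss curves can be compared independently of the RNG-state issue.
if step == 6:
    set_deterministic(seed)
```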
<img width="1675" alt="Screenshot 2025-05-14 at 2 06 23 PM"
src="https://github.com/user-attachments/assets/22882b71-378c-44fa-bd48-8a8f238aa1b0"
/>