Save RNG states during checkpointing for deterministic debugging #1194

Open
wwwjn opened this issue May 14, 2025 · 0 comments

wwwjn (Contributor) commented May 14, 2025

Context

Currently we are not saving the random number generator (RNG) states during checkpointing. For models like Flux, part of the input is randomly generated. If we save a checkpoint at step=x and later resume from step=x, the RNG states will differ from the original run, so the generated noise (part of the Flux input) is not the same and the loss is not deterministic.
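
As a minimal illustration (not torchtitan code), the noise drawn each step depends on the global RNG state, so a resumed run only reproduces it if that state is restored:

import torch

torch.manual_seed(0)
saved_state = torch.get_rng_state()   # what a checkpoint would need to capture
noise_a = torch.randn(4)              # "step x" noise in the original run

torch.set_rng_state(saved_state)      # resume with the saved RNG state
noise_b = torch.randn(4)              # same noise as noise_a

noise_c = torch.randn(4)              # without restoring, the stream has moved on
print(torch.equal(noise_a, noise_b))  # True
print(torch.equal(noise_a, noise_c))  # False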

Example Implementation

In `train.py`:

    def state_dict(self) -> dict[str, Any]:
        # Save the training step and RNG states for reproducibility
        device_module = utils.device_module
        self.device_rng_state = device_module.get_rng_state()
        self.cpu_rng_state = torch.get_rng_state()

        def compute_rng_hash(state: torch.Tensor) -> int:
            """Compute a short hash over the leading bytes of an RNG state tensor."""
            return int.from_bytes(state.cpu().numpy().tobytes()[:32], byteorder="little")

        logger.info(
            f"In trainer.state_dict(), read RNG states at step {self.step}: "
            f"CPU {compute_rng_hash(self.cpu_rng_state)} "
            f"device {compute_rng_hash(self.device_rng_state)}"
        )

        return {
            "step": self.step,
            "device_rng_states": self.device_rng_state,
            "cpu_rng_states": self.cpu_rng_state,
        }

    def load_state_dict(self, state_dict: dict[str, Any]):
        self.step = state_dict["step"]
        self.device_rng_state = state_dict["device_rng_states"]
        self.cpu_rng_state = state_dict["cpu_rng_states"]

        # Restore the RNG states captured when the checkpoint was saved
        device_module = utils.device_module
        device_module.set_rng_state(self.device_rng_state)
        torch.set_rng_state(self.cpu_rng_state)

        def compute_rng_hash(state: torch.Tensor) -> int:
            """Compute a short hash over the leading bytes of an RNG state tensor."""
            return int.from_bytes(state.cpu().numpy().tobytes()[:32], byteorder="little")

        logger.info(
            f"Loaded RNG states at step {self.step}: "
            f"CPU {compute_rng_hash(self.cpu_rng_state)} "
            f"device {compute_rng_hash(self.device_rng_state)}"
        )
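
For reference, here is a hedged sketch of the same idea as a standalone object implementing the `Stateful` protocol from torch.distributed.checkpoint; the class name and wiring are illustrative, not the exact torchtitan implementation:

import torch
from torch.distributed.checkpoint.stateful import Stateful

class RNGState(Stateful):
    """Illustrative wrapper that saves and restores CPU and device RNG states."""

    def __init__(self, device_module=torch.cuda):
        # device_module stands in for utils.device_module in torchtitan
        self.device_module = device_module

    def state_dict(self) -> dict[str, torch.Tensor]:
        return {
            "cpu_rng_states": torch.get_rng_state(),
            "device_rng_states": self.device_module.get_rng_state(),
        }

    def load_state_dict(self, state_dict) -> None:
        torch.set_rng_state(state_dict["cpu_rng_states"])
        self.device_module.set_rng_state(state_dict["device_rng_states"])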

TODOs

The example implementation above is not yet correct: based on testing, the RNG state read back after loading does not match the state that was saved.
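
A hedged debugging sketch for narrowing this down: round-trip the RNG state through plain torch.save / torch.load first. If this check passes but the checkpointed run still mismatches, the divergence likely comes from how the checkpoint backend handles these tensors or from when the states are captured relative to the dataloader. The file path below is illustrative.

import torch

cpu_state = torch.get_rng_state()
torch.save({"cpu_rng_states": cpu_state}, "/tmp/rng_state.pt")
loaded = torch.load("/tmp/rng_state.pt")

# The round-tripped state should be byte-for-byte identical
assert torch.equal(cpu_state, loaded["cpu_rng_states"])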

wwwjn self-assigned this May 14, 2025
wwwjn added the enhancement (New feature or request) label May 14, 2025
wwwjn added a commit that referenced this issue May 15, 2025
## Context:
1. Change flux-dev / flux-schnell model training to be ~30,000 steps based
on current MAST training results
2. Enable checkpointing. We enabled `final_layer` reshard_after_forward to
solve the issue described
[here](#1167 (comment))

## Test
If we run the following 2 runs, the training loss curves should be identical
with `deterministic = True`:
1. Without checkpoint save and load, total steps = 10
2. Save a checkpoint at step 5, load it at step 5, and continue training

Currently issue #1194 makes the training loss not strictly identical. To
exclude the influence of #1194, we reset the seeds (by calling
`set_deterministic()`) at the beginning of step 6. Then the checkpoint
save/load makes the training loss identical.

![Screenshot 2025-05-14 at 2 06 23 PM](https://github.com/user-attachments/assets/22882b71-378c-44fa-bd48-8a8f238aa1b0)
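
A hedged sketch of the reseeding workaround mentioned in the commit message above (the helper name and seed value are illustrative stand-ins, not the exact torchtitan `set_deterministic()` API): after loading the step-5 checkpoint, reset the seeds before step 6 so both runs draw the same noise from that point on.

import torch

def reset_seeds(seed: int = 0) -> None:
    # illustrative stand-in for set_deterministic(); reseeds CPU and CUDA RNGs
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# call once right after the step-5 checkpoint is loaded, before step 6 runs
reset_seeds(0)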
wwwjn added a commit that referenced this issue May 16, 2025