Seeing - "Recomputed values for the following tensors have different metadata than during the forward pass." #1117
could you share the config to reproduce? I think the complaint is from activation checkpointing -- I saw it when exploring possible ways to register hooks for load balancing updates #1114
The toml file follows.

```toml
[profiling]
[metrics]
[model]
[optimizer]
[lr_scheduler]
[training]
[parallelism]
[checkpoint]
[activation_checkpoint]
[float8]
```
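For reference, a sketch of what the `[activation_checkpoint]` section typically looks like in torchtitan configs; the values here are illustrative placeholders, not the reporter's actual settings:

```toml
# Illustrative placeholders, not the reporter's actual settings
[activation_checkpoint]
mode = "selective"          # "none", "selective", or "full"
selective_ac_option = "op"  # op-level SAC; an integer string instead checkpoints every Nth layer
```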
Hmm I'm not sure if I can reproduce. Some more questions:
@tianyu-l, please see the answers below. "What model configs are you using?" - The toml file is above, if that is what you are asking; otherwise, the source is unchanged.
Looks like the checkpoint recompute has an issue - this looks funny.
One interesting observation is that if I make the following change, training proceeds, but hits NaN after some time.
args[0] is an empty dictionary, and args[1] and args[2] are tensors. The checkpointing code looks complicated and difficult to debug. It would be good to be able to relate the saved-vs-recomputed mismatch to a model layer.
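As background, non-reentrant activation checkpointing records the metadata of every tensor saved during the forward pass and compares it against the recompute. Here is a minimal, hypothetical sketch (not the reporter's model) of a checkpointed region whose output depends on mutable state, which trips exactly this error:

```python
import torch
from torch.utils.checkpoint import checkpoint

call = {"n": 0}  # mutable global state the checkpointed region depends on

def unstable_block(x):
    # Hypothetical failure mode: the slice length changes between the
    # original forward (n=1) and the recompute (n=2), so the tensors
    # autograd saves inside the region get different shapes.
    call["n"] += 1
    k = x.shape[0] - (call["n"] - 1)
    return x[:k].sin()

x = torch.randn(4, 8, requires_grad=True)
out = checkpoint(unstable_block, x, use_reentrant=False)
# backward() triggers the recompute; the saved vs. recomputed metadata
# differ, raising torch.utils.checkpoint.CheckpointError
out.sum().backward()
```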
cc @soulitzer in case you have context
Would it be possible to pass `debug=True` and share the output?
@soulitzer you can access the log here, run with debug=True, for llama4_17bx16e at N=2 PPN=8 TP=8 FSDP=2: https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/N2_PPN8_TP8_FSDP2_llama4_17bx16e.log I have also added a log for rank 1 only, which might be easier to inspect.
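For readers who want the same kind of log: with non-reentrant checkpointing, passing debug=True makes a CheckpointError include per-op traces from both the original forward and the recompute. A minimal sketch:

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    return x.sin().cos()

x = torch.randn(32, 32, requires_grad=True)
# debug=True changes nothing unless a CheckpointError is raised, in which
# case the error message includes per-op traces for forward and recompute
out = checkpoint(block, x, use_reentrant=False, debug=True)
out.sum().backward()
```

There is also a `torch.utils.checkpoint.set_checkpoint_debug_enabled(True)` context manager that overrides the flag globally, which is easier to apply when the checkpoint calls are buried inside a framework like torchtitan.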
Thanks for the logs. It looks like there's a split_with_sizes that received slightly different inputs between the original forward and the recompute.
original:
recompute:
Any idea why that is the case?
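To make the mismatch concrete, here is a sketch of how a changed size vector propagates into saved-tensor shapes. Only the 99 vs 98 and 105 vs 106 entries come from this thread; the remaining counts are filler so each list sums to 1024:

```python
import torch

tokens = torch.randn(1024, 5120)

# Per-expert token counts: both lists sum to 1024 but two entries differ
sizes_forward   = [99, 120, 130, 105, 570]
sizes_recompute = [98, 120, 130, 106, 570]

# torch.split with a list of sizes dispatches to aten::split_with_sizes
chunks_fwd = torch.split(tokens, sizes_forward, dim=0)
chunks_rec = torch.split(tokens, sizes_recompute, dim=0)

# First chunk: [99, 5120] in the forward pass vs. [98, 5120] on recompute,
# i.e. the saved/recomputed shape mismatch check_recomputed_tensors_match reports
print(chunks_fwd[0].shape, chunks_rec[0].shape)
```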
Interesting pointer. Both the original and recompute size lists in your comment sum to 1024, but differ in two split locations (99 vs 98, and 105 vs 106). What does the position in the error message refer to?
Position 91 means it is the 91st tensor saved in the checkpointed region.
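To see that ordering in isolation, here is a sketch using `torch.autograd.graph.saved_tensors_hooks`, which fires once per tensor autograd saves for backward; non-reentrant checkpointing uses the same pack mechanism internally, and "tensor at position N" counts these packs within the region:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

n = 0

def pack(t):
    # Called once per tensor autograd saves for backward, in order;
    # this order is what "tensor at position N" counts
    global n
    n += 1
    print(f"saved tensor #{n}: shape={tuple(t.shape)}, dtype={t.dtype}")
    return t

def unpack(t):
    return t

x = torch.randn(4, 4, requires_grad=True)
with saved_tensors_hooks(pack, unpack):
    y = (x.sin() @ x.cos()).relu()
y.sum().backward()
```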
How do I map the position to a layer in the model? Also, what is the code that decides the split?
Ah, I wouldn't look at the position number here. I'd just search for split_with_sizes in the debug log. The way it's structured is basically: for each checkpointed region, the ops recorded during the original forward are listed, followed by the ops recorded during the recompute, so you can diff the two around the mismatching op.
@soulitzer, thanks. I added a PyTorch PR that attaches layer identification to checkpoint discrepancies.
Thanks for the PR, I added a comment here pytorch/pytorch#153021 (review). |
A few questions:
Not a lot. What type of information are you looking for? There's a code comment on some AC internals here https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L602.
Yes
Your forward logic depending on global state, e.g. are you explicitly branching on any globals? Are there any TorchDispatchModes/TorchFunctionModes active?
Don't think it's possible in TorchTitan through the config (@tianyu-l correct me if I'm wrong)
It always saves certain "compute-intensive ops", except that only every other matmul is saved.
I believe you can always define your own selective AC policy; a sketch follows.
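A minimal sketch of a custom selective-AC policy in the spirit of what torchtitan does, using PyTorch's `create_selective_checkpoint_contexts` API (torch >= 2.4); the op list and the every-other-matmul rule here are illustrative, not torchtitan's actual code:

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative "compute-intensive" op list; torchtitan's real list differs
SAVE_OPS = {torch.ops.aten.addmm.default}

def make_policy():
    # Separate counters per phase: ctx.is_recompute distinguishes the
    # original forward from the recompute, keeping their decisions in sync
    counts = {"forward": 0, "recompute": 0}

    def policy_fn(ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            key = "recompute" if ctx.is_recompute else "forward"
            counts[key] += 1
            # Save only every other matmul; recompute the rest
            return (
                CheckpointPolicy.MUST_SAVE
                if counts[key] % 2 == 0
                else CheckpointPolicy.PREFER_RECOMPUTE
            )
        if op in SAVE_OPS:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy_fn

def block(x, w1, w2):
    return torch.mm(torch.mm(x, w1).relu(), w2)

x = torch.randn(8, 16, requires_grad=True)
w1 = torch.randn(16, 16, requires_grad=True)
w2 = torch.randn(16, 16, requires_grad=True)

out = checkpoint(
    block, x, w1, w2,
    use_reentrant=False,
    # fresh policy (and counters) per checkpoint invocation
    context_fn=lambda: create_selective_checkpoint_contexts(make_policy()),
)
out.sum().backward()
```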
Not sure. It is the Llama4 model out of the box.
Seeing the following with the llama4_17bx16e model.

```
[rank11]: File ".../lib/python3.10/site-packages/torch/utils/checkpoint.py", line 902, in check_recomputed_tensors_match
[rank11]:   raise CheckpointError(
[rank11]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank11]: tensor at position 46:
[rank11]: saved metadata: {'shape': torch.Size([965, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
[rank11]: recomputed metadata: {'shape': torch.Size([964, 5120]),
[rank13]: Traceback (most recent call last):
```