Seeing - "Recomputed values for the following tensors have different metadata than during the forward pass." #1117

Open
githubsgi opened this issue Apr 17, 2025 · 19 comments


@githubsgi
Contributor

Seeing the following error with the llama4_17bx16e model.

rank11: File ".../lib/python3.10/site-packages/torch/utils/checkpoint.py", line 902, in check_recomputed_tensors_match
rank11: raise CheckpointError(
rank11: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
rank11: tensor at position 46:
rank11: saved metadata: {'shape': torch.Size([965, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
rank11: recomputed metadata: {'shape': torch.Size([964, 5120]), ...}

@tianyu-l
Contributor

Could you share the config to reproduce? I think the complaint is from activation checkpointing -- I saw it when exploring possible ways to register hooks for load balancing updates in #1114

@githubsgi
Contributor Author

The TOML file follows.

```toml
[job]
dump_folder = "./outputs_llama4_17bx16e"
description = "Llama 4 Scout 17Bx16E training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama4"
flavor = "17bx16e"
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 4e-3
eps = 1e-15

[lr_scheduler]
warmup_steps = 600
lr_min = 0.1

[training]
batch_size = 1
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
compile = false
dataset = "c4"
dataset_path = "./data/hf/c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = "output,router.gate"
```

@tianyu-l
Contributor

Hmm I'm not sure if I can reproduce. Some more questions:

  • what model configs are you using?
  • are you using Grouped MM or the for-loop implementation for MoE? This could depend on your hardware
  • are you running the load balancing I recently added?

@githubsgi
Contributor Author

@tianyu-l, please see the answers below.

what model configs are you using? - The TOML file is above, if that is what you are asking. Otherwise, the source is unchanged.
are you using Grouped MM or for-loop implementation for MoE? - Not using Grouped MM.
are you running the load balancing I recently added? - Just tried with the latest commit. Same issue.

@githubsgi
Contributor Author

githubsgi commented Apr 25, 2025

Looks like the checkpoint recompute has issues -- this looks funny:

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True
                frame.check_recomputed_tensors_match(gid)

One interesting observation: if I make the following change, training proceeds, but hits NaN after some time.

                try:
                    #print (f"frame.recompute_fn {frame.recompute_fn}")
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                        print (f" frame.recompute_fn(*args) {args}")
                except _StopRecomputationError :
                    print (f"_StopRecomputationError  {len(args)} {args[0]} {args[1].shape} {args[2].shape} ")
                    pass

args[0] is an empty dictionary, and args[1] and args[2] are tensors.

The checkpointing code looks complicated and difficult to debug. It would be good to be able to relate the saved/recomputed mismatch back to a model layer.

@tianyu-l
Contributor

cc @soulitzer in case you have context

@soulitzer

It would be good to be able to relate the saved/recomputed mismatch back to a model layer.

Would it be possible to pass debug=True to checkpoint? With this flag enabled, the error message would contain the list of operators that ran during the forward and the recompute, along with captured stack traces.
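
For reference, a minimal sketch of where the flag goes when calling the non-reentrant checkpoint API directly (the module and tensor below are made-up stand-ins, not the Llama4 code; in torchtitan the flag would need to be threaded through its AC wrapper instead):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-ins for the real block/input, purely to show the call site.
block = torch.nn.Linear(5120, 5120).to(torch.bfloat16)
x = torch.randn(965, 5120, dtype=torch.bfloat16, requires_grad=True)

# debug=True requires the non-reentrant implementation (use_reentrant=False).
# On a metadata mismatch, the CheckpointError then lists the ops recorded
# during the forward and the recompute, with a stack trace for each.
out = checkpoint(block, x, use_reentrant=False, debug=True)
out.sum().backward()
```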

@ratnampa

@soulitzer you can access the log with debug=True here (llama4_17bx16e, N=2 PPN=8 TP=8 FSDP=2): https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/N2_PPN8_TP8_FSDP2_llama4_17bx16e.log

I have also added a log for rank 1 only, which might be easier to inspect.
https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/rank1.log

@soulitzer

Thanks for the logs. It looks like there's a split_with_sizes that received slightly different inputs between the original forward and the recompute.

original:

torch._ops.aten.split_with_sizes.default($181, ['14', '99', '71', '38', '66', '105', '81', '78', '86', '43', '33', '118', '22', '35', '83', '52'])   

recompute:

torch._ops.aten.split_with_sizes.default($245, ['14', '98', '71', '38', '66', '106', '81', '78', '86', '43', '33', '118', '22', '35', '83', '52'])

Any idea why that is the case?

@githubsgi
Contributor Author

Interesting pointer. Both the original and the recomputed splits above sum to 1024, but differ in two locations (99 vs 98 and 105 vs 106). What does the position in the following refer to?

[rank1]: saved metadata: {'shape': torch.Size([99, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 49:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 51:
[rank1]: saved metadata: {'shape': torch.Size([99, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 52:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 53:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 55:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 84:
[rank1]: saved metadata: {'shape': torch.Size([105, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 85:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 87:
[rank1]: saved metadata: {'shape': torch.Size([105, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 88:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 89:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 91:

@soulitzer

Position 91 means it is the 91st tensor saved in the checkpointed region.
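
As a toy illustration of what that count corresponds to (the non-reentrant checkpoint indexes saved tensors in a similar spirit via saved-tensor hooks, though this is not the actual checkpoint.py code):

```python
import torch

saved = []

def pack(t):
    saved.append(t)  # "position" is just the index in save order
    return t

def unpack(t):
    return t

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x.sin().exp().sum()  # sin saves x, exp saves its output

for pos, t in enumerate(saved):
    print(f"tensor at position {pos}: shape={tuple(t.shape)}")
```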

@githubsgi
Contributor Author

githubsgi commented Apr 30, 2025

How do I map the position to a layer in the model? Also, what is the code that decides the split?

@soulitzer

Ah, I wouldn't look at the position number here. I'd just search for split_with_sizes, and below that you'd find the Python and C++ stack traces, which should have the module information. In this case, what you're looking for should be /home/ratnampa/torchtitan/torchtitan/experiments/llama4/model/moe.py:40:forward

The way it's structured is basically:

op1
stack trace for op1
op2
stack trace for op2
...

@githubsgi
Contributor Author

@soulitzer, thanks. Added a PyTorch PR to add layer identification to checkpoint discrepancies.

@soulitzer

Thanks for the PR, I added a comment here pytorch/pytorch#153021 (review).

@githubsgi
Contributor Author

A few questions.

  1. Are there any more design/RFC docs on activation checkpointing other than this?
  2. Is the AC metadata stored on the CPU? I guess the saved activations are left on the accelerators?
  3. I do see differences in the input to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?
  4. What is the best way to not do AC on specific layers?
  5. I see the selective_ac_option; what is an example of using the "op" option?

@soulitzer

Are there any more design/RFC docs on activation checkpointing other than this?

Not a lot. What type of information are you looking for? There's a code comment on some AC internals here https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L602.

Is the AC metadata stored on the CPU? I guess the saved activations are left on the accelerators?

Yes

I do see differences in the input to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?

Your forward logic might depend on global state, e.g. are you explicitly branching on any globals, or are there any TorchDispatchModes/TorchFunctionModes active?
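
As a contrived sketch of how global state can produce exactly this kind of mismatch (purely illustrative, nothing to do with the Llama4 code; in the MoE case the analogous state would be anything that makes the router assign different per-expert token counts on recompute):

```python
import torch
from torch.utils.checkpoint import checkpoint

step = 0  # global state read inside the checkpointed region

def forward(x):
    global step
    step += 1                      # also runs during the recompute
    k = 4 if step % 2 == 1 else 3  # slice size depends on global state
    return (x[:k] ** 2).sum()      # pow saves x[:k] for backward

x = torch.randn(8, requires_grad=True)
out = checkpoint(forward, x, use_reentrant=False)
# The recompute sees step == 2, so the saved tensor is [3] instead of [4],
# and check_recomputed_tensors_match raises the metadata-mismatch error.
out.backward()
```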

What is the best way to not do AC on specific layers?

I don't think it's possible in TorchTitan through the config (@tianyu-l correct me if I'm wrong)

I see the selective_ac_option; what is an example of using the "op" option?

It always saves a set of "compute-intensive ops", except that for matmuls only every other one is saved.

use_op_sac = ac_config.selective_ac_option == "op"
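
For anyone curious what that looks like at the op level, here is a rough sketch of the idea using the selective-checkpoint API in recent PyTorch (2.4+); the op list and the every-other-matmul counter are illustrative, not torchtitan's exact implementation:

```python
import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative subset of "compute-intensive" ops to always save.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def _get_policy():
    mm_count = 0

    def policy_fn(ctx, op, *args, **kwargs):
        nonlocal mm_count
        if op == torch.ops.aten.mm.default:
            mm_count += 1
            # Save every other matmul, recompute the rest.
            return (
                CheckpointPolicy.MUST_SAVE
                if mm_count % 2 == 0
                else CheckpointPolicy.PREFER_RECOMPUTE
            )
        return (
            CheckpointPolicy.MUST_SAVE
            if op in _save_list
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    return policy_fn

def op_sac_forward(module, *inputs):
    context_fn = functools.partial(
        create_selective_checkpoint_contexts, _get_policy()
    )
    return checkpoint(module, *inputs, use_reentrant=False, context_fn=context_fn)
```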

@tianyu-l
Contributor

What is the best way to not do AC on specific layers?

I believe you can always define your own apply_ac method
https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/parallelize_llama.py#L292
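
For example, a minimal sketch of a custom variant that skips AC for chosen layers, assuming the torchtitan-style layout where blocks live in model.layers (an nn.ModuleDict keyed by the layer id as a string); the function name and the skip_layer_ids parameter are made up for illustration:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_ac_to_selected_layers(model: nn.Module, skip_layer_ids: set) -> None:
    # skip_layer_ids holds layer-id strings, e.g. {"0", "5"}.
    for layer_id, block in model.layers.named_children():
        if layer_id in skip_layer_ids:
            continue  # leave this block un-checkpointed
        model.layers.register_module(
            layer_id, checkpoint_wrapper(block, preserve_rng_state=False)
        )
```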

@githubsgi
Contributor Author

@tianyu-l, @soulitzer

I do see differences in the input to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?

Your forward logic might depend on global state, e.g. are you explicitly branching on any globals, or are there any TorchDispatchModes/TorchFunctionModes active?

Not sure. It is the out-of-the-box Llama4 model.
