[PyTorch] Miscellaneous fixes for FP8 DPA module (#804)
* initialize tp_group for FP8 DPA
Signed-off-by: Charlene Yang <[email protected]>
* fix cuDNN version in unit tests for cuDNN v9
Signed-off-by: Charlene Yang <[email protected]>
* add hook to ignore missing fused_attn._extra_state keys when training from old checkpoints
Signed-off-by: Charlene Yang <[email protected]>
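The hook above can be sketched with PyTorch's public `register_load_state_dict_post_hook` API: a post-hook receives the incompatible-keys record and can drop the missing `*_extra_state` entries so that checkpoints saved before the FP8 metadata existed still load under `strict=True`. This is a minimal illustrative sketch with hypothetical module and method names, not the actual TransformerEngine implementation.

```python
import torch


class CoreAttention(torch.nn.Module):
    """Toy stand-in for a module that stores FP8 metadata in _extra_state.

    Old checkpoints predate the _extra_state entry, so a load_state_dict
    post-hook removes the corresponding "missing key" reports and strict
    loading succeeds anyway. (Hypothetical names; TransformerEngine's
    real implementation differs.)
    """

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4))
        # Runs after load_state_dict has collected missing/unexpected keys.
        self.register_load_state_dict_post_hook(self._ignore_missing_extra_state)

    def get_extra_state(self):
        # Newer checkpoints carry (stand-in) FP8 metadata here.
        return {"fp8_meta": None}

    def set_extra_state(self, state):
        # Restore FP8 metadata when the checkpoint provides it.
        pass

    @staticmethod
    def _ignore_missing_extra_state(module, incompatible_keys):
        # Drop missing *_extra_state keys so old checkpoints load cleanly.
        for key in list(incompatible_keys.missing_keys):
            if key.endswith("_extra_state"):
                incompatible_keys.missing_keys.remove(key)


# Loading an "old" checkpoint that lacks the _extra_state entry:
attn = CoreAttention()
old_ckpt = {"weight": torch.ones(4)}
result = attn.load_state_dict(old_ckpt, strict=True)  # no error raised
```

Without the hook, `strict=True` would raise because `get_extra_state` is overridden and the loader therefore expects an `_extra_state` key in the checkpoint.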
* remove test and redundant implementation from last commit
Signed-off-by: Charlene Yang <[email protected]>
* remove warning message and replace with docstring
Signed-off-by: Charlene Yang <[email protected]>
* remove tp_size/tp_group from FusedAttention; amax reduction is handled via fp8_group
Signed-off-by: Charlene Yang <[email protected]>
* move core_attention.fused_attention._extra_state to core_attention._extra_state
Signed-off-by: Charlene Yang <[email protected]>
* simplify post_state_dict_hooks between FU and DPA
Signed-off-by: Charlene Yang <[email protected]>
* add temporary test
Signed-off-by: Charlene Yang <[email protected]>
* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test
Signed-off-by: Charlene Yang <[email protected]>
* remove the test
Signed-off-by: Charlene Yang <[email protected]>
* disable pylint warning for the hook's self argument, which the hook signature requires
Signed-off-by: Charlene Yang <[email protected]>
---------
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>