[PyTorch] Miscellaneous fixes for FP8 DPA module (#804)
* initialize tp_group for FP8 DPA
  Signed-off-by: Charlene Yang <[email protected]>
* fix cuDNN version in unit tests for cuDNN v9
  Signed-off-by: Charlene Yang <[email protected]>
* add a hook to ignore missing fused_attn._extra_state entries when training from old checkpoints (see the sketch below)
  Signed-off-by: Charlene Yang <[email protected]>
* remove the test and redundant implementation from the last commit
  Signed-off-by: Charlene Yang <[email protected]>
* remove the warning message and replace it with a docstring
  Signed-off-by: Charlene Yang <[email protected]>
* remove tp_size/tp_group from FusedAttention; amax reduction is handled via fp8_group (see the usage sketch below)
  Signed-off-by: Charlene Yang <[email protected]>
* move core_attention.fused_attention._extra_state to core_attention._extra_state
  Signed-off-by: Charlene Yang <[email protected]>
* simplify the post_state_dict_hooks shared between FusedAttention and DotProductAttention
  Signed-off-by: Charlene Yang <[email protected]>
* add a temporary test
  Signed-off-by: Charlene Yang <[email protected]>
* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test
  Signed-off-by: Charlene Yang <[email protected]>
* remove the test
  Signed-off-by: Charlene Yang <[email protected]>
* disable the pylint warning for the hook's self argument, which the hook signature requires
  Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
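The checkpoint-compatibility hook relies on PyTorch's load-state-dict post-hooks (`register_load_state_dict_post_hook`, available since PyTorch 1.13). Below is a minimal, self-contained sketch, not the Transformer Engine implementation: the class, buffer, and metadata names are illustrative. It shows how such a hook can drop missing `_extra_state` keys so that checkpoints saved before the FP8 extra state existed still load with `strict=True`:

```python
import torch


class FusedAttention(torch.nn.Module):
    """Toy stand-in for the FP8 fused-attention submodule."""

    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))
        self.register_load_state_dict_post_hook(self._ignore_missing_extra_state)

    def get_extra_state(self):
        # Hypothetical FP8 metadata serialized alongside the parameters.
        return {"amax_history": torch.zeros(16)}

    def set_extra_state(self, state):
        # Real code would restore FP8 scaling factors here; no-op in this sketch.
        del state

    # ``self`` is unused but required by the bound-method hook signature,
    # hence the pylint disable (mirroring the commit's pylint note).
    def _ignore_missing_extra_state(self, module, incompatible_keys):  # pylint: disable=unused-argument
        # Checkpoints written before the FP8 extra state existed report
        # ``*._extra_state`` as missing; drop those entries so that
        # ``load_state_dict(strict=True)`` still succeeds.
        for key in list(incompatible_keys.missing_keys):
            if key.endswith("_extra_state"):
                incompatible_keys.missing_keys.remove(key)


# Simulate loading an "old" checkpoint that predates the extra state.
module = FusedAttention()
old_ckpt = {"scale": torch.ones(1)}  # no "_extra_state" key
module.load_state_dict(old_ckpt, strict=True)  # passes thanks to the hook
```

The hooks run before the strict-mode check, so mutating `incompatible_keys.missing_keys` in place is enough to suppress the error for those keys.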
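For the tp_group removal, the point is that FP8 amax reduction is scoped by the process group passed to `fp8_autocast` rather than by per-module tensor-parallel groups. A hedged usage sketch, assuming Transformer Engine's `fp8_autocast` context manager with its `fp8_group` argument; the module, shapes, and group choice below are placeholders:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# One process group scopes FP8 amax/scale reduction for the whole model, so
# individual modules such as FusedAttention need no tp_size/tp_group of their own.
dist.init_process_group(backend="nccl")
fp8_group = dist.new_group(ranks=list(range(dist.get_world_size())))

model = te.Linear(1024, 1024).cuda()        # placeholder for any TE module
inp = torch.randn(16, 1024, device="cuda")  # placeholder input

with te.fp8_autocast(enabled=True, fp8_group=fp8_group):
    out = model(inp)  # amax/scale synchronization happens over fp8_group
```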