Commit 90d66ce
authored
feat(FLCE): expose
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
This PR is a follow-up to #830, exposing `accum_dtype` option for monkey
patch functions.
All bf16 convergence tests related to fused linear cross entropy are
also enforced to run with `accum_dtype=torch.float32` for numerical
stability.
Related: #512, #742, #827, #850
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
Signed-off-by: Tcc0403 <[email protected]>accum_dtype for hf model monkey patch (#851)1 parent e6de786 commit 90d66ce
File tree
3 files changed
+4
-2
lines changed- src/liger_kernel/transformers/model
- test/convergence/bf16
3 files changed
+4
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| 16 | + | |
16 | 17 | | |
17 | 18 | | |
18 | 19 | | |
| |||
23 | 24 | | |
24 | 25 | | |
25 | 26 | | |
| 27 | + | |
26 | 28 | | |
27 | 29 | | |
28 | 30 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
926 | 926 | | |
927 | 927 | | |
928 | 928 | | |
929 | | - | |
| 929 | + | |
930 | 930 | | |
931 | 931 | | |
932 | 932 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
860 | 860 | | |
861 | 861 | | |
862 | 862 | | |
863 | | - | |
| 863 | + | |
864 | 864 | | |
865 | 865 | | |
866 | 866 | | |
| |||
0 commit comments