Support skip scaling for the input tensor in the Triton rowwise FP8 kernel #4362
Conversation
✅ Deploy Preview for pytorch-fbgemm-docs ready!
This pull request was exported from Phabricator. Differential Revision: D76759999
Force-pushed from 1d9715b to 3efacfc
Force-pushed from 3efacfc to 6bf05c9
Force-pushed from 6bf05c9 to 3f6f8c0
Force-pushed from 3f6f8c0 to 733c0b5
This pull request has been merged in 6152f34.
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1431
The scaling process (generating the inputs to the FP8 GEMM) adds non-trivial cost to FP8 quantization and can offset the gain of the FP8 GEMM, especially in memory-bound cases.
By [re-designing the activation layer](https://fb.workplace.com/groups/1033540429995021/permalink/24472616132327454) of the model, most elements of the input tensor to the FP8 GEMM (more specifically, the activation, corresponding to input `a` in this revision) can fall within the FP8 range.
Note that even when the values are already within the FP8 range, FP8 scaling can still help accuracy, since the scaling process tries to use the full FP8 bits to encode as much information as possible. E2E model quality therefore needs to be evaluated on a per-model basis.
[A study](https://docs.google.com/document/d/1jEOVqOIn3cKe3PgFcHsxW9EKYi_-6QHnIuiPY8gLEwQ/edit?tab=t.0) has shown that, by replacing FP8 scaling with an FP8 clamp on an Ads 500x inference model, E2E throughput can improve by a further 7%, with the E2E QPS gain increasing from 21% (FP8) to 28% (FP8 with skip scaling).
This revision adds support for this case in the Triton row-wise kernel.
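For context, here is a minimal PyTorch sketch of the two activation-quantization paths (the function names and the e4m3 dtype are illustrative assumptions, not the FBGEMM API). The standard row-wise path pays for an extra amax reduction and rescale over `a`, while the skip-scaling path only clamps and casts, which is where the saving comes from in memory-bound cases.

```python
import torch

# Max representable magnitude for the FP8 dtype (e4m3fn here; an assumption for illustration).
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_a_rowwise(a: torch.Tensor):
    # Standard row-wise scaling: read `a` to find each row's amax, rescale the row
    # into FP8 range, and return per-row inverse scales to apply after the GEMM.
    row_amax = a.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / row_amax
    a_fp8 = (a * scale).to(torch.float8_e4m3fn)
    a_inv_scale = (1.0 / scale).squeeze(1)  # shape [M], float32
    return a_fp8, a_inv_scale

def quantize_a_skip_scaling(a: torch.Tensor):
    # Skip-scaling path: the activation is assumed to already lie (mostly) within the
    # FP8 range, so a clamp + cast suffices; the per-row scale degenerates to ones.
    a_fp8 = a.clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    a_inv_scale = torch.ones(a.shape[0], device=a.device, dtype=torch.float32)
    return a_fp8, a_inv_scale
```

With skip scaling, the row-wise GEMM can then consume `a_fp8` with unit scales, avoiding the extra pass over the activation that the amax reduction would otherwise require.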
Reviewed By: y-x-c
Differential Revision: D76759999