[WIP] Add Support for msamp FP4 Quantization and Unit Test #203
base: main
Conversation
msamp/operators/fp4_quant/setup.py (Outdated)
        )
    ],
    cmdclass={'build_ext': cpp_extension.BuildExtension}
)
Please add a newline at the end of the file.
The same for the other files.
Added
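For context, a minimal setup.py for building a torch C++/CUDA extension of this kind looks roughly like the sketch below; the package name, source file, and the choice of CUDAExtension are assumptions, not the PR's exact contents.

    # Hypothetical sketch of the fp4_quant extension's setup.py; names are assumptions.
    from setuptools import setup
    from torch.utils import cpp_extension

    setup(
        name='msamp_fp4_quant',
        ext_modules=[
            cpp_extension.CUDAExtension(
                name='msamp_fp4_quant',
                sources=['quant.cu'],
            ),
        ],
        cmdclass={'build_ext': cpp_extension.BuildExtension},
    )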
msamp/operators/fp4_quant/quant.cu (Outdated)
__nv_bfloat16* output_data = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());

const int threadsPerBlock = HIP_GET_NUM_THREADS(size);  // 512
// const int blocks = HIP_GET_BLOCKS(size, threadsPerBlock);  // max grid num: HIP_MAX_GRID_NUM = 65535
Please remove all unused code like this
Removed
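For reference, the commented-out launch configuration removed above computes a standard elementwise-kernel grid. Below is a small sketch of the equivalent arithmetic, assuming 512 threads per block and a 65535 grid cap as the inline comments indicate; it is not the actual code behind the HIP_* macros.

    # Sketch of the launch-config arithmetic implied by the inline comments
    # (512 threads per block, grid capped at 65535); not the HIP_* macros' real code.
    def launch_config(size: int, max_threads: int = 512, max_grid: int = 65535):
        threads_per_block = min(size, max_threads)
        blocks = min((size + threads_per_block - 1) // threads_per_block, max_grid)
        return blocks, threads_per_block

    # launch_config(1 << 20) -> (2048, 512)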
msamp/common/dtype/floating.py (Outdated)
@@ -15,24 +15,29 @@ class Floating:
     qfp_max: dict = {}

     @staticmethod
-    def _get_fp_max(exp, man, inf_existed=True):
+    def _get_fp_max(exp, man, inf_existed=True, nan_existed=True):
It seems we don't need to revise this file, because the parameter nan_existed is never set to False in the new code.
Reverted
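For background on why nan_existed was added at all: the largest finite value of a custom float format depends on whether the all-ones exponent is reserved for inf and/or NaN. The sketch below shows the usual formula, not necessarily MS-AMP's exact implementation; FP4 (E2M1) reserves neither, which is presumably what motivated the extra flag.

    def fp_max(exp: int, man: int, inf_existed: bool = True, nan_existed: bool = True) -> float:
        """Largest finite value of a float format with `exp` exponent and `man` mantissa bits.

        Sketch of the standard formula, not MS-AMP's exact code.
        """
        bias = 2 ** (exp - 1) - 1
        if inf_existed:
            # All-ones exponent reserved for inf/NaN (IEEE style, e.g. FP16, FP8 E5M2).
            max_exp_field = 2 ** exp - 2
            max_mantissa = 2 - 2 ** (-man)
        elif nan_existed:
            # No inf, but the all-ones exponent with all-ones mantissa encodes NaN (e.g. FP8 E4M3).
            max_exp_field = 2 ** exp - 1
            max_mantissa = 2 - 2 ** (1 - man)
        else:
            # Neither inf nor NaN reserved (e.g. FP4 E2M1): every code is a finite value.
            max_exp_field = 2 ** exp - 1
            max_mantissa = 2 - 2 ** (-man)
        return max_mantissa * 2.0 ** (max_exp_field - bias)

    # fp_max(4, 3, inf_existed=False, nan_existed=True)   -> 448.0 (FP8 E4M3)
    # fp_max(2, 1, inf_existed=False, nan_existed=False)  -> 6.0   (FP4 E2M1)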
msamp/megatron/layers.py (Outdated)
@@ -13,6 +13,14 @@
 from msamp.common.tensor import ScalingTensor
 from msamp.operators.gemm import Gemm

+import os
+
+USE_W_SIMU_FP4 = bool(int(os.getenv('USE_W_SIMU_FP4', 0)))
Shall we make the naming clearer? For example, use WEIGHT and ACTIVATION instead of W and A, and SIMULATION/SIMULATE instead of SIMU. Also, we need to add the MSAMP_ prefix; otherwise it is not scoped to the MS-AMP project.
Renamed
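Following the review suggestion, the flag would read roughly as below; the exact name finally chosen in the PR is not visible in this thread, so the one used here is an assumption.

    import os

    # Hypothetical renamed flag: MSAMP_ prefix, WEIGHT and SIMULATE spelled out in full.
    MSAMP_USE_WEIGHT_SIMULATE_FP4 = bool(int(os.getenv('MSAMP_USE_WEIGHT_SIMULATE_FP4', '0')))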
Makefile (Outdated)
@@ -24,4 +24,5 @@ lint: cpplint mdlint
 postinstall:
 	cd msamp/operators/dist_op && bash build.sh && cd -
 	cd msamp/operators/arithmetic && pip install -v . && cd -
+	cd msamp/operators/fp4_quant && pip install -v . && cd -
Please just use quantize or quantization for naming.
Done
msamp/megatron/layers.py (Outdated)
@@ -175,6 +196,9 @@ def backward(ctx, grad_output):
                 wgrad_qtype,
                 use_split_accumulator=True,
             )
+            if USE_W_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
+                scaled_w = ctx.saved_tensors[0]
+                grad_weight.mul_(FP4_QUANT.apply_DGE_item(scaled_w, k=5.0, power_clamp_max=3.0))
Let's just eliminate the explicit argument assignments for k and power_clamp_max here, if their default values are already provided.
Eliminated
class FP4_QUANT:
    """FP4 Quantization operator."""
    @staticmethod
    def apply_DGE_item(
What does DGE mean here?
Please add the full term in the function comment.
Added
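Per the thread above, DGE stands for Differentiable Gradient Estimator (compare the USE_W_DIFFERENTIABLE_GRADIENT_ESTIMATOR flag in the backward pass). A docstring along the following lines answers the request; the function body is only a placeholder (a clamped power-law derivative) to make the sketch runnable, not the estimator actually implemented in the PR.

    import torch

    class FP4_QUANT:
        """FP4 quantization operator (sketch)."""
        @staticmethod
        def apply_DGE_item(x: torch.Tensor, k: float = 5.0, power_clamp_max: float = 3.0) -> torch.Tensor:
            """Return the Differentiable Gradient Estimator (DGE) correction factor for x.

            Placeholder body: the derivative of a power-law surrogate of the quantizer,
            clamped to power_clamp_max; used here only to illustrate signature and defaults.
            """
            correction = x.abs().clamp(min=1e-8).pow(1.0 / k - 1.0) / k
            return correction.clamp(max=power_clamp_max)

    # With the defaults in place, the backward pass can drop the explicit arguments:
    # grad_weight.mul_(FP4_QUANT.apply_DGE_item(scaled_w))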
    @staticmethod
    def quantize_simu_fp4_in_bf16(
simulate or simulation instead of simu?
Renamed
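To illustrate what simulated FP4 in bf16 usually means: each value is rounded to the nearest FP4 (E2M1) representable value while the tensor stays stored in bf16. The sketch below assumes plain nearest-value rounding with saturation and no per-tensor scale factor, which may differ from the kernel in this PR.

    import torch

    # FP4 (E2M1) representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
    FP4_E2M1_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def simulate_fp4_quantize_in_bf16(x: torch.Tensor) -> torch.Tensor:
        """Round a bf16 tensor onto the FP4 (E2M1) value grid, keeping bf16 storage.

        Assumed rounding: nearest level, saturating at +/-6, no per-tensor scaling.
        """
        levels = FP4_E2M1_LEVELS.to(device=x.device)
        x_f32 = x.float()
        sign = torch.sign(x_f32)
        mag = x_f32.abs().clamp(max=6.0)
        # Index of the nearest representable magnitude for every element.
        idx = torch.argmin((mag.unsqueeze(-1) - levels).abs(), dim=-1)
        return (sign * levels[idx]).to(torch.bfloat16)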
Please add the MIT license header.
Added
Work in progress, please do not merge.
This PR adds new features for FP4 quantization to the MS-AMP library.
Working items: