
[WIP] Add Support for msamp FP4 Quantization and Unit Test #203


Open · wants to merge 9 commits into base: main

Conversation


@Mr-Philo Mr-Philo commented May 22, 2025

Work in progress, please do not merge.

This PR adds new features for FP4 quantization to the MS-AMP library.

Working items:

  • Custom FP4 quant function in CUDA
  • Custom Differentiable Gradient Estimation for weight update
  • Forward and Backward pass for simulated FP4 quantization
  • Unit Test for FP4 quantization
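To make the "simulated FP4 quantization" item concrete, here is a hypothetical plain-Python sketch of the core idea: round each value to the nearest representable FP4 (E2M1) number. The actual PR implements this as a CUDA kernel operating on bf16 tensors; the function name and tie-breaking behavior below are assumptions, not the PR's code.

```python
# Non-negative values representable in FP4 E2M1 (no inf/nan encodings);
# the format saturates at a maximum magnitude of 6.
FP4_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_simulated_fp4(x: float) -> float:
    """Round x to the nearest FP4 (E2M1) value, saturating at +/-6."""
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), 6.0)  # saturate at the FP4 max
    # Nearest representable magnitude (ties resolve to the smaller value here;
    # a real kernel might use round-to-nearest-even instead).
    nearest = min(FP4_E2M1_VALUES, key=lambda v: abs(v - mag))
    return sign * nearest
```

In training, such a function is typically applied elementwise to a scaled copy of the weights in the forward pass, while the backward pass treats it as (approximately) the identity or applies a gradient estimator, as the later work items describe.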

)
],
cmdclass={'build_ext': cpp_extension.BuildExtension}
)
Collaborator:

Please add a newline at the end of the file.

Collaborator:

The same for other files

Author:

Added

__nv_bfloat16* output_data = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());

const int threadsPerBlock = HIP_GET_NUM_THREADS(size); // 512
// const int blocks = HIP_GET_BLOCKS(size, threadsPerBlock); // max grid num: HIP_MAX_GRID_NUM = 65535
Collaborator:

Please remove all unused code like this.

Author:

Removed

@@ -15,24 +15,29 @@ class Floating:
qfp_max: dict = {}

@staticmethod
def _get_fp_max(exp, man, inf_existed=True):
def _get_fp_max(exp, man, inf_existed=True, nan_existed=True):
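For context on what `_get_fp_max` computes, the largest finite value of a small floating-point format follows from its exponent/mantissa bit counts and which special encodings (inf, nan) are reserved. The sketch below is illustrative and mirrors only the signature in the diff; it is not the MS-AMP implementation.

```python
def get_fp_max(exp: int, man: int, inf_existed: bool = True,
               nan_existed: bool = True) -> float:
    """Largest finite value of a format with `exp` exponent and `man` mantissa bits."""
    bias = 2 ** (exp - 1) - 1
    if inf_existed:
        # All-ones exponent is reserved for inf/nan: max usable exponent
        # field is 2^exp - 2, with an all-ones mantissa.
        return (2 - 2 ** -man) * 2.0 ** (2 ** exp - 2 - bias)
    if nan_existed:
        # No inf, single nan (all-ones everywhere): the top mantissa code
        # in the top exponent is lost to nan.
        return (2 - 2 * 2 ** -man) * 2.0 ** (2 ** exp - 1 - bias)
    # Neither inf nor nan: every encoding is a finite number.
    return (2 - 2 ** -man) * 2.0 ** (2 ** exp - 1 - bias)
```

This reproduces the familiar values 57344 for FP8 E5M2 (inf reserved), 448 for FP8 E4M3 (no inf, one nan), and 6 for FP4 E2M1 (no inf or nan).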
Collaborator:

It seems we don't need to revise this file, because the parameter nan_existed is never set to False in the new code.

Author:

Reverted

@@ -13,6 +13,14 @@
from msamp.common.tensor import ScalingTensor
from msamp.operators.gemm import Gemm

import os

USE_W_SIMU_FP4 = bool(int(os.getenv('USE_W_SIMU_FP4', 0)))
Collaborator (@yzygitzh, Jul 10, 2025):

Shall we make the naming clearer? For example, use WEIGHT and ACTIVATION instead of W and A, and SIMULATION/SIMULATE instead of SIMU. Also, we need to add the MSAMP_ prefix; otherwise the variable isn't scoped to the MS-AMP project.

Author:

Renamed
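Following the reviewer's naming suggestion, the environment flag might look like the sketch below. The exact renamed variable the PR settled on is not shown in this thread, so the name here is an assumption; the parsing idiom matches the original `bool(int(os.getenv(...)))` line.

```python
import os

def env_flag(name: str) -> bool:
    """Parse an on/off environment variable ('0'/'1') into a bool."""
    return bool(int(os.getenv(name, '0')))

# Hypothetical renamed flag (full words plus MSAMP_ prefix, per review
# feedback); the actual name chosen in the PR may differ.
MSAMP_USE_WEIGHT_SIMULATE_FP4 = env_flag('MSAMP_USE_WEIGHT_SIMULATE_FP4')
```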

Makefile (outdated diff)
@@ -24,4 +24,5 @@ lint: cpplint mdlint
postinstall:
cd msamp/operators/dist_op && bash build.sh && cd -
cd msamp/operators/arithmetic && pip install -v . && cd -
cd msamp/operators/fp4_quant && pip install -v . && cd -
Collaborator:

Please just use quantize or quantization for naming.

Author:

Done

@@ -175,6 +196,9 @@ def backward(ctx, grad_output):
wgrad_qtype,
use_split_accumulator=True,
)
if USE_W_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
scaled_w = ctx.saved_tensors[0]
grad_weight.mul_(FP4_QUANT.apply_DGE_item(scaled_w, k=5.0, power_clamp_max=3.0))
Collaborator:

Let's eliminate the explicit argument assignments for k and power_clamp_max here, since their default values are already provided.

Author:

Eliminated

class FP4_QUANT:
"""FP4 Quantization operator."""
@staticmethod
def apply_DGE_item(
Collaborator:

What does DGE mean here?

Collaborator:

Please add full term in function comment

Author:

Added
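DGE here expands to the "Differentiable Gradient Estimation" item from the PR's work list: the hard quantizer's derivative is zero almost everywhere, so the backward pass scales weight gradients by the derivative of a smooth surrogate instead. The surrogate below (a clamped power-law soft step) is a hypothetical sketch, not the MS-AMP kernel; only the default values k=5.0 and power_clamp_max=3.0 echo the earlier diff.

```python
# Hypothetical sketch of a Differentiable Gradient Estimator (DGE) factor.
# Idea: stand in for the hard quantizer with a smooth surrogate
# f(x) = sign(x) * |x|**(1/k) and use its derivative as the gradient
# correction, clamped because f'(x) -> infinity as x -> 0.

def dge_factor(x: float, k: float = 5.0, power_clamp_max: float = 3.0) -> float:
    """Derivative of the surrogate quantizer at x, clamped to power_clamp_max."""
    if x == 0.0:
        return power_clamp_max
    d = (1.0 / k) * abs(x) ** (1.0 / k - 1.0)
    return min(d, power_clamp_max)
```

In the backward pass shown earlier, such a factor would be applied elementwise, i.e. `grad_weight` is multiplied by `dge_factor(scaled_w)` for each weight element.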



@staticmethod
def quantize_simu_fp4_in_bf16(
Collaborator:

simulate or simulation instead of simu?

Author:

Renamed

Collaborator:

Please add the MIT license header.

Author:

Added
