Quantization Memory Requirements #1228

Open
sneha5gsm opened this issue Mar 5, 2025 · 4 comments
Labels: question (Further information is requested)

@sneha5gsm

Hello!

I was trying the various quantization recipes for quantizing a 70B Llama 3 based model to FP8, INT8, and INT4 (A16) precisions, as described in the vLLM quantization docs.

  1. Could you help me understand the memory requirements of the quantization recipes, i.e. SmoothQuant (SmoothQuantModifier), GPTQ (GPTQModifier), and RTN (QuantizationModifier)? A calculation/formula would help, for example like the one we have for the kv cache (written out as a script after this list):
memory in bytes for kv cache = 80 (layers) * 8 (kv heads) * 128 (head_dim) * 8192 (seq length) * 2 (k and v) * 2 (bytes per value, fp16)

I understand that calculate_offload_device_map creates a custom device map by reserving memory for GPTQ (reserve_for_hessians), but I would still like to understand the memory requirements so I can utilize the GPU memory efficiently, understand where the GPU memory is consumed, and confirm that there are no bugs.

  2. Also, I understand that currently, for quantization of big models, the model is split in a pipeline-parallel way across the GPUs available on the instance.
  • Since the GPU in use at any given time is the one holding the layer currently being quantized, would the time taken to quantize the model on multiple GPUs be similar to quantizing it on a single GPU?
  • Is it possible to split the model in a tensor-parallel way?
  • I understand that 'non-sequential GPTQ' is deprecated, but how much memory does non-sequential GPTQ require? I think the memory calculation above would help. Also, how much speed-up would we see with the non-sequential approach compared to the sequential one?
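
For reference, here is that kv-cache calculation written out as a small script (the dimensions are just the Llama-3-70B-style values used above):

```python
# Back-of-envelope kv-cache sizing, matching the formula above.
num_layers = 80
num_kv_heads = 8
head_dim = 128
seq_len = 8192
bytes_per_value = 2  # fp16

kv_cache_bytes = (
    num_layers * num_kv_heads * head_dim * seq_len
    * 2               # keys and values
    * bytes_per_value
)
print(f"kv cache per sequence: {kv_cache_bytes / 1024**3:.2f} GiB")  # ~2.50 GiB
```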

Thank you!

@dsikka dsikka added the question Further information is requested label Mar 5, 2025
@horheynm horheynm self-assigned this Mar 7, 2025
@horheynm
Collaborator

horheynm commented Mar 7, 2025

Hi @sneha5gsm

Let me take a look at the raw memory requirements and see if there is an equation, similar to the kv-cache one, to approximate them.

For 2:

  • It will be faster on multiple GPUs, because in the single-GPU case we rely on CPU offloading, which is memory-bound: the transfer time between CPU and GPU RAM dominates.
  • Yes, it should be possible, but there is no support for it currently. We are planning something similar by bringing FSDP back into oneshot.
  • To do non-sequential GPTQ, we can do Hessian accumulation in parallel and also quantize layers in parallel, given the activations for each calibration sample. We would need to compute all the memory needed in the forward pass up to each targeted layer to obtain its input activations, and then approximate the Hessian memory requirements from those activations (a rough sketch follows below).
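
As a rough starting point while I check the exact numbers: GPTQ keeps one Hessian per targeted Linear layer, sized by that layer's input dimension. Here is a back-of-envelope sketch, assuming the Hessian is held in fp32 and using Llama-3-70B-style layer shapes (both are illustrative assumptions, not the exact implementation):

```python
# Rough per-layer GPTQ Hessian sizing: H is (d_in x d_in) for each targeted Linear,
# so memory scales with the square of the layer's input dimension.
# Assumptions: Hessian kept in fp32 (4 bytes/value); Llama-3-70B-style shapes.
HESSIAN_BYTES_PER_VALUE = 4

def hessian_bytes(d_in: int) -> int:
    return d_in * d_in * HESSIAN_BYTES_PER_VALUE

# Illustrative input dimensions for the Linear layers in one decoder block.
layer_input_dims = {
    "q/k/v/o_proj": 8192,   # hidden_size
    "gate/up_proj": 8192,   # hidden_size
    "down_proj": 28672,     # intermediate_size
}

for name, d_in in layer_input_dims.items():
    print(f"{name}: {hessian_bytes(d_in) / 1024**3:.2f} GiB per Hessian")
# q/k/v/o_proj: 0.25 GiB each, gate/up_proj: 0.25 GiB each, down_proj: ~3.06 GiB
```

Sequential GPTQ only needs the Hessians for the block currently being calibrated; a non-sequential approach would have to hold them for all targeted layers at once, which is where the memory grows quickly.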

@SnehaGhantasalaTR

@horheynm

Thank you for taking the time to look into the first query, and for the answers to the second one!

Follow-up questions for 2:

  • Is it possible to quantize layers in parallel across all the GPUs, but sequentially within each GPU? That way, at least all the GPUs are used concurrently.
  • Regarding:

It will be faster on multiple GPUs, because in the single-GPU case we rely on CPU offloading, which is memory-bound: the transfer time between CPU and GPU RAM dominates.

What if we have 2 GPUs? That way, while one GPU is offloading, the other is ready to compute.

Thanks

@horheynm
Collaborator

horheynm commented Mar 7, 2025

Yes, you can, if you are not concerned about error propagation from one layer to another.

GPTQ assumes that the model is pretrained, which means its activations are relatively stable: a layer's prediction is very close to the ground-truth output. As a result, we don't have to estimate the Hessian with full backpropagation; instead, we look at the layer's input activations to approximate the Hessian, which is what GPTQ uses.

So, given a calibration dataset, because of the above you can quantize the layers in parallel if you have the activations. One reason we do them sequentially is that the output of one layer is the input of the next, so going sequentially corrects for the errors as the layers are quantized. Another reason is that doing it in parallel is very expensive: x, the input activation, in Llama has a shape of roughly [batch, seq, d_model]; in fp16 that is roughly 8 MB for a batch of 1, seq of 1024, d_model of 4096, and 2 bytes per value, just for the activations.
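
To make the "approximate the Hessian from the input activations" point concrete, here is a minimal sketch of the per-layer accumulation a GPTQ-style method performs over the calibration batches (the shapes and exact scaling are illustrative, not the library's implementation):

```python
import torch

# Minimal sketch: approximate a layer's Hessian from its input activations,
# accumulated over calibration batches. Illustrative only.
d_model = 4096
hessian = torch.zeros(d_model, d_model, dtype=torch.float32)
num_samples = 0

def accumulate_hessian(x: torch.Tensor) -> None:
    """x: input activations of shape [batch, seq, d_model] (fp16 in practice)."""
    global num_samples
    x = x.reshape(-1, d_model).float()   # flatten to [batch * seq, d_model]
    n = x.shape[0]
    # Running average of 2 * X^T X, the usual GPTQ-style Hessian proxy.
    hessian.mul_(num_samples / (num_samples + n))
    hessian.add_((2.0 / (num_samples + n)) * (x.t() @ x))
    num_samples += n

# One calibration batch of shape [1, 1024, 4096] in fp16 is
# 1 * 1024 * 4096 * 2 bytes = 8 MiB of activations for this layer.
x = torch.randn(1, 1024, d_model, dtype=torch.float16)
accumulate_hessian(x)
print(hessian.shape)  # torch.Size([4096, 4096]), ~64 MiB in fp32
```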

Overlapping offloading and compute to maximize usage is something we can optimize. If you'd like, we can guide you on how to optimize this and contribute it to the repo.
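
As a starting point, the usual pattern for overlapping host-to-device transfers with compute is pinned host memory plus copies issued on a side CUDA stream. A minimal, library-agnostic sketch (the tensors and shapes are made up for illustration; this is not the offloading scheme llm-compressor uses internally):

```python
import torch

# Minimal sketch of overlapping host->device weight transfers with compute,
# using pinned memory and a side CUDA stream. Illustrative only.
assert torch.cuda.is_available(), "requires a CUDA device"
device = torch.device("cuda:0")
copy_stream = torch.cuda.Stream(device=device)

# Two "layers" worth of weights living in pinned CPU memory.
cpu_weights = [torch.randn(4096, 4096).pin_memory() for _ in range(2)]
gpu_weights = [torch.empty_like(w, device=device) for w in cpu_weights]

x = torch.randn(8, 4096, device=device)

# Prefetch layer 0; then, while layer i computes, prefetch layer i + 1.
with torch.cuda.stream(copy_stream):
    gpu_weights[0].copy_(cpu_weights[0], non_blocking=True)

for i in range(len(cpu_weights)):
    torch.cuda.current_stream().wait_stream(copy_stream)  # weights for layer i are ready
    if i + 1 < len(cpu_weights):
        with torch.cuda.stream(copy_stream):
            gpu_weights[i + 1].copy_(cpu_weights[i + 1], non_blocking=True)
    x = x @ gpu_weights[i]  # compute for layer i overlaps the copy of layer i + 1

torch.cuda.synchronize()
```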

@sneha5gsm
Author

@horheynm
Did you get a chance to look at the raw memory requirements for the quantization recipes?

Also, regarding:

Overlapping offloading and compute to maximize usage is something we can optimize. If you'd like, we can guide you on how to optimize this and contribute it to the repo.

How would I go about contributing to this?
