Update quantization documentation for int4 blockwise. #641

Open · wants to merge 1 commit into base: main
5 changes: 3 additions & 2 deletions ai_edge_torch/generative/README.md
@@ -46,8 +46,9 @@ Once converted, you will get a quantized `.tflite` model which will be ready for

In the current release, the following schemes are supported:

* Dynamic range quantization with FP32 activations and INT8 weights for linear ops
* FP16 quantization with FP16 weights and FP32 activations and computation for all ops
* Dynamic quantization with FP32 activations and INT8 weights for linear ops
* FP16 quantization with FP16 weights and FP32 activations and computation for all ops
* Dynamic INT4 blockwise quantization: FP32 activations, INT4 weights, and integer computation; the block size must be a multiple of 32

These correspond to the available recipes in `quant_recipes.py`.
<br/>
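
For illustration, each scheme above corresponds to a preset helper in `quant_recipes.py`. The sketch below uses the helper names referenced elsewhere in this change; treat the exact signatures as assumptions rather than the definitive API:

```python
from ai_edge_torch.generative.quantize import quant_recipes

# Preset recipes corresponding to the schemes listed above (illustrative sketch).
dynamic_int8_recipe = quant_recipes.full_int8_dynamic_recipe()        # FP32 activations, INT8 weights
fp16_recipe = quant_recipes.full_fp16_recipe()                        # FP16 weights, FP32 activations
int4_block_recipe = quant_recipes.full_int4_dynamic_block_recipe(32)  # INT4 weights, block size 32
```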
35 changes: 35 additions & 0 deletions ai_edge_torch/generative/examples/gemma3/README.md
@@ -6,6 +6,41 @@ Gemma 3 is the latest model in the Gemma family of open-weight models.

The Gemma 3 Tokenizer is available in the Gemma PyTorch repo [here](https://github.com/google/gemma_pytorch). The reauthored models here are compatible with that tokenizer.

## Convert & Quantize Gemma 3 to TFLite

The Gemma 3 model can be converted and quantized with any of the supported quantization schemes using the following command:

```bash
python convert_gemma3_to_tflite.py --quantize=<string for the desired quantization scheme> \
--checkpoint_path=<path to the torch safetensors directory> \
--output_path=<path to the directory where the tflite file will be saved> \
--prefill_seq_lens=<maximum length of supported input> \
--kv_cache_max_len=<maximum of prefill + decode context length> \
--mask_as_input=True
```

For example, the following command was used to create the dynamic int4 block32 models:

```bash
python convert_gemma3_to_tflite.py --quantize="dynamic_int4_block32" \
--checkpoint_path=/tmp/gemma-3-pytorch-gemma-3-1b-pt-v1 --output_path="/tmp/" \
--prefill_seq_lens=2048 --kv_cache_max_len=4096 --mask_as_input=True
```

All ready-to-use quantization schemes can be found [here](https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/converter.py#L46).


A comparison of the various quantization schemes is shown below:

| Quantization Scheme   | Model Size | PIQA Score | CPU Prefill Speed | CPU Decode Speed | Peak Memory Usage |
| --------------------- | ---------- | ---------- | ----------------- | ---------------- | ----------------- |
| Dynamic INT8 | 973 MB | 73.61 | 172.65 tokens/s | 34.97 tokens/s | 1.63 GB |
| Dynamic INT4 Block32 | 711 MB | 72.9 | 124.24 tokens/s | 41.06 tokens/s | 1.41 GB |
| Dynamic INT4 Block128| 650 MB | 71.6 | 146.22 tokens/s | 42.78 tokens/s | 1.31 GB |

Note: All speed and memory usage numbers were benchmarked on a Snapdragon 8 Elite device; performance may vary from device to device.


## Gemma 3 Task File Creation

Creation of a Task file is needed to use the converted model and tokenizer in the [LLM Inference API](https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference). To create the `.task` file for Gemma 3, pip install the [mediapipe](https://pypi.org/project/mediapipe/) Python package and then execute the following Python code:
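
The exact snippet is collapsed in this diff; a minimal sketch of what such a bundling script typically looks like with the MediaPipe GenAI bundler is shown below. File paths and token strings are placeholders to adapt to your converted model:

```python
from mediapipe.tasks.python.genai import bundler

# Bundle the converted .tflite model and its tokenizer into a .task file.
# All paths and token values below are placeholders, not values from this PR.
config = bundler.BundleConfig(
    tflite_model="/tmp/gemma3_1b_q4_block32.tflite",
    tokenizer_model="/tmp/tokenizer.model",
    start_token="<bos>",
    stop_tokens=["<eos>", "<end_of_turn>"],
    output_filename="/tmp/gemma3.task",
)
bundler.create_bundle(config)
```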
11 changes: 6 additions & 5 deletions ai_edge_torch/generative/quantize/README.md
@@ -18,11 +18,12 @@ Once converted, you will get a quantized `.tflite` model which will be ready for

In the current release, the following schemes are supported:

* Dynamic range quantization: FP32 activations, INT8 weights, and integer computation
* Weight-only quantization: FP32 activations, INT8 weights, and floating point computation
* Dynamic INT8 quantization: FP32 activations, INT8 weights, and integer computation
* Weight-only INT8 quantization: FP32 activations, INT8 weights, and floating point computation
* FP16 quantization: FP16 weights, FP32 activations and floating point computation for all ops
* Dynamic INT4 blockwise quantization: FP32 activations, INT4 weights, and integer computation; the block size must be a multiple of 32

These correspond to the available recipes in `quant_recipes.py`.
Preset recipes are available in `quant_recipes.py`.

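For example, a minimal sketch of applying the new INT4 blockwise preset through the Python conversion API, assuming the `all_supported_int4_dynamic_block_recipe` helper from `quant_recipes.py` and the usual `quant_config` keyword of `ai_edge_torch.convert` (the model and its sample inputs are placeholders):

```python
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

# Dynamic INT4 blockwise recipe; the block size must be a multiple of 32.
quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)

# `model`, `tokens`, and `input_pos` stand in for a re-authored generative
# model and its sample prefill inputs.
edge_model = ai_edge_torch.convert(
    model, (tokens, input_pos), quant_config=quant_config
)
edge_model.export("/tmp/model_q4_block32.tflite")
```
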
## Advanced usage

@@ -36,8 +37,8 @@ def custom_selective_quantization_recipe() -> quant_config.QuantConfig:
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=create_layer_quant_fp16(),
embedding=create_layer_quant_int8_dynamic(),
attention=create_layer_quant_int8_weight_only(),
feedforward=create_layer_quant_int8_dynamic(),
attention=create_layer_quant_int4_block(32),
feedforward=create_layer_quant_int4_block(256),
)
)
```
6 changes: 5 additions & 1 deletion ai_edge_torch/generative/quantize/quant_recipe.py
@@ -16,9 +16,12 @@
from dataclasses import dataclass
from typing import Optional, Union

from ai_edge_torch.generative.layers import model_config
from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import supported_schemes

ModelConfig = model_config.ModelConfig


@dataclass
class LayerQuantRecipe:
@@ -52,7 +55,7 @@ def __str__(self):
f'w:{self.weight_dtype.name}, '
f'{self.mode.name}, '
f'{self.algorithm.name}, '
f'{self.granularity.name}'
f'{self.granularity.name}, '
f'{self.block_size}'
)
return f'{base_str})'
@@ -133,6 +136,7 @@ class GenerativeQuantRecipe:
feedforward: Union[
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
] = None
_model_config: Optional[ModelConfig] = None

def __str__(self):
return f"""GenerativeQuantRecipe(
3 changes: 2 additions & 1 deletion ai_edge_torch/generative/quantize/quant_recipes.py
@@ -63,6 +63,7 @@ def all_supported_int4_dynamic_block_recipe(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
block_size
)
),
embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
)
)
1 change: 0 additions & 1 deletion ai_edge_torch/generative/test/test_quantize.py
@@ -14,7 +14,6 @@
# ==============================================================================

import ai_edge_torch
from ai_edge_torch import config
from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
7 changes: 5 additions & 2 deletions ai_edge_torch/generative/tools/batch_convert.py
@@ -282,9 +282,12 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
)
converter.convert_to_tflite(
pytorch_model,
tflite_path=os.path.join(config.tflite_output_path, output_filename),
output_path=config.tflite_output_path,
output_name_prefix=output_filename,
prefill_seq_len=config.prefill_seq_lens,
quantize=True if precision == ExportPrecision.INT8 else False,
quantize=converter.QuantizationName.DYNAMIC_INT8
if precision == ExportPrecision.INT8
else converter.QuantizationName.NONE,
export_config=ExportConfig(),
)
logging.info("Successfully converted model: %s", output_filename)
105 changes: 95 additions & 10 deletions ai_edge_torch/generative/utilities/converter.py
@@ -15,6 +15,7 @@

"""Common utility functions for model conversion."""

import enum
import os
import pathlib
from typing import Optional, Union
@@ -42,6 +43,27 @@ def forward(self, *export_args, **export_kwargs):
return self.module(*export_args, **full_kwargs)


class QuantizationName(str, enum.Enum):
"""Strings for all supported quantization recipes.

none: No quantization.
dynamic_int8: Dynamic range quantization with int8 weights.
weight_only_int8: Weight only quantization with int8 weights.
fp16: Float16 quantization.
dynamic_int4_block32: Dynamic range quantization with int4 weights and block
size of 32; better model quality but slower inference.
dynamic_int4_block128: Dynamic range quantization with int4 weights and block
size of 128; faster inference but worse model quality.
"""

NONE = 'none'
DYNAMIC_INT8 = 'dynamic_int8'
WEIGHT_ONLY_INT8 = 'weight_only_int8'
FP16 = 'fp16'
DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'


def define_conversion_flags(
model_name: str,
default_mask_as_input: bool = False,
@@ -74,10 +96,10 @@ def define_conversion_flags(
1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
flags.DEFINE_bool(
flags.DEFINE_string(
'quantize',
True,
'Whether the model should be quantized.',
'dynamic_int8',
'How the model should be quantized.',
)
flags.DEFINE_multi_integer(
'lora_ranks',
@@ -99,6 +121,66 @@ return flags
return flags


def get_quant_recipe_from_flag(
quantize: str,
) -> Optional[quant_recipes.QuantizationRecipe]:
"""Processes the quantization flag and returns the corresponding recipe.

Args:
quantize: The quantization type.

Returns:
The quantization recipe, or None if no quantization is needed.

Raises:
ValueError: If the quantization type is not supported.
"""
match quantize:
case QuantizationName.NONE:
return None
case QuantizationName.DYNAMIC_INT8:
return quant_recipes.full_int8_dynamic_recipe()
case QuantizationName.WEIGHT_ONLY_INT8:
return quant_recipes.full_int8_weight_only_recipe()
case QuantizationName.FP16:
return quant_recipes.full_fp16_recipe()
case QuantizationName.DYNAMIC_INT4_BLOCK32:
return quant_recipes.full_int4_dynamic_block_recipe(32)
case QuantizationName.DYNAMIC_INT4_BLOCK128:
return quant_recipes.full_int4_dynamic_block_recipe(128)
case _:
raise ValueError(f'Unsupported quantization flag: {quantize}')


def create_quantize_suffix(quantize: str) -> str:
"""Creates a suffix for the output file name based on the quantization type.

Args:
quantize: The quantization type.

Returns:
A string representing the quantization suffix.

Raises:
ValueError: If the quantization type is not supported.
"""
match quantize:
case QuantizationName.NONE:
return 'f32'
case QuantizationName.DYNAMIC_INT8:
return 'q8'
case QuantizationName.WEIGHT_ONLY_INT8:
return 'q8_wo'
case QuantizationName.FP16:
return 'fp16'
case QuantizationName.DYNAMIC_INT4_BLOCK32:
return 'q4_block32'
case QuantizationName.DYNAMIC_INT4_BLOCK128:
return 'q4_block128'
case _:
raise ValueError(f'Unsupported quantization flag: {quantize}')


def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
if isinstance(mask_len, list):
return [
@@ -118,7 +200,7 @@ def convert_to_tflite(
prefill_seq_len: Union[int, list[int]],
pixel_values_size: torch.Size = None,
pixel_seq_len: int = 0,
quantize: bool = True,
quantize: str = 'dynamic_int8',
config: cfg.ModelConfig = None,
lora_ranks: Optional[list[int]] = None,
export_config: ExportConfig = None,
@@ -164,8 +246,8 @@ def convert_to_tflite(
embeddings generated by the image encoder with pixel values. The actual
length of prefill_seq_len will be added by pixel_seq_len when pixel
values are passed.
quantize (bool, optional): Whether the model should be quanized. Defaults
to True.
quantize (str, optional): The quantization type. Defaults to
'dynamic_int8'.
config (cfg.ModelConfig, optional): The model config used to configure KV
cache. If None, it uses the config of the pytorch_model.
lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
@@ -186,7 +268,7 @@
lora = lora_utils.LoRA.zeros(rank, config)
loras.append(lora)

quant_suffix = 'q8' if quantize else 'f32'
quant_suffix = create_quantize_suffix(quantize)
kv_size = config.kv_cache_max_len
lora_suffix = (
'' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
@@ -220,7 +302,7 @@ def _export_helper(
prefill_seq_lens: list[int],
pixel_values_size: torch.Size,
pixel_seq_len: int,
quantize: bool,
quantize: str,
config: cfg.ModelConfig,
loras: list[None | lora_utils.LoRA],
export_config: ExportConfig,
@@ -269,7 +351,8 @@
kv_layout=export_config.kvcache_layout,
)

quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
quant_config = get_quant_recipe_from_flag(quantize)
if quant_config is not None:
  quant_config._model_config = config

# For export, we create a module that captures any non-exportable
# arguments, e.g. the generation config object.
@@ -334,5 +417,7 @@
sample_kwargs=sample_kwargs,
)

edge_model = converter.convert(quant_config=quant_config)
edge_model = converter.convert(
quant_config=quant_config,
)
edge_model.export(output_file)
2 changes: 1 addition & 1 deletion ai_edge_torch/lowertools/_shim.py
@@ -50,7 +50,7 @@ def exported_programs_to_tflite(
*,
quant_config: Optional[qcfg.QuantConfig] = None,
_tfl_converter_flags: Optional[dict[str, Any]] = None,
_saved_model_dir: Optional[str] = None
_saved_model_dir: Optional[str] = None,
):
"""Converts a list of ExportedProgram to a TFLite model.

15 changes: 14 additions & 1 deletion ai_edge_torch/lowertools/translate_recipe.py
@@ -29,6 +29,8 @@
_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
# TODO: b/415833584 - Improve the regex for pre-softmax layer.
_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'


@@ -95,10 +97,11 @@ def _set_quant_config(
rm: quantizer.recipe_manager.RecipeManager,
layer_recipe: quant_recipe.LayerQuantRecipe,
regex: str,
operation_name: _OpName = _OpName.ALL_SUPPORTED,
):
rm.add_quantization_config(
regex=regex,
operation_name=_OpName.ALL_SUPPORTED,
operation_name=operation_name,
op_config=_OpQuantConfig(
weight_tensor_config=_TensorQuantConfig(
num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
@@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe(

if recipe.embedding is not None:
_set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
if (
recipe._model_config is not None
and recipe._model_config.lm_head_share_weight_with_embedding
):
_set_quant_config(
rm,
recipe.embedding,
_DECODE_LOGITS_REGEX_STR,
_OpName.FULLY_CONNECTED,
)

if recipe.attention is not None:
if isinstance(recipe.attention, dict):