diff --git a/ai_edge_torch/generative/README.md b/ai_edge_torch/generative/README.md
index fd1312a3..902b0d83 100644
--- a/ai_edge_torch/generative/README.md
+++ b/ai_edge_torch/generative/README.md
@@ -46,8 +46,9 @@ Once converted, you will get a quantized `.tflite` model which will be ready for
In the current release, the following schemes are supported:
-* Dynamic range quantization with FP32 activations and INT8 weights for linear ops
-* FP16 quantization with FP16 weights and FP32 activations and computation for all ops
+* Dynamic INT8 quantization with FP32 activations and INT8 weights for linear ops
+* FP16 quantization with FP16 weights and FP32 activations and computation for all ops
+* Dynamic INT4 blockwise quantization with FP32 activations, INT4 weights, and integer computation; the block size must be a multiple of 32
These correspond to the available recipes in `quant_recipes.py`
diff --git a/ai_edge_torch/generative/examples/gemma3/README.md b/ai_edge_torch/generative/examples/gemma3/README.md
index 9c55e339..9a2a8057 100644
--- a/ai_edge_torch/generative/examples/gemma3/README.md
+++ b/ai_edge_torch/generative/examples/gemma3/README.md
@@ -6,6 +6,41 @@ The Gemma 3 is the latest model in the Gemma family of open weights models.
The Gemma 3 Tokenizer is available in the Gemma PyTorch repo [here](https://github.com/google/gemma_pytorch). The reauthored models here are compatible with that tokenizer.
+## Convert & Quantize Gemma 3 to TFLite
+
+The Gemma 3 model can be converted and quantized to the various supported quantization schemes using the following command:
+
+```bash
+python convert_gemma3_to_tflite.py --quantize=<quantization_scheme> \
+  --checkpoint_path=<path_to_checkpoint> \
+  --output_path=<output_directory> \
+  --prefill_seq_lens=<prefill_seq_lens> \
+  --kv_cache_max_len=<kv_cache_max_len> \
+  --mask_as_input=True
+```
+
+For example, the following command was used to create the dynamic INT4 block32 models:
+
+```bash
+python convert_gemma3_to_tflite.py --quantize="dynamic_int4_block32" \
+ --checkpoint_path=/tmp/gemma-3-pytorch-gemma-3-1b-pt-v1 --output_path="/tmp/" \
+ --prefill_seq_lens=2048 --kv_cache_max_len=4096 --mask_as_input=True
+```
+
+All ready-to-use quantization schemes can be found [here](https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/converter.py#L46).
+
+
+A comparison of the various quantization schemes is shown below:
+
+| Quantization Scheme   | Model Size | PIQA Score | CPU Prefill Speed | CPU Decode Speed | Peak Memory Usage |
+| --------------------- | ---------- | ---------- | ----------------- | ---------------- | ----------------- |
+| Dynamic INT8          | 973 MB     | 73.61      | 172.65 tokens/s   | 34.97 tokens/s   | 1.63 GB           |
+| Dynamic INT4 Block32  | 711 MB     | 72.9       | 124.24 tokens/s   | 41.06 tokens/s   | 1.41 GB           |
+| Dynamic INT4 Block128 | 650 MB     | 71.6       | 146.22 tokens/s   | 42.78 tokens/s   | 1.31 GB           |
+
+Note: All speed and memory usage numbers were benchmarked on a Snapdragon 8 Elite device; performance may vary from device to device.
+
+
## Gemma 3 Task File Creation
Creation of a Task file is needed to use the converted model and tokenizer in the [LLM Inference API](https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference). To create the `.task` file for Gemma 3, pip install the [mediapipe](https://pypi.org/project/mediapipe/) Python package and then execute the following Python code:
diff --git a/ai_edge_torch/generative/quantize/README.md b/ai_edge_torch/generative/quantize/README.md
index 7a9f1fb2..588af979 100644
--- a/ai_edge_torch/generative/quantize/README.md
+++ b/ai_edge_torch/generative/quantize/README.md
@@ -18,11 +18,12 @@ Once converted, you will get a quantized `.tflite` model which will be ready for
In the current release, the following schemes are supported:
-* Dynamic range quantization: FP32 activations, INT8 weights, and integer computation
-* Weight-only quantization: FP32 activations, INT8 weights, and floating point computation
+* Dynamic INT8 quantization: FP32 activations, INT8 weights, and integer computation
+* Weight-only INT8 quantization: FP32 activations, INT8 weights, and floating point computation
* FP16 quantization: FP16 weights, FP32 activations and floating point computation for all ops
+* Dynamic INT4 blockwise quantization: FP32 activations, INT4 weights, and integer computation; the block size must be a multiple of 32
-These correspond to the available recipes in `quant_recipes.py`.
+Preset recipes for all of these schemes are available in `quant_recipes.py`.
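+
+For example, the new dynamic INT4 blockwise scheme can be applied by passing its preset recipe as the `quant_config` during conversion. This is a minimal sketch in which `model` and `sample_inputs` stand in for your own PyTorch model and example inputs:
+
+```python
+import ai_edge_torch
+from ai_edge_torch.generative.quantize import quant_recipes
+
+# Preset recipe: dynamic INT4 blockwise quantization with a block size of 32.
+quant_config = quant_recipes.full_int4_dynamic_block_recipe(32)
+
+# Quantization is applied as part of the conversion itself.
+edge_model = ai_edge_torch.convert(model, sample_inputs, quant_config=quant_config)
+edge_model.export('/tmp/quantized_model.tflite')
+```
+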
## Advanced usage
@@ -36,8 +37,8 @@ def custom_selective_quantization_recipe() -> quant_config.QuantConfig:
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=create_layer_quant_fp16(),
embedding=create_layer_quant_int8_dynamic(),
- attention=create_layer_quant_int8_weight_only(),
- feedforward=create_layer_quant_int8_dynamic(),
+ attention=create_layer_quant_int4_block(32),
+ feedforward=create_layer_quant_int4_block(256),
)
)
```
diff --git a/ai_edge_torch/generative/quantize/quant_recipe.py b/ai_edge_torch/generative/quantize/quant_recipe.py
index c4906b20..a287bfc2 100644
--- a/ai_edge_torch/generative/quantize/quant_recipe.py
+++ b/ai_edge_torch/generative/quantize/quant_recipe.py
@@ -16,9 +16,12 @@
from dataclasses import dataclass
from typing import Optional, Union
+from ai_edge_torch.generative.layers import model_config
from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import supported_schemes
+ModelConfig = model_config.ModelConfig
+
@dataclass
class LayerQuantRecipe:
@@ -52,7 +55,7 @@ def __str__(self):
f'w:{self.weight_dtype.name}, '
f'{self.mode.name}, '
f'{self.algorithm.name}, '
- f'{self.granularity.name}'
+ f'{self.granularity.name}, '
f'{self.block_size}'
)
return f'{base_str})'
@@ -133,6 +136,7 @@ class GenerativeQuantRecipe:
feedforward: Union[
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
] = None
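+  # Model config of the model being quantized. Set internally during
+  # conversion so that recipe translation can account for model structure,
+  # e.g. an LM head that shares its weight with the token embedding.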
+ _model_config: Optional[ModelConfig] = None
def __str__(self):
return f"""GenerativeQuantRecipe(
diff --git a/ai_edge_torch/generative/quantize/quant_recipes.py b/ai_edge_torch/generative/quantize/quant_recipes.py
index 1ae7c431..4f392108 100644
--- a/ai_edge_torch/generative/quantize/quant_recipes.py
+++ b/ai_edge_torch/generative/quantize/quant_recipes.py
@@ -63,6 +63,7 @@ def all_supported_int4_dynamic_block_recipe(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
block_size
- )
+ ),
+ embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
)
)
diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py
index acb467bd..c2539bdd 100644
--- a/ai_edge_torch/generative/test/test_quantize.py
+++ b/ai_edge_torch/generative/test/test_quantize.py
@@ -14,7 +14,6 @@
# ==============================================================================
import ai_edge_torch
-from ai_edge_torch import config
from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
diff --git a/ai_edge_torch/generative/tools/batch_convert.py b/ai_edge_torch/generative/tools/batch_convert.py
index f106b9db..82d03fb4 100644
--- a/ai_edge_torch/generative/tools/batch_convert.py
+++ b/ai_edge_torch/generative/tools/batch_convert.py
@@ -282,9 +282,12 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
)
converter.convert_to_tflite(
pytorch_model,
- tflite_path=os.path.join(config.tflite_output_path, output_filename),
+ output_path=config.tflite_output_path,
+ output_name_prefix=output_filename,
prefill_seq_len=config.prefill_seq_lens,
- quantize=True if precision == ExportPrecision.INT8 else False,
+ quantize=converter.QuantizationName.DYNAMIC_INT8
+ if precision == ExportPrecision.INT8
+ else converter.QuantizationName.NONE,
export_config=ExportConfig(),
)
logging.info("Successfully converted model: %s", output_filename)
diff --git a/ai_edge_torch/generative/utilities/converter.py b/ai_edge_torch/generative/utilities/converter.py
index ee57fa0e..21a78e87 100644
--- a/ai_edge_torch/generative/utilities/converter.py
+++ b/ai_edge_torch/generative/utilities/converter.py
@@ -15,6 +15,7 @@
"""Common utility functions for model conversion."""
+import enum
import os
import pathlib
from typing import Optional, Union
@@ -42,6 +43,27 @@ def forward(self, *export_args, **export_kwargs):
return self.module(*export_args, **full_kwargs)
+class QuantizationName(str, enum.Enum):
+ """Strings for all supported quantization recipes.
+
+ none: No quantization.
+ dynamic_int8: Dynamic range quantization with int8 weights.
+ weight_only_int8: Weight only quantization with int8 weights.
+ fp16: Float16 quantization.
+ dynamic_int4_block32: Dynamic range quantization with int4 weights and block
+ size of 32, better model quality but slower inference.
+ dynamic_int4_block128: Dynamic range quantization with int4 weights and block
+ size of 128, faster inference but worse model quality.
+ """
+
+ NONE = 'none'
+ DYNAMIC_INT8 = 'dynamic_int8'
+ WEIGHT_ONLY_INT8 = 'weight_only_int8'
+ FP16 = 'fp16'
+ DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
+ DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'
+
+
def define_conversion_flags(
model_name: str,
default_mask_as_input: bool = False,
@@ -74,10 +96,10 @@ def define_conversion_flags(
1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
- flags.DEFINE_bool(
+ flags.DEFINE_string(
'quantize',
- True,
- 'Whether the model should be quantized.',
+ 'dynamic_int8',
+ 'How the model should be quantized.',
)
flags.DEFINE_multi_integer(
'lora_ranks',
@@ -99,6 +121,66 @@ def define_conversion_flags(
return flags
+def get_quant_recipe_from_flag(
+ quantize: str,
+) -> Optional[quant_recipes.QuantizationRecipe]:
+ """Processes the quantization flag and returns the corresponding recipe.
+
+ Args:
+ quantize: The quantization type.
+
+ Returns:
+ The quantization recipe, or None if no quantization is needed.
+
+ Raises:
+ ValueError: If the quantization type is not supported.
+ """
+ match quantize:
+ case QuantizationName.NONE:
+ return None
+ case QuantizationName.DYNAMIC_INT8:
+ return quant_recipes.full_int8_dynamic_recipe()
+ case QuantizationName.WEIGHT_ONLY_INT8:
+ return quant_recipes.full_int8_weight_only_recipe()
+ case QuantizationName.FP16:
+ return quant_recipes.full_fp16_recipe()
+ case QuantizationName.DYNAMIC_INT4_BLOCK32:
+ return quant_recipes.full_int4_dynamic_block_recipe(32)
+ case QuantizationName.DYNAMIC_INT4_BLOCK128:
+ return quant_recipes.full_int4_dynamic_block_recipe(128)
+ case _:
+ raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
+def create_quantize_suffix(quantize: str) -> str:
+ """Creates a suffix for the output file name based on the quantization type.
+
+ Args:
+ quantize: The quantization type.
+
+ Returns:
+ A string representing the quantization suffix.
+
+ Raises:
+ ValueError: If the quantization type is not supported.
+ """
+ match quantize:
+ case QuantizationName.NONE:
+ return 'f32'
+ case QuantizationName.DYNAMIC_INT8:
+ return 'q8'
+ case QuantizationName.WEIGHT_ONLY_INT8:
+ return 'q8_wo'
+ case QuantizationName.FP16:
+ return 'fp16'
+ case QuantizationName.DYNAMIC_INT4_BLOCK32:
+ return 'q4_block32'
+ case QuantizationName.DYNAMIC_INT4_BLOCK128:
+ return 'q4_block128'
+ case _:
+ raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
if isinstance(mask_len, list):
return [
@@ -118,7 +200,7 @@ def convert_to_tflite(
prefill_seq_len: Union[int, list[int]],
pixel_values_size: torch.Size = None,
pixel_seq_len: int = 0,
- quantize: bool = True,
+ quantize: str = 'dynamic_int8',
config: cfg.ModelConfig = None,
lora_ranks: Optional[list[int]] = None,
export_config: ExportConfig = None,
@@ -164,8 +246,8 @@ def convert_to_tflite(
embeddings generated by the image encoder with pixel values. The actual
length of prefill_seq_len will be added by pixel_seq_len when pixel
values are passed.
- quantize (bool, optional): Whether the model should be quanized. Defaults
- to True.
+ quantize (str, optional): The quantization type. Defaults to
+ 'dynamic_int8'.
config (cfg.ModelConfig, optional): The model config used to configure KV
cache. If None, it uses the config of the pytorch_model.
lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
@@ -186,7 +268,7 @@ def convert_to_tflite(
lora = lora_utils.LoRA.zeros(rank, config)
loras.append(lora)
- quant_suffix = 'q8' if quantize else 'f32'
+ quant_suffix = create_quantize_suffix(quantize)
kv_size = config.kv_cache_max_len
lora_suffix = (
'' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
@@ -220,7 +302,7 @@ def _export_helper(
prefill_seq_lens: list[int],
pixel_values_size: torch.Size,
pixel_seq_len: int,
- quantize: bool,
+ quantize: str,
config: cfg.ModelConfig,
loras: list[None | lora_utils.LoRA],
export_config: ExportConfig,
@@ -269,7 +351,8 @@ def _export_helper(
kv_layout=export_config.kvcache_layout,
)
- quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  quant_config = get_quant_recipe_from_flag(quantize)
+  if quant_config is not None:
+    quant_config._model_config = config
# For export, we create a module that captures any non-exportable,
# arugments, e.g. the generation config object.
@@ -334,5 +417,7 @@ def _export_helper(
sample_kwargs=sample_kwargs,
)
- edge_model = converter.convert(quant_config=quant_config)
+ edge_model = converter.convert(
+ quant_config=quant_config,
+ )
edge_model.export(output_file)
diff --git a/ai_edge_torch/lowertools/_shim.py b/ai_edge_torch/lowertools/_shim.py
index 972d9efd..9d1d680f 100644
--- a/ai_edge_torch/lowertools/_shim.py
+++ b/ai_edge_torch/lowertools/_shim.py
@@ -50,7 +50,7 @@ def exported_programs_to_tflite(
*,
quant_config: Optional[qcfg.QuantConfig] = None,
_tfl_converter_flags: Optional[dict[str, Any]] = None,
- _saved_model_dir: Optional[str] = None
+ _saved_model_dir: Optional[str] = None,
):
"""Converts a list of ExportedProgram to a TFLite model.
diff --git a/ai_edge_torch/lowertools/translate_recipe.py b/ai_edge_torch/lowertools/translate_recipe.py
index 6c7fa37e..bb694748 100644
--- a/ai_edge_torch/lowertools/translate_recipe.py
+++ b/ai_edge_torch/lowertools/translate_recipe.py
@@ -29,6 +29,8 @@
_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+# TODO: b/415833584 - Improve the regex for pre-softmax layer.
+_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
@@ -95,10 +97,11 @@ def _set_quant_config(
rm: quantizer.recipe_manager.RecipeManager,
layer_recipe: quant_recipe.LayerQuantRecipe,
regex: str,
+ operation_name: _OpName = _OpName.ALL_SUPPORTED,
):
rm.add_quantization_config(
regex=regex,
- operation_name=_OpName.ALL_SUPPORTED,
+ operation_name=operation_name,
op_config=_OpQuantConfig(
weight_tensor_config=_TensorQuantConfig(
num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
@@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe(
if recipe.embedding is not None:
_set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
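+  # If the LM head shares its weight with the token embedding, also apply the
+  # embedding quantization config to the decode logits FULLY_CONNECTED op so
+  # the shared weight tensor is quantized consistently.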
+ if (
+ recipe._model_config is not None
+ and recipe._model_config.lm_head_share_weight_with_embedding
+ ):
+ _set_quant_config(
+ rm,
+ recipe.embedding,
+ _DECODE_LOGITS_REGEX_STR,
+ _OpName.FULLY_CONNECTED,
+ )
if recipe.attention is not None:
if isinstance(recipe.attention, dict):