diff --git a/ai_edge_torch/generative/README.md b/ai_edge_torch/generative/README.md
index fd1312a3..902b0d83 100644
--- a/ai_edge_torch/generative/README.md
+++ b/ai_edge_torch/generative/README.md
@@ -46,8 +46,9 @@ Once converted, you will get a quantized `.tflite` model which will be ready for
 In the current release, the following schemes are supported:
 
-* Dynamic range quantization with FP32 activations and INT8 weights for linear ops
-* FP16 quantization with FP16 weights and FP32 activations and computation for all ops
+* Dynamic INT8 quantization with FP32 activations and INT8 weights for linear ops
+* FP16 quantization with FP16 weights and FP32 activations and computation for all ops
+* Dynamic INT4 blockwise quantization: FP32 activations, INT4 weights, and integer computation; block size must be a multiple of 32
 
 These correspond to the available recipes in `quant_recipes.py`
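To make the mapping from these schemes to `quant_recipes.py` concrete, here is a minimal, hedged sketch of applying a preset recipe at conversion time. It assumes the public `ai_edge_torch.convert` API used by the generative examples; `model`, `tokens`, and `input_pos` are placeholders for a reauthored model and its sample prefill inputs.

```python
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

# Pick one of the preset recipes corresponding to the schemes above.
quant_config = quant_recipes.full_int8_dynamic_recipe()
# quant_config = quant_recipes.full_fp16_recipe()
# quant_config = quant_recipes.full_int4_dynamic_block_recipe(32)

# `model`, `tokens`, and `input_pos` are placeholders for a reauthored
# generative model and its sample prefill inputs.
edge_model = ai_edge_torch.convert(
    model, (tokens, input_pos), quant_config=quant_config
)
edge_model.export('/tmp/quantized_model.tflite')
```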
diff --git a/ai_edge_torch/generative/examples/gemma3/README.md b/ai_edge_torch/generative/examples/gemma3/README.md
index 9c55e339..9a2a8057 100644
--- a/ai_edge_torch/generative/examples/gemma3/README.md
+++ b/ai_edge_torch/generative/examples/gemma3/README.md
@@ -6,6 +6,41 @@ The Gemma 3 is the latest model in the Gemma family of open weights models. The
 Gemma 3 Tokenizer is available in the Gemma PyTorch repo [here](https://github.com/google/gemma_pytorch). The reauthored models here are compatible with that tokenizer.
 
+## Convert & Quantize Gemma 3 to TFLite
+
+Converting and quantizing the Gemma 3 model to the various quantization schemes can be done with the following command:
+
+```bash
+python convert_gemma3_to_tflite.py --quantize= \
+    --checkpoint_path= \
+    --output_path= \
+    --prefill_seq_lens= \
+    --kv_cache_max_len= \
+    --mask_as_input=True
+```
+
+For example, the following command was used to create the dynamic INT4 block32 models:
+
+```bash
+python convert_gemma3_to_tflite.py --quantize="dynamic_int4_block32" \
+    --checkpoint_path=/tmp/gemma-3-pytorch-gemma-3-1b-pt-v1 --output_path="/tmp/" \
+    --prefill_seq_lens=2048 --kv_cache_max_len=4096 --mask_as_input=True
+```
+
+All ready-to-use quantization schemes can be found [here](https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/converter.py#L46).
+
+
+A comparison of the various quantization schemes is shown below:
+
+| Quantization Scheme | Model Size | PIQA Score | CPU Prefill Speed | CPU Decode Speed | Peak Memory Usage |
+| ------------------- | ---------- | ---------- | ----------------- | ---------------- | ----------------- |
+| Dynamic INT8 | 973 MB | 73.61 | 172.65 tokens/s | 34.97 tokens/s | 1.63 GB |
+| Dynamic INT4 Block32 | 711 MB | 72.9 | 124.24 tokens/s | 41.06 tokens/s | 1.41 GB |
+| Dynamic INT4 Block128 | 650 MB | 71.6 | 146.22 tokens/s | 42.78 tokens/s | 1.31 GB |
+
+Note: All speed and memory usage numbers were benchmarked on a Snapdragon 8 Elite device; performance may vary from device to device.
+
+
 ## Gemma 3 Task File Creation
 
 Creation of a Task file is needed to use the converted model and tokenizer in the [LLM Inference API](https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference). To create the `.task` file for Gemma 3, pip install the [mediapipe](https://pypi.org/project/mediapipe/) Python package and then execute the following Python code:
 
diff --git a/ai_edge_torch/generative/quantize/README.md b/ai_edge_torch/generative/quantize/README.md
index 7a9f1fb2..588af979 100644
--- a/ai_edge_torch/generative/quantize/README.md
+++ b/ai_edge_torch/generative/quantize/README.md
@@ -18,11 +18,12 @@ Once converted, you will get a quantized `.tflite` model which will be ready for
 In the current release, the following schemes are supported:
 
-* Dynamic range quantization: FP32 activations, INT8 weights, and integer computation
-* Weight-only quantization: FP32 activations, INT8 weights, and floating point computation
+* Dynamic INT8 quantization: FP32 activations, INT8 weights, and integer computation
+* Weight-only INT8 quantization: FP32 activations, INT8 weights, and floating point computation
 * FP16 quantization: FP16 weights, FP32 activations and floating point computation for all ops
+* Dynamic INT4 blockwise quantization: FP32 activations, INT4 weights, and integer computation; block size must be a multiple of 32
 
-These correspond to the available recipes in `quant_recipes.py`.
+Preset recipes for these schemes are available in `quant_recipes.py`.
## Advanced usage @@ -36,8 +37,8 @@ def custom_selective_quantization_recipe() -> quant_config.QuantConfig: generative_recipe=quant_recipe.GenerativeQuantRecipe( default=create_layer_quant_fp16(), embedding=create_layer_quant_int8_dynamic(), - attention=create_layer_quant_int8_weight_only(), - feedforward=create_layer_quant_int8_dynamic(), + attention=create_layer_quant_int4_block(32), + feedforward=create_layer_quant_int4_block(256), ) ) ``` diff --git a/ai_edge_torch/generative/quantize/quant_recipe.py b/ai_edge_torch/generative/quantize/quant_recipe.py index c4906b20..a287bfc2 100644 --- a/ai_edge_torch/generative/quantize/quant_recipe.py +++ b/ai_edge_torch/generative/quantize/quant_recipe.py @@ -16,9 +16,12 @@ from dataclasses import dataclass from typing import Optional, Union +from ai_edge_torch.generative.layers import model_config from ai_edge_torch.generative.quantize import quant_attrs from ai_edge_torch.generative.quantize import supported_schemes +ModelConfig = model_config.ModelConfig + @dataclass class LayerQuantRecipe: @@ -52,7 +55,7 @@ def __str__(self): f'w:{self.weight_dtype.name}, ' f'{self.mode.name}, ' f'{self.algorithm.name}, ' - f'{self.granularity.name}' + f'{self.granularity.name}, ' f'{self.block_size}' ) return f'{base_str})' @@ -133,6 +136,7 @@ class GenerativeQuantRecipe: feedforward: Union[ Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]] ] = None + _model_config: Optional[ModelConfig] = None def __str__(self): return f"""GenerativeQuantRecipe( diff --git a/ai_edge_torch/generative/quantize/quant_recipes.py b/ai_edge_torch/generative/quantize/quant_recipes.py index 1ae7c431..4f392108 100644 --- a/ai_edge_torch/generative/quantize/quant_recipes.py +++ b/ai_edge_torch/generative/quantize/quant_recipes.py @@ -63,6 +63,7 @@ def all_supported_int4_dynamic_block_recipe( generative_recipe=quant_recipe.GenerativeQuantRecipe( default=quant_recipe_utils.create_layer_quant_int4_dynamic_block( block_size - ) + ), + embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(), ) ) diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py index acb467bd..c2539bdd 100644 --- a/ai_edge_torch/generative/test/test_quantize.py +++ b/ai_edge_torch/generative/test/test_quantize.py @@ -14,7 +14,6 @@ # ============================================================================== import ai_edge_torch -from ai_edge_torch import config from ai_edge_torch.generative.examples.test_models import toy_model # NOQA from ai_edge_torch.generative.quantize import quant_recipe from ai_edge_torch.generative.quantize import quant_recipe_utils diff --git a/ai_edge_torch/generative/tools/batch_convert.py b/ai_edge_torch/generative/tools/batch_convert.py index f106b9db..82d03fb4 100644 --- a/ai_edge_torch/generative/tools/batch_convert.py +++ b/ai_edge_torch/generative/tools/batch_convert.py @@ -282,9 +282,12 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None: ) converter.convert_to_tflite( pytorch_model, - tflite_path=os.path.join(config.tflite_output_path, output_filename), + output_path=config.tflite_output_path, + output_name_prefix=output_filename, prefill_seq_len=config.prefill_seq_lens, - quantize=True if precision == ExportPrecision.INT8 else False, + quantize=converter.QuantizationName.DYNAMIC_INT8 + if precision == ExportPrecision.INT8 + else converter.QuantizationName.NONE, export_config=ExportConfig(), ) logging.info("Successfully converted model: %s", output_filename) diff --git 
a/ai_edge_torch/generative/utilities/converter.py b/ai_edge_torch/generative/utilities/converter.py index ee57fa0e..21a78e87 100644 --- a/ai_edge_torch/generative/utilities/converter.py +++ b/ai_edge_torch/generative/utilities/converter.py @@ -15,6 +15,7 @@ """Common utility functions for model conversion.""" +import enum import os import pathlib from typing import Optional, Union @@ -42,6 +43,27 @@ def forward(self, *export_args, **export_kwargs): return self.module(*export_args, **full_kwargs) +class QuantizationName(str, enum.Enum): + """Strings for all supported quantization recipes. + + none: No quantization. + dynamic_int8: Dynamic range quantization with int8 weights. + weight_only_int8: Weight only quantization with int8 weights. + fp16: Float16 quantization. + dynamic_int4_block32: Dynamic range quantization with int4 weights and block + size of 32, better model quality but slower inference. + dynamic_int4_block128: Dynamic range quantization with int4 weights and block + size of 128, faster inference but worse model quality. + """ + + NONE = 'none' + DYNAMIC_INT8 = 'dynamic_int8' + WEIGHT_ONLY_INT8 = 'weight_only_int8' + FP16 = 'fp16' + DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32' + DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128' + + def define_conversion_flags( model_name: str, default_mask_as_input: bool = False, @@ -74,10 +96,10 @@ def define_conversion_flags( 1280, 'The maximum size of KV cache buffer, including both prefill and decode.', ) - flags.DEFINE_bool( + flags.DEFINE_string( 'quantize', - True, - 'Whether the model should be quantized.', + 'dynamic_int8', + 'How the model should be quantized.', ) flags.DEFINE_multi_integer( 'lora_ranks', @@ -99,6 +121,66 @@ def define_conversion_flags( return flags +def get_quant_recipe_from_flag( + quantize: str, +) -> Optional[quant_recipes.QuantizationRecipe]: + """Processes the quantization flag and returns the corresponding recipe. + + Args: + quantize: The quantization type. + + Returns: + The quantization recipe, or None if no quantization is needed. + + Raises: + ValueError: If the quantization type is not supported. + """ + match quantize: + case QuantizationName.NONE: + return None + case QuantizationName.DYNAMIC_INT8: + return quant_recipes.full_int8_dynamic_recipe() + case QuantizationName.WEIGHT_ONLY_INT8: + return quant_recipes.full_int8_weight_only_recipe() + case QuantizationName.FP16: + return quant_recipes.full_fp16_recipe() + case QuantizationName.DYNAMIC_INT4_BLOCK32: + return quant_recipes.full_int4_dynamic_block_recipe(32) + case QuantizationName.DYNAMIC_INT4_BLOCK128: + return quant_recipes.full_int4_dynamic_block_recipe(128) + case _: + raise ValueError(f'Unsupported quantization flag: {quantize}') + + +def create_quantize_suffix(quantize: str) -> str: + """Creates a suffix for the output file name based on the quantization type. + + Args: + quantize: The quantization type. + + Returns: + A string representing the quantization suffix. + + Raises: + ValueError: If the quantization type is not supported. 
+ """ + match quantize: + case QuantizationName.NONE: + return 'f32' + case QuantizationName.DYNAMIC_INT8: + return 'q8' + case QuantizationName.WEIGHT_ONLY_INT8: + return 'q8_wo' + case QuantizationName.FP16: + return 'fp16' + case QuantizationName.DYNAMIC_INT4_BLOCK32: + return 'q4_block32' + case QuantizationName.DYNAMIC_INT4_BLOCK128: + return 'q4_block128' + case _: + raise ValueError(f'Unsupported quantization flag: {quantize}') + + def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor: if isinstance(mask_len, list): return [ @@ -118,7 +200,7 @@ def convert_to_tflite( prefill_seq_len: Union[int, list[int]], pixel_values_size: torch.Size = None, pixel_seq_len: int = 0, - quantize: bool = True, + quantize: str = 'dynamic_int8', config: cfg.ModelConfig = None, lora_ranks: Optional[list[int]] = None, export_config: ExportConfig = None, @@ -164,8 +246,8 @@ def convert_to_tflite( embeddings generated by the image encoder with pixel values. The actual length of prefill_seq_len will be added by pixel_seq_len when pixel values are passed. - quantize (bool, optional): Whether the model should be quanized. Defaults - to True. + quantize (str, optional): The quantization type. Defaults to + 'dynamic_int8'. config (cfg.ModelConfig, optional): The model config used to configure KV cache. If None, it uses the config of the pytorch_model. lora_ranks (list[int], optional): The ranks of the LORA layers. If None, @@ -186,7 +268,7 @@ def convert_to_tflite( lora = lora_utils.LoRA.zeros(rank, config) loras.append(lora) - quant_suffix = 'q8' if quantize else 'f32' + quant_suffix = create_quantize_suffix(quantize) kv_size = config.kv_cache_max_len lora_suffix = ( '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}' @@ -220,7 +302,7 @@ def _export_helper( prefill_seq_lens: list[int], pixel_values_size: torch.Size, pixel_seq_len: int, - quantize: bool, + quantize: str, config: cfg.ModelConfig, loras: list[None | lora_utils.LoRA], export_config: ExportConfig, @@ -269,7 +351,8 @@ def _export_helper( kv_layout=export_config.kvcache_layout, ) - quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None + quant_config = get_quant_recipe_from_flag(quantize) + quant_config._model_config = config # For export, we create a module that captures any non-exportable, # arugments, e.g. the generation config object. @@ -334,5 +417,7 @@ def _export_helper( sample_kwargs=sample_kwargs, ) - edge_model = converter.convert(quant_config=quant_config) + edge_model = converter.convert( + quant_config=quant_config, + ) edge_model.export(output_file) diff --git a/ai_edge_torch/lowertools/_shim.py b/ai_edge_torch/lowertools/_shim.py index 972d9efd..9d1d680f 100644 --- a/ai_edge_torch/lowertools/_shim.py +++ b/ai_edge_torch/lowertools/_shim.py @@ -50,7 +50,7 @@ def exported_programs_to_tflite( *, quant_config: Optional[qcfg.QuantConfig] = None, _tfl_converter_flags: Optional[dict[str, Any]] = None, - _saved_model_dir: Optional[str] = None + _saved_model_dir: Optional[str] = None, ): """Converts a list of ExportedProgram to a TFLite model. 
diff --git a/ai_edge_torch/lowertools/translate_recipe.py b/ai_edge_torch/lowertools/translate_recipe.py index 6c7fa37e..bb694748 100644 --- a/ai_edge_torch/lowertools/translate_recipe.py +++ b/ai_edge_torch/lowertools/translate_recipe.py @@ -29,6 +29,8 @@ _ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention' _FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward' _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding' +# TODO: b/415833584 - Improve the regex for pre-softmax layer. +_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall' _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}' @@ -95,10 +97,11 @@ def _set_quant_config( rm: quantizer.recipe_manager.RecipeManager, layer_recipe: quant_recipe.LayerQuantRecipe, regex: str, + operation_name: _OpName = _OpName.ALL_SUPPORTED, ): rm.add_quantization_config( regex=regex, - operation_name=_OpName.ALL_SUPPORTED, + operation_name=operation_name, op_config=_OpQuantConfig( weight_tensor_config=_TensorQuantConfig( num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype), @@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe( if recipe.embedding is not None: _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR) + if ( + recipe._model_config is not None + and recipe._model_config.lm_head_share_weight_with_embedding + ): + _set_quant_config( + rm, + recipe.embedding, + _DECODE_LOGITS_REGEX_STR, + _OpName.FULLY_CONNECTED, + ) if recipe.attention is not None: if isinstance(recipe.attention, dict):
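To show how the new shared-weights branch in `translate_to_ai_edge_recipe` is driven, here is a hedged sketch of a recipe carrying the model config added to `GenerativeQuantRecipe` in this change. `model_cfg` is an assumed `model_config.ModelConfig` with `lm_head_share_weight_with_embedding` enabled; in practice the recipe is wrapped in `quant_config.QuantConfig(generative_recipe=...)` as in the quantize README.

```python
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils

# `model_cfg` is an assumed ModelConfig for a model whose lm_head shares its
# weight with the token embedding.
recipe = quant_recipe.GenerativeQuantRecipe(
    default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(32),
    embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
    _model_config=model_cfg,
)

# With _model_config present and lm_head_share_weight_with_embedding True, the
# embedding recipe above is also applied to the pre-softmax logits op matched
# by _DECODE_LOGITS_REGEX_STR ('StatefulPartitionedCall') under the
# FULLY_CONNECTED operation name, so the shared weight gets a consistent scheme.
```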