Merged

Commits (35):
- 2107403: Support HPU fp8 quantization (nirda7, Nov 25, 2024)
- cb7aa63: Refactor fp8 inc config and flow (nirda7, Nov 28, 2024)
- 54035b8: adjust destructors and calling finish measurements through shutdown (nirda7, Dec 8, 2024)
- 274755c: Add documentation changes (nirda7, Dec 10, 2024)
- af44fbd: fix CR comments (nirda7, Dec 16, 2024)
- 4876f10: add more documentation changes (nirda7, Dec 19, 2024)
- cebdcef: some more CR fixes (nirda7, Dec 23, 2024)
- 19e426e: remove gaudi-installation duplication (nirda7, Jan 2, 2025)
- 1f66b75: change inc.rst to inc.md (nirda7, Jan 7, 2025)
- 22440be: fix more CR comments (nirda7, Jan 9, 2025)
- 67c9285: Add INC and Intel Gaudi to supported hardware table (nirda7, Jan 13, 2025)
- e8675b2: fix formatting (nirda7, Jan 14, 2025)
- 3e9eb49: Fix weights load device use (nirda7, Jan 14, 2025)
- f11ff57: fix shutdown flow after executors refactor (nirda7, Jan 21, 2025)
- 86785a9: fix shutdown flow (nirda7, Feb 6, 2025)
- 651c372: add spdx header to inc.py (nirda7, Feb 6, 2025)
- 5e9a52d: fix unsynced distructors calling to None (nirda7, Mar 20, 2025)
- bc6ac1e: Fix inc flow and remove weights_load_device - use cpu by default (nirda7, Mar 31, 2025)
- 66c3513: fix get_name return type for inc.py (nirda7, May 6, 2025)
- d7a18af: fix md files (nirda7, May 27, 2025)
- e6e0829: fix CR comments and remove hpu worker (ulivne, Jun 23, 2025)
- 82c1bac: remvoe resolve_input method (ulivne, Jun 24, 2025)
- d5b29df: undo more changes in linear.py (ulivne, Jun 24, 2025)
- a04deb2: restore empty line (ulivne, Jun 24, 2025)
- 82af219: remove uneeded empty lines (ulivne, Jun 24, 2025)
- deefe48: restore removed files from original state (ulivne, Jun 24, 2025)
- e4dcfc3: Fix pre-commit (ulivne, Jun 24, 2025)
- 6c48bfd: additional pre-commit (ulivne, Jun 24, 2025)
- 08ae540: pre commit type fix (ulivne, Jun 24, 2025)
- add32bf: Add moeConfig in inc.py (ulivne, Jun 24, 2025)
- 9f9dc69: Support hpu for v1 kv cache dtype validation (ulivne, Jun 26, 2025)
- a4bc3cc: Merge remote-tracking branch 'upstream/main' into dev/hpu_fp8 (ulivne, Jul 8, 2025)
- 664707b: Additional fix for CR comment (ulivne, Jul 8, 2025)
- bddcef0: Merge remote-tracking branch 'upstream/main' into dev/hpu_fp8 (ulivne, Jul 15, 2025)
- f557ab6: restore function and pre-commit fixes (ulivne, Jul 15, 2025)
1 change: 1 addition & 0 deletions docs/features/quantization/README.md
@@ -13,6 +13,7 @@ Contents:
- [BitBLAS](bitblas.md)
- [GGUF](gguf.md)
- [GPTQModel](gptqmodel.md)
- [Inc](inc.md)
Member: From the doc, seems like it should be INC

Suggested change: `- [Inc](inc.md)` → `- [INC](inc.md)`

Contributor: done

- [INT4 W4A16](int4.md)
- [INT8 W8A8](int8.md)
- [FP8 W8A8](fp8.md)
56 changes: 56 additions & 0 deletions docs/features/quantization/inc.md
@@ -0,0 +1,56 @@
---
title: FP8 INC
---
[](){ #inc }

vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
Currently, quantization has been validated only on Llama models.

Intel Gaudi supports quantization of various modules and functions, including, but not limited to, `Linear`, `KVCache`, `Matmul`, and `Softmax`. For more information, see
[Supported Modules / Supported Functions / Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).

!!! note
Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.

!!! note
`QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
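For illustration, a typical workflow points `QUANT_CONFIG` at a MEASURE-mode file during calibration and at a QUANTIZE-mode file during inference. The file names below are placeholders, not outputs of any specific tool:

```bash
# Calibration run: collect measurements with a MEASURE-mode config.
export QUANT_CONFIG=/path/to/maxabs_measure.json

# Inference run: quantize to FP8 using a QUANTIZE-mode config that consumes
# the measurements collected above.
export QUANT_CONFIG=/path/to/maxabs_quant.json
```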

## Run Online Inference Using FP8

Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:

```bash
export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor_paralel_size 8
```
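Once the server is up, you can sanity-check the FP8 deployment with a request to the OpenAI-compatible endpoint. The port below is vLLM's default; adjust it to your setup:

```bash
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-3.1-405B-Instruct", "prompt": "Intel Gaudi is", "max_tokens": 32}'
```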

!!! tip
If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.

!!! tip
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the following environment variables:
`VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
`VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in milliseconds, e.g., 600000 equals 10 minutes.
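For example, to allow up to 10 minutes for both (values are illustrative):

```bash
export VLLM_ENGINE_ITERATION_TIMEOUT_S=600   # engine iteration timeout, in seconds
export VLLM_RPC_TIMEOUT=600000               # RPC timeout, in milliseconds
```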

## Run Offline Inference Using FP8

To run offline inference (after completing the model calibration process):

* Set the `QUANT_CONFIG` environment variable to point to a JSON configuration file in QUANTIZE mode.
* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
* Call the `shutdown` method of the model executor at the end of the run.

```python
from vllm import LLM
llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc")
...
# Call llm.generate on the required prompts and sampling params.
...
llm.llm_engine.model_executor.shutdown()
```
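A more complete sketch of the offline flow, assuming `QUANT_CONFIG` already points to a QUANTIZE-mode config (model name, config path, and prompts are illustrative):

```python
import os

from vllm import LLM, SamplingParams

# Assumed path; point this at your own QUANTIZE-mode JSON config.
os.environ.setdefault("QUANT_CONFIG", "/path/to/maxabs_quant.json")

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    quantization="inc",
    kv_cache_dtype="fp8_inc",
)

prompts = ["Intel Gaudi accelerators are", "FP8 quantization reduces"]
sampling_params = SamplingParams(temperature=0.7, max_tokens=64)

for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)

# Finish INC measurement/quantization state cleanly before exiting.
llm.llm_engine.model_executor.shutdown()
```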

## Device for Loading the Model's Weights

The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
This reduces the device memory footprint of the model, as only the quantized weights are stored in device memory.
25 changes: 13 additions & 12 deletions docs/features/quantization/supported_hardware.md
@@ -5,18 +5,19 @@ title: Supported Hardware

The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:

| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ |

- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware.
5 changes: 3 additions & 2 deletions docs/getting_started/installation/intel_gaudi.md
@@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run:
hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
pip list | grep neural # verify that neural_compressor is installed
pip list | grep neural # verify that neural_compressor_pt is installed
```

Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
@@ -120,12 +120,13 @@ docker run \
- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html)
for accelerating low-batch latency and throughput
- Attention with Linear Biases (ALiBi)
- INC quantization

### Unsupported features

- Beam search
- LoRA adapters
- Quantization
- AWQ quantization
- Prefill chunking (mixed-batch inferencing)

### Supported configurations
13 changes: 8 additions & 5 deletions vllm/config.py
@@ -863,7 +863,7 @@ def _verify_quantization(self) -> None:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
]
if self.quantization is not None:
self.quantization = cast(QuantizationMethods, self.quantization)
@@ -1446,7 +1446,7 @@ def get_and_verify_max_len(self, max_model_len: int):


BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
Member: Based on the comment here, can we remove the new cache dtype now? #12010 (comment)

Contributor: Since the HPU worker for v1 has been moved to a plugin and v0 will be deprecated soon, we want to make the "fp8_inc to fp8_e4m3" mapping more visible.

[screenshot omitted]

Alternatively, do you think we can update the mapping function above to be conditional, like:

    STR_DTYPE_TO_TORCH_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.uint8,
        "fp8_e4m3": torch.uint8 if not current_platform.is_support_fp8_e4m3() else torch.float8_e4m3fn,
        "fp8_e5m2": torch.uint8,
        "int8": torch.int8,
    }

Member: It's okay, let's just keep fp8_inc then

PrefixCachingHashAlgo = Literal["builtin", "sha256"]


@@ -1476,7 +1476,7 @@ class CacheConfig:
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3)."""
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
@@ -1566,7 +1566,7 @@ def _verify_cache_dtype(self) -> None:
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
"scaling factor.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

@@ -1685,6 +1685,9 @@ class LoadConfig:
default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format."""
device: Optional[str] = None
"""Device to which model weights will be loaded, default to
device_config.device"""
ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
@@ -1792,7 +1795,7 @@ class ParallelConfig:
or equal to the number of GPUs available, "mp" will be used to
keep processing on a single host. Otherwise, this will default
to "ray" if Ray is installed and fail otherwise. Note that tpu
and hpu only support Ray for distributed inference."""
only support Ray for distributed inference."""

worker_cls: str = "auto"
"""The full name of the worker class to use. If "auto", the worker class
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -170,6 +170,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints


def is_online_quantization(quantization: Any) -> bool:
return quantization in ["inc"]


@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
@@ -973,6 +977,8 @@ def create_load_config(self) -> LoadConfig:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
device="cpu"
Contributor: If the weights are first loaded to the CPU, at which step are they moved to the HPU?

Contributor (author): They are moved by the underlying INC logic, after quantization to fp8.

if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
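In other words, the load-device selection added above behaves roughly like the sketch below. This is an illustrative paraphrase, not code from this PR, and the helper name is invented: with `--quantization inc`, weights are loaded to the CPU first, and INC later quantizes them to FP8 and moves them to the HPU.

```python
from typing import Optional


def resolve_weights_load_device(quantization: Optional[str],
                                device_config_device: str) -> str:
    # Online ("inc") quantization loads unquantized weights on the CPU;
    # everything else keeps the configured device (e.g. "hpu").
    load_device = "cpu" if quantization == "inc" else None
    return load_device if load_device is not None else device_config_device


assert resolve_weights_load_device("inc", "hpu") == "cpu"
assert resolve_weights_load_device(None, "hpu") == "hpu"
```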
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -35,6 +35,7 @@
"moe_wna16",
"torchao",
"auto-round",
"inc",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -103,6 +104,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
from .inc import INCConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
@@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"inc": INCConfig,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
45 changes: 45 additions & 0 deletions vllm/model_executor/layers/quantization/inc.py
Member: What is the purpose of adding this quantization class if it seems to do nothing?

Contributor (author): We want to be compatible with the API - we have some definitions there complying with the INC quantization method.

Member: Can you please add a header to this file explaining its purpose? It is rather confusing otherwise, and this is a good place to define how this quant method works.

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import torch

from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)


class INCConfig(QuantizationConfig):
    """Config class for FP8 using Intel Neural Compressor."""

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "inc"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "INCConfig":
        raise AssertionError

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return UnquantizedFusedMoEMethod(layer.moe_config)
        return None

    @classmethod
    def get_min_capability(cls) -> int:
        raise AssertionError

    @staticmethod
    def get_config_filenames() -> list[str]:
        return []
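A minimal sketch of how this config class is resolved through the registry updated earlier in this PR (illustrative only; it assumes a vLLM build that contains these changes):

```python
import torch

from vllm.model_executor.layers.quantization import get_quantization_config

# Resolve the INC config class by its registered name.
inc_cls = get_quantization_config("inc")

assert inc_cls.get_name() == "inc"
assert inc_cls.get_supported_act_dtypes() == [torch.bfloat16]
assert inc_cls.get_config_filenames() == []  # INC has no checkpoint config file
```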
10 changes: 9 additions & 1 deletion vllm/model_executor/model_loader/base_loader.py
@@ -6,9 +6,12 @@
import torch.nn as nn

from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)

logger = init_logger(__name__)


class BaseModelLoader(ABC):
"""Base class for model loaders."""
@@ -32,11 +35,16 @@ def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
target_device = torch.device(device_config.device)
load_config = vllm_config.load_config
load_device = device_config.device if load_config.device is None else \
load_config.device
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)

logger.info("Loading weights on %s ...", load_device)
Member: Make this debug

Contributor: done

# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -148,8 +148,8 @@ def get_quant_config(model_config: ModelConfig,
quant_cls = get_quantization_config(model_config.quantization)

# GGUF doesn't have config file
if model_config.quantization == "gguf":
return quant_cls.from_config({})
if model_config.quantization in ("gguf", "inc"):
return quant_cls()
Member: Is this valid for gguf?

# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
1 change: 1 addition & 0 deletions vllm/utils.py
@@ -176,6 +176,7 @@
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn,
}

TORCH_DTYPE_TO_NUMPY_DTYPE = {
Expand Down