Skip to content

Commit

Permalink
vectorize cast function; type annotation changes; add to exclusion list
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Yao <[email protected]>
  • Loading branch information
yuanyao-nv committed Aug 9, 2024
1 parent 26c05d4 commit 9ba8a86
Show file tree
Hide file tree
Showing 16 changed files with 63 additions and 31 deletions.
4 changes: 2 additions & 2 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -3718,7 +3718,7 @@ for from_type, to_type in test_cases:
"-INF",
"-4",
"0.01",
"-1000000",
"-0.0",
],
dtype=np.float32,
)
Expand Down Expand Up @@ -3746,7 +3746,7 @@ for from_type, to_type in test_cases:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
expected = evaluate_float4e2m1_from_bits(
expected = unpacked_float4e2m1_to_float32(
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
)
output = make_tensor(
Expand Down
4 changes: 2 additions & 2 deletions docs/TestCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -2587,7 +2587,7 @@ for from_type, to_type in test_cases:
"-INF",
"-4",
"0.01",
"-1000000",
"-0.0",
],
dtype=np.float32,
)
Expand Down Expand Up @@ -2615,7 +2615,7 @@ for from_type, to_type in test_cases:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
expected = evaluate_float4e2m1_from_bits(
expected = unpacked_float4e2m1_to_float32(
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
)
output = make_tensor(
Expand Down
6 changes: 3 additions & 3 deletions onnx/backend/test/case/node/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
tensor_dtype_to_field,
)
from onnx.numpy_helper import (
evaluate_float4e2m1_from_bits,
float8e4m3_to_float32,
float8e5m2_to_float32,
unpacked_float4e2m1_to_float32,
)


Expand Down Expand Up @@ -303,7 +303,7 @@ def export() -> None:
"-INF",
"-4",
"0.01",
"-1000000",
"-0.0",
],
dtype=np.float32,
)
Expand Down Expand Up @@ -331,7 +331,7 @@ def export() -> None:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
expected = evaluate_float4e2m1_from_bits(
expected = unpacked_float4e2m1_to_float32(
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
)
output = make_tensor(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

*'�o�h�x�������������������B��Bx
*'�o�h�x�������������������B��Bx
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
*
�w�By
�w�By
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
*
�w�Bx
�w�Bx
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
*
�w�Bx
�w�Bx
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
*
�w�By
�w�By
26 changes: 15 additions & 11 deletions onnx/numpy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,26 @@ def unpack_int4(
return res


def evaluate_float4e2m1_from_bits(x: np.ndarray[np.uint8]) -> np.ndarray[np.float32]:
"""Evaluate the numerical value of a single float4e2m1 element represented as uint8
def unpacked_float4e2m1_to_float32(x: np.ndarray) -> np.ndarray:
"""Evaluate the numerical value of an array of unpacked float4e2m1 values (as uint8)
See :ref:`onnx-detail-int4` for technical details.
Args:
x: a uint8 element representing a float4e2m1 (using the 4 LSB)
x: an array of uint8 elements representing a float4e2m1 (using the 4 LSB)
Returns:
A float32 element representing the value of the float4e2m1 input.
An array of float32 elements representing the values of the float4e2m1 input.
"""
# x is stored in 4 LSB of int
S = np.where(np.bitwise_and(x, 0x08), -1, 1)
M = x & 0x01
E = (x & 0x06) >> 1

val = np.where(E==0, S*(M/2.0), S*(1.0+M/2.0) *2.0 **(E-1)) # denormalized, normalized
sign = np.where(np.bitwise_and(x, 0x08), -1, 1)
mantissa = x & 0x01
exponent = (x & 0x06) >> 1

val = np.where(
exponent == 0,
sign * (mantissa / 2.0),
sign * (1.0 + mantissa / 2.0) * 2.0 ** (exponent - 1),
) # denormalized, normalized
return val


Expand All @@ -258,8 +262,8 @@ def unpack_float4e2m1(
res_high, res_low = subbyte.unpack_single_4bitx2(data.ravel(), False)
res = np.empty((res_high.size + res_low.size,), dtype=np.float32)

res[0::2] = evaluate_float4e2m1_from_bits(res_high)
res[1::2] = evaluate_float4e2m1_from_bits(res_low)
res[0::2] = unpacked_float4e2m1_to_float32(res_high)
res[1::2] = unpacked_float4e2m1_to_float32(res_low)

if (
res.size == np.prod(dims) + 1
Expand Down
4 changes: 2 additions & 2 deletions onnx/reference/ops/op_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
)
from onnx.numpy_helper import (
bfloat16_to_float32,
evaluate_float4e2m1_from_bits,
float8e4m3_to_float32,
float8e5m2_to_float32,
unpacked_float4e2m1_to_float32,
)
from onnx.onnx_pb import TensorProto
from onnx.reference.op_run import OpRun
Expand Down Expand Up @@ -131,7 +131,7 @@ def cast_to(x, to, saturate): # noqa: PLR0911
if x.dtype == float4e2m1 and x.dtype.descr[0][0] == "float4e2m1":
if to == TensorProto.FLOAT4E2M1:
return x
res = evaluate_float4e2m1_from_bits(x)
res = unpacked_float4e2m1_to_float32(x)
if to == TensorProto.FLOAT:
return res.astype(np.float32)
elif to == TensorProto.FLOAT16:
Expand Down
4 changes: 2 additions & 2 deletions onnx/reference/ops/op_dequantize_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
)
from onnx.helper import np_dtype_to_tensor_dtype
from onnx.numpy_helper import (
evaluate_float4e2m1_from_bits,
float8e4m3_to_float32,
float8e5m2_to_float32,
unpacked_float4e2m1_to_float32,
)
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_quantize_linear import reshape_input
Expand Down Expand Up @@ -93,7 +93,7 @@ def _run(
elif x_type == TensorProto.FLOAT8E5M2FNUZ:
dx = float8e5m2_to_float32(x, fn=True, uz=True)
elif x_type == TensorProto.FLOAT4E2M1:
dx = evaluate_float4e2m1_from_bits(x)
dx = unpacked_float4e2m1_to_float32(x)
else:
dx = x.astype(np.float32)
y = dx * reshape_input(x_scale, x.shape, axis, block_size)
Expand Down
32 changes: 28 additions & 4 deletions onnx/subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def unpack_single_4bitx2(
return (x_low.astype(dtype), x_high.astype(dtype))


def float32_to_float4e2m1_unpacked(x: np.ndarray | np.dtype) -> np.ndarray:
def float32_to_float4e2m1_unpacked_slow(x: np.ndarray | np.dtype) -> np.ndarray:
"""Cast float32 to float4e2m1 (without packing).
Args:
Expand All @@ -85,7 +85,7 @@ def float32_to_float4e2m1_unpacked(x: np.ndarray | np.dtype) -> np.ndarray:
def float32_to_float4e2m1(value):
if np.isnan(value):
return 0x7
s = 0x0 if value >= 0 else 0x8
s = 0x8 if np.signbit(value) else 0x0
magnitude = np.abs(value)
if np.isinf(magnitude):
ret = 0x7
Expand Down Expand Up @@ -116,14 +116,38 @@ def float32_to_float4e2m1(value):
return y.astype(np.uint8) # type: ignore[no-any-return]


def float32x2_to_float4e2m1x2(val_low: np.dtype, val_high: np.dtype) -> np.ndarray:
def float32_to_float4e2m1_unpacked(values: np.ndarray) -> np.ndarray:
    """Cast float32 to float4e2m1 (without packing).

    Every input is rounded to the nearest representable float4e2m1
    magnitude (0, 0.5, 1, 1.5, 2, 3, 4 or 6). A value exactly halfway
    between two representable magnitudes goes to the one whose mantissa
    bit is 0, and anything above 5.0 (infinity included) saturates to 6.
    The sign bit of the input is kept, so -0.0 becomes 0x8. NaN always
    becomes 0x7.

    Args:
        values: element or array to be converted

    Returns:
        An ndarray with unpacked float4e2m1 elements (as uint8)
    """
    # Accept Python scalars and sequences as well as ndarrays: everything
    # below relies on ndarray attributes and boolean-mask assignment.
    values = np.asarray(values)

    sign = np.where(np.signbit(values), 0x8, 0x0).astype(np.uint8)
    magnitude = np.abs(values)

    # Start from code 0x0, which also covers magnitudes <= 0.25.
    res = np.zeros(values.shape, dtype=np.uint8)
    # The open/closed interval ends alternate so that each tie lands on the
    # neighbouring code with an even mantissa bit.
    res[(magnitude > 0.25) & (magnitude < 0.75)] = 0x1
    res[(magnitude >= 0.75) & (magnitude <= 1.25)] = 0x2
    res[(magnitude > 1.25) & (magnitude < 1.75)] = 0x3
    res[(magnitude >= 1.75) & (magnitude <= 2.5)] = 0x4
    res[(magnitude > 2.5) & (magnitude < 3.5)] = 0x5
    res[(magnitude >= 3.5) & (magnitude <= 5.0)] = 0x6
    res[magnitude > 5.0] = 0x7
    res |= sign
    # NaN fails every comparison above; force it to 0x7 after the sign has
    # been merged so the result does not depend on the NaN's sign bit.
    res[np.isnan(values)] = 0x7
    return res


def float32x2_to_float4e2m1x2(val_low: np.ndarray, val_high: np.ndarray) -> np.ndarray:
"""Cast two elements to float4e2m1 and pack to a single byte
Args:
val_low: element to be packed in the 4 LSB
val_high: element to be packed in the 4 MSB
Returns:
An ndarray with a single uint8 element, containing both float4e2m1 elements
An ndarray with uint8 elements, containing both float4e2m1 elements
"""
i8_high = float32_to_float4e2m1_unpacked(val_high)
i8_low = float32_to_float4e2m1_unpacked(val_low)
Expand Down
4 changes: 4 additions & 0 deletions onnx/test/test_backend_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
"|test_cast_UINT4_to_FLOAT16" # No corresponding Numpy type for Tensor Type.
"|test_cast_INT4_to_FLOAT16" # No corresponding Numpy type for Tensor Type.
"|test_maxpool_2d_ceil_output_size_reduce_by_one" # TODO: remove after https://github.com/microsoft/onnxruntime/pull/18377 in Ort release.
"|test_quantizeLinear_float4e2m1" # No corresponding Numpy type for Tensor Type.
"|test_dequantizeLinear_float4e2m1" # No corresponding Numpy type for Tensor Type.
"|cast_float4e2m1" # No corresponding Numpy type for Tensor Type.
"|to_float4e2m1" # No corresponding Numpy type for Tensor Type.
")"
)

Expand Down

0 comments on commit 9ba8a86

Please sign in to comment.