
Commit

more cleanup
Signed-off-by: Yuan Yao <[email protected]>
yuanyao-nv committed Aug 7, 2024
1 parent f4ca510 commit ba82f37
Showing 10 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/IR.md
@@ -421,7 +421,7 @@ It is common to represent a tensor as a nested list. This generally works fine,

|Group|Types|Description|
|---|---|---|
-Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433) and [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
+Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz, float4e2m1|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433) and [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
Signed Integer Types|int4, int8, int16, int32, int64|Signed integers are supported for 4-64 bit widths.
Unsigned Integer Types|uint4, uint8, uint16, uint32, uint64|Unsigned integers are supported for 4-64 bit widths.
Complex Types|complex64, complex128|A complex number with either 32- or 64-bit real and imaginary parts.
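For context on the new float4e2m1 entry: the format packs one sign bit, two exponent bits (bias 1), and one mantissa bit into four bits, so there are only eight distinct magnitudes. The sketch below decodes the codes under that interpretation; it is an illustration of the format, not code from this commit.

```python
def fp4_e2m1_to_float(code: int) -> float:
    """Decode a 4-bit E2M1 code: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit."""
    sign = -1.0 if code & 0x08 else 1.0
    exponent = (code >> 1) & 0x03
    mantissa = code & 0x01
    if exponent == 0:  # subnormal range: 0 or 0.5
        magnitude = 0.5 * mantissa
    else:              # normal range: (1 + mantissa/2) * 2**(exponent - 1)
        magnitude = (1.0 + 0.5 * mantissa) * 2.0 ** (exponent - 1)
    return sign * magnitude

# Codes 0..7 decode to 0, 0.5, 1, 1.5, 2, 3, 4, 6; codes 8..15 are their negatives.
print([fp4_e2m1_to_float(c) for c in range(8)])
```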
2 changes: 1 addition & 1 deletion docs/docsgen/source/api/numpy_helper.md
@@ -33,7 +33,7 @@
.. autofunction:: onnx.numpy_helper.to_array
```

-As numpy does not support all the types defined in ONNX (float 8 types, bfloat16, int4, uint4),
+As numpy does not support all the types defined in ONNX (float 8 types, bfloat16, int4, uint4, float4e2m1),
these two functions use a custom dtype defined in :mod:`onnx._custom_element_types`.

## sequence
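A minimal round-trip sketch of what the paragraph above describes, assuming float4e2m1 behaves like the other custom dtypes (a tagged uint8 carrying the raw 4-bit code); the helper names come from this repository, but the snippet is illustrative rather than taken from its tests.

```python
import numpy as np
import onnx
from onnx import numpy_helper
from onnx._custom_element_types import float4e2m1  # tagged uint8 dtype

# Four raw 4-bit codes, tagged with the custom dtype (as op_cast does further below).
codes = np.array([0x0, 0x2, 0x5, 0x7], dtype=np.uint8).astype(float4e2m1)

tp = numpy_helper.from_array(codes, name="x")
print(tp.data_type == onnx.TensorProto.FLOAT4E2M1)  # expected: True

# to_array also returns the custom dtype, i.e. raw codes rather than float values.
back = numpy_helper.to_array(tp)
```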
3 changes: 2 additions & 1 deletion onnx/common/ir_pb_converter.cc
@@ -51,7 +51,8 @@ Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
-case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: {
+case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
+case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: {
ret.int32s().reserve(tp.int32_data_size());
for (int i = 0; i < tp.int32_data_size(); i++) {
ret.int32s().push_back(tp.int32_data(i));
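The hunk above routes FLOAT4E2M1 through the same branch as the FP8 types: when the tensor is not stored as raw bytes, each element's 4-bit code sits in its own int32_data entry and is copied verbatim. A hand-built proto illustrating that layout (an illustration under that assumption, not code from this commit):

```python
from onnx import TensorProto

# One raw 4-bit code per int32_data entry, as the text parser below also produces.
t = TensorProto()
t.name = "w"
t.data_type = TensorProto.FLOAT4E2M1
t.dims.extend([4])
t.int32_data.extend([0x0, 0x2, 0x7, 0x9])  # codes for 0.0, 1.0, 6.0, -0.5
```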
1 change: 1 addition & 0 deletions onnx/defs/parser.cc
@@ -473,6 +473,7 @@ Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypePr
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2:
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
case TensorProto::DataType::TensorProto_DataType_BOOL:
+case TensorProto::DataType::TensorProto_DataType_FLOAT4E2M1:
PARSE_TOKEN(intval);
// TODO: check values are in the correct range.
tensorProto.add_int32_data(intval);
1 change: 1 addition & 0 deletions onnx/defs/parser.h
@@ -96,6 +96,7 @@ class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
map_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
map_["uint4"] = TensorProto_DataType_UINT4;
map_["int4"] = TensorProto_DataType_INT4;
+map_["float4e2m1"] = TensorProto_DataType_FLOAT4E2M1;
}

static bool IsTypeName(const std::string& dtype) {
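Together with the intval handling in parser.cc above, registering the "float4e2m1" name lets the textual format spell the new type. A hedged sketch of what that enables (the ir_version and opset numbers are assumptions; the tensor literal is written as raw 4-bit integer codes, matching the integer-token parsing above):

```python
import onnx
import onnx.parser

model = onnx.parser.parse_model("""
   <ir_version: 11, opset_import: ["" : 23]>
   g () => (float4e2m1[4] y) {
       y = Constant<value = float4e2m1[4] {0, 2, 7, 9}>()
   }
""")
print(model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.FLOAT4E2M1)
```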
7 changes: 5 additions & 2 deletions onnx/numpy_helper.py
@@ -229,7 +229,7 @@ def evaluate_float4e2m1_from_bits(x: np.uint8) -> np.float32:
x: a uint8 element representing a float4e2m1 (using the 4 LSB)
Returns:
-Packed array with size `ceil(farray.size/2)` (single dimension).
+A float32 element representing the value of the float4e2m1 input.
"""
# x is stored in 4 LSB of int
S = -1 if bool(x & 0x08) else 1
@@ -619,10 +619,13 @@ def from_array(tensor: np.ndarray, name: str | None = None) -> TensorProto:
elif dt == custom_np_types.uint4 and dt.descr[0][0] == "uint4":
to = TensorProto.UINT4
dt_to = np.uint8 # type: ignore[assignment]
+elif dt == custom_np_types.float4e2m1 and dt.descr[0][0] == "float4e2m1":
+to = TensorProto.FLOAT4E2M1
+dt_to = np.uint8
else:
return _from_array(tensor, name)

-if to in (TensorProto.UINT4, TensorProto.INT4):
+if to in (TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1):
value = tensor.astype(dt_to).ravel()
if value.size % 2 == 1:
raise ValueError(
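The even-size check above exists because two 4-bit elements share one byte when serialized. A standalone sketch of that packing step, assuming the low-nibble-first convention used for the other 4-bit types (not the repository's own helper):

```python
import numpy as np

def pack_4bitx2(codes: np.ndarray) -> bytes:
    """Pack an even-length array of 4-bit codes, two per byte, low nibble first."""
    flat = codes.astype(np.uint8).ravel() & 0x0F
    if flat.size % 2 == 1:
        raise ValueError("expected an even number of 4-bit elements")
    return ((flat[1::2] << 4) | flat[0::2]).astype(np.uint8).tobytes()

print(pack_4bitx2(np.array([0x1, 0x2, 0x3, 0x4])).hex())  # "2143"
```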
2 changes: 2 additions & 0 deletions onnx/reference/custom_element_types.py
@@ -12,6 +12,7 @@

from onnx._custom_element_types import (
bfloat16,
+float4e2m1,
float8e4m3fn,
float8e4m3fnuz,
float8e5m2,
@@ -22,6 +23,7 @@

_supported_types = [
(bfloat16, "bfloat16", "bfloat16"),
+(float4e2m1, "float4e2m1", "float4_e2m1"),
(float8e4m3fn, "e4m3fn", "float8_e4m3fn"),
(float8e4m3fnuz, "e4m3fnuz", "float8_e4m3fnuz"),
(float8e5m2, "e5m2", "float8_e5m2"),
2 changes: 1 addition & 1 deletion onnx/reference/ops/op_cast.py
@@ -140,7 +140,7 @@ def cast_to(x, to, saturate): # noqa: PLR0911

if to == TensorProto.FLOAT4E2M1:
xf = x.astype(np.float32)
-y = subbyte.float32_to_float4e2m1_unpacked(xf)
+y = subbyte.float32_to_float4e2m1_unpacked(xf).astype(float4e2m1)
return y.reshape(x.shape)

if to == TensorProto.STRING:
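With the cast result now tagged with the custom dtype, the reference evaluator can run Cast to the new type end to end. A hedged sketch (the opset number 23 is an assumption; the output carries the custom float4e2m1 dtype, i.e. raw codes rather than float values):

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

node = helper.make_node("Cast", ["x"], ["y"], to=TensorProto.FLOAT4E2M1)
graph = helper.make_graph(
    [node],
    "cast_to_fp4",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT4E2M1, [4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])

ref = ReferenceEvaluator(model)
(y,) = ref.run(None, {"x": np.array([0.3, 1.2, 2.4, 7.0], dtype=np.float32)})
print(y.dtype.descr[0][0])  # expected: "float4e2m1"
```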
1 change: 1 addition & 0 deletions onnx/test/numpy_helper_test.py
@@ -648,6 +648,7 @@ def test_to_array_from_array(self, att):
def test_to_array_from_array_subtype(self):
self._to_array_from_array(onnx.TensorProto.INT4)
self._to_array_from_array(onnx.TensorProto.UINT4)
+self._to_array_from_array(onnx.TensorProto.FLOAT4E2M1)

def test_to_array_from_array_string(self):
self._to_array_from_array(onnx.TensorProto.STRING, False)
1 change: 1 addition & 0 deletions onnx/test/parser_test.py
@@ -292,6 +292,7 @@ def test_parse_various_float_values(self, test_literal, expect_exception):
("uint16", TensorProto.UINT16),
("uint32", TensorProto.UINT32),
("uint64", TensorProto.UINT64),
+("float4e2m1", TensorProto.FLOAT4E2M1),
]
)
def test_parse_graph_types(self, name, itype) -> None:
