Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Yao <[email protected]>
  • Loading branch information
yuanyao-nv committed Aug 24, 2024
1 parent d45ad99 commit 74e11de
Show file tree
Hide file tree
Showing 83 changed files with 507 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/IR.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ It is common to represent a tensor as a nested list. This generally works fine,

|Group|Types|Description|
|---|---|---|
Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433) and [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz, float4e2m1|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433), [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915), and the [Open Compute Project](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
Signed Integer Types|int4, int8, int16, int32, int64|Signed integers are supported for 4-64 bit widths.
Unsigned Integer Types|uint4, uint8, uint16, uint32, uint64|Unsigned integers are supported for 4-64 bit widths.
Complex Types|complex64, complex128|A complex number with either 32- or 64-bit real and imaginary parts.
Expand Down
2 changes: 1 addition & 1 deletion docs/docsgen/source/api/numpy_helper.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
.. autofunction:: onnx.numpy_helper.to_array
```

As numpy does not support all the types defined in ONNX (float 8 types, blofat16, int4, uint4),
As numpy does not support all the types defined in ONNX (float 8 types, blofat16, int4, uint4, float4e2m1),
these two functions use a custom dtype defined in :mod:`onnx._custom_element_types`.

## sequence
Expand Down
118 changes: 118 additions & 0 deletions docs/docsgen/source/technical/float4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
<!--
Copyright (c) ONNX Project Contributors
SPDX-License-Identifier: Apache-2.0
-->

(onnx-detail-float4)=

# Float stored in 4 bits

## Papers

4 bit floating point formats have emerged as a solution to the
rising cost and deployment challenges of large language models.
The S1E2M1 format has been part of the [Open Compute Project (OCP)](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
standard.

As a result, a new data type was introduced in `onnx==1.18.0`
to support a limited set of operators to enable computation
with float4.

- `FLOAT4E2M1`: 1 bit for the sign, 2 bits for the exponents, and 1 bit for the mantissa.
No nan or infinities.

## E2M1

$S$ stands for the sign. $10_2$ describe a number base 2.

```{eval-rst}
.. list-table:: Float4 type
:widths: 10 10
:header-rows: 1
* -
- E2M1
* - Exponent bias
- 1
* - Infinities
-
* - NaN
-
* - Zeros
- :math:`S.00.0_2`
* - Max
- :math:`S.11.1_2`
* - Min
- :math:`S.00.1_2 = 2^{-1}`
```

Let's denote the bit representation as $S.b_2 b_1 b_0$.
The float value is defined by the following expressions:

```{eval-rst}
.. list-table:: Float4 type values
:widths: 10 10
:header-rows: 1
* -
- E2M1
* - exponent :math:`\neq` 0
- :math:`(-1)^S 2^{\sum_{i=1}^2 b_i 2^{i-1} - 1} \left( 1 + b_0 2^{-1} \right)`
* - exponent :math:`=` 0
- :math:`(-1)^S b_0 2^{-1}`
```

The following table lists all the representable values by float4 E2M1, ignoring the sign bit:
```{eval-rst}
.. list-table:: Float4 type values
:widths: 10 10
:header-rows: 1
* - bits (ignoring sign bit)
- E2M1
* - 000
- 0
* - 001
- 0.5
* - 010
- 1
* - 011
- 1.5
* - 100
- 2
* - 101
- 3
* - 110
- 4
* - 111
- 6
```

## Cast

Upcasting from float4 to float32, float16, bfloat16, and float8 is exact.
The behavior for downcasting to float 4 is summarized below

| x | E2M1 |
| ----------------- | ------------------------------------------------- |
| -6<=x<=6 | E2M1 converted value of x. Round to nearest even. |
| x=+/-0 | +/-0 |
| x>6 | 6 |
| x<-6 | -6 |
| +Inf | 6 |
| -Inf | -6 |
| NaN | 6 |

## Packing and Unpacking

Float4 is stored as 2x4bit in a single byte.
The first element is stored in the 4 LSB and the second element is stored in the 4 MSB,
i.e. for elements `x` and `y` that are consecutive elements in the array:
```
pack(x,y): y << 4 | x & 0x0F
unpack(z): x = z & 0x0F, y = z >> 4
```
In case the total number of elements is odd, padding of 4 bits will be appended.
The storage size of a 4 bit tensor of size `N` is `ceil(N/2)`.
1 change: 1 addition & 0 deletions docs/docsgen/source/technical/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ deeper than the code documentation.
float8
int4
float4
```
2 changes: 2 additions & 0 deletions onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"IR_VERSION_2020_5_8",
"IR_VERSION_2021_7_30",
"IR_VERSION_2023_5_5",
"IR_VERSION_2024_3_25",
"EXPERIMENTAL",
"STABLE",
# Modules
Expand Down Expand Up @@ -95,6 +96,7 @@
IR_VERSION_2020_5_8,
IR_VERSION_2021_7_30,
IR_VERSION_2023_5_5,
IR_VERSION_2024_3_25,
ModelProto,
NodeProto,
OperatorSetIdProto,
Expand Down
6 changes: 6 additions & 0 deletions onnx/_custom_element_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
#: than its onnx size.
int4 = np.dtype((np.int8, {"int4": (np.int8, 0)}))

#: Defines float 4 e2m1 type, see See :ref:`onnx-detail-float4` for technical details.
#: Do note that one integer is stored using a byte and therefore is twice bigger
#: than its onnx size.
float4e2m1 = np.dtype((np.uint8, {"float4e2m1": (np.uint8, 0)}))

mapping_name_to_data_type = {
"bfloat16": onnx.TensorProto.BFLOAT16,
"e4m3fn": onnx.TensorProto.FLOAT8E4M3FN,
Expand All @@ -60,4 +65,5 @@
"e5m2fnuz": onnx.TensorProto.FLOAT8E5M2FNUZ,
"int4": onnx.TensorProto.INT4,
"uint4": onnx.TensorProto.UINT4,
"float4e2m1": onnx.TensorProto.FLOAT4E2M1,
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ���@D{,@�#@��@�8@��&@w�@��8@[m=@ @�22@*�@�@>;@���?ZT�?~/?�5@�81@Վ7@�P>@&�2@��@�d1@\<�?�N&@}2�?�G<@�6@�@���?��0@�@��@��?j[$@��#@sK$@>9<@o�)@��@��@+@V�?4�(@a�(@յ�?'ճ?�V@�k@��@\@��>@�^�?��?O��?�n'@��?<<@YZ�?
ByJ���@D{,@�#@��@�8@��&@w�@��8@Zm=@ @�22@*�@�@>;@���?ZT�?~/?�5@�81@Վ7@�P>@%�2@��@�d1@\<�?�N&@~2�?�G<@�6@�@���?��0@�@��@��?j[$@��#@sK$@>9<@o�)@��@��@+@V�?5�(@a�(@յ�?(ճ?�V@�k@��@\@��>@�^�?��?O��?�n'@��?<<@YZ�?
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ�ø?OL?(�%?�?���>��3?J��>���?8s�?l{�>��i?�? �?�_�?���=ت�=�<��{?"Hd?��?���?fm?C��>E@e?[��=��1?M>`H�?T� ?���>��>Z�b?%��>$�?c�<tm*?�(?3*?��?B@?{H�>]��>k�E?9�v=:�:?p-<?�Y>�c>N�>X��>�S?/x�>L��?)Z�=�yW>��%>�06?�>a�>��|>
ByJ�ĸ?PL?*�%?�?���>��3?J��>���?8s�?l{�>��i?�? �?�_�?���=ت�=�<��{?#Hd?��?���?fm?C��>F@e?[��=��1?M>`H�?T� ?���>��>Z�b?%��>$�?c�<um*?�(?3*?��?D@?{H�>]��>k�E?9�v=;�:?q-<?�Y>�c>N�>X��>�S?/x�>L��?)Z�=�yW>��%>�06?�>a�>��|>
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ�ɚ�?���>%�]?,��?(��?�~]��{X?ng��ӽ'��>��>���?A�3?b��=�c�>�ѧ><��?��P�ȝ>!<F��Wտ�??>H?��/��q�?�����\;=�>�V=�?��?� >AG�>�L�����\����t>!��?M=�?���v����Pj��$��P馿��?����J�پ�P��� 7?LϠ�D<X��4N�4u�>cQ��-s����漎�>� �=���>���ߵ�
ByJ�ɚ�?���>%�]?,��?)��?�~]��{X?ng��ӽ'��>��>���?A�3?b��=�c�>�ѧ><��?��P�ȝ>!<F��Wտ�??>H?��/��q�?�����\;=�>�W=�?��?� >@G�>�L�����\����t>!��?M=�?���v����Pj��$��Q馿��?����I�پ�P��� 7?LϠ�D<X��4N�4u�>bQ��-s����漎�>� �=���>���ߵ�
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ�#�?��>�OF?V�?�"�?�F�܂B?~��p�ҽ�z�>4~>|�w?��&?���=���>��>� {?�4O��W�>=�4��F���3?{6?�r#�z�?m�w��K;=�{=�}#~?�%y?j>��>��9�98���l����>�fc?0�`?�3���N��� O�u��4���m�?jS�fӾH�e�c.)?���d�V�A�:�(�>Q��&8^����s3�>��=�b�>ٵ�t*��
ByJ�#�?~��>�OF?V�?�"�?�F�܂B?~��p�ҽ�z�>4~>|�w?��&?���=���>��>� {?�4O��W�>=�4��F���3?{6?�r#�z�?m�w��K;=�{=�}#~?�%y?j>��>��9�:8���l����>�fc?0�`?�3���N��� O�u��4���m�?jS�fӾH�e�c.)?���d�V�B�:�(�>Q��'8^����s3�>��=�b�>ٵ�t*��
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ByJ���?q�e?̍2?eo?�~�>+�D?�@�>B�?�W�?���>���?�?�%?/9�?0��=��=զ�<�*�?5�?���?C�@�R�?���>��?w]�=�B?��>
��??1?���>���>��?!�>;,%?��<Ξ8?W6?�T8?ܸ�?;U?��>���>��\?��v=�N?��O?��Z>h�>'3�>�&�>)�%?��>��$@���=�Y>�&>��G?��>T^?�v>
ByJ���?q�e?̍2?eo?�~�>+�D?�@�>B�?�W�?���>���?�?�%?/9�?0��=��=Ԧ�<�*�?5�?���?C�@�R�?���>��?w]�=�B?��>
��??1?���>���>��?!�>;,%?��<Ϟ8?W6?�T8?۸�?;U?��>���>��\?��v=�N?��O?��Z>g�>(3�>�&�>)�%?��>��$@���=�Y>�&>��G?��>T^?�v>
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ��<@@�b�?B\�?!!�@'T@�%�?�?�?x�?���?9�?�T�?gz@1�?��?;ь?)1�?��@��?�R�?>��?���@LT�?��?]�?^|�@w}@P"�?�?�?U@�r@%��?C�?ص?�]l@@Ӈ?S��?@��?q9�?ظ�?��?��?� @�6@G�e@���?�z�?�J�?��?2'@Q�?^׶?T��?��?��?�?/�?�H�?l�?_��?���?
ByJ��<@@�b�?C\�?!!�@'T@�%�?�?�?x�?���?9�?�T�?gz@1�?��?<ь?)1�?��@��?�R�??��?���@LT�?��?]�?_|�@w}@P"�?�?�?U@�r@&��?C�?ص?�]l@@Ӈ?S��?@��?q9�?ٸ�?��?��?� @�6@H�e@���?�z�?�J�?��?2'@P�?^׶?T��?��?��?�?/�?�H�?l�?_��?���?
Expand Down
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dft/test_data_set_0/output_0.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
ByJ��?�K�>��Q?A�@t��?
V$���I?�W��uB���>�d�=��?VQ?V��=q��>�~W>�Q�?�x�G>e�+�8^_�$��>�o2?�.�ޓ@�3ٽxD�<+7����?п�?24�=I�z>kL*��'A������=D�?D�?��
�|��
��B⽁���KR�?h@�:U�Z����?S��;%��<�)����>�f�����a�J�>�s=��?>��*�S �
ByJ��?�K�>��Q?A�@t��? V$���I?�W��uB���>�d�=��?VQ?V��=q��>�~W>�Q�?�x�G>d�+�8^_�$��>�o2?�.�ޓ@�3ٽxD�<+7����?п�?24�=I�z>jL*��'A������=D�?D�?��
�|��
��B����KR�?h@�:U�Z����?T��;%��;�)����>�f�����a�J�>�s=��?>��*�S �
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
ByJ��?�K�>��Q?A�@t��?
V$���I?�W��uB���>�d�=��?VQ?V��=q��>�~W>�Q�?�x�G>e�+�8^_�$��>�o2?�.�ޓ@�3ٽxD�<+7����?п�?24�=I�z>kL*��'A������=D�?D�?��
�|��
��B⽁���KR�?h@�:U�Z����?S��;%��<�)����>�f�����a�J�>�s=��?>��*�S �
ByJ��?�K�>��Q?A�@t��? V$���I?�W��uB���>�d�=��?VQ?V��=q��>�~W>�Q�?�x�G>d�+�8^_�$��>�o2?�.�ޓ@�3ٽxD�<+7����?п�?24�=I�z>jL*��'A������=D�?D�?��
�|��
��B����KR�?h@�:U�Z����?T��;%��;�)����>�f�����a�J�>�s=��?>��*�S �
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

ByJ(C�=�iF>)��>��E?��x?��x?��E?$��>�iF>C�=
ByJ(@�=�iF>*��>��E?��x?��x?��E?"��>�iF>@�=
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

ByJ(C�=�iF>)��>��E?��x?��x?��E?$��>�iF>C�=
ByJ(@�=�iF>*��>��E?��x?��x?��E?"��>�iF>@�=
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ByJ����?��?aK@��&@���?�q�?̌s?;�)>vU�>�5�>5�?;��?8C?���>���>�?3h�?B�>7.^?�g%@�L$@%sy?TK�?�@i@')�?�@>�P�?k/�?�&�?���>|w@���?U]�>���? /�?˙�?���>GG�?Dv�?/�?X@�,�?�?���?�;�?��?���?qBf?�>k?�z?]�?"�?�S�>P��>�i�>(?��+?�4?I�4?��U?�h�?}t�>���?�>�?rr?mHh?p�:?��:?��? ��?ܟ?�p:?�h~?��r?��?hӟ>%�?�E�?w�p?�r�?��?Cd
@�^@��?�E�?6`�?tp�?|R�?��?E�z?2~?,O<?
ByJ����?��?aK@��&@���?�q�?̌s?;�)>wU�>�5�>5�?;��?8C?���>���>�?3h�?B�>8.^?�g%@�L$@&sy?TK�?�@i@')�?�@>�P�?k/�?�&�?���>|w@���?V]�>���?/�?˙�?���>GG�?Dv�?/�?X@�,�?�?���?�;�?��?���?qBf?�>k?�z?]�?"�?�S�>Q��>�i�>(?��+?�4?I�4?��U?�h�?~t�>���?�>�?rr?mHh?p�:?��:?��? ��?ܟ?�p:?�h~?��r?��?hӟ>%�?�E�?w�p?�r�?��?Cd
@�^@��?�E�?6`�?tp�?|R�?��?E�z?2~?,O<?
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJl�d�?�;�>:�ſr����d���>D��>6nQ?��[?:���#rc�vyH�E3U?> ����,?�UD?�Pi?�ҿ�o����?��>�2�?�ϗ?��m����=msþ���
BYJl�d�?�;�>:�ſr����d���>D��>6nQ?��[?:���#rc�vyH�G3U?? ����,?�UD?�Pi?�ҿ�o����?��>�2�?�ϗ?��m����=msþ���
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJl�d�?�;�>:�ſr����d���>D��>6nQ?��[?:���#rc�vyH�E3U?> ����,?�UD?�Pi?�ҿ�o����?��>�2�?�ϗ?��m����=msþ���
BYJl�d�?�;�>:�ſr����d���>D��>6nQ?��[?:���#rc�vyH�G3U?? ����,?�UD?�Pi?�ҿ�o����?��>�2�?�ϗ?��m����=msþ���
Original file line number Diff line number Diff line change
@@ -1 +1 @@
BYJl�d�?�;�>:�ſr����d���>D��>6nQ?��[?:���#rc�vyH�E3U?> ����,?�UD?�Pi?�ҿ�o����?��>�2�?�ϗ?��m����=msþ���
BYJl�d�?�;�>:�ſr����d���>D��>6nQ?��[?:���#rc�vyH�G3U?? ����,?�UD?�Pi?�ҿ�o����?��>�2�?�ϗ?��m����=msþ���
Binary file modified onnx/backend/test/data/node/test_pow/test_data_set_0/output_0.pb
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
�1�?
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
�1�?
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
�1�?
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
�1�?
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
�1�?
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
 BreducedJ0�eMIz�@�&~X@{��U��忎c;��^@n6���"@��
 BreducedJ0�eMIz�@�&~X@}��U��忎c;��^@n6���"@��
�1�?
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ�%E5@hd�>�B�?͹�@�,J@����6��?����ӽ�.�>�>��@./V?F��=E��>w�>�b@�S��>�7v�H��a3?��y?6�O��-�@��~;=,�@�uH@W�@DM>cD�>�)����c����"� >0R�?_��?�N˾�%��`2���O��ru*��\@f1����ݸͿ��[?�W�6~[�\�����>��¼�>����>�S�=�<�>V-��ҽ�
ByJ�$E5@gd�>�B�?̹�@�,J@����6��?����ӽ�.�>�>��@//V?F��=E��>w�>�b@�S��>�7v�H��a3?��y?6�O��-�@��~;=,�@�uH@W�@DM>cD�>�)����c����"� >0R�?^��?�N˾�%��`2���O��ru*��\@g1����ݸͿ��[?�W�5~[�\�����>��¼�>����>�S�=�<�>V-��ҽ�
Expand Down
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ByJ�R������>�R�?=����JQ�f�����?-/�k%ԽF��>��>��A[�s?pm�=�u�>5z�>&PAzU����>s꒿��*?�D?!�?��j��I�����̎;=]�A��T�A�0A��>�h�>�P��P@@!���6k!>��4@��%@��оD����e޿������@�> �7���RmB���{?Y~�Ap1]��Ο�̛�>pw�ԟ������>�m�=�ğ>xZ<�~
ByJ�R������>�R�?=����JQ�f�����?-/�j%ԽF��>�>��A[�s?pm�=�u�>5z�>&PAzU����>s꒿��*?�D? �?��j��I�����̎;=]�A��T�A�0A��>�h�>�P��Q@@!���6k!>��4@��%@��оD����e޿������@�> �7���RmB���{?Y~�Ap1]��Ο�͛�>pw�՟������>�m�=�ğ>xZ<�
Expand Down
1 change: 1 addition & 0 deletions onnx/checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) {
case TensorProto::FLOAT8E5M2FNUZ:
case TensorProto::UINT4:
case TensorProto::INT4:
case TensorProto::FLOAT4E2M1:
check_field(int32_data);
break;

Expand Down
3 changes: 2 additions & 1 deletion onnx/common/ir_pb_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: {
ret.int32s().reserve(tp.int32_data_size());
for (int i = 0; i < tp.int32_data_size(); i++) {
ret.int32s().push_back(tp.int32_data(i));
Expand Down
1 change: 1 addition & 0 deletions onnx/defs/data_type_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ TypesWrapper::TypesWrapper() {
type_str_to_tensor_data_type_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
type_str_to_tensor_data_type_["uint4"] = TensorProto_DataType_UINT4;
type_str_to_tensor_data_type_["int4"] = TensorProto_DataType_INT4;
type_str_to_tensor_data_type_["float4e2m1"] = TensorProto_DataType_FLOAT4E2M1;

for (auto& str_type_pair : type_str_to_tensor_data_type_) {
tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
Expand Down
1 change: 1 addition & 0 deletions onnx/defs/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypePr
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2:
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
case TensorProto::DataType::TensorProto_DataType_BOOL:
case TensorProto::DataType::TensorProto_DataType_FLOAT4E2M1:
PARSE_TOKEN(intval);
// TODO: check values are in the correct range.
tensorProto.add_int32_data(intval);
Expand Down
1 change: 1 addition & 0 deletions onnx/defs/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
map_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
map_["uint4"] = TensorProto_DataType_UINT4;
map_["int4"] = TensorProto_DataType_INT4;
map_["float4e2m1"] = TensorProto_DataType_FLOAT4E2M1;
}

static bool IsTypeName(const std::string& dtype) {
Expand Down
Loading

0 comments on commit 74e11de

Please sign in to comment.