Numpy Serialization is not Unique for FP8 dtypes #324

@jon-chuang

>>> import jax.numpy as jnp
>>> import ml_dtypes
>>> import numpy as np

>>> np.arange(10).astype(ml_dtypes.float8_e4m3b11fnuz).dtype.str
'<V1'
>>> np.arange(10).astype(ml_dtypes.float8_e4m3fn).dtype.str
'<V1'
>>> np.arange(10).astype(ml_dtypes.float8_e4m3).dtype.str
'<V1'
>>> np.asarray(jnp.arange(10).astype(jnp.float8_e3m4)).dtype.str
'<V1'

I just want to point out how big a footgun this is. Mapping all of the above dtypes to the same kNpyDescrKind seems fundamentally misguided, because the serialized dtype string then no longer identifies which FP8 format an array actually holds:

static constexpr char kNpyDescrKind = 'V'; // Void
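The collision follows from how NumPy builds dtype.str, the array-protocol type string: a byte-order character, then the kind character, then the item size. If every FP8 type reports kind 'V' and item size 1, the strings must coincide. A minimal sketch to check this, assuming a recent ml_dtypes (float8_e4m3 and float8_e3m4 are only present in newer releases, so missing names are skipped); group_by_descr is a hypothetical helper, not part of NumPy or ml_dtypes:

```python
import ml_dtypes
import numpy as np

# FP8 names from the report; older ml_dtypes releases lack some of them,
# so only the ones actually present are checked.
FP8_NAMES = ["float8_e4m3b11fnuz", "float8_e4m3fn", "float8_e4m3", "float8_e3m4"]
FP8_DTYPES = [np.dtype(getattr(ml_dtypes, n)) for n in FP8_NAMES if hasattr(ml_dtypes, n)]


def group_by_descr(dtypes):
    """Hypothetical helper: group dtype names by their dtype.str type string."""
    groups = {}
    for dt in dtypes:
        groups.setdefault(dt.str, []).append(dt.name)
    return groups


for dt in FP8_DTYPES:
    # Same kind ('V') and same itemsize (1) for every type -> same dtype.str.
    print(dt.name, dt.kind, dt.itemsize, dt.str)

groups = group_by_descr(FP8_DTYPES)
print(groups)  # one key shared by every FP8 name

# Parsing the shared string back only yields a generic 1-byte void dtype.
shared = next(iter(groups))
print(repr(np.dtype(shared)))
```

Anything that keys on dtype.str alone (a file header, a cache key, a dtype registry) therefore cannot tell these formats apart.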

That said, I doubt we would want to change this now, for backwards-compatibility reasons.
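If the kind stays 'V', one possible user-side workaround is to avoid relying on the dtype string at all and record the dtype name explicitly next to the raw bytes. This is a sketch under stated assumptions, not a fix in ml_dtypes itself: it uses NumPy's .npz container, handles only 1-byte (FP8) types, and resolves the stored name with getattr on ml_dtypes. save_tagged and load_tagged are hypothetical helpers, not part of NumPy, JAX, or ml_dtypes:

```python
import io

import ml_dtypes
import numpy as np


def save_tagged(file, arr):
    """Hypothetical helper: store an FP8 array as raw uint8 bytes plus its dtype name.

    Viewing as uint8 means NumPy never has to describe the FP8 dtype itself,
    so the ambiguous 'V1' type string never reaches the file.
    """
    if arr.dtype.itemsize != 1:
        raise ValueError("this sketch only handles 1-byte (FP8) dtypes")
    raw = np.ascontiguousarray(arr).view(np.uint8)
    np.savez(file, data=raw, dtype_name=np.array(arr.dtype.name))


def load_tagged(file):
    """Hypothetical inverse of save_tagged: look the dtype up on ml_dtypes by name."""
    with np.load(file) as npz:
        name = npz["dtype_name"].item()
        raw = npz["data"]
    return raw.view(np.dtype(getattr(ml_dtypes, name)))


# Usage: two formats that share dtype.str come back as distinct dtypes.
for fp8 in (ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3b11fnuz):
    buf = io.BytesIO()
    save_tagged(buf, np.arange(10).astype(fp8))
    buf.seek(0)
    restored = load_tagged(buf)
    print(restored.dtype, restored.astype(np.float32))
```

The trade-off is that the file is no longer a plain .npy readable by other tools as an FP8 array; the name tag is a convention both writer and reader must share.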
