Commit: refactor coerce and constant

makslevental committed Aug 7, 2023
1 parent 7a38d5d commit 1a8a998
Showing 12 changed files with 835 additions and 701 deletions.
103 changes: 89 additions & 14 deletions mlir_utils/dialects/ext/arith.py
@@ -6,6 +6,7 @@

import numpy as np
from mlir.dialects import arith as arith_dialect
from mlir.dialects import complex as complex_dialect
from mlir.dialects._arith_ops_ext import _is_integer_like_type
from mlir.dialects._ods_common import get_op_result_or_value
from mlir.dialects.linalg.opdsl.lang.emitter import (
@@ -19,6 +20,7 @@
Context,
DenseElementsAttr,
IndexType,
InsertionPoint,
IntegerAttr,
IntegerType,
Location,
@@ -28,9 +30,20 @@
Type,
Value,
register_attribute_builder,
ComplexType,
BF16Type,
F16Type,
F32Type,
F64Type,
FloatAttr,
)

-from mlir_utils.util import get_result_or_results, maybe_cast, get_user_code_loc
+from mlir_utils.util import (
+    get_result_or_results,
+    maybe_cast,
+    get_user_code_loc,
+    register_value_caster,
+)

try:
from mlir_utils.dialects.arith import *
@@ -46,7 +59,8 @@ def constant(
index: Optional[bool] = None,
*,
loc: Location = None,
-) -> arith_dialect.ConstantOp:
+    ip: InsertionPoint = None,
+) -> Value:
"""Instantiate arith.constant with value `value`.
Args:
@@ -67,21 +81,62 @@
type = IndexType.get()
if type is None:
type = infer_mlir_type(value)
-    elif RankedTensorType.isinstance(type) and isinstance(value, (int, float, bool)):
+
+    assert type is not None
+
+    if _is_complex_type(type):
+        value = complex(value)
+        return maybe_cast(
+            get_result_or_results(
+                complex_dialect.ConstantOp(
+                    type,
+                    list(
+                        map(
+                            lambda x: FloatAttr.get(type.element_type, x),
+                            [value.real, value.imag],
+                        )
+                    ),
+                    loc=loc,
+                    ip=ip,
+                )
+            )
+        )
+
+    if _is_floating_point_type(type) and not isinstance(value, np.ndarray):
+        value = float(value)
+
+    if RankedTensorType.isinstance(type) and isinstance(value, (int, float, bool)):
        ranked_tensor_type = RankedTensorType(type)
-        value = np.ones(
+        value = np.full(
            ranked_tensor_type.shape,
            value,
            dtype=mlir_type_to_np_dtype(ranked_tensor_type.element_type),
        )
-    assert type is not None

    if isinstance(value, np.ndarray):
        value = DenseElementsAttr.get(
            value,
            type=type,
        )

    return maybe_cast(
-        get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc))
+        get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc, ip=ip))
)
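
A minimal usage sketch of the refactored `constant` (not part of the commit; assumes an ambient MLIR Context/Location/InsertionPoint, e.g. opened with mlir_utils' context manager, and that the elided part of the signature is `constant(value, type=None, index=None, ...)`):

    from mlir_utils.dialects.ext.arith import constant
    from mlir_utils.types import _cmp64_t

    c_bool = constant(True)             # i1 constant
    c_idx = constant(0, index=True)     # index-typed constant
    c_f = constant(1.0)                 # f32/f64 per infer_mlir_type (types.py below)
    c_z = constant(1 + 2j, _cmp64_t())  # complex element type -> complex.ConstantOp path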


def index_cast(
value: Value,
*,
to: Type = None,
loc: Location = None,
ip: InsertionPoint = None,
) -> Value:
if loc is None:
loc = get_user_code_loc()
if to is None:
to = IndexType.get()
return maybe_cast(
get_result_or_results(arith_dialect.IndexCastOp(to, value, loc=loc, ip=ip))
)
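
For example (sketch, same ambient-context assumptions as above), `index_cast` defaults to casting toward `index`, which the scf changes below rely on:

    i32_val = constant(7)                        # i32 value
    idx_val = index_cast(i32_val)                # arith.index_cast : i32 -> index
    back = index_cast(idx_val, to=i32_val.type)  # and back to i32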


@@ -231,6 +286,7 @@ def _binary_op(
rhs: "ArithValue",
op: str,
predicate: str = None,
signedness: str = None,
*,
loc: Location = None,
) -> "ArithValue":
@@ -247,12 +303,15 @@
"""
if loc is None:
loc = get_user_code_loc()
-    if not isinstance(rhs, lhs.__class__):
+    if (
+        isinstance(rhs, Value)
+        and lhs.type != rhs.type
+        or isinstance(rhs, (float, int, bool, np.ndarray))
+    ):
        lhs, rhs = lhs.coerce(rhs)
-    if lhs.type != rhs.type:
-        raise ValueError(f"{lhs=} {rhs=} must have the same type.")
+    assert lhs.type == rhs.type, f"{lhs=} {rhs=} must have the same type."

assert op in {"add", "sub", "mul", "cmp", "truediv", "floordiv", "mod"}
assert op in {"add", "and", "or", "sub", "mul", "cmp", "truediv", "floordiv", "mod"}

if op == "cmp":
assert predicate is not None
@@ -301,15 +360,20 @@ def _binary_op(
elif _is_integer_like_type(lhs.dtype):
# eq, ne signs don't matter
if predicate not in {"eq", "ne"}:
-                if lhs.dtype.is_signed:
-                    predicate = "s" + predicate
+                if signedness is not None:
+                    predicate = signedness + predicate
                else:
-                    predicate = "u" + predicate
+                    if lhs.dtype.is_signed:
+                        predicate = "s" + predicate
+                    else:
+                        predicate = "u" + predicate
return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype)
else:
return lhs.__class__(op(lhs, rhs, loc=loc), dtype=lhs.dtype)
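
A sketch of what the widened coercion check and the new ops enable (inferred from the code above; the mapping from op names like "and"/"or" to the arith ops lives in the elided part of this function):

    a = constant(2)   # Scalar wrapping an i32 constant
    b = a + 40        # python int -> a.coerce(40) -> arith.addi
    m = a & b         # new __and__ -> "and" (presumably arith.andi)
    o = a | b         # new __or__  -> "or"  (presumably arith.ori)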


# TODO(max): these could be generic in the dtype
# TODO(max): hit .verify() before constructing (maybe)
class ArithValue(Value, metaclass=ArithValueMeta):
"""Class for functionality shared by Value subclasses that support
arithmetic operations.
@@ -363,6 +427,9 @@ def __repr__(self):
__rsub__ = partialmethod(_binary_op, op="sub")
__rmul__ = partialmethod(_binary_op, op="mul")

__and__ = partialmethod(_binary_op, op="and")
__or__ = partialmethod(_binary_op, op="or")

def __eq__(self, other):
if not isinstance(other, self.__class__):
try:
@@ -435,6 +502,14 @@ def __float__(self):
def coerce(self, other) -> tuple["Scalar", "Scalar"]:
if isinstance(other, (int, float, bool)):
other = Scalar(other, dtype=self.dtype)
+        elif isinstance(other, Scalar) and _is_index_type(self.type):
+            other = index_cast(other)
+        elif isinstance(other, Scalar) and _is_index_type(other.type):
+            other = index_cast(other, to=self.type)
        else:
-            raise ValueError(f"can't coerce {other=} to Scalar")
+            raise ValueError(f"can't coerce {other=} to {self=}")
return self, other


for t in [BF16Type, F16Type, F32Type, F64Type, IndexType, IntegerType, ComplexType]:
register_value_caster(t.static_typeid)(Scalar)
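
Registering `Scalar` as the value caster for these types means op results come back pre-wrapped, so the operator overloads above apply without manual wrapping; roughly:

    x = constant(1.0)             # already a Scalar, not a bare Value
    y = x * 3.0                   # overloads resolve directly
    assert isinstance(y, Scalar)
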
19 changes: 11 additions & 8 deletions mlir_utils/dialects/ext/scf.py
@@ -5,6 +5,7 @@
from typing import Optional, Sequence

from bytecode import ConcreteBytecode, ConcreteInstr
from mlir.dialects.linalg.opdsl.lang.emitter import _is_index_type
from mlir.dialects.scf import IfOp, ForOp
from mlir.ir import InsertionPoint, Value, OpResultList, OpResult

@@ -16,7 +17,7 @@
OpCode,
)
from mlir_utils.ast.util import ast_call, set_lineno
from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.arith import constant, index_cast
from mlir_utils.dialects.scf import yield_ as yield__
from mlir_utils.util import (
region_op,
@@ -43,15 +44,17 @@ def _for(
if stop is None:
stop = start
start = 0
-    if isinstance(start, int):
-        start = constant(start, index=True)
-    if isinstance(stop, int):
-        stop = constant(stop, index=True)
-    if isinstance(step, int):
-        step = constant(step, index=True)
+    params = [start, stop, step]
+    for i, p in enumerate(params):
+        if isinstance(p, int):
+            p = constant(p, index=True)
+        if not _is_index_type(p.type):
+            p = index_cast(p)
+        params[i] = p

if loc is None:
loc = get_user_code_loc()
-    return ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+    return ForOp(*params, iter_args, loc=loc, ip=ip)


for_ = region_op(_for, terminator=yield__)
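
With the bound normalization above, start/stop/step may be any mix of python ints, `index` values, and castable integer values. A sketch against the raw `_for` builder (the `for_` wrapper adds the terminator via region_op; `ForOp.body`/`induction_variable` are upstream `mlir.dialects.scf` APIs):

    from mlir.dialects.scf import YieldOp

    ub = constant(10)      # i32 Scalar; _for index_casts it
    loop = _for(0, ub, 2)  # 0 and 2 become index constants
    with InsertionPoint(loop.body):
        i = loop.induction_variable
        # ... loop body ...
        YieldOp([])        # scf.yield terminator
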
14 changes: 8 additions & 6 deletions mlir_utils/dialects/ext/tensor.py
@@ -140,7 +140,7 @@ def __getitem__(self, idx: tuple) -> "Tensor":
if isinstance(idx, tuple) and all(i == slice(None) for i in idx):
return self
if idx is None:
-            return _expand_dims(self, (0,))
+            return expand_dims(self, (0,))

idx = list((idx,) if isinstance(idx, int) else idx)
for i, d in enumerate(idx):
@@ -198,8 +198,10 @@ def coerce(self, other) -> tuple["Tensor", "Tensor"]:
if isinstance(other, (int, float)):
other = Tensor(np.full(self.shape, other), dtype=self.dtype)
return self, other
-        elif _is_scalar(other):
-            other = tensor.splat(self.type, other)
+        elif isinstance(other, Scalar):
+            other = tensor.splat(
+                RankedTensorType.get(self.shape, other.dtype), other
+            )
return self, other

raise ValueError(f"can't coerce unknown {other=}")
@@ -256,7 +258,7 @@ def static_strides(self):
return tuple(strides)


-def _expand_dims(inp, newaxis_dims) -> Tensor:
+def expand_dims(inp, newaxis_dims) -> Tensor:
"""Expand the shape of a tensor.
Insert a new axis that will appear at the `axis` position in the expanded
@@ -514,7 +516,7 @@ def _extract_slice(
raise ValueError(f"non-constant indices not supported {indexer}")

# This adds newaxis/None dimensions.
-    return _expand_dims(out, indexer.newaxis_dims)
+    return expand_dims(out, indexer.newaxis_dims)


def _insert_slice(
@@ -523,7 +525,7 @@ def _insert_slice(
idx,
):
if isinstance(source, Scalar):
-        source = _expand_dims(source, (0,))
+        source = expand_dims(source, (0,))

indexer = _indices_to_indexer(idx, dest.shape)

53 changes: 44 additions & 9 deletions mlir_utils/types.py
@@ -5,19 +5,24 @@
import numpy as np
from mlir.ir import (
    Attribute,
+    BF16Type,
+    ComplexType,
    F16Type,
    F32Type,
    F64Type,
+    Float8E5M2Type,
+    Float8E4M3FNType,
+    Float8E4M3B11FNUZType,
    IndexType,
    IntegerType,
    MemRefType,
+    NoneType,
+    OpaqueType,
    RankedTensorType,
    Type,
    UnrankedMemRefType,
    UnrankedTensorType,
    VectorType,
-    BF16Type,
-    OpaqueType,
)

_index_t = lambda: IndexType.get()
@@ -43,6 +48,16 @@
_f64_t = lambda: F64Type.get()
_bf16_t = lambda: BF16Type.get()

_f8e5m2_t = lambda: Float8E5M2Type.get()
_f8e4m3_t = lambda: Float8E4M3FNType.get()
_f8e4m3b11fnuz_t = lambda: Float8E4M3B11FNUZType.get()

_cmp16_t = lambda: ComplexType.get(_f16_t())
_cmp32_t = lambda: ComplexType.get(_f32_t())
_cmp64_t = lambda: ComplexType.get(_f64_t())

_none_t = lambda: NoneType.get()

opaque_t = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)


@@ -53,26 +68,29 @@ def _placeholder_opaque_t():
_name_to_type = {
"index_t": _index_t,
"bool_t": _bool_t,

"i8_t": _i8_t,
"i16_t": _i16_t,
"i32_t": _i32_t,
"i64_t": _i64_t,

"si8_t": _si8_t,
"si16_t": _si16_t,
"si32_t": _si32_t,
"si64_t": _si64_t,

"ui8_t": _ui8_t,
"ui16_t": _ui16_t,
"ui32_t": _ui32_t,
"ui64_t": _ui64_t,

"f16_t": _f16_t,
"f32_t": _f32_t,
"f64_t": _f64_t,
"bf16_t": _bf16_t,
"f8e5m2_t": _f8e5m2_t,
"f8e4m3_t": _f8e4m3_t,
"f8e4m3b11fnuz_t": _f8e4m3b11fnuz_t,
"cmp16_t": _cmp16_t,
"cmp32_t": _cmp32_t,
"cmp64_t": _cmp64_t,
"none_t": _none_t,
}


@@ -115,7 +133,7 @@ def mlir_type_to_np_dtype(mlir_type):

def infer_mlir_type(
py_val: Union[int, float, bool, np.ndarray]
-) -> Union[IntegerType, F64Type, RankedTensorType]:
+) -> Union[IntegerType, F32Type, F64Type, RankedTensorType]:
"""Infer MLIR type (`ir.Type`) from supported python values.
Note ints and floats are mapped to 64-bit types.
@@ -129,9 +147,26 @@ def infer_mlir_type(
if isinstance(py_val, bool):
return _bool_t()
elif isinstance(py_val, int):
-        return _i64_t()
+        if -(2 ** 31) <= py_val < 2 ** 31:
+            return _i32_t()
+        elif 2 ** 31 <= py_val < 2 ** 32:
+            return _ui32_t()
+        elif -(2 ** 63) <= py_val < 2 ** 63:
+            return _i64_t()
+        elif 2 ** 63 <= py_val < 2 ** 64:
+            return _ui64_t()
+        else:
+            raise RuntimeError(f"Nonrepresentable integer {py_val}.")
elif isinstance(py_val, float):
-        return _f64_t()
+        if (
+            abs(py_val) == float("inf")
+            or abs(py_val) == 0.0
+            or py_val != py_val  # NaN
+            or np.finfo(np.float32).min <= abs(py_val) <= np.finfo(np.float32).max
+        ):
+            return _f32_t()
+        else:
+            return _f64_t()
elif isinstance(py_val, np.ndarray):
dtype = np_dtype_to_mlir_type(py_val.dtype.type)
return RankedTensorType.get(py_val.shape, dtype)
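
The resulting width selection, summarized (one value per branch):

    infer_mlir_type(True)     # i1
    infer_mlir_type(1)        # i32: fits in [-2**31, 2**31)
    infer_mlir_type(2**31)    # ui32
    infer_mlir_type(-2**40)   # i64
    infer_mlir_type(2**63)    # ui64
    infer_mlir_type(3.14)     # f32: within float32 range
    infer_mlir_type(1e300)    # f64: overflows float32
    infer_mlir_type(np.ones((2, 3), dtype=np.int8))  # tensor<2x3xi8>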