Skip to content

frontend: support "casting" between MLIR types #1825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions MODULE.bazel.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,24 @@ frontend_test(
],
)

frontend_test(
name = "cggi_test",
srcs = ["cggi_test.py"],
tags = [
# copybara: manual
"notap",
],
)

frontend_test(
name = "cast_test",
srcs = ["cast_test.py"],
tags = [
# copybara: manual
"notap",
],
)

bzl_library(
name = "testing_bzl",
srcs = ["testing.bzl"],
Expand Down
30 changes: 30 additions & 0 deletions frontend/cast_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from heir import compile
from heir.mlir import I1, I8, Secret
from heir.backends.cleartext import CleartextBackend


from absl.testing import absltest # fmt: skip
class EndToEndTest(absltest.TestCase):

def test_cggi_cast(self):

@compile(
scheme="cggi",
backend=CleartextBackend(),
debug=True,
)
def foo(x: Secret[I8]):
x0 = I1((x >> 7) & 1)
return x0

# Test cleartext functionality
self.assertEqual(1, foo.original(255))
self.assertEqual(0, foo.original(16))

# Test FHE functionality
self.assertEqual(1, foo(255))
self.assertEqual(0, foo(16))


if __name__ == "__main__":
absltest.main()
4 changes: 3 additions & 1 deletion frontend/cggi_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from heir import compile
from heir.mlir import I8, Secret
from heir.backends.cleartext import CleartextBackend


from absl.testing import absltest # fmt: skip
class EndToEndTest(absltest.TestCase):

def test_simple_arithmetic(self):
def test_simple_cggi_arithmetic(self):

@compile(
scheme="cggi",
backend=CleartextBackend(),
debug="True",
)
def foo(a: Secret[I8], b: Secret[I8]):
Expand Down
136 changes: 129 additions & 7 deletions frontend/heir/mlir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,106 @@
from typing import Generic, Self, TypeVar, TypeVarTuple, get_args, get_origin
from numba.core.types import Type as NumbaType
from numba.core.types import boolean, int8, int16, int32, int64, float32, float64
from numba.extending import typeof_impl, type_callable

T = TypeVar("T")
Ts = TypeVarTuple("Ts")

operator_error_message = "MLIRType should only be used for annotations."
# List of all MLIR types we define here, for use in other parts of the compiler
MLIR_TYPES = [] # populated via MLIRType's __init_subclass__


def check_for_value(a: "MLIRType"):
if not hasattr(a, "value"):
raise RuntimeError(
"Trying to use an operator on an MLIRType without a value."
)


class MLIRType(ABC):

def __init__(self, value: int):
self.value = value

def __int__(self):
check_for_value(self)
return int(self.value)

def __index__(self):
check_for_value(self)
return int(self.value)

def __str__(self):
check_for_value(self)
return str(self.value)

def __repr__(self):
check_for_value(self)
return str(self.value)

def __eq__(self, other):
check_for_value(self)
if isinstance(other, MLIRType):
check_for_value(other)
return self.value == other.value
return self.value == other

def __ne__(self, other):
return not self.__eq__(other)

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
MLIR_TYPES.append(cls)

@staticmethod
@abstractmethod
def numba_type() -> NumbaType:
raise NotImplementedError("No numba type exists for a generic MLIRType")

def __add__(self, other) -> Self:
raise RuntimeError(operator_error_message)
@staticmethod
@abstractmethod
def mlir_type() -> str:
raise NotImplementedError("No mlir type exists for a generic MLIRType")

def __add__(self, other):
check_for_value(self)
return self.value + other

def __radd__(self, other):
check_for_value(self)
return other + self.value

def __sub__(self, other):
check_for_value(self)
return self.value - other

def __rsub__(self, other):
check_for_value(self)
return other - self.value

def __mul__(self, other):
check_for_value(self)
return self.value * other

def __rmul__(self, other):
check_for_value(self)
return other * self.value

def __rshift__(self, other):
check_for_value(self)
return self.value >> other

def __rrshift__(self, other):
check_for_value(self)
return other >> self.value

def __sub__(self, other) -> Self:
raise RuntimeError(operator_error_message)
def __lshift__(self, other):
check_for_value(self)
return self.value << other

def __mul__(self, other) -> Self:
raise RuntimeError(operator_error_message)
def __rlshift__(self, other):
check_for_value(self)
return other << self.value


class Secret(Generic[T], MLIRType):
Expand All @@ -34,62 +112,106 @@ class Secret(Generic[T], MLIRType):
def numba_type() -> NumbaType:
raise NotImplementedError("No numba type exists for a generic Secret")

@staticmethod
def mlir_type() -> str:
raise NotImplementedError("No mlir type exists for a generic Secret")


class Tensor(Generic[*Ts], MLIRType):

@staticmethod
def numba_type() -> NumbaType:
raise NotImplementedError("No numba type exists for a generic Tensor")

@staticmethod
def mlir_type() -> str:
raise NotImplementedError("No mlir type exists for a generic Tensor")


class F32(MLIRType):
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
@staticmethod
def numba_type() -> NumbaType:
return float32

@staticmethod
def mlir_type() -> str:
return "f32"


class F64(MLIRType):
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
@staticmethod
def numba_type() -> NumbaType:
return float64

@staticmethod
def mlir_type() -> str:
return "f64"


class I1(MLIRType):

@staticmethod
def numba_type() -> NumbaType:
return boolean

@staticmethod
def mlir_type() -> str:
return "i1"


class I8(MLIRType):

@staticmethod
def numba_type() -> NumbaType:
return int8

@staticmethod
def mlir_type() -> str:
return "i8"


class I16(MLIRType):

@staticmethod
def numba_type() -> NumbaType:
return int16

@staticmethod
def mlir_type() -> str:
return "i16"


class I32(MLIRType):

@staticmethod
def numba_type() -> NumbaType:
return int32

@staticmethod
def mlir_type() -> str:
return "i32"


class I64(MLIRType):

@staticmethod
def numba_type() -> NumbaType:
return int64

@staticmethod
def mlir_type() -> str:
return "i64"


# Register the types defined above with Numba
for typ in [I8, I16, I32, I64, I1, F32, F64]:

@type_callable(typ)
def build_typer_function(context, typ=typ):
return lambda value: typ.numba_type()


# Helper functions

Expand Down
Loading
Loading