Skip to content

Commit 1e4dc37

Browse files
frontend: allow use of MLIR Types as "casts"
* allow cleartext backend for CGGI * allow types.Boolean where types.Integer is expected * add support for `>>`/arith.shrsi * better support for `MLIRType` when running vanilla Python
1 parent b0cf72d commit 1e4dc37

File tree

9 files changed

+332
-37
lines changed

9 files changed

+332
-37
lines changed

MODULE.bazel.lock

Lines changed: 25 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

frontend/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,24 @@ frontend_test(
8888
],
8989
)
9090

91+
frontend_test(
92+
name = "cggi_test",
93+
srcs = ["cggi_test.py"],
94+
tags = [
95+
# copybara: manual
96+
"notap",
97+
],
98+
)
99+
100+
frontend_test(
101+
name = "cast_test",
102+
srcs = ["cast_test.py"],
103+
tags = [
104+
# copybara: manual
105+
"notap",
106+
],
107+
)
108+
91109
bzl_library(
92110
name = "testing_bzl",
93111
srcs = ["testing.bzl"],

frontend/cast_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from heir import compile
2+
from heir.mlir import I1, I8, Secret
3+
from heir.backends.cleartext import CleartextBackend
4+
5+
6+
from absl.testing import absltest # fmt: skip
7+
class EndToEndTest(absltest.TestCase):
8+
9+
def test_cggi_cast(self):
10+
11+
@compile(
12+
scheme="cggi",
13+
backend=CleartextBackend(),
14+
debug=True,
15+
)
16+
def foo(x: Secret[I8]):
17+
x0 = I1((x >> 7) & 1)
18+
return x0
19+
20+
# Test cleartext functionality
21+
self.assertEqual(1, foo.original(255))
22+
self.assertEqual(0, foo.original(16))
23+
24+
# Test FHE functionality
25+
self.assertEqual(1, foo(255))
26+
self.assertEqual(0, foo(16))
27+
28+
29+
if __name__ == "__main__":
30+
absltest.main()

frontend/cggi_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from heir import compile
2-
from heir.mlir import I8, Secret
2+
from heir.mlir import I1, I8, Secret
3+
from heir.backends.cleartext import CleartextBackend
34

45

56
from absl.testing import absltest # fmt: skip
67
class EndToEndTest(absltest.TestCase):
78

8-
def test_simple_arithmetic(self):
9+
def test_simple_cggi_arithmetic(self):
910

1011
@compile(
1112
scheme="cggi",
13+
backend=CleartextBackend(),
1214
debug="True",
1315
)
1416
def foo(a: Secret[I8], b: Secret[I8]):

frontend/heir/mlir/types.py

Lines changed: 129 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,106 @@
44
from typing import Generic, Self, TypeVar, TypeVarTuple, get_args, get_origin
55
from numba.core.types import Type as NumbaType
66
from numba.core.types import boolean, int8, int16, int32, int64, float32, float64
7+
from numba.extending import typeof_impl, type_callable
78

89
T = TypeVar("T")
910
Ts = TypeVarTuple("Ts")
1011

11-
operator_error_message = "MLIRType should only be used for annotations."
12+
# List of all MLIR types we define here, for use in other parts of the compiler
13+
MLIR_TYPES = [] # populated via MLIRType's __init_subclass__
14+
15+
16+
def check_for_value(a: "MLIRType"):
17+
if not hasattr(a, "value"):
18+
raise RuntimeError(
19+
"Trying to use an operator on an MLIRType without a value."
20+
)
1221

1322

1423
class MLIRType(ABC):
1524

25+
def __init__(self, value: int):
26+
self.value = value
27+
28+
def __int__(self):
29+
check_for_value(self)
30+
return int(self.value)
31+
32+
def __index__(self):
33+
check_for_value(self)
34+
return int(self.value)
35+
36+
def __str__(self):
37+
check_for_value(self)
38+
return str(self.value)
39+
40+
def __repr__(self):
41+
check_for_value(self)
42+
return str(self.value)
43+
44+
def __eq__(self, other):
45+
check_for_value(self)
46+
if isinstance(other, MLIRType):
47+
check_for_value(other)
48+
return self.value == other.value
49+
return self.value == other
50+
51+
def __ne__(self, other):
52+
return not self.__eq__(other)
53+
54+
def __init_subclass__(cls, **kwargs):
55+
super().__init_subclass__(**kwargs)
56+
MLIR_TYPES.append(cls)
57+
1658
@staticmethod
1759
@abstractmethod
1860
def numba_type() -> NumbaType:
1961
raise NotImplementedError("No numba type exists for a generic MLIRType")
2062

21-
def __add__(self, other) -> Self:
22-
raise RuntimeError(operator_error_message)
63+
@staticmethod
64+
@abstractmethod
65+
def mlir_type() -> str:
66+
raise NotImplementedError("No mlir type exists for a generic MLIRType")
67+
68+
def __add__(self, other):
69+
check_for_value(self)
70+
return self.value + other
71+
72+
def __radd__(self, other):
73+
check_for_value(self)
74+
return other + self.value
75+
76+
def __sub__(self, other):
77+
check_for_value(self)
78+
return self.value - other
79+
80+
def __rsub__(self, other):
81+
check_for_value(self)
82+
return other - self.value
83+
84+
def __mul__(self, other):
85+
check_for_value(self)
86+
return self.value * other
87+
88+
def __rmul__(self, other):
89+
check_for_value(self)
90+
return other * self.value
91+
92+
def __rshift__(self, other):
93+
check_for_value(self)
94+
return self.value >> other
95+
96+
def __rrshift__(self, other):
97+
check_for_value(self)
98+
return other >> self.value
2399

24-
def __sub__(self, other) -> Self:
25-
raise RuntimeError(operator_error_message)
100+
def __lshift__(self, other):
101+
check_for_value(self)
102+
return self.value << other
26103

27-
def __mul__(self, other) -> Self:
28-
raise RuntimeError(operator_error_message)
104+
def __rlshift__(self, other):
105+
check_for_value(self)
106+
return other << self.value
29107

30108

31109
class Secret(Generic[T], MLIRType):
@@ -34,62 +112,106 @@ class Secret(Generic[T], MLIRType):
34112
def numba_type() -> NumbaType:
35113
raise NotImplementedError("No numba type exists for a generic Secret")
36114

115+
@staticmethod
116+
def mlir_type() -> str:
117+
raise NotImplementedError("No mlir type exists for a generic Secret")
118+
37119

38120
class Tensor(Generic[*Ts], MLIRType):
39121

40122
@staticmethod
41123
def numba_type() -> NumbaType:
42124
raise NotImplementedError("No numba type exists for a generic Tensor")
43125

126+
@staticmethod
127+
def mlir_type() -> str:
128+
raise NotImplementedError("No mlir type exists for a generic Tensor")
129+
44130

45131
class F32(MLIRType):
46132
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
47133
@staticmethod
48134
def numba_type() -> NumbaType:
49135
return float32
50136

137+
@staticmethod
138+
def mlir_type() -> str:
139+
return "f32"
140+
51141

52142
class F64(MLIRType):
53143
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
54144
@staticmethod
55145
def numba_type() -> NumbaType:
56146
return float64
57147

148+
@staticmethod
149+
def mlir_type() -> str:
150+
return "f64"
151+
58152

59153
class I1(MLIRType):
60154

61155
@staticmethod
62156
def numba_type() -> NumbaType:
63157
return boolean
64158

159+
@staticmethod
160+
def mlir_type() -> str:
161+
return "i1"
162+
65163

66164
class I8(MLIRType):
67165

68166
@staticmethod
69167
def numba_type() -> NumbaType:
70168
return int8
71169

170+
@staticmethod
171+
def mlir_type() -> str:
172+
return "i8"
173+
72174

73175
class I16(MLIRType):
74176

75177
@staticmethod
76178
def numba_type() -> NumbaType:
77179
return int16
78180

181+
@staticmethod
182+
def mlir_type() -> str:
183+
return "i16"
184+
79185

80186
class I32(MLIRType):
81187

82188
@staticmethod
83189
def numba_type() -> NumbaType:
84190
return int32
85191

192+
@staticmethod
193+
def mlir_type() -> str:
194+
return "i32"
195+
86196

87197
class I64(MLIRType):
88198

89199
@staticmethod
90200
def numba_type() -> NumbaType:
91201
return int64
92202

203+
@staticmethod
204+
def mlir_type() -> str:
205+
return "i64"
206+
207+
208+
# Register the types defined above with Numba
209+
for typ in [I8, I16, I32, I64, I1, F32, F64]:
210+
211+
@type_callable(typ)
212+
def build_typer_function(context, typ=typ):
213+
return lambda value: typ.numba_type()
214+
93215

94216
# Helper functions
95217

0 commit comments

Comments
 (0)