4
4
from typing import Generic , Self , TypeVar , TypeVarTuple , get_args , get_origin
5
5
from numba .core .types import Type as NumbaType
6
6
from numba .core .types import boolean , int8 , int16 , int32 , int64 , float32 , float64
7
+ from numba .extending import typeof_impl , type_callable
7
8
8
9
T = TypeVar ("T" )
9
10
Ts = TypeVarTuple ("Ts" )
10
11
11
- operator_error_message = "MLIRType should only be used for annotations."
12
+ # List of all MLIR types we define here, for use in other parts of the compiler
13
+ MLIR_TYPES = [] # populated via MLIRType's __init_subclass__
14
+
15
+
16
+ def check_for_value (a : "MLIRType" ):
17
+ if not hasattr (a , "value" ):
18
+ raise RuntimeError (
19
+ "Trying to use an operator on an MLIRType without a value."
20
+ )
12
21
13
22
14
23
class MLIRType (ABC ):
15
24
25
+ def __init__ (self , value : int ):
26
+ self .value = value
27
+
28
+ def __int__ (self ):
29
+ check_for_value (self )
30
+ return int (self .value )
31
+
32
+ def __index__ (self ):
33
+ check_for_value (self )
34
+ return int (self .value )
35
+
36
+ def __str__ (self ):
37
+ check_for_value (self )
38
+ return str (self .value )
39
+
40
+ def __repr__ (self ):
41
+ check_for_value (self )
42
+ return str (self .value )
43
+
44
+ def __eq__ (self , other ):
45
+ check_for_value (self )
46
+ if isinstance (other , MLIRType ):
47
+ check_for_value (other )
48
+ return self .value == other .value
49
+ return self .value == other
50
+
51
+ def __ne__ (self , other ):
52
+ return not self .__eq__ (other )
53
+
54
+ def __init_subclass__ (cls , ** kwargs ):
55
+ super ().__init_subclass__ (** kwargs )
56
+ MLIR_TYPES .append (cls )
57
+
16
58
@staticmethod
17
59
@abstractmethod
18
60
def numba_type () -> NumbaType :
19
61
raise NotImplementedError ("No numba type exists for a generic MLIRType" )
20
62
21
- def __add__ (self , other ) -> Self :
22
- raise RuntimeError (operator_error_message )
63
+ @staticmethod
64
+ @abstractmethod
65
+ def mlir_type () -> str :
66
+ raise NotImplementedError ("No mlir type exists for a generic MLIRType" )
67
+
68
+ def __add__ (self , other ):
69
+ check_for_value (self )
70
+ return self .value + other
71
+
72
+ def __radd__ (self , other ):
73
+ check_for_value (self )
74
+ return other + self .value
75
+
76
+ def __sub__ (self , other ):
77
+ check_for_value (self )
78
+ return self .value - other
79
+
80
+ def __rsub__ (self , other ):
81
+ check_for_value (self )
82
+ return other - self .value
83
+
84
+ def __mul__ (self , other ):
85
+ check_for_value (self )
86
+ return self .value * other
87
+
88
+ def __rmul__ (self , other ):
89
+ check_for_value (self )
90
+ return other * self .value
91
+
92
+ def __rshift__ (self , other ):
93
+ check_for_value (self )
94
+ return self .value >> other
95
+
96
+ def __rrshift__ (self , other ):
97
+ check_for_value (self )
98
+ return other >> self .value
23
99
24
- def __sub__ (self , other ) -> Self :
25
- raise RuntimeError (operator_error_message )
100
+ def __lshift__ (self , other ):
101
+ check_for_value (self )
102
+ return self .value << other
26
103
27
- def __mul__ (self , other ) -> Self :
28
- raise RuntimeError (operator_error_message )
104
+ def __rlshift__ (self , other ):
105
+ check_for_value (self )
106
+ return other << self .value
29
107
30
108
31
109
class Secret (Generic [T ], MLIRType ):
@@ -34,62 +112,106 @@ class Secret(Generic[T], MLIRType):
34
112
def numba_type () -> NumbaType :
35
113
raise NotImplementedError ("No numba type exists for a generic Secret" )
36
114
115
+ @staticmethod
116
+ def mlir_type () -> str :
117
+ raise NotImplementedError ("No mlir type exists for a generic Secret" )
118
+
37
119
38
120
class Tensor (Generic [* Ts ], MLIRType ):
39
121
40
122
@staticmethod
41
123
def numba_type () -> NumbaType :
42
124
raise NotImplementedError ("No numba type exists for a generic Tensor" )
43
125
126
+ @staticmethod
127
+ def mlir_type () -> str :
128
+ raise NotImplementedError ("No mlir type exists for a generic Tensor" )
129
+
44
130
45
131
class F32 (MLIRType ):
46
132
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
47
133
@staticmethod
48
134
def numba_type () -> NumbaType :
49
135
return float32
50
136
137
+ @staticmethod
138
+ def mlir_type () -> str :
139
+ return "f32"
140
+
51
141
52
142
class F64 (MLIRType ):
53
143
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
54
144
@staticmethod
55
145
def numba_type () -> NumbaType :
56
146
return float64
57
147
148
+ @staticmethod
149
+ def mlir_type () -> str :
150
+ return "f64"
151
+
58
152
59
153
class I1 (MLIRType ):
60
154
61
155
@staticmethod
62
156
def numba_type () -> NumbaType :
63
157
return boolean
64
158
159
+ @staticmethod
160
+ def mlir_type () -> str :
161
+ return "i1"
162
+
65
163
66
164
class I8 (MLIRType ):
67
165
68
166
@staticmethod
69
167
def numba_type () -> NumbaType :
70
168
return int8
71
169
170
+ @staticmethod
171
+ def mlir_type () -> str :
172
+ return "i8"
173
+
72
174
73
175
class I16 (MLIRType ):
74
176
75
177
@staticmethod
76
178
def numba_type () -> NumbaType :
77
179
return int16
78
180
181
+ @staticmethod
182
+ def mlir_type () -> str :
183
+ return "i16"
184
+
79
185
80
186
class I32 (MLIRType ):
81
187
82
188
@staticmethod
83
189
def numba_type () -> NumbaType :
84
190
return int32
85
191
192
+ @staticmethod
193
+ def mlir_type () -> str :
194
+ return "i32"
195
+
86
196
87
197
class I64 (MLIRType ):
88
198
89
199
@staticmethod
90
200
def numba_type () -> NumbaType :
91
201
return int64
92
202
203
+ @staticmethod
204
+ def mlir_type () -> str :
205
+ return "i64"
206
+
207
+
208
+ # Register the types defined above with Numba
209
+ for typ in [I8 , I16 , I32 , I64 , I1 , F32 , F64 ]:
210
+
211
+ @type_callable (typ )
212
+ def build_typer_function (context , typ = typ ):
213
+ return lambda value : typ .numba_type ()
214
+
93
215
94
216
# Helper functions
95
217
0 commit comments