diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 9336d77a58..c1d26c8ddd 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -408,12 +408,11 @@ "https://bcr.bazel.build/modules/rules_java/7.2.0/MODULE.bazel": "06c0334c9be61e6cef2c8c84a7800cef502063269a5af25ceb100b192453d4ab", "https://bcr.bazel.build/modules/rules_java/7.3.2/MODULE.bazel": "50dece891cfdf1741ea230d001aa9c14398062f2b7c066470accace78e412bc2", "https://bcr.bazel.build/modules/rules_java/7.6.1/MODULE.bazel": "2f14b7e8a1aa2f67ae92bc69d1ec0fa8d9f827c4e17ff5e5f02e91caa3b2d0fe", - "https://bcr.bazel.build/modules/rules_java/8.11.0/MODULE.bazel": "c3d280bc5ff1038dcb3bacb95d3f6b83da8dd27bba57820ec89ea4085da767ad", - "https://bcr.bazel.build/modules/rules_java/8.11.0/source.json": "302b52a39259a85aa06ca3addb9787864ca3e03b432a5f964ea68244397e7544", "https://bcr.bazel.build/modules/rules_java/8.3.2/MODULE.bazel": "7336d5511ad5af0b8615fdc7477535a2e4e723a357b6713af439fe8cf0195017", "https://bcr.bazel.build/modules/rules_java/8.5.1/MODULE.bazel": "d8a9e38cc5228881f7055a6079f6f7821a073df3744d441978e7a43e20226939", "https://bcr.bazel.build/modules/rules_java/8.6.1/MODULE.bazel": "f4808e2ab5b0197f094cabce9f4b006a27766beb6a9975931da07099560ca9c2", "https://bcr.bazel.build/modules/rules_java/8.6.3/MODULE.bazel": "e90505b7a931d194245ffcfb6ff4ca8ef9d46b4e830d12e64817752e0198e2ed", + "https://bcr.bazel.build/modules/rules_java/8.6.3/source.json": "8330cc5d277085bbcf93e9f1c85c24d06975585606a1215df4faf886a8d3cc9e", "https://bcr.bazel.build/modules/rules_jvm_external/4.4.2/MODULE.bazel": "a56b85e418c83eb1839819f0b515c431010160383306d13ec21959ac412d2fe7", "https://bcr.bazel.build/modules/rules_jvm_external/5.1/MODULE.bazel": "33f6f999e03183f7d088c9be518a63467dfd0be94a11d0055fe2d210f89aa909", "https://bcr.bazel.build/modules/rules_jvm_external/5.2/MODULE.bazel": "d9351ba35217ad0de03816ef3ed63f89d411349353077348a45348b096615036", @@ -2242,6 +2241,28 @@ ] } }, + "@@rules_java+//java:rules_java_deps.bzl%compatibility_proxy": { + "general": { + "bzlTransitiveDigest": "9SPkp75wN6dP9sKiOgU1uOCTNNA08v8PVBMYs+SZ27s=", + "usagesDigest": "ZsMvJZXTm/u0lYuq0P/yTF2h6FdynULVT3mwKA1SE5k=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "compatibility_proxy": { + "repoRuleId": "@@rules_java+//java:rules_java_deps.bzl%_compatibility_proxy_repo_rule", + "attributes": {} + } + }, + "recordedRepoMappingEntries": [ + [ + "rules_java+", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, "@@rules_kotlin+//src/main/starlark/core/repositories:bzlmod_setup.bzl%rules_kotlin_extensions": { "general": { "bzlTransitiveDigest": "sFhcgPbDQehmbD1EOXzX4H1q/CD5df8zwG4kp4jbvr8=", @@ -2430,7 +2451,7 @@ }, "@@rules_python+//python/extensions:pip.bzl%pip": { "general": { - "bzlTransitiveDigest": "pBYZfL5VYF4BT5jARY3p7l3d7Uknixz0stBBKWj2iaw=", + "bzlTransitiveDigest": "NLNmCO7BV8DxB9vHDXiQ7ZU1Kdtl0Sr3I2f3bKDuOCc=", "usagesDigest": "fJqnUFr/GgZxd/VmNa3e4q3P8xv3wi51f6kXAa9+tsg=", "recordedFileInputs": { "@@//requirements-dev.txt": "54fd907ca5b52c0522f9c479b4fe17cf8af6843d1d5b58f62a14705a0552c692", @@ -13686,7 +13707,7 @@ }, "@@rules_rust+//rust:extensions.bzl%rust": { "general": { - "bzlTransitiveDigest": "0XHSLimPR0yIi1HIZgsMrr67WhwiJ9YwNk9n4/D118s=", + "bzlTransitiveDigest": "u0AbcREuyXinWoLmuEjo7UWBx+Y+ROXvSITNXzSqOvo=", "usagesDigest": "ozx08ZbgRXTJw0zCaO/xtMUzgGLvwaQkZGnUo6tlyHM=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, diff --git a/frontend/BUILD b/frontend/BUILD index 2429fca897..4e51de535e 100644 --- a/frontend/BUILD +++ b/frontend/BUILD @@ -88,6 +88,24 @@ frontend_test( ], ) +frontend_test( + name = "cggi_test", + srcs = ["cggi_test.py"], + tags = [ + # copybara: manual + "notap", + ], +) + +frontend_test( + name = "cast_test", + srcs = ["cast_test.py"], + tags = [ + # copybara: manual + "notap", + ], +) + bzl_library( name = "testing_bzl", srcs = ["testing.bzl"], diff --git a/frontend/cast_test.py b/frontend/cast_test.py new file mode 100644 index 0000000000..3bd197698e --- /dev/null +++ b/frontend/cast_test.py @@ -0,0 +1,30 @@ +from heir import compile +from heir.mlir import I1, I8, Secret +from heir.backends.cleartext import CleartextBackend + + +from absl.testing import absltest # fmt: skip +class EndToEndTest(absltest.TestCase): + + def test_cggi_cast(self): + + @compile( + scheme="cggi", + backend=CleartextBackend(), + debug=True, + ) + def foo(x: Secret[I8]): + x0 = I1((x >> 7) & 1) + return x0 + + # Test cleartext functionality + self.assertEqual(1, foo.original(255)) + self.assertEqual(0, foo.original(16)) + + # Test FHE functionality + self.assertEqual(1, foo(255)) + self.assertEqual(0, foo(16)) + + +if __name__ == "__main__": + absltest.main() diff --git a/frontend/cggi_test.py b/frontend/cggi_test.py index d89d5dba94..13a0beb08f 100644 --- a/frontend/cggi_test.py +++ b/frontend/cggi_test.py @@ -1,14 +1,16 @@ from heir import compile from heir.mlir import I8, Secret +from heir.backends.cleartext import CleartextBackend from absl.testing import absltest # fmt: skip class EndToEndTest(absltest.TestCase): - def test_simple_arithmetic(self): + def test_simple_cggi_arithmetic(self): @compile( scheme="cggi", + backend=CleartextBackend(), debug="True", ) def foo(a: Secret[I8], b: Secret[I8]): diff --git a/frontend/heir/mlir/types.py b/frontend/heir/mlir/types.py index 06e94bb08d..2714f9b08a 100644 --- a/frontend/heir/mlir/types.py +++ b/frontend/heir/mlir/types.py @@ -4,28 +4,106 @@ from typing import Generic, Self, TypeVar, TypeVarTuple, get_args, get_origin from numba.core.types import Type as NumbaType from numba.core.types import boolean, int8, int16, int32, int64, float32, float64 +from numba.extending import typeof_impl, type_callable T = TypeVar("T") Ts = TypeVarTuple("Ts") -operator_error_message = "MLIRType should only be used for annotations." +# List of all MLIR types we define here, for use in other parts of the compiler +MLIR_TYPES = [] # populated via MLIRType's __init_subclass__ + + +def check_for_value(a: "MLIRType"): + if not hasattr(a, "value"): + raise RuntimeError( + "Trying to use an operator on an MLIRType without a value." + ) class MLIRType(ABC): + def __init__(self, value: int): + self.value = value + + def __int__(self): + check_for_value(self) + return int(self.value) + + def __index__(self): + check_for_value(self) + return int(self.value) + + def __str__(self): + check_for_value(self) + return str(self.value) + + def __repr__(self): + check_for_value(self) + return str(self.value) + + def __eq__(self, other): + check_for_value(self) + if isinstance(other, MLIRType): + check_for_value(other) + return self.value == other.value + return self.value == other + + def __ne__(self, other): + return not self.__eq__(other) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + MLIR_TYPES.append(cls) + @staticmethod @abstractmethod def numba_type() -> NumbaType: raise NotImplementedError("No numba type exists for a generic MLIRType") - def __add__(self, other) -> Self: - raise RuntimeError(operator_error_message) + @staticmethod + @abstractmethod + def mlir_type() -> str: + raise NotImplementedError("No mlir type exists for a generic MLIRType") + + def __add__(self, other): + check_for_value(self) + return self.value + other + + def __radd__(self, other): + check_for_value(self) + return other + self.value + + def __sub__(self, other): + check_for_value(self) + return self.value - other + + def __rsub__(self, other): + check_for_value(self) + return other - self.value + + def __mul__(self, other): + check_for_value(self) + return self.value * other + + def __rmul__(self, other): + check_for_value(self) + return other * self.value + + def __rshift__(self, other): + check_for_value(self) + return self.value >> other + + def __rrshift__(self, other): + check_for_value(self) + return other >> self.value - def __sub__(self, other) -> Self: - raise RuntimeError(operator_error_message) + def __lshift__(self, other): + check_for_value(self) + return self.value << other - def __mul__(self, other) -> Self: - raise RuntimeError(operator_error_message) + def __rlshift__(self, other): + check_for_value(self) + return other << self.value class Secret(Generic[T], MLIRType): @@ -34,6 +112,10 @@ class Secret(Generic[T], MLIRType): def numba_type() -> NumbaType: raise NotImplementedError("No numba type exists for a generic Secret") + @staticmethod + def mlir_type() -> str: + raise NotImplementedError("No mlir type exists for a generic Secret") + class Tensor(Generic[*Ts], MLIRType): @@ -41,6 +123,10 @@ class Tensor(Generic[*Ts], MLIRType): def numba_type() -> NumbaType: raise NotImplementedError("No numba type exists for a generic Tensor") + @staticmethod + def mlir_type() -> str: + raise NotImplementedError("No mlir type exists for a generic Tensor") + class F32(MLIRType): # TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod @@ -48,6 +134,10 @@ class F32(MLIRType): def numba_type() -> NumbaType: return float32 + @staticmethod + def mlir_type() -> str: + return "f32" + class F64(MLIRType): # TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod @@ -55,6 +145,10 @@ class F64(MLIRType): def numba_type() -> NumbaType: return float64 + @staticmethod + def mlir_type() -> str: + return "f64" + class I1(MLIRType): @@ -62,6 +156,10 @@ class I1(MLIRType): def numba_type() -> NumbaType: return boolean + @staticmethod + def mlir_type() -> str: + return "i1" + class I8(MLIRType): @@ -69,6 +167,10 @@ class I8(MLIRType): def numba_type() -> NumbaType: return int8 + @staticmethod + def mlir_type() -> str: + return "i8" + class I16(MLIRType): @@ -76,6 +178,10 @@ class I16(MLIRType): def numba_type() -> NumbaType: return int16 + @staticmethod + def mlir_type() -> str: + return "i16" + class I32(MLIRType): @@ -83,6 +189,10 @@ class I32(MLIRType): def numba_type() -> NumbaType: return int32 + @staticmethod + def mlir_type() -> str: + return "i32" + class I64(MLIRType): @@ -90,6 +200,18 @@ class I64(MLIRType): def numba_type() -> NumbaType: return int64 + @staticmethod + def mlir_type() -> str: + return "i64" + + +# Register the types defined above with Numba +for typ in [I8, I16, I32, I64, I1, F32, F64]: + + @type_callable(typ) + def build_typer_function(context, typ=typ): + return lambda value: typ.numba_type() + # Helper functions diff --git a/frontend/heir/mlir_emitter.py b/frontend/heir/mlir_emitter.py index caedf3940a..9311d9cfe8 100644 --- a/frontend/heir/mlir_emitter.py +++ b/frontend/heir/mlir_emitter.py @@ -12,7 +12,8 @@ from numba.core import controlflow from numba.core.types import Type as NumbaType -from heir.interfaces import InternalCompilerError +from heir.mlir.types import MLIRType, MLIR_TYPES, I1, I8, I16, I32, I64, F32, F64 +from heir.interfaces import CompilerError, DebugMessage, InternalCompilerError def mlirType(numba_type: NumbaType) -> str: @@ -36,6 +37,73 @@ def mlirType(numba_type: NumbaType) -> str: raise InternalCompilerError("Unsupported type: " + str(numba_type)) +def isIntegerLike(typ: NumbaType | MLIRType) -> bool: + if isinstance(typ, type) and issubclass(typ, MLIRType): + return typ in {I1, I8, I16, I32, I64} + if isinstance(typ, NumbaType): + return isinstance(typ, types.Integer) or isinstance(typ, types.Boolean) + raise InternalCompilerError(f"Encountered unexpected type {typ}") + + +def isFloatLike(typ: NumbaType | MLIRType) -> bool: + if isinstance(typ, type) and issubclass(type, MLIRType): + return typ in {F32, F64} + if isinstance(typ, NumbaType): + return isinstance(typ, types.Float) + raise InternalCompilerError(f"Encountered unexpected type {typ}") + + +# Needed because, e.g. Boolean doesn't have a bitwidth +def getBitwidth(typ: NumbaType | MLIRType) -> int: + if isinstance(typ, type) and issubclass(typ, MLIRType): + if typ in {I1, I8, I16, I32, I64}: + # e.g., .__name__ -> "I32" -> "32" + return int(typ.__name__[1:]) + if isinstance(typ, types.Integer): + return typ.bitwidth + if isinstance(typ, types.Boolean): + return 1 + raise InternalCompilerError(f"unexpected type {typ} ({type(typ)})") + + +def mlirCastOp( + from_type: NumbaType, to_type: MLIRType, value: str, loc: ir.Loc +) -> str: + if isIntegerLike(from_type) and isIntegerLike(to_type): + from_width = getBitwidth(from_type) + to_width = getBitwidth(to_type.numba_type()) + if from_width == to_width: + raise CompilerError( + f"Cannot create cast of {value} from {from_type} to {to_type} as they" + " have the same bitwidth", + loc, + ) + if from_width > to_width: + return ( + f"arith.trunci {value} : {mlirType(from_type)} to" + f" {to_type.mlir_type()} {mlirLoc(loc)}" + ) + if from_width < to_width: + return ( + f"arith.extsi {value} : {mlirType(from_type)} to" + f" {to_type.mlir_type()} {mlirLoc(loc)}" + ) + if isFloatLike(from_type) and isIntegerLike(to_type): + return ( + f"arith.fptosi {value} : {mlirType(from_type)} to" + f" {mlirType(to_type)} {mlirLoc(loc)}" + ) + if isIntegerLike(from_type) and isFloatLike(to_type): + return ( + f"arith.sitofp {value} : {mlirType(from_type)} to" + f" {mlirType(to_type)} {mlirLoc(loc)}" + ) + raise CompilerError( + f"Encountered unsupported cast of {value} from {from_type} to {to_type}", + loc, + ) + + def mlirLoc(loc: ir.Loc) -> str: return ( f"loc(\"{loc.filename or ''}\":{loc.line or 0}:{loc.col or 0})" @@ -246,6 +314,9 @@ def __init__( self.return_types = (return_types,) self.temp_var_id = 0 self.numba_names_to_ssa_var_names = {} + # The globals_map maps the numba-assigned name for a global (e.g. '$4load_global.0') + # to a tuple of (name, value) where name is the "pretty" name (e.g., 'foo') + # and value is the actual Python object referenced (the underlying function/module/class/object/etc) self.globals_map = {} self.loops = {} self.cfa = self.get_control_flow() @@ -419,12 +490,39 @@ def emit_assign(self, assign): func = assign.value.func # if assert fails, variable was undefined assert func.name in self.globals_map - if self.globals_map[func.name] == "bool": + name, global_ = self.globals_map[func.name] + if name == "bool": # nothing to do, forward the name to the arg of bool() self.forward_name(from_var=assign.target, to_var=assign.value.args[0]) return "" + if global_ in MLIR_TYPES: + if len(assign.value.args) != 1: + raise CompilerError( + "MLIR type cast requires exactly one argument", assign.value.loc + ) + value = assign.value.args[0].name + if ( + mlirType(self.typemap.get(assign.target.name)) + != global_.mlir_type() + ): + raise InternalCompilerError( + f"MLIR type cast of {value} from" + f" {mlirType(self.typemap.get(value))} to" + f" {global_.mlir_type()} is not correctly reflected in types" + " inferred for the assignment, which expects" + f" {mlirType(self.typemap.get(assign.target.name))}" + ) + target_ssa = self.get_or_create_name(assign.target) + ssa_id = self.get_or_create_name(assign.value.args[0]) + cast = mlirCastOp( + self.typemap.get(value), + global_, + ssa_id, + assign.loc, + ) + return f"{target_ssa} = {cast}" else: - raise InternalCompilerError("Unknown global " + func.name) + raise InternalCompilerError("Call to unknown function " + name) case ir.Expr(op="cast"): # not sure what to do here. maybe will be needed for type conversions # when interfacing with C @@ -446,7 +544,10 @@ def emit_assign(self, assign): self.forward_name_to_id(assign.target, name.strip("%")) return const_str case ir.Global(): - self.globals_map[assign.target.name] = assign.value.name + self.globals_map[assign.target.name] = ( + assign.value.name, + assign.value.value, + ) return "" case ir.Var(): # Sometimes we need this to be assigned? @@ -463,24 +564,26 @@ def emit_ext_if_needed(self, lhs, rhs): return self.get_name(lhs), self.get_name(rhs), "", lhs_type # types aren't integer types - if not isinstance(lhs_type, types.Integer) or not isinstance( - rhs_type, types.Integer - ): + if not isIntegerLike(lhs_type) or not isIntegerLike(rhs_type): raise InternalCompilerError( "Extension handling for non-integer (e.g., floats, tensors) types" " is not yet supported. Please ensure (inferred) bit-widths match." + f" Failed to extend {lhs_type} and {rhs_type} types." ) # TODO (#1162): Support bitwidth extension for float types # (this probably requires adding support for local variable type hints, # such as `b : F16 = 1.0` as there is no clear "natural" bitwidth for literals) # TODO (#1162): Support bitwidth extension for non-scalar types (e.g., tensors) - if lhs_type.bitwidth == rhs_type.bitwidth: + lhs_bitwidth = getBitwidth(lhs_type) + rhs_bitwidth = getBitwidth(rhs_type) + + if lhs_bitwidth == rhs_bitwidth: return self.get_name(lhs), self.get_name(rhs), "", lhs_type # time to emit some extensions! short, long = lhs, rhs - if lhs_type.bitwidth > rhs_type.bitwidth: + if lhs_bitwidth > rhs_bitwidth: short, long = rhs, lhs tmp = self.get_next_name() @@ -491,7 +594,7 @@ def emit_ext_if_needed(self, lhs, rhs): f"{mlirLoc(short.loc)}\n" ) - if lhs_type.bitwidth > rhs_type.bitwidth: + if lhs_bitwidth > rhs_bitwidth: return self.get_name(lhs), tmp, ext, lhs_type return tmp, self.get_name(rhs), ext, rhs_type @@ -518,6 +621,10 @@ def emit_binop(self, binop): return f"arith.sub{suffix} {lhs_ssa}, {rhs_ssa}", ext, ty case operator.lshift: return f"arith.shl{suffix} {lhs_ssa}, {rhs_ssa}", ext, ty + case operator.rshift: + # Used signed semantics when integer types + suffix = "si" if suffix == "i" else suffix + return f"arith.shr{suffix} {lhs_ssa}, {rhs_ssa}", ext, ty case operator.and_: return f"arith.and{suffix} {lhs_ssa}, {rhs_ssa}", ext, ty case operator.xor: diff --git a/frontend/heir/numba_nbep1_reverter/old_builtins.py b/frontend/heir/numba_nbep1_reverter/old_builtins.py index f6c7cacd3d..2a87e4ea5b 100644 --- a/frontend/heir/numba_nbep1_reverter/old_builtins.py +++ b/frontend/heir/numba_nbep1_reverter/old_builtins.py @@ -1,8 +1,9 @@ """This file is a near-verbatim copy of numba.core.typing.old_builtins.py, i.e., numba.core.typing.builtins if config.USE_LEGACY_TYPE_SYSTEM is set -with only one change: we override integer_binop_cases to stop numba -from upcasting, e.g., int8 + int8 to int64 or int32 (intp in numba). +with two changes: we override (1) integer_binop_cases and +(2) BitwiseShiftOperation to stop numba from upcasting, +e.g., int8 + int8 to int64 or int32 (intp in numba) Numba decided to do this for simplicity, see the explanation in NBEP1: https://numba.readthedocs.io/en/stable/proposals/integer-typing.html @@ -332,22 +333,15 @@ class PowerBuiltin(BinOpPower): class BitwiseShiftOperation(ConcreteTemplate): - # For bitshifts, only the first operand's signedness matters - # to choose the operation's signedness (the second operand - # should always be positive but will generally be considered - # signed anyway, since it's often a constant integer). - # (also, see issue #1995 for right-shifts) - - # The RHS type is fixed to 64-bit signed/unsigned ints. - # The implementation will always cast the operands to the width of the - # result type, which is the widest between the LHS type and (u)intp. + # For bitshifts, only the first operand's type matters for the result type. + # The result type should always match the LHS type (op), not upcast. cases = [ - signature(max(op, types.intp), op, op2) + signature(op, op, op2) for op in sorted(types.signed_domain) for op2 in [types.uint64, types.int64] ] cases += [ - signature(max(op, types.uintp), op, op2) + signature(op, op, op2) for op in sorted(types.unsigned_domain) for op2 in [types.uint64, types.int64] ] diff --git a/frontend/heir/pipeline.py b/frontend/heir/pipeline.py index 1333fca029..6930ffb3d4 100644 --- a/frontend/heir/pipeline.py +++ b/frontend/heir/pipeline.py @@ -190,7 +190,9 @@ def run_pipeline( ) # Run backend (which will call heir_translate and other tools, e.g., clang, as needed) - if "--mlir-to-cggi" in heir_opt_options: + if "--mlir-to-cggi" in heir_opt_options and not isinstance( + backend, CleartextBackend + ): raise NotImplementedError( "Backend compilation is unsupported for CGGI scheme, check CGGI" f" output at {mlirpath}" diff --git a/frontend/testing.bzl b/frontend/testing.bzl index 367eaf8dcc..b3b2d07957 100644 --- a/frontend/testing.bzl +++ b/frontend/testing.bzl @@ -28,6 +28,7 @@ def frontend_test(name, srcs, deps = [], data = [], tags = []): deps = deps + [ ":frontend", "@com_google_absl_py//absl/testing:absltest", + "@edu_berkeley_abc//:abc", ], imports = ["."], data = data, @@ -42,6 +43,8 @@ def frontend_test(name, srcs, deps = [], data = [], tags = []): "HEIR_OPT_PATH": "tools/heir-opt", "HEIR_TRANSLATE_PATH": "tools/heir-translate", "PYBIND11_INCLUDE_PATH": "pybind11/include", + "HEIR_YOSYS_SCRIPTS_DIR": "lib/Transforms/YosysOptimizer/yosys", + "HEIR_ABC_BINARY": "$(rootpath @edu_berkeley_abc//:abc)", "NUMBA_USE_LEGACY_TYPE_SYSTEM": "1", }, )