diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 9a13644401..b0ff57e63e 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -15,6 +15,7 @@ add_subdirectory(CGGI) add_subdirectory(CKKS) add_subdirectory(Comb) add_subdirectory(Jaxite) +add_subdirectory(JaxiteWord) add_subdirectory(LinAlg) add_subdirectory(LWE) add_subdirectory(ModArith) diff --git a/lib/Dialect/JaxiteWord/CMakeLists.txt b/lib/Dialect/JaxiteWord/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/lib/Dialect/JaxiteWord/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/lib/Dialect/JaxiteWord/IR/BUILD b/lib/Dialect/JaxiteWord/IR/BUILD new file mode 100644 index 0000000000..9dd4cd2ede --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/BUILD @@ -0,0 +1,74 @@ +# JaxiteWord, an exit dialect to JaxiteWord API + +load("@heir//lib/Dialect:dialect.bzl", "add_heir_dialect_library") +load("@llvm-project//mlir:tblgen.bzl", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = ["JaxiteWordDialect.cpp"], + hdrs = [ + "JaxiteWordDialect.h", + "JaxiteWordOps.h", + "JaxiteWordTypes.h", + ], + deps = [ + ":dialect_inc_gen", + ":ops_inc_gen", + ":types_inc_gen", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", + ], +) + +td_library( + name = "td_files", + srcs = [ + "JaxiteWordDialect.td", + "JaxiteWordOps.td", + "JaxiteWordTypes.td", + ], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +add_heir_dialect_library( + name = "dialect_inc_gen", + dialect = "JaxiteWord", + kind = "dialect", + td_file = "JaxiteWordDialect.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "types_inc_gen", + dialect = "JaxiteWord", + kind = "type", + td_file = "JaxiteWordTypes.td", + deps = [ + ":td_files", + ], +) + +add_heir_dialect_library( + name = "ops_inc_gen", + dialect = "JaxiteWord", + kind = "op", + td_file = "JaxiteWordOps.td", + deps = [ + ":td_files", + "@heir//lib/Dialect/LWE/IR:td_files", + ], +) diff --git a/lib/Dialect/JaxiteWord/IR/CMakeLists.txt b/lib/Dialect/JaxiteWord/IR/CMakeLists.txt new file mode 100644 index 0000000000..6e1168085a --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +add_heir_dialect(JaxiteWord jaxiteword) + +add_mlir_dialect_library(HEIRJaxiteWord + JaxiteWordDialect.cpp + + DEPENDS + HEIRJaxiteWordIncGen +) diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp new file mode 100644 index 0000000000..0da3ab19a0 --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp @@ -0,0 +1,42 @@ +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp.inc" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc" +#define GET_OP_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.cpp.inc" + +namespace mlir { +namespace heir { +namespace jaxiteword { + +void JaxiteWordDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.cpp.inc" + >(); +} + +LogicalResult AddOp::verify() { + if (getModulusList().getType().getModulusList().size() != + getValueA().getType().getTowers()) { + return emitOpError() << "Number of Towers of moudlus should match the " + "number of towers/limbs"; + } + return success(); +} + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h new file mode 100644 index 0000000000..1d54a31e8b --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h @@ -0,0 +1,13 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ + +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project + +// Generated headers (block clang-format from messing up order) +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h.inc" + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td new file mode 100644 index 0000000000..5837a68a27 --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td @@ -0,0 +1,22 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ + +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" + +def JaxiteWord_Dialect : Dialect { + let name = "jaxiteword"; + + let description = [{ + The `jaxiteword` dialect is an exit dialect for generating py code against the jaxiteword library API, + using the jaxiteword parameters and encoding scheme. + + See https://github.com/google/jaxite/jaxite_word + }]; + + let cppNamespace = "::mlir::heir::jaxiteword"; + + let useDefaultTypePrinterParser = 1; +} + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h new file mode 100644 index 0000000000..e0bd601740 --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h @@ -0,0 +1,13 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_H_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_H_ + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project + +#define GET_OP_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h.inc" + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_H_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td new file mode 100644 index 0000000000..566da51f08 --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td @@ -0,0 +1,38 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_TD_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_TD_ + +include "JaxiteWordDialect.td" +include "JaxiteWordTypes.td" + +include "lib/Dialect/LWE/IR/LWETypes.td" + +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/CommonAttrConstraints.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + + +class JaxiteWord_Op traits = []> : + Op { + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; + let cppNamespace = "::mlir::heir::jaxiteword"; +} + + +def AddOp : JaxiteWord_Op<"add", [AllTypesMatch<["value_a", "value_b", "result"]>,Commutative,Pure]> { + let description = [{ + The operation computed by this function is homomorphic addition. + }]; + let arguments = (ins JaxiteWord_Ciphertext:$value_a, + JaxiteWord_Ciphertext:$value_b, + JaxiteWord_ModulusList:$modulus_list + ); + let results = (outs JaxiteWord_Ciphertext:$result); + let hasVerifier = 1; +} + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_TD_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h b/lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h new file mode 100644 index 0000000000..624ea7be6a --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h @@ -0,0 +1,13 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_H_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_H_ + +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +// #include "mlir/include/mlir/IR/TypeSupport.h" // from @llvm-project +// #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project + +// Include LLVM ADT definitions. + +#define GET_TYPEDEF_CLASSES +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h.inc" + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_H_ diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.td new file mode 100644 index 0000000000..bcb8a1d70f --- /dev/null +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.td @@ -0,0 +1,38 @@ +#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_TD_ +#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_TD_ + +include "JaxiteWordDialect.td" + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" + +class JaxiteWord_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; // print out all information in the arguments +} + +def JaxiteWord_ModulusList : JaxiteWord_Type<"ModulusList", "modulus_list"> { + let description = [{ + A list of modulus values. + }]; + let parameters = (ins ArrayRefParameter<"::mlir::Type">:$modulus_list); + let assemblyFormat = "`<` $modulus_list `>`"; + // jaxiteword.modulus_list +} + +def JaxiteWord_Ciphertext : JaxiteWord_Type<"Ciphertext", "ciphertext"> { + let description = [{ + A ciphertext - a three dimensional array. + }]; + let parameters = (ins "int":$polys, + "int":$towers, + "int":$degrees); + let assemblyFormat = "`<` $polys `,` $towers `,` $degrees `>`"; + // jaxiteword.ciphertext<3, 4, 5> +} + +#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_TD_ diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD new file mode 100644 index 0000000000..d6b6ec7411 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD @@ -0,0 +1,37 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "LWEToJaxiteWord", + srcs = ["LWEToJaxiteWord.cpp"], + hdrs = [ + "LWEToJaxiteWord.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/BGV/IR:Dialect", + "@heir//lib/Dialect/CKKS/IR:Dialect", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Utils", + "@heir//lib/Utils:ConversionUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + ], +) + +add_heir_transforms( + header_filename = "LWEToJaxiteWord.h.inc", + pass_name = "LWEToJaxiteWord", + td_file = "LWEToJaxiteWord.td", +) diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/CMakeLists.txt b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/CMakeLists.txt new file mode 100644 index 0000000000..da7acfe136 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(HEIRLWEToJaxiteWord + LWEToJaxiteWord.cpp + + LINK_LIBS PUBLIC + HEIRConversionUtils + HEIRLWE + HEIRLWE + HEIRJaxiteWord + + MLIRIR + MLIRPass + MLIRInferTypeOpInterface + MLIRArithDialect + MLIRFuncDialect + LLVMSupport + MLIRSupport + MLIRDialect + MLIRTransformUtils + MLIRIR +) diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp new file mode 100644 index 0000000000..3180ba00fb --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp @@ -0,0 +1,394 @@ +#include "lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h" + +#include +#include + +#include "lib/Dialect/BGV/IR/BGVDialect.h" +#include "lib/Dialect/BGV/IR/BGVOps.h" +#include "lib/Dialect/CKKS/IR/CKKSDialect.h" +#include "lib/Dialect/CKKS/IR/CKKSOps.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h" +#include "lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h" +#include "lib/Dialect/LWE/IR/LWEAttributes.h" +#include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Utils/ConversionUtils.h" +#include "lib/Utils/Utils.h" +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir::lwe { + +#define GEN_PASS_DEF_LWETOJAXITEWORD +#include "lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h.inc" + +// ToJaxiteWordTypeConverter::ToJaxiteWordTypeConverter(MLIRContext *ctx) { +// addConversion([](Type type) { return type; }); +// addConversion([ctx](lwe::RLWEPublicKeyType type) -> Type { +// return jaxiteword::PublicKeyType::get(ctx); +// }); +// addConversion([ctx](lwe::RLWESecretKeyType type) -> Type { +// return jaxiteword::PrivateKeyType::get(ctx); +// }); +// addConversion([ctx](lwe::NewLWEPublicKeyType type) -> Type { +// return jaxiteword::PublicKeyType::get(ctx); +// }); +// addConversion([ctx](lwe::NewLWESecretKeyType type) -> Type { +// return jaxiteword::PrivateKeyType::get(ctx); +// }); +// } + +// FailureOr getContextualCryptoContext(Operation *op) { +// auto result = getContextualArgFromFunc(op); +// if (failed(result)) { +// return op->emitOpError() +// << "Found LWE op in a function without a public key argument." +// " Did the AddCryptoContextArg pattern fail to run?"; +// } +// return result.value(); +// } + +namespace { +// NOTE: we can not use containsDialect +// for FuncOp declaration, which does not have a body +template +bool containsArgumentOfDialect(func::FuncOp funcOp) { + return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) { + return DialectEqual()(&argType.getDialect()); + }); +} + +// struct AddCryptoContextArg : public OpConversionPattern { +// AddCryptoContextArg(mlir::MLIRContext *context) +// : OpConversionPattern(context, /* benefit= */ 2) {} + +// using OpConversionPattern::OpConversionPattern; + +// LogicalResult matchAndRewrite( +// func::FuncOp op, OpAdaptor adaptor, +// ConversionPatternRewriter &rewriter) const override { +// auto containsCryptoOps = +// containsDialects( +// op); +// auto containsCryptoArg = +// containsArgumentOfDialect(op); +// if (!(containsCryptoOps || containsCryptoArg)) { +// return failure(); +// } + +// auto cryptoContextType = +// jaxiteword::CryptoContextType::get(getContext()); +// rewriter.modifyOpInPlace(op, [&] { +// if (op.isDeclaration()) { +// auto newFuncType = op.getTypeWithArgsAndResults( +// ArrayRef{0}, ArrayRef{cryptoContextType}, {}, +// {}); +// op.setType(newFuncType); +// } else { +// op.insertArgument(0, cryptoContextType, nullptr, op.getLoc()); +// } +// }); + +// return success(); +// } +// }; + +struct ConvertFuncCallOp : public OpConversionPattern { + ConvertFuncCallOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::CallOp op, typename func::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + auto cryptoContext = result.value(); + + auto callee = op.getCallee(); + auto operands = adaptor.getOperands(); + auto resultTypes = op.getResultTypes(); + + SmallVector newOperands; + newOperands.push_back(cryptoContext); + for (auto operand : operands) { + newOperands.push_back(operand); + } + + rewriter.replaceOpWithNewOp(op, callee, resultTypes, + newOperands); + return success(); + } +}; + +// struct ConvertEncryptOp : public OpConversionPattern { +// ConvertEncryptOp(mlir::MLIRContext *context) +// : OpConversionPattern(context) {} + +// using OpConversionPattern::OpConversionPattern; + +// LogicalResult matchAndRewrite( +// lwe::RLWEEncryptOp op, typename lwe::RLWEEncryptOp::Adaptor adaptor, +// ConversionPatternRewriter &rewriter) const override { +// FailureOr result = getContextualCryptoContext(op.getOperation()); +// if (failed(result)) return result; + +// auto keyType = dyn_cast(op.getKey().getType()); +// if (!keyType) +// return op.emitError() +// << "OpenFHE only supports public key encryption for LWE."; + +// Value cryptoContext = result.value(); +// rewriter.replaceOp(op, +// rewriter.create( +// op.getLoc(), op.getOutput().getType(), +// cryptoContext, adaptor.getInput(), +// adaptor.getKey())); +// return success(); +// } +// }; + +// struct ConvertDecryptOp : public OpConversionPattern { +// ConvertDecryptOp(mlir::MLIRContext *context) +// : OpConversionPattern(context) {} + +// using OpConversionPattern::OpConversionPattern; + +// LogicalResult matchAndRewrite( +// RLWEDecryptOp op, RLWEDecryptOp::Adaptor adaptor, +// ConversionPatternRewriter &rewriter) const override { +// FailureOr result = getContextualCryptoContext(op.getOperation()); +// if (failed(result)) return result; + +// Value cryptoContext = result.value(); +// rewriter.replaceOp(op, +// rewriter.create( +// op.getLoc(), op.getOutput().getType(), +// cryptoContext, adaptor.getInput(), +// adaptor.getSecretKey())); +// return success(); +// } +// }; + +// struct ConvertEncodeOp : public OpConversionPattern { +// explicit ConvertEncodeOp(const mlir::TypeConverter &typeConverter, +// mlir::MLIRContext *context) +// : mlir::OpConversionPattern(typeConverter, context) +// {} + +// // OpenFHE has a convention that all inputs to MakePackedPlaintext are +// // std::vector, so we need to cast the input to that type. +// LogicalResult matchAndRewrite( +// lwe::RLWEEncodeOp op, OpAdaptor adaptor, +// ConversionPatternRewriter &rewriter) const override { +// FailureOr result = getContextualCryptoContext(op.getOperation()); +// if (failed(result)) return result; +// Value cryptoContext = result.value(); + +// Value input = adaptor.getInput(); +// auto elementTy = getElementTypeOrSelf(input.getType()); + +// auto tensorTy = mlir::dyn_cast(input.getType()); +// // Replicate scalar inputs into a splat tensor with shape matching +// // the ring dimension. +// if (!tensorTy) { +// auto ringDegree = +// op.getRing().getPolynomialModulus().getPolynomial().getDegree(); +// tensor::SplatOp splat = rewriter.create( +// op.getLoc(), RankedTensorType::get({ringDegree}, elementTy), +// input); +// input = splat.getResult(); +// tensorTy = splat.getType(); +// } + +// // Cast inputs to the correct types for OpenFHE API. +// if (auto intTy = mlir::dyn_cast(elementTy)) { +// if (intTy.getWidth() > 64) +// return op.emitError() << "No supported packing technique for integers +// " +// "bigger than 64 bits."; + +// if (intTy.getWidth() < 64) { +// // OpenFHE has a convention that all inputs to MakePackedPlaintext +// are +// // std::vector, so we need to cast the input to that type. +// auto int64Ty = rewriter.getIntegerType(64); +// auto newTensorTy = RankedTensorType::get(tensorTy.getShape(), +// int64Ty); input = +// rewriter.create(op.getLoc(), newTensorTy, input); +// } +// } else { +// auto floatTy = cast(elementTy); +// if (floatTy.getWidth() > 64) +// return op.emitError() << "No supported packing technique for floats " +// "bigger than 64 bits."; + +// if (floatTy.getWidth() < 64) { +// // OpenFHE has a convention that all inputs to +// MakeCKKSPackedPlaintext +// // are std::vector, so we need to cast the input to that +// type. auto f64Ty = rewriter.getF64Type(); auto newTensorTy = +// RankedTensorType::get(tensorTy.getShape(), f64Ty); input = +// rewriter.create(op.getLoc(), newTensorTy, input); +// } +// } + +// lwe::NewLWEPlaintextType plaintextType = lwe::NewLWEPlaintextType::get( +// op.getContext(), +// lwe::ApplicationDataAttr::get(adaptor.getInput().getType(), +// lwe::NoOverflowAttr::get(getContext())), +// lwe::PlaintextSpaceAttr::get(getContext(), op.getRing(), +// op.getEncoding())); + +// return llvm::TypeSwitch(op.getEncoding()) +// .Case([&](auto encoding) { +// rewriter.replaceOpWithNewOp( +// op, plaintextType, cryptoContext, input); +// return success(); +// }) +// .Case([&](auto encoding) { +// // TODO (#1192): support coefficient packing in +// `--lwe-to-jaxiteword` op.emitError() << "HEIR does not yet support +// coefficient encoding " +// " when targeting OpenFHE"; +// return failure(); +// }) +// .Case([&](auto encoding) { +// rewriter.replaceOpWithNewOp( +// op, plaintextType, cryptoContext, input); +// return success(); +// }) +// .Default([&](Attribute) -> LogicalResult { +// // encoding isn't support explicitly: +// op.emitError( +// "Unexpected encoding while targeting OpenFHE. " +// "If you expect this type of encoding to be supported " +// "for the OpenFHE backend, please file a bug report."); +// return failure(); +// }); +// } +// }; + +// struct ConvertBootstrapOp : public OpConversionPattern { +// ConvertBootstrapOp(mlir::MLIRContext *context) +// : OpConversionPattern(context) {} + +// using OpConversionPattern::OpConversionPattern; + +// LogicalResult matchAndRewrite( +// ckks::BootstrapOp op, ckks::BootstrapOp::Adaptor adaptor, +// ConversionPatternRewriter &rewriter) const override { +// FailureOr result = getContextualCryptoContext(op.getOperation()); +// if (failed(result)) return result; + +// Value cryptoContext = result.value(); +// rewriter.replaceOpWithNewOp( +// op, op.getOutput().getType(), cryptoContext, adaptor.getInput()); +// return success(); +// } +// }; +} // namespace + +struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + ToJaxiteWordTypeConverter typeConverter(context); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + // We can keep the following ops, which the emitter can handle directly + target.addLegalOp(); + + RewritePatternSet patterns(context); + addStructuralConversionPatterns(typeConverter, patterns, target); + + patterns.add< + ///////////////////// + // LWE Op Patterns // + ///////////////////// + + // // Update Func Op Signature + // AddCryptoContextArg, + + // // Update Func CallOp Signature + // ConvertFuncCallOp, + + // // Handle LWE encode and en/decrypt + // // Note: `lwe.decode` is handled directly by the OpenFHE emitter + // ConvertEncodeOp, ConvertEncryptOp, ConvertDecryptOp, + + // Scheme-agnostic RLWE Arithmetic Ops: + ConvertLWEBinOp + // ConvertLWEBinOp, + // ConvertLWEBinOp, + // ConvertUnaryOp, + + // /////////////////////////////////// + // // Scheme-Specific Op Patterns // + // /////////////////////////////////// + // // The Add/(Sub)/Mul-Plain ops are not really scheme-specific, + // // but do not currently have an analogue in the LWE dialect. + // // TODO (#1193): Extend "common lwe" to support ctxt-ptxt ops + + // // AddPlain + // ConvertCiphertextPlaintextOp, + // ConvertCiphertextPlaintextOp, + + // // SubPlain + // ConvertCiphertextPlaintextOp, + // ConvertCiphertextPlaintextOp, + + // // MulPlain + // ConvertCiphertextPlaintextOp, + // ConvertCiphertextPlaintextOp, + + // // Rotate + // ConvertRotateOp, + // ConvertRotateOp, + // // Relin + // ConvertRelinOp, + // ConvertRelinOp, + // // Modulus Switch (BGV only) + // lwe::ConvertModulusSwitchOp, + // // Rescale (CKKS version of Modulus Switch) + // lwe::ConvertModulusSwitchOp, + // // Bootstrap (CKKS only) + // ConvertBootstrapOp + // End of Pattern List + >(typeConverter, context); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace mlir::heir::lwe diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h new file mode 100644 index 0000000000..f77d3dc6c9 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h @@ -0,0 +1,90 @@ +#ifndef LIB_DIALECT_LWE_CONVERSIONS_LWETOJAXITEWORD_LWETOJAXITEWORD_H_ +#define LIB_DIALECT_LWE_CONVERSIONS_LWETOJAXITEWORD_LWETOJAXITEWORD_H_ + +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Utils/ConversionUtils.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir::lwe { + +#define GEN_PASS_DECL +#include "lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.h.inc" + +class ToJaxiteWordTypeConverter : public TypeConverter { + public: + ToJaxiteWordTypeConverter(MLIRContext *ctx); +}; + +FailureOr getContextualCryptoContext(Operation *op); + +template +struct ConvertUnaryOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + UnaryOp op, typename UnaryOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + rewriter.replaceOp(op, rewriter.create( + op.getLoc(), cryptoContext, adaptor.getInput())); + return success(); + } +}; + +template +struct ConvertLWEBinOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + BinOp op, typename BinOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + rewriter.replaceOpWithNewOp(op, op.getOutput().getType(), + cryptoContext, adaptor.getLhs(), + adaptor.getRhs()); + return success(); + } +}; + +template +struct ConvertCiphertextPlaintextOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + BinOp op, typename BinOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + + Value cryptoContext = result.value(); + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), cryptoContext, + adaptor.getCiphertextInput(), adaptor.getPlaintextInput()); + return success(); + } +}; + +inline bool checkRelinToBasis(llvm::ArrayRef toBasis) { + if (toBasis.size() != 2) return false; + return toBasis[0] == 0 && toBasis[1] == 1; +} + +} // namespace mlir::heir::lwe + +#endif // LIB_DIALECT_LWE_CONVERSIONS_LWETOJAXITEWORD_LWETOJAXITEWORD_H_ diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.td b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.td new file mode 100644 index 0000000000..b1548968c3 --- /dev/null +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.td @@ -0,0 +1,22 @@ +#ifndef LIB_DIALECT_LWE_CONVERSIONS_LWETOJAXITEWORD_LWETOJAXITEWORD_TD_ +#define LIB_DIALECT_LWE_CONVERSIONS_LWETOJAXITEWORD_LWETOJAXITEWORD_TD_ + +include "mlir/Pass/PassBase.td" + +def LWEToJaxiteWord : Pass<"lwe-to-jaxiteword"> { + let summary = "Lower `lwe` to `jaxiteword` dialect."; + + let description = [{ + This pass lowers the `lwe` dialect to `JaxiteWord` dialect. + Currently, this also includes patterns that apply directly to `ckks` and `bgv` dialect operations. + TODO (#1193): investigate if the need for `ckks/bgv` patterns in `--lwe-to-jaxiteword` is permanent. + }]; + + let dependentDialects = [ + "mlir::heir::lwe::LWEDialect", + "mlir::heir::jaxiteword::JaxiteWordDialect", + "mlir::tensor::TensorDialect", + ]; +} + +#endif // LIB_DIALECT_LWE_CONVERSIONS_LWETOJAXITEWORD_LWETOJAXITEWORD_TD_ diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt index 3185777dec..6be940bd79 100644 --- a/lib/Target/CMakeLists.txt +++ b/lib/Target/CMakeLists.txt @@ -2,6 +2,7 @@ add_library(HEIRTarget INTERFACE) add_subdirectory(Jaxite) +add_subdirectory(JaxiteWord) add_subdirectory(Metadata) add_subdirectory(OpenFhePke) add_subdirectory(TfheRust) diff --git a/lib/Target/JaxiteWord/BUILD b/lib/Target/JaxiteWord/BUILD new file mode 100644 index 0000000000..b678090fbd --- /dev/null +++ b/lib/Target/JaxiteWord/BUILD @@ -0,0 +1,29 @@ +# JaxiteWord Emitter + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "JaxiteWordEmitter", + srcs = ["JaxiteWordEmitter.cpp"], + hdrs = [ + "JaxiteWordEmitter.h", + "JaxiteWordTemplates.h", + ], + deps = [ + "@heir//lib/Analysis/SelectVariableNames", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Utils:TargetUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + ], +) diff --git a/lib/Target/JaxiteWord/CMakeLists.txt b/lib/Target/JaxiteWord/CMakeLists.txt new file mode 100644 index 0000000000..9e8c3e2b33 --- /dev/null +++ b/lib/Target/JaxiteWord/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_library(HEIRJaxiteWordEmitter + JaxiteWordEmitter.cpp + + LINK_LIBS PUBLIC + HEIRSelectVariableNames + HEIRJaxiteWord + HEIRLWE + HEIRTargetUtils + LLVMSupport + MLIRArithDialect + MLIRAffineDialect + MLIRFuncDialect + MLIRIR + MLIRMemRefDialect + MLIRSupport + MLIRTensorDialect + MLIRTranslateLib +) +target_link_libraries(HEIRTarget INTERFACE HEIRJaxiteWordEmitter) diff --git a/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp b/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp new file mode 100644 index 0000000000..f45b2b8aa0 --- /dev/null +++ b/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp @@ -0,0 +1,341 @@ +#include "lib/Target/JaxiteWord/JaxiteWordEmitter.h" + +#include +#include +#include +#include +#include + +#include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h" +#include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Target/JaxiteWord/JaxiteWordTemplates.h" +#include "lib/Utils/TargetUtils.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace jaxiteword { + +void registerToJaxiteWordTranslation() { + TranslateFromMLIRRegistration reg( + "emit-jaxiteword", + "translate the JaxiteWord dialect to python code for jaxiteword", + [](Operation *op, llvm::raw_ostream &output) { + return translateToJaxiteWord(op, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} + +LogicalResult translateToJaxiteWord(Operation *op, llvm::raw_ostream &os) { + SelectVariableNames variableNames(op); + JaxiteWordEmitter emitter(os, &variableNames); + return emitter.translate(*op); +} + +LogicalResult JaxiteWordEmitter::translate(Operation &op) { + LogicalResult status = + llvm::TypeSwitch(op) + // Builtin ops + .Case([&](auto op) { return printOperation(op); }) + // Func ops + .Case( + [&](auto op) { return printOperation(op); }) + // JaxiteWord ops + .Case([&](auto op) { return printOperation(op); }) + // Tensor ops + .Case( + [&](auto op) { return printOperation(op); }) + // Memref ops + .Case( + [&](auto op) { return printOperation(op); }) + // Arith ops + .Case([&](auto op) { return success(); }) + .Default([&](Operation &) { + return op.emitOpError("unable to find printer for op"); + }); + + if (failed(status)) { + op.emitOpError(llvm::formatv("Failed to translate op {0}", op.getName())); + return failure(); + } + return success(); +} + +LogicalResult JaxiteWordEmitter::printOperation(ModuleOp moduleOp) { + os << kModulePrelude << "\n"; + for (Operation &op : moduleOp) { + if (failed(translate(op))) { + return failure(); + } + } + return success(); +} + +LogicalResult JaxiteWordEmitter::printOperation(func::FuncOp funcOp) { + os << "def " << funcOp.getName() << "(\n"; + os.indent(); + for (Value arg : funcOp.getArguments()) { + auto argName = variableNames->getNameForValue(arg); + os << argName << ": "; + if (failed(emitType(arg.getType()))) { + return funcOp.emitOpError() + << "Failed to emit JaxiteWord type " << arg.getType(); + } + os << ",\n"; + if (isa(arg.getType())) { + CiphertextArg_ = argName; + } + if (isa(arg.getType())) { + ModulusListArg_ = argName; + } + } + os.unindent(); + os << ")"; + + if (CiphertextArg_.empty() || ModulusListArg_.empty()) { + return funcOp.emitWarning() << "Missing server keyset or ModulusList"; + } + + if (funcOp.getNumResults() > 0) { + os << " -> "; + if (funcOp.getNumResults() == 1) { + Type result = funcOp.getResultTypes()[0]; + if (failed(emitType(result))) { + return funcOp.emitOpError() + << "Failed to emit JaxiteWord type " << result; + } + } else { + auto result = commaSeparatedTypes( + funcOp.getResultTypes(), [&](Type type) -> FailureOr { + auto result = convertType(type); + if (failed(result)) { + return funcOp.emitOpError() + << "Failed to emit JaxiteWord type " << type; + } + return result; + }); + os << "(" << result.value() << ")"; + } + } + + os << ":\n"; + os.indent(); + + for (Block &block : funcOp.getBlocks()) { + for (Operation &op : block.getOperations()) { + if (failed(translate(op))) { + return failure(); + } + } + } + + os.unindent(); + os << "\n"; + return success(); +} + +LogicalResult JaxiteWordEmitter::printOperation(func::ReturnOp op) { + std::function resultValue = [&](Value value) { + if (isa(value)) { + // Function arguments used as outputs. + return variableNames->getNameForValue(value); + } else { + return "temp_nodes[" + + std::to_string(variableNames->getIntForValue(value)) + "]"; + } + }; + if (op.getNumOperands() == 0) { + return success(); + } + if (op.getNumOperands() == 1) { + os << "return " << resultValue(op.getOperands()[0]) << "\n"; + return success(); + } else { + os << "return (" << commaSeparatedValues(op.getOperands(), resultValue) + << ")\n"; + return success(); + } + return failure(); +} + +LogicalResult JaxiteWordEmitter::printOperation(AddOp op) { + emitAssignPrefix(op.getResult()); + os << op.getOperationName() << "(" + << "v" << variableNames->getIntForValue(op.getValueA()) << ", " + << "v" << variableNames->getIntForValue(op.getValueB()) << ", " + << "v" << variableNames->getIntForValue(op.getModulusList()) << ")\n"; + return success(); +} + +void JaxiteWordEmitter::emitAssignPrefix(Value result) { + os << "temp_nodes[" << variableNames->getIntForValue(result) << "] = "; +} + +LogicalResult JaxiteWordEmitter::printOperation(tensor::ExtractOp op) { + emitAssignPrefix(op.getResult()); + if (isa(op.getTensor())) { + os << variableNames->getNameForValue(op.getTensor()); + } else { + os << "temp_nodes[" << variableNames->getIntForValue(op.getTensor()) << "]"; + } + os << "[" + << dyn_cast( + dyn_cast(op.getIndices()[0].getDefiningOp()) + .getValue()) + .getValue() + << "]\n"; + return success(); +} + +LogicalResult JaxiteWordEmitter::printOperation(tensor::FromElementsOp op) { + if (op.getNumOperands() == 0) { + return success(); + } + if (isa(op->getOperands()[0].getDefiningOp())) { + return success(); + } + emitAssignPrefix(op.getResult()); + os << "[" << commaSeparatedValues(op.getOperands(), [&](Value value) { + return "temp_nodes[" + + std::to_string(variableNames->getIntForValue(value)) + "]"; + }) << "]\n"; + return success(); +} + +// Loading variables. +// Example: temp_nodes[idx] = input[i] +LogicalResult JaxiteWordEmitter::printOperation(memref::LoadOp op) { + emitAssignPrefix(op.getResult()); + os << variableNames->getNameForValue(op.getMemref()); + if (isa(op.getMemref())) { + // We assume the arguments to the function are flattened. + // We assume here that the indices are SSA values (not integer attributes). + os << "[" + << flattenedIndex( + op.getMemRefType(), op.getIndices(), + [&](Value value) { + return dyn_cast( + dyn_cast(value.getDefiningOp()) + .getValue()) + .getValue() + .getSExtValue(); + }) + << "]"; + } else { + os << bracketEnclosedValues(op.getIndices(), [&](Value value) { + SmallString<16> idx_str; + dyn_cast( + dyn_cast(value.getDefiningOp()).getValue()) + .getValue() + .toStringUnsigned(idx_str); + return std::string(idx_str); + }); + } + os << "\n"; + return success(); +} + +// memref::AllocOp initializes a variable of a specific shape. Translation in +// JaxiteWord is to allocate a flattened array. +// Example: temp_nodes[idx] = jnp.full((ixj), None) +// Note: memref::AllocOp and memref::StoreOp need to be in sync on how the +// indices are processed +LogicalResult JaxiteWordEmitter::printOperation(memref::AllocOp op) { + emitAssignPrefix(op.getResult()); + os << "jnp.full((" + << std::accumulate(std::next(op.getMemref().getType().getShape().begin()), + op.getMemref().getType().getShape().end(), + std::to_string(op.getMemref().getType().getShape()[0]), + [&](const std::string &a, int64_t b) { + return a + "*" + std::to_string(b); + }) + << "), None)"; + os << "\n"; + return success(); +} + +// Assuming StoreOp is only used while storing results. +// Example: temp_nodes[result_idx][idx] = temp_nodes[i] +// Note: memref::AllocOp and memref::StoreOp need to be in sync on how the +// indices are processed. +LogicalResult JaxiteWordEmitter::printOperation(memref::StoreOp op) { + os << "temp_nodes[" << variableNames->getIntForValue(op.getMemref()) << "]"; + os << "[" + << flattenedIndex( + op.getMemRefType(), op.getIndices(), + [&](Value value) { + return dyn_cast( + dyn_cast(value.getDefiningOp()) + .getValue()) + .getValue() + .getSExtValue(); + }) + << "]"; + os << " = " << "temp_nodes[" + << variableNames->getIntForValue(op.getValueToStore()) << "]"; + os << "\n"; + return success(); +} + +FailureOr JaxiteWordEmitter::convertType(Type type) { + // Note: these are probably not the right type names to use exactly, and + // they will need to change to the right values once we try to compile it + // against a specific API version. + if (auto shapedType = dyn_cast(type)) { + // A lambda in a type switch statement can't return multiple types. + // FIXME: why can't both types be FailureOr? + auto elementTy = convertType(shapedType.getElementType()); + if (failed(elementTy)) return failure(); + + return std::string(std::string("list[") + elementTy.value() + "]"); + } + return llvm::TypeSwitch>(type) + .Case( + [&](auto type) { return std::string("jaxite_word.Ciphertext"); }) + .Case( + [&](auto type) { return std::string("jaxite_word.ModulusList"); }) + .Default([&](Type &) { return failure(); }); +} + +LogicalResult JaxiteWordEmitter::emitType(Type type) { + auto result = convertType(type); + if (failed(result)) { + return failure(); + } + os << result; + return success(); +} + +JaxiteWordEmitter::JaxiteWordEmitter(raw_ostream &os, + SelectVariableNames *variableNames) + : os(os), variableNames(variableNames) {} + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir diff --git a/lib/Target/JaxiteWord/JaxiteWordEmitter.h b/lib/Target/JaxiteWord/JaxiteWordEmitter.h new file mode 100644 index 0000000000..f7369cd344 --- /dev/null +++ b/lib/Target/JaxiteWord/JaxiteWordEmitter.h @@ -0,0 +1,69 @@ +#ifndef INCLUDE_TARGET_JAXITEWORD_JAXITEWORDEMITTER_H_ +#define INCLUDE_TARGET_JAXITEWORD_JAXITEWORDEMITTER_H_ + +#include + +#include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace jaxiteword { + +void registerToJaxiteWordTranslation(); + +/// Translates the given operation to Jaxire. +::mlir::LogicalResult translateToJaxiteWord(::mlir::Operation *op, + llvm::raw_ostream &os); + +class JaxiteWordEmitter { + public: + JaxiteWordEmitter(raw_ostream &os, SelectVariableNames *variableNames); + + LogicalResult translate(::mlir::Operation &operation); + + private: + // Output stream to emit to. + raw_indented_ostream os; + + // Pre-populated analysis selecting unique variable names for all the SSA + // values. + SelectVariableNames *variableNames; + + // ciphertext arg. + std::string CiphertextArg_; + + // A list of modulus to be used for the add operation. + std::string ModulusListArg_; + + LogicalResult printOperation(ModuleOp moduleOp); + LogicalResult printOperation(func::FuncOp funcOp); + LogicalResult printOperation(func::ReturnOp returnOp); + LogicalResult printOperation(AddOp op); + LogicalResult printOperation(tensor::ExtractOp op); + LogicalResult printOperation(tensor::FromElementsOp op); + LogicalResult printOperation(memref::AllocOp op); + LogicalResult printOperation(memref::LoadOp op); + LogicalResult printOperation(memref::StoreOp op); + LogicalResult emitType(Type type); + FailureOr convertType(Type type); + + void emitAssignPrefix(Value result); +}; + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TARGET_JAXITEWORD_JAXITEWORDEMITTER_H_ diff --git a/lib/Target/JaxiteWord/JaxiteWordTemplates.h b/lib/Target/JaxiteWord/JaxiteWordTemplates.h new file mode 100644 index 0000000000..f6db5a43d0 --- /dev/null +++ b/lib/Target/JaxiteWord/JaxiteWordTemplates.h @@ -0,0 +1,24 @@ +#ifndef LIB_TARGET_JAXITEWORD_JAXITEWORDTEMPLATES_H_ +#define LIB_TARGET_JAXITEWORD_JAXITEWORDTEMPLATES_H_ + +#include + +namespace mlir { +namespace heir { +namespace jaxiteword { + +constexpr std::string_view kModulePrelude = R"python( +import jax +import jax.numpy as jnp + +from typing import Dict, List + +from jaxite.jaxite_word import jaxite_word + +)python"; + +} // namespace jaxiteword +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_JAXITEWORD_JAXITEWORDTEMPLATES_H_ diff --git a/tests/Dialect/JaxiteWord/Emitters/BUILD b/tests/Dialect/JaxiteWord/Emitters/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/JaxiteWord/Emitters/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/JaxiteWord/Emitters/emit_jaxiteword.mlir b/tests/Dialect/JaxiteWord/Emitters/emit_jaxiteword.mlir new file mode 100644 index 0000000000..77fb230f92 --- /dev/null +++ b/tests/Dialect/JaxiteWord/Emitters/emit_jaxiteword.mlir @@ -0,0 +1,12 @@ +// RUN: heir-translate --emit-jaxite %s | FileCheck %s + +!ct = !jaxiteword.ciphertext<2, 3, 4> +!ml = !jaxiteword.modulus_list + +// CHECK-LABEL: func.func @test_add( +func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + // find functions here: third_party/heir/tests/Dialect/Openfhe/IR/ops.mlir + // ToDo: How to create value for all inputs? + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct +} diff --git a/tests/Dialect/JaxiteWord/IR/BUILD b/tests/Dialect/JaxiteWord/IR/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Dialect/JaxiteWord/IR/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Dialect/JaxiteWord/IR/ops.mlir b/tests/Dialect/JaxiteWord/IR/ops.mlir new file mode 100644 index 0000000000..8fc032fc26 --- /dev/null +++ b/tests/Dialect/JaxiteWord/IR/ops.mlir @@ -0,0 +1,11 @@ +// RUN: heir-opt %s | FileCheck %s + +!ct = !jaxiteword.ciphertext<2, 3, 4> +!ml = !jaxiteword.modulus_list + +// CHECK-LABEL: func.func @test_add( +func.func @test_add(%ct1 : !ct, %ct2 : !ct, %modulus_list: !ml) -> !ct { + // find functions here: third_party/heir/tests/Dialect/Openfhe/IR/ops.mlir + %out = jaxiteword.add %ct1, %ct2, %modulus_list: (!ct, !ct, !ml) -> !ct + return %out : !ct +} diff --git a/tools/BUILD b/tools/BUILD index ded8a6b58f..384ab81595 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -53,6 +53,7 @@ cc_binary( "@heir//lib/Dialect/CKKS/IR:Dialect", "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/Jaxite/IR:Dialect", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", "@heir//lib/Dialect/LWE/Conversions/LWEToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", @@ -187,6 +188,7 @@ cc_binary( deps = [ "@heir//lib/Source/AutoHog:AutoHogImporter", "@heir//lib/Target/Jaxite:JaxiteEmitter", + "@heir//lib/Target/JaxiteWord:JaxiteWordEmitter", "@heir//lib/Target/Lattigo:LattigoEmitter", "@heir//lib/Target/Metadata:MetadataEmitter", "@heir//lib/Target/OpenFhePke:OpenFheRegistration", @@ -209,6 +211,7 @@ cc_binary( "@heir//lib/Dialect/CKKS/IR:Dialect", "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/Jaxite/IR:Dialect", + "@heir//lib/Dialect/JaxiteWord/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Lattigo/IR:Dialect", "@heir//lib/Dialect/Mgmt/IR:Dialect", diff --git a/tools/heir-lsp.cpp b/tools/heir-lsp.cpp index 3795d861bd..147d80fc78 100644 --- a/tools/heir-lsp.cpp +++ b/tools/heir-lsp.cpp @@ -3,6 +3,7 @@ #include "lib/Dialect/CKKS/IR/CKKSDialect.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" #include "lib/Dialect/LWE/IR/LWEDialect.h" #include "lib/Dialect/Lattigo/IR/LattigoDialect.h" #include "lib/Dialect/Mgmt/IR/MgmtDialect.h" @@ -42,6 +43,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 4d3dff96b9..61bd9e6ffb 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -17,6 +17,7 @@ #include "lib/Dialect/CKKS/IR/CKKSDialect.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" +#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h" #include "lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" #include "lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.h" @@ -143,6 +144,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tools/heir-translate.cpp b/tools/heir-translate.cpp index 1d9faffe72..ba8bd05c55 100644 --- a/tools/heir-translate.cpp +++ b/tools/heir-translate.cpp @@ -1,5 +1,6 @@ #include "lib/Source/AutoHog/AutoHogImporter.h" #include "lib/Target/Jaxite/JaxiteEmitter.h" +#include "lib/Target/JaxiteWord/JaxiteWordEmitter.h" #include "lib/Target/Lattigo/LattigoEmitter.h" #include "lib/Target/Metadata/MetadataEmitter.h" #include "lib/Target/OpenFhePke/OpenFheTranslateRegistration.h" @@ -23,6 +24,7 @@ int main(int argc, char **argv) { // jaxite output mlir::heir::jaxite::registerToJaxiteTranslation(); + mlir::heir::jaxiteword::registerToJaxiteWordTranslation(); // OpenFHE mlir::heir::openfhe::registerTranslateOptions();