Skip to content

update arbitrary-precision sub, add and multiplication of jaxite_word" #1538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_subdirectory(CGGI)
add_subdirectory(CKKS)
add_subdirectory(Comb)
add_subdirectory(Jaxite)
add_subdirectory(JaxiteWord)
add_subdirectory(LinAlg)
add_subdirectory(LWE)
add_subdirectory(ModArith)
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/JaxiteWord/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
74 changes: 74 additions & 0 deletions lib/Dialect/JaxiteWord/IR/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# JaxiteWord, an exit dialect to JaxiteWord API

load("@heir//lib/Dialect:dialect.bzl", "add_heir_dialect_library")
load("@llvm-project//mlir:tblgen.bzl", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Dialect",
srcs = ["JaxiteWordDialect.cpp"],
hdrs = [
"JaxiteWordDialect.h",
"JaxiteWordOps.h",
"JaxiteWordTypes.h",
],
deps = [
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
],
)

td_library(
name = "td_files",
srcs = [
"JaxiteWordDialect.td",
"JaxiteWordOps.td",
"JaxiteWordTypes.td",
],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)

add_heir_dialect_library(
name = "dialect_inc_gen",
dialect = "JaxiteWord",
kind = "dialect",
td_file = "JaxiteWordDialect.td",
deps = [
":td_files",
],
)

add_heir_dialect_library(
name = "types_inc_gen",
dialect = "JaxiteWord",
kind = "type",
td_file = "JaxiteWordTypes.td",
deps = [
":td_files",
],
)

add_heir_dialect_library(
name = "ops_inc_gen",
dialect = "JaxiteWord",
kind = "op",
td_file = "JaxiteWordOps.td",
deps = [
":td_files",
"@heir//lib/Dialect/LWE/IR:td_files",
],
)
8 changes: 8 additions & 0 deletions lib/Dialect/JaxiteWord/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_heir_dialect(JaxiteWord jaxiteword)

add_mlir_dialect_library(HEIRJaxiteWord
JaxiteWordDialect.cpp

DEPENDS
HEIRJaxiteWordIncGen
)
42 changes: 42 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h"

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.cpp.inc"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h"
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc"
#define GET_OP_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.cpp.inc"

namespace mlir {
namespace heir {
namespace jaxiteword {

void JaxiteWordDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.cpp.inc"
>();
}

LogicalResult AddOp::verify() {
if (getModulusList().getType().getModulusList().size() !=
getValueA().getType().getTowers()) {
return emitOpError() << "Number of Towers of moudlus should match the "
"number of towers/limbs";
}
return success();
}

} // namespace jaxiteword
} // namespace heir
} // namespace mlir
13 changes: 13 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.h.inc"

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_
22 changes: 22 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_

include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"

def JaxiteWord_Dialect : Dialect {
let name = "jaxiteword";

let description = [{
The `jaxiteword` dialect is an exit dialect for generating py code against the jaxiteword library API,
using the jaxiteword parameters and encoding scheme.

See https://github.com/google/jaxite/jaxite_word
}];

let cppNamespace = "::mlir::heir::jaxiteword";

let useDefaultTypePrinterParser = 1;
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDDIALECT_H_
13 changes: 13 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_H_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_H_

#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

#define GET_OP_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordOps.h.inc"

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_H_
38 changes: 38 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_TD_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_TD_

include "JaxiteWordDialect.td"
include "JaxiteWordTypes.td"

include "lib/Dialect/LWE/IR/LWETypes.td"

include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


class JaxiteWord_Op<string mnemonic, list<Trait> traits = []> :
Op<JaxiteWord_Dialect, mnemonic, traits> {
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];
let cppNamespace = "::mlir::heir::jaxiteword";
}


def AddOp : JaxiteWord_Op<"add", [AllTypesMatch<["value_a", "value_b", "result"]>,Commutative,Pure]> {
let description = [{
The operation computed by this function is homomorphic addition.
}];
let arguments = (ins JaxiteWord_Ciphertext:$value_a,
JaxiteWord_Ciphertext:$value_b,
JaxiteWord_ModulusList:$modulus_list
);
let results = (outs JaxiteWord_Ciphertext:$result);
let hasVerifier = 1;
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDOPS_TD_
13 changes: 13 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_H_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_H_

#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
// #include "mlir/include/mlir/IR/TypeSupport.h" // from @llvm-project
// #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project

// Include LLVM ADT definitions.

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.h.inc"

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_H_
38 changes: 38 additions & 0 deletions lib/Dialect/JaxiteWord/IR/JaxiteWordTypes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifndef LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_TD_
#define LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_TD_

include "JaxiteWordDialect.td"

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

class JaxiteWord_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<JaxiteWord_Dialect, name, traits> {
let mnemonic = typeMnemonic;
let assemblyFormat = "`<` struct(params) `>`"; // print out all information in the arguments
}

def JaxiteWord_ModulusList : JaxiteWord_Type<"ModulusList", "modulus_list"> {
let description = [{
A list of modulus values.
}];
let parameters = (ins ArrayRefParameter<"::mlir::Type">:$modulus_list);
let assemblyFormat = "`<` $modulus_list `>`";
// jaxiteword.modulus_list<i32, i32, i32, i32>
}

def JaxiteWord_Ciphertext : JaxiteWord_Type<"Ciphertext", "ciphertext"> {
let description = [{
A ciphertext - a three dimensional array.
}];
let parameters = (ins "int":$polys,
"int":$towers,
"int":$degrees);
let assemblyFormat = "`<` $polys `,` $towers `,` $degrees `>`";
// jaxiteword.ciphertext<3, 4, 5>
}

#endif // LIB_DIALECT_JAXITEWORD_IR_JAXITEWORDTYPES_TD_
37 changes: 37 additions & 0 deletions lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "LWEToJaxiteWord",
srcs = ["LWEToJaxiteWord.cpp"],
hdrs = [
"LWEToJaxiteWord.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/BGV/IR:Dialect",
"@heir//lib/Dialect/CKKS/IR:Dialect",
"@heir//lib/Dialect/JaxiteWord/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
],
)

add_heir_transforms(
header_filename = "LWEToJaxiteWord.h.inc",
pass_name = "LWEToJaxiteWord",
td_file = "LWEToJaxiteWord.td",
)
20 changes: 20 additions & 0 deletions lib/Dialect/LWE/Conversions/LWEToJaxiteWord/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
add_mlir_conversion_library(HEIRLWEToJaxiteWord
LWEToJaxiteWord.cpp

LINK_LIBS PUBLIC
HEIRConversionUtils
HEIRLWE
HEIRLWE
HEIRJaxiteWord

MLIRIR
MLIRPass
MLIRInferTypeOpInterface
MLIRArithDialect
MLIRFuncDialect
LLVMSupport
MLIRSupport
MLIRDialect
MLIRTransformUtils
MLIRIR
)
Loading
Loading