Skip to content

WIP: create PISA dialect (+ emitter and passes) #1046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ repos:
- id: codespell
args: ["-L", "crate, fpt"]

# The PISA dialect contains operation names that look like misspellings.
exclude: >
(?x)^(
.*\/pisa\/.*\.mlir|
.*\/PISA\/.*\.td|
.*\/PISA\/.*\.cpp
)$

# Changes tabs to spaces
- repo: https://github.com/Lucas-C/pre-commit-hooks
Expand Down
110 changes: 110 additions & 0 deletions lib/Dialect/PISA/IR/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# PISA dialect implementation

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Dialect",
srcs = [
"PISADialect.cpp",
],
hdrs = [
"PISADialect.h",
"PISAOps.h",
],
deps = [
"dialect_inc_gen",
"ops_inc_gen",
":PISAOps",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "PISAOps",
srcs = [
"PISAOps.cpp",
],
hdrs = [
"PISADialect.h",
"PISAOps.h",
],
deps = [
":dialect_inc_gen",
":ops_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
],
)

td_library(
name = "td_files",
srcs = [
"PISADialect.td",
"PISAOps.td",
],
# include from the heir - root to enable fully - qualified include - paths
includes = ["../../../.."],
deps = [
"@heir//lib/Utils/DRR",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)

gentbl_cc_library(
name = "dialect_inc_gen",
tbl_outs = [
(
[
"-gen-dialect-decls",
],
"PISADialect.h.inc",
),
(
[
"-gen-dialect-defs",
],
"PISADialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PISADialect.td",
deps = [
":td_files",
],
)

gentbl_cc_library(
name = "ops_inc_gen",
tbl_outs = [
(
["-gen-op-decls"],
"PISAOps.h.inc",
),
(
["-gen-op-defs"],
"PISAOps.cpp.inc",
),
(
["-gen-op-doc"],
"PISAOps.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PISAOps.td",
deps = [
":dialect_inc_gen",
":td_files",
],
)
28 changes: 28 additions & 0 deletions lib/Dialect/PISA/IR/PISADialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "lib/Dialect/PISA/IR/PISADialect.h"

#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// NOLINTNEXTLINE(misc-include-cleaner): Required to define PISAOps

#include "lib/Dialect/PISA/IR/PISAOps.h"

// Generated definitions
#include "lib/Dialect/PISA/IR/PISADialect.cpp.inc"

#define GET_OP_CLASSES
#include "lib/Dialect/PISA/IR/PISAOps.cpp.inc"

namespace mlir {
namespace heir {
namespace pisa {

void PISADialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "lib/Dialect/PISA/IR/PISAOps.cpp.inc"
>();
}

} // namespace pisa
} // namespace heir
} // namespace mlir
10 changes: 10 additions & 0 deletions lib/Dialect/PISA/IR/PISADialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef LIB_DIALECT_PISA_IR_PISADIALECT_H_
#define LIB_DIALECT_PISA_IR_PISADIALECT_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "lib/Dialect/PISA/IR/PISADialect.h.inc"

#endif // LIB_DIALECT_PISA_IR_PISADIALECT_H_
16 changes: 16 additions & 0 deletions lib/Dialect/PISA/IR/PISADialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef LIB_DIALECT_PISA_IR_PISADIALECT_TD_
#define LIB_DIALECT_PISA_IR_PISADIALECT_TD_

include "mlir/IR/DialectBase.td"

def PISA_Dialect : Dialect {
let name = "pisa";
let description = [{
// FIXME: add documentation
The `pisa` dialect is ...
}];

let cppNamespace = "::mlir::heir::pisa";
}

#endif // LIB_DIALECT_PISA_IR_PISADIALECT_TD_
7 changes: 7 additions & 0 deletions lib/Dialect/PISA/IR/PISAOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include "lib/Dialect/PISA/IR/PISAOps.h"

namespace mlir {
namespace heir {
namespace pisa {} // namespace pisa
} // namespace heir
} // namespace mlir
12 changes: 12 additions & 0 deletions lib/Dialect/PISA/IR/PISAOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef LIB_DIALECT_PISA_IR_PISAOPS_H_
#define LIB_DIALECT_PISA_IR_PISAOPS_H_

#include "lib/Dialect/ModArith/IR/ModArithTypes.h" // required for the type predicate we use
#include "lib/Dialect/PISA/IR/PISADialect.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

#define GET_OP_CLASSES
#include "lib/Dialect/PISA/IR/PISAOps.h.inc"

#endif // LIB_DIALECT_PISA_IR_PISAOPS_H_
106 changes: 106 additions & 0 deletions lib/Dialect/PISA/IR/PISAOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#ifndef LIB_DIALECT_PISA_IR_PISAOPS_TD_
#define LIB_DIALECT_PISA_IR_PISAOPS_TD_

include "lib/Dialect/PISA/IR/PISADialect.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// We only accept tensors of mod_arith with 32-bit typed moduli.
// Note that we do NOT allow moduli that are concretely less than 32 bits but have a larger type (e.g., I64)
// as those allow the compiler to emit code that relies on temporarily using up to 64 bits before mod-reducing.
def Tensor8192I32 : TypeConstraint<CPred<[{
mlir::isa<mlir::RankedTensorType>($_self) &&
mlir::cast<mlir::RankedTensorType>($_self).getRank() == 1 &&
mlir::cast<mlir::RankedTensorType>($_self).getDimSize(0) == 8192 &&
llvm::isa<mlir::heir::mod_arith::ModArithType>(mlir::cast<mlir::RankedTensorType>($_self).getElementType()) &&
mlir::cast<mlir::heir::mod_arith::ModArithType>(mlir::cast<mlir::RankedTensorType>($_self).getElementType()).getModulus().getType().isInteger(32)
}]>, "tensor<8192xmod_arith.int< ... : i32>>">;

class PISA_Op<string mnemonic, list<Trait> traits = [Pure]> :
Op<PISA_Dialect, mnemonic, traits> {
let cppNamespace = "::mlir::heir::pisa";
}

class PISA_BinaryOp<string mnemonic, list<Trait> traits = []> :
PISA_Op<mnemonic, traits # [SameOperandsAndResultType]>,
Arguments<(ins Tensor8192I32:$lhs, Tensor8192I32:$rhs, I32Attr:$q, I32Attr:$i)>,
Results<(outs Tensor8192I32:$output)> {
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";
}

def PISA_AddOp : PISA_BinaryOp<"add", [Commutative]> {
let summary = "addition operation";
let description = [{
Computes addition of two polynomials (irrespective of ntt/coefficient representation).
}];
}

def PISA_SubOp : PISA_BinaryOp<"sub", []> {
let summary = "subtraction operation";
let description = [{
Computes subtraction of two polynomials (irrespective of ntt/coefficient representation).
}];
}

def PISA_MulOp : PISA_BinaryOp<"mul", [Commutative]> {
let summary = "multiplication operation";
let description = [{
Computes addition of two polynomials (in ntt representation).
}];
}

def PISA_MuliOp : PISA_Op<"muli", [SameOperandsAndResultType]> {
let summary = "multiplication-with-immediate operation";
let description = [{
Computes multiplication of a polynomial (in ntt representation) with a constant.
}];
let arguments = (ins Tensor8192I32:$lhs, I32Attr:$q, I32Attr:$i, I32Attr:$imm);
let results = (outs Tensor8192I32:$output);
let assemblyFormat = "$lhs attr-dict `:` qualified(type($output))";
}

def PISA_MacOp : PISA_Op<"mac", [SameOperandsAndResultType]> {
let summary = "multiply-and-accumulate operation";
let description = [{
Computes multiplication of two polynomials (in ntt representation) and adds the result to a third polynomial.
}];
let arguments = (ins Tensor8192I32:$lhs, Tensor8192I32:$rhs, Tensor8192I32:$acc, I32Attr:$q, I32Attr:$i);
let results = (outs Tensor8192I32:$output);
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` qualified(type($output))";
}

def PISA_MaciOp : PISA_Op<"maci", [SameOperandsAndResultType]> {
let summary = "multiply-and-accumulate-with-immediate operation";
let description = [{
Computes multiplication of a polynomial (in ntt representation) with a constant and adds the result to a third polynomial.
}];
let arguments = (ins Tensor8192I32:$lhs, Tensor8192I32:$acc, I32Attr:$q, I32Attr:$i, I32Attr:$imm);
let results = (outs Tensor8192I32:$output);
let assemblyFormat = "$lhs `,` $acc attr-dict `:` qualified(type($output))";
}

def PISA_NTTOp : PISA_Op<"ntt", [SameOperandsAndResultType]> {
let summary = "number-theoretic-transform operation";
let description = [{
Computes number-theoretic-transform of a polynomial.
}];
let arguments = (ins Tensor8192I32:$poly, Tensor8192I32:$w, I32Attr:$q, I32Attr:$i);
let results = (outs Tensor8192I32:$output);
let assemblyFormat = "$poly `,` $w attr-dict `:` qualified(type($output))";
}

def PISA_INTTOp : PISA_Op<"intt", [SameOperandsAndResultType]> {
let summary = "inverse number-theoretic-transform operation";
let description = [{
Computes inverse number-theoretic-transform of a polynomial.
}];
let arguments = (ins Tensor8192I32:$poly, Tensor8192I32:$w, I32Attr:$q, I32Attr:$i);
let results = (outs Tensor8192I32:$output);
let assemblyFormat = "$poly `,` $w attr-dict `:` qualified(type($output))";
}


#endif // LIB_DIALECT_PISA_IR_PISAOPS_TD_
1 change: 1 addition & 0 deletions lib/Dialect/Polynomial/Conversions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(PolynomialToModArith)
add_subdirectory(PolynomialToPISA)
45 changes: 45 additions & 0 deletions lib/Dialect/Polynomial/Conversions/PolynomialToPISA/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "PolynomialToPISA",
srcs = ["PolynomialToPISA.cpp"],
hdrs = ["PolynomialToPISA.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/PISA/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=PolynomialToPISA",
],
"PolynomialToPISA.h.inc",
),
(
["-gen-pass-doc"],
"PolynomialToPISA.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PolynomialToPISA.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
Empty file.
Loading
Loading