Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lib/Dialect/ModArith/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cc_library(
srcs = [
"ModArithDialect.cpp",
"ModArithOps.cpp",
"ModArithTypes.cpp",
],
hdrs = [
"ModArithDialect.h",
Expand All @@ -24,6 +25,7 @@ cc_library(
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect/RNS/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
Expand All @@ -44,7 +46,9 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//lib/Dialect/RNS/IR:td_files",
"@heir//lib/Utils/DRR",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down Expand Up @@ -89,6 +93,5 @@ gentbl_cc_library(
td_file = "ModArithCanonicalization.td",
deps = [
":td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
],
)
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/IR/ModArithDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def ModArith_Dialect : Dialect {

let dependentDialects = [
"arith::ArithDialect",
"mlir::heir::rns::RNSDialect"
];
}

Expand Down
39 changes: 2 additions & 37 deletions lib/Dialect/ModArith/IR/ModArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cstdint>
#include <vector>

#include "lib/Dialect/RNS/IR/RNSTypes.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
Expand Down Expand Up @@ -32,20 +33,6 @@ namespace mlir {
namespace heir {
namespace mod_arith {

/// Ensures that the underlying integer type is wide enough for the coefficient
template <typename OpType>
LogicalResult verifyModArithType(OpType op, ModArithType type) {
APInt modulus = type.getModulus().getValue();
unsigned bitWidth = modulus.getBitWidth();
unsigned modWidth = modulus.getActiveBits();
if (modWidth > bitWidth - 1)
return op.emitOpError()
<< "underlying type's bitwidth must be 1 bit larger than "
<< "the modulus bitwidth, but got " << bitWidth
<< " while modulus requires width " << modWidth << ".";
return success();
}

template <typename OpType>
LogicalResult verifySameWidth(OpType op, ModArithType modArithType,
IntegerType integerType) {
Expand All @@ -62,29 +49,7 @@ LogicalResult verifySameWidth(OpType op, ModArithType modArithType,
LogicalResult ExtractOp::verify() {
auto modArithType = getOperandModArithType(*this);
auto integerType = getResultIntegerType(*this);
auto result = verifySameWidth(*this, modArithType, integerType);
if (result.failed()) return result;
return verifyModArithType(*this, modArithType);
}

LogicalResult ReduceOp::verify() {
return verifyModArithType(*this, getResultModArithType(*this));
}

LogicalResult AddOp::verify() {
return verifyModArithType(*this, getResultModArithType(*this));
}

LogicalResult SubOp::verify() {
return verifyModArithType(*this, getResultModArithType(*this));
}

LogicalResult MulOp::verify() {
return verifyModArithType(*this, getResultModArithType(*this));
}

LogicalResult MacOp::verify() {
return verifyModArithType(*this, getResultModArithType(*this));
return verifySameWidth(*this, modArithType, integerType);
}

LogicalResult BarrettReduceOp::verify() {
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithOps.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_

#include "lib/Dialect/RNS/IR/RNSTypes.h"

// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithOps
#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
Expand Down
10 changes: 5 additions & 5 deletions lib/Dialect/ModArith/IR/ModArithOps.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_

include "lib/Dialect/RNS/IR/RNSTypes.td"
include "lib/Dialect/ModArith/IR/ModArithDialect.td"
include "lib/Dialect/ModArith/IR/ModArithTypes.td"
include "mlir/IR/BuiltinAttributes.td"
Expand All @@ -9,6 +10,8 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def ModArithOrRNSLike: TypeOrValueSemanticsContainer<AnyTypeOf<[ModArith_ModArithType, RNS]>,
"mod_arith_or_rns-like">;

class ModArith_Op<string mnemonic, list<Trait> traits = [Pure]> :
Op<ModArith_Dialect, mnemonic, traits> {
Expand Down Expand Up @@ -121,15 +124,13 @@ def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOp
ModArithLike:$input
);
let results = (outs ModArithLike:$output);
let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` type($output)";
}

class ModArith_BinaryOp<string mnemonic, list<Trait> traits = []> :
ModArith_Op<mnemonic, traits # [SameOperandsAndResultType, Pure, ElementwiseMappable]>,
Arguments<(ins ModArithLike:$lhs, ModArithLike:$rhs)>,
Results<(outs ModArithLike:$output)> {
let hasVerifier = 1;
Arguments<(ins ModArithOrRNSLike:$lhs, ModArithOrRNSLike:$rhs)>,
Results<(outs ModArithOrRNSLike:$output)> {
let assemblyFormat ="operands attr-dict `:` type($output)";
}

Expand Down Expand Up @@ -180,7 +181,6 @@ def ModArith_MacOp : ModArith_Op<"mac", [SameOperandsAndResultType, Pure, Elemen
}];
let arguments = (ins ModArithLike:$lhs, ModArithLike:$rhs, ModArithLike:$acc);
let results = (outs ModArithLike:$output);
let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` type($output)";
}

Expand Down
27 changes: 27 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"

#include "llvm/include/llvm/ADT/STLFunctionalExtras.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace mod_arith {

LogicalResult ModArithType::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::mlir::IntegerAttr modulus) {
APInt value = modulus.getValue();
unsigned bitWidth = value.getBitWidth();
unsigned modWidth = value.getActiveBits();
if (modWidth > bitWidth - 1)
return emitError()
<< "underlying type's bitwidth must be 1 bit larger than "
<< "the modulus bitwidth, but got " << bitWidth
<< " while modulus requires width " << modWidth << ".";
return success();
}

} // namespace mod_arith
} // namespace heir
} // namespace mlir
2 changes: 2 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def ModArith_ModArithType : ModArith_Type<"ModArith", "int", [MemRefElementTypeI
return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
}
}];

let genVerifyDecl = 1;
}

def ModArithLike: TypeOrValueSemanticsContainer<ModArith_ModArithType, "mod_arith-like">;
Expand Down
5 changes: 2 additions & 3 deletions lib/Dialect/Polynomial/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//lib/Dialect/ModArith/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down Expand Up @@ -82,7 +84,6 @@ add_heir_dialect_library(
td_file = "PolynomialOps.td",
deps = [
":td_files",
"@heir//lib/Dialect/ModArith/IR:td_files",
],
)

Expand All @@ -93,7 +94,5 @@ gentbl_cc_library(
td_file = "PolynomialCanonicalization.td",
deps = [
":td_files",
"@heir//lib/Dialect/ModArith/IR:td_files",
"@llvm-project//mlir:ArithOpsTdFiles",
],
)
1 change: 1 addition & 0 deletions tests/Dialect/ModArith/IR/invalid-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: heir-opt --verify-diagnostics --split-input-file %s | FileCheck %s

// expected-error@+1 {{underlying type's bitwidth must be 1 bit larger than the modulus bitwidth, but got 8 while modulus requires width 8}}
!Zp = !mod_arith.int<255 : i8>

// -----
Expand Down
10 changes: 10 additions & 0 deletions tests/Dialect/ModArith/IR/syntax.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// RUN: heir-opt --mlir-print-local-scope %s | FileCheck %s

!Zp = !mod_arith.int<17 : i10>
!Zp2 = !mod_arith.int<257 : i10>
!Zp_vec = tensor<4x!Zp>
!rns = !rns.rns<!Zp, !Zp2>

// CHECK: @test_arith_syntax
func.func @test_arith_syntax() {
Expand Down Expand Up @@ -83,3 +85,11 @@ func.func @test_arith_syntax() {

return
}

// CHECK: @test_rns_syntax
func.func @test_rns_syntax(%arg0: !rns, %arg1: !rns) -> !rns {
// CHECK: mod_arith.add
// CHECK-SAME: rns
%result = mod_arith.add %arg0, %arg1 : !rns
return %result : !rns
}
Loading