Skip to content

Commit cc4b804

Browse files
committed
ModArith: Support RNS type in ops
1 parent b59fe98 commit cc4b804

File tree

7 files changed

+48
-7
lines changed

7 files changed

+48
-7
lines changed

lib/Dialect/ModArith/IR/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ cc_library(
2424
":dialect_inc_gen",
2525
":ops_inc_gen",
2626
":types_inc_gen",
27+
"@heir//lib/Dialect/RNS/IR:Dialect",
2728
"@llvm-project//llvm:Support",
2829
"@llvm-project//mlir:ArithDialect",
2930
"@llvm-project//mlir:CommonFolders",
@@ -44,7 +45,9 @@ td_library(
4445
# include from the heir-root to enable fully-qualified include-paths
4546
includes = ["../../../.."],
4647
deps = [
48+
"@heir//lib/Dialect/RNS/IR:td_files",
4749
"@heir//lib/Utils/DRR",
50+
"@llvm-project//mlir:ArithOpsTdFiles",
4851
"@llvm-project//mlir:BuiltinDialectTdFiles",
4952
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
5053
"@llvm-project//mlir:OpBaseTdFiles",
@@ -89,6 +92,5 @@ gentbl_cc_library(
8992
td_file = "ModArithCanonicalization.td",
9093
deps = [
9194
":td_files",
92-
"@llvm-project//mlir:ArithOpsTdFiles",
9395
],
9496
)

lib/Dialect/ModArith/IR/ModArithDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def ModArith_Dialect : Dialect {
1616

1717
let dependentDialects = [
1818
"arith::ArithDialect",
19+
"mlir::heir::rns::RNSDialect"
1920
];
2021
}
2122

lib/Dialect/ModArith/IR/ModArithOps.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cstdint>
66
#include <vector>
77

8+
#include "lib/Dialect/RNS/IR/RNSTypes.h"
89
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
910
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
1011
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -46,6 +47,30 @@ LogicalResult verifyModArithType(OpType op, ModArithType type) {
4647
return success();
4748
}
4849

50+
/// Ensures that the underlying integer type is wide enough for the coefficient
51+
template <typename OpType>
52+
LogicalResult verifyRNSType(OpType op, rns::RNSType type) {
53+
for (auto basisType : type.getBasisTypes()) {
54+
if (auto modArithType = dyn_cast<ModArithType>(basisType)) {
55+
if (failed(verifyModArithType(op, modArithType))) {
56+
return op.emitOpError()
57+
<< "Every basis type in the RNS type must be a valid ModArith.";
58+
}
59+
} else {
60+
return op.emitOpError() << "Unsupported RNS type.";
61+
}
62+
}
63+
return success();
64+
}
65+
66+
template <typename OpType>
67+
LogicalResult verifyModArithOrRNSType(OpType op, Type type) {
68+
if (auto rnsType = dyn_cast<rns::RNSType>(type)) {
69+
return verifyRNSType(op, rnsType);
70+
}
71+
return verifyModArithType(op, cast<ModArithType>(getElementTypeOrSelf(type)));
72+
}
73+
4974
template <typename OpType>
5075
LogicalResult verifySameWidth(OpType op, ModArithType modArithType,
5176
IntegerType integerType) {
@@ -72,7 +97,7 @@ LogicalResult ReduceOp::verify() {
7297
}
7398

7499
LogicalResult AddOp::verify() {
75-
return verifyModArithType(*this, getResultModArithType(*this));
100+
return verifyModArithOrRNSType(*this, this->getResult().getType());
76101
}
77102

78103
LogicalResult SubOp::verify() {

lib/Dialect/ModArith/IR/ModArithOps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_
22
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_
33

4+
#include "lib/Dialect/RNS/IR/RNSTypes.h"
5+
46
// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithOps
57
#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
68
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"

lib/Dialect/ModArith/IR/ModArithOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_
22
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_
33

4+
include "lib/Dialect/RNS/IR/RNSTypes.td"
45
include "lib/Dialect/ModArith/IR/ModArithDialect.td"
56
include "lib/Dialect/ModArith/IR/ModArithTypes.td"
67
include "mlir/IR/BuiltinAttributes.td"
@@ -9,6 +10,7 @@ include "mlir/IR/OpBase.td"
910
include "mlir/Interfaces/InferTypeOpInterface.td"
1011
include "mlir/Interfaces/SideEffectInterfaces.td"
1112

13+
def ModArithOrRNSLike: TypeOrValueSemanticsContainer<AnyTypeOf<[ModArith_ModArithType, RNS]>, "mod_arith_or_rns-like">;
1214

1315
class ModArith_Op<string mnemonic, list<Trait> traits = [Pure]> :
1416
Op<ModArith_Dialect, mnemonic, traits> {
@@ -128,8 +130,8 @@ def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOp
128130

129131
class ModArith_BinaryOp<string mnemonic, list<Trait> traits = []> :
130132
ModArith_Op<mnemonic, traits # [SameOperandsAndResultType, Pure, ElementwiseMappable]>,
131-
Arguments<(ins ModArithLike:$lhs, ModArithLike:$rhs)>,
132-
Results<(outs ModArithLike:$output)> {
133+
Arguments<(ins ModArithOrRNSLike:$lhs, ModArithOrRNSLike:$rhs)>,
134+
Results<(outs ModArithOrRNSLike:$output)> {
133135
let hasVerifier = 1;
134136
let assemblyFormat ="operands attr-dict `:` type($output)";
135137
}

lib/Dialect/Polynomial/IR/BUILD

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ td_library(
3838
# include from the heir-root to enable fully-qualified include-paths
3939
includes = ["../../../.."],
4040
deps = [
41+
"@heir//lib/Dialect/ModArith/IR:td_files",
42+
"@llvm-project//mlir:ArithOpsTdFiles",
4143
"@llvm-project//mlir:BuiltinDialectTdFiles",
4244
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
4345
"@llvm-project//mlir:OpBaseTdFiles",
@@ -82,7 +84,6 @@ add_heir_dialect_library(
8284
td_file = "PolynomialOps.td",
8385
deps = [
8486
":td_files",
85-
"@heir//lib/Dialect/ModArith/IR:td_files",
8687
],
8788
)
8889

@@ -93,7 +94,5 @@ gentbl_cc_library(
9394
td_file = "PolynomialCanonicalization.td",
9495
deps = [
9596
":td_files",
96-
"@heir//lib/Dialect/ModArith/IR:td_files",
97-
"@llvm-project//mlir:ArithOpsTdFiles",
9897
],
9998
)

tests/Dialect/ModArith/IR/syntax.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// RUN: heir-opt --mlir-print-local-scope %s | FileCheck %s
22

33
!Zp = !mod_arith.int<17 : i10>
4+
!Zp2 = !mod_arith.int<257 : i10>
45
!Zp_vec = tensor<4x!Zp>
6+
!rns = !rns.rns<!Zp, !Zp2>
57

68
// CHECK: @test_arith_syntax
79
func.func @test_arith_syntax() {
@@ -83,3 +85,11 @@ func.func @test_arith_syntax() {
8385

8486
return
8587
}
88+
89+
// CHECK: @test_rns_syntax
90+
func.func @test_rns_syntax(%arg0: !rns, %arg1: !rns) -> !rns {
91+
// CHECK: mod_arith.add
92+
// CHECK-SAME: rns
93+
%result = mod_arith.add %arg0, %arg1 : !rns
94+
return %result : !rns
95+
}

0 commit comments

Comments
 (0)