Skip to content

Support trivial encryption in secret dialect and noise analysis #1887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions lib/Analysis/NoiseAnalysis/BFV/NoiseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,21 @@ LogicalResult NoiseAnalysis<NoiseModel>::visitOperation(

auto res =
llvm::TypeSwitch<Operation &, LogicalResult>(*op)
.template Case<secret::RevealOp, secret::ConcealOp>([&](auto op) {
// Reveal outputs are not secret, so no noise. Conceal outputs are
// a fresh encryption. Both are handled properly by setToEntryState
// based on the type of the result.
//
// TODO(#1875): support trivial encryptions which have zero noise.
.template Case<secret::RevealOp>([&](auto revealOp) {
// Reveal outputs are not secret, so no noise.
for (auto result : results) {
setToEntryState(result);
}
return success();
})
.template Case<secret::ConcealOp>([&](auto concealOp) {
// Conceal outputs have the noise of a fresh encryption, unless
// they are trivial encryptions, in which case there is zero noise.
if (concealOp.getTrivial()) {
propagate(concealOp.getResult(), NoiseState::of(0.0));
return success();
}

for (auto result : results) {
setToEntryState(result);
}
Expand Down
21 changes: 15 additions & 6 deletions lib/Analysis/NoiseAnalysis/BGV/NoiseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,21 @@ LogicalResult NoiseAnalysis<NoiseModel>::visitOperation(

auto res =
llvm::TypeSwitch<Operation &, LogicalResult>(*op)
.template Case<secret::RevealOp, secret::ConcealOp>([&](auto op) {
// Reveal outputs are not secret, so no noise. Conceal outputs are
// a fresh encryption. Both are handled properly by setToEntryState
// based on the type of the result.
//
// TODO(#1875): support trivial encryptions which have zero noise.
.template Case<secret::RevealOp>([&](auto revealOp) {
// Reveal outputs are not secret, so no noise.
for (auto result : results) {
setToEntryState(result);
}
return success();
})
.template Case<secret::ConcealOp>([&](auto concealOp) {
// Conceal outputs have the noise of a fresh encryption, unless
// they are trivial encryptions, in which case there is zero noise.
if (concealOp.getTrivial()) {
propagate(concealOp.getResult(), NoiseState::of(0.0));
return success();
}

for (auto result : results) {
setToEntryState(result);
}
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/LWE/IR/LWEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ LogicalResult ReinterpretApplicationDataOp::verify() {
return success();
}

// Verification for RLWE_EncryptOp
LogicalResult RLWEEncryptOp::verify() {
Type keyType = getKey().getType();
auto keyRing =
Expand Down
11 changes: 11 additions & 0 deletions lib/Dialect/LWE/IR/LWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,17 @@ def LWE_RLWEEncryptOp : LWE_Op<"rlwe_encrypt", [
let hasVerifier = 1;
}

def LWE_RLWETrivialEncryptOp : LWE_Op<"rlwe_trivial_encrypt", [
NewEncodingsMatch<"input", "NewLWEPlaintextType", "output", "NewLWECiphertextType">]> {
let summary = "Trivially encrypt an RLWE plaintext to a RLWE ciphertext";
let description = [{
The ciphertext is valid for any secret key.
}];

let arguments = (ins NewLWEPlaintext:$input);
let results = (outs NewLWECiphertext:$output);
}

def LWE_RLWEDecryptOp : LWE_Op<"rlwe_decrypt", [
NewEncodingsMatch<"input", "NewLWECiphertextType", "output", "NewLWEPlaintextType">,
KeyAndCiphertextMatch<"secret_key", "NewLWESecretKeyType", "input", "NewLWECiphertextType">]> {
Expand Down
46 changes: 39 additions & 7 deletions lib/Dialect/Secret/Conversions/Patterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ LogicalResult ConvertClientConceal::matchAndRewrite(
ContextAwareConversionPatternRewriter &rewriter) const {
func::FuncOp parentFunc = op->getParentOfType<func::FuncOp>();
if (!parentFunc || !parentFunc->hasAttr(kClientEncFuncAttrName)) {
return op->emitError() << "expected to be inside a function with attribute "
<< kClientEncFuncAttrName;
return failure();
}
if (op.getTrivial()) {
return failure();
}

// The encryption func encrypts a single value, so it must have a single
Expand All @@ -51,11 +53,7 @@ LogicalResult ConvertClientConceal::matchAndRewrite(
auto resultCtTy =
dyn_cast<lwe::NewLWECiphertextType>(parentFunc.getResultTypes()[0]);
if (!resultCtTy) {
return parentFunc->emitError()
<< "expected secret.conceal op to be inside a function with a "
"single LWE ciphertext return type; it may be that "
"the type converter failed to run on this func "
"because the mgmt attribute is missing.";
return failure();
}

if (resultCtTy.getCiphertextSpace()
Expand Down Expand Up @@ -96,6 +94,40 @@ LogicalResult ConvertClientConceal::matchAndRewrite(
return success();
}

LogicalResult ConvertTrivialConceal::matchAndRewrite(
secret::ConcealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter &rewriter) const {
if (!op.getTrivial()) {
return failure();
}

SmallVector<Type> resultTypes;
if (failed(typeConverter->convertType(op.getResult().getType(),
op.getResult(), resultTypes))) {
return op->emitError() << "failed to convert result type for trivial "
"conceal op";
}
lwe::NewLWECiphertextType resultCtTy =
dyn_cast<lwe::NewLWECiphertextType>(resultTypes[0]);
auto plaintextTy = lwe::NewLWEPlaintextType::get(
op.getContext(), resultCtTy.getApplicationData(),
resultCtTy.getPlaintextSpace());
auto encoded = rewriter.create<lwe::RLWEEncodeOp>(
op.getLoc(), plaintextTy, adaptor.getCleartext(),
resultCtTy.getPlaintextSpace().getEncoding(),
resultCtTy.getPlaintextSpace().getRing());

auto encryptOp = rewriter.create<lwe::RLWETrivialEncryptOp>(
op.getLoc(), resultCtTy, encoded.getResult());

// Copy attributes from the original op to preserve any mgmt attrs needed by
// dialect conversion from secret to scheme.
encryptOp->setAttrs(op->getAttrs());

rewriter.replaceOp(op, encryptOp);
return success();
}

LogicalResult ConvertClientReveal::matchAndRewrite(
secret::RevealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter &rewriter) const {
Expand Down
11 changes: 10 additions & 1 deletion lib/Dialect/Secret/Conversions/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace heir {
// Lower a client encryption function's secret.conceal op to lwe.rlwe_encode +
// lwe.rlwe_encrypt. Modifies the containing function to add new secret key
// material args.
// TODO(#1875): support trivial encryptions
struct ConvertClientConceal
: public ContextAwareOpConversionPattern<secret::ConcealOp> {
ConvertClientConceal(const ContextAwareTypeConverter &typeConverter,
Expand All @@ -36,6 +35,16 @@ struct ConvertClientConceal
polynomial::RingAttr ring;
};

// Lower a trivial secret.conceal .
struct ConvertTrivialConceal
: public ContextAwareOpConversionPattern<secret::ConcealOp> {
using ContextAwareOpConversionPattern::ContextAwareOpConversionPattern;

LogicalResult matchAndRewrite(
secret::ConcealOp op, OpAdaptor adaptor,
ContextAwareConversionPatternRewriter &rewriter) const override;
};

// Lower a client decryption function's secret.reveal op to lwe.rlwe_decrypt +
// lwe.rlwe_decode. Modifies the containing function to add new secret key
// material args.
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
SecretGenericOpCipherPlainConversion<arith::AddIOp, bgv::AddPlainOp>,
SecretGenericOpCipherPlainConversion<arith::SubIOp, bgv::SubPlainOp>,
SecretGenericOpCipherPlainConversion<arith::MulIOp, bgv::MulPlainOp>,
SecretGenericFuncCallConversion>(typeConverter, context);
SecretGenericFuncCallConversion, ConvertTrivialConceal>(typeConverter,
context);

patterns.add<ConvertClientConceal>(typeConverter, context, usePublicKey,
rlweRing.value());
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
ConvertAnyContextAware<tensor::ExtractSliceOp>,
ConvertAnyContextAware<tensor::ExtractOp>,
ConvertAnyContextAware<tensor::InsertOp>,
SecretGenericFuncCallConversion>(typeConverter, context);
SecretGenericFuncCallConversion, ConvertTrivialConceal>(typeConverter,
context);

patterns.add<ConvertClientConceal>(typeConverter, context, usePublicKey,
rlweRing.value());
Expand Down
11 changes: 10 additions & 1 deletion lib/Dialect/Secret/IR/SecretOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,18 @@ LogicalResult GenericOp::verify() {
}

void ConcealOp::build(OpBuilder &builder, OperationState &result,
Value cleartextValue) {
Value cleartextValue, std::optional<UnitAttr> trivial) {
Type resultType = SecretType::get(cleartextValue.getType());
build(builder, result, resultType, cleartextValue);
if (trivial.has_value()) {
result.addAttribute("trivial", builder.getUnitAttr());
}
}

void ConcealOp::build(OpBuilder &builder, OperationState &result,
Value cleartextValue) {
build(builder, result, cleartextValue.getType(), cleartextValue,
std::nullopt);
}

void RevealOp::build(OpBuilder &builder, OperationState &result,
Expand Down
13 changes: 10 additions & 3 deletions lib/Dialect/Secret/IR/SecretOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,27 @@ def Secret_ConcealOp : Secret_Op<"conceal", [Pure]> {
op is also useful for type materialization in the dialect conversion
framework.

If the trivial attribute is set, the conceal operation corresponds to a
"trivial" encryption for backends that support it. In particular, most LWE
schemes support encrypting a value with respect to any secret key, by
choosing a zero-valued LWE sample and zero noise.

Examples:

```mlir
%Y = secret.conceal %value : i32 -> !secret.secret<i32>
%0 = secret.conceal %value : i32 -> !secret.secret<i32>
%1 = secret.conceal %value {trivial} : i32 -> !secret.secret<i32>
```
}];

let arguments = (ins AnyType:$cleartext);
let arguments = (ins AnyType:$cleartext, UnitAttr:$trivial);
let results = (outs Secret:$output);
let assemblyFormat = "$cleartext attr-dict `:` qualified(type($cleartext)) `->` qualified(type($output))";

let builders = [
// Builder to infer output type from the input type
OpBuilder<(ins "Value":$cleartext)>
OpBuilder<(ins "Value":$cleartext, "std::optional<UnitAttr>":$trivial)>,
OpBuilder<(ins "Value":$cleartext)>,
];
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: heir-opt --secret-to-bgv %s | FileCheck %s

!ct = !secret.secret<tensor<1024xi1>>
#mgmt = #mgmt.mgmt<level = 0, dimension = 2>

module attributes {bgv.schemeParam = #bgv.scheme_param<logN = 14, Q = [67239937, 17179967489, 17180262401, 17180295169, 17180393473, 70368744210433], P = [70368744570881, 70368744701953], plaintextModulus = 65537>} {
// CHECK: func @test_arith_ops
func.func @test_arith_ops(%arg0 : tensor<1024xi1>) -> (!ct {mgmt.mgmt = #mgmt}) {
// CHECK: lwe.rlwe_encode
// CHECK: lwe.rlwe_trivial_encrypt
%0 = secret.conceal %arg0 {trivial, mgmt.mgmt = #mgmt} : tensor<1024xi1> -> !ct
// CHECK: return
return %0 : !ct
}
}
10 changes: 8 additions & 2 deletions tests/Dialect/Secret/IR/syntax.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: heir-opt %s > %t
// RUN: FileCheck %s < %t
// RUN: heir-opt %s | FileCheck %s

// This simply tests for syntax.

Expand All @@ -18,4 +17,11 @@ module {
} -> (!secret.secret<memref<1x16xi8>>)
func.return %Z : !secret.secret<memref<1x16xi8>>
}

// CHECK: conceal_trivial
func.func @conceal_trivial() -> !secret.secret<i8> {
%c7 = arith.constant 7 : i8
%0 = secret.conceal %c7 {trivial} : i8 -> !secret.secret<i8>
func.return %0 : !secret.secret<i8>
}
}
Loading