Skip to content

demo: add AES demo for tfhe_rust CPU #1765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,3 @@ jobs:
- name: "Run `bazel test`"
run: |
bazel test --noincompatible_strict_action_env --//:enable_openmp=0 -c fastbuild //...

# Tests specifically for the tfhe-rs codegen
- name: rustup toolchain install
uses: dtolnay/rust-toolchain@439cf607258077187679211f12aa6f19af4a0af7 # [email protected]
with:
toolchain: stable

- name: Test rust codegen targets
run: |
bash .github/workflows/run_rust_tests.sh
23 changes: 0 additions & 23 deletions .github/workflows/run_rust_tests.sh

This file was deleted.

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ repos:
hooks:
- id: codespell
args: ["-L", "crate, fpt"]

exclude: "^(.*[.]lock)"

# Changes tabs to spaces
- repo: https://github.com/Lucas-C/pre-commit-hooks
Expand Down
30 changes: 30 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,33 @@ git_override(
commit = "c84a140c93352cdabbfb547c531be34515b12228",
remote = "https://github.com/google/re2",
)

bazel_dep(name = "rules_rust", version = "0.60.0")

crate = use_extension("@rules_rust//crate_universe:extensions.bzl", "crate")
crate.spec(
features = ["derive"],
package = "clap",
version = "4.1.8",
)
crate.spec(
package = "rayon",
version = "1.6.1",
)
crate.spec(
features = ["derive"],
package = "serde",
version = "1.0.152",
)
crate.spec(
features = [
"boolean",
"shortint",
"integer",
"x86_64-unix",
],
package = "tfhe",
version = "0.5.3",
)
crate.from_specs()
use_repo(crate, "crates")
3,850 changes: 817 additions & 3,033 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions tests/Examples/common/aes/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# common MLIR files used across multiple test directories

package(
default_applicable_licenses = ["@heir//:license"],
)

exports_files(
glob([
"*.mlir",
]),
visibility = ["@heir//tests/Examples:__subpackages__"],
)
13 changes: 13 additions & 0 deletions tests/Examples/common/aes/add_round_key.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// One block is composed of 16 bytes of 8 bits each

// This function adds (xors) the two blocks
#map = affine_map<(d0) -> (d0)>
func.func @add_round_key(%arg0: tensor<16xi8> {secret.secret}, %arg1: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%8 = tensor.empty() : tensor<16xi8>
%9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1: tensor<16xi8>, tensor<16xi8>) outs(%8 : tensor<16xi8>) {
^bb0(%in1: i8, %in2: i8, %out: i8):
%res = arith.xori %in1, %in2 : i8
linalg.yield %res : i8
} -> tensor<16xi8>
return %9 : tensor<16xi8>
}
79 changes: 79 additions & 0 deletions tests/Examples/common/aes/inv_mix_columns.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Inverse mix columns step, computing a linear combination of the columns.

// heir-opt -- tests/Examples/common/aes/inv_mix_columns.mlir --full-loop-unroll --inline --symbol-dce --tosa-to-boolean-tfhe

#map = affine_map<(d0) -> (d0)>

#mapA = affine_map<(d0, d1) -> (d0, d1)> // output of 2
#map1A = affine_map<(d0, d1) -> (d1)> // output dim 1
#map2A = affine_map<(d0, d1) -> (d0)> // output dim 1
module {
func.func private @mul_gf256_4(%x: i8 {secret.secret}, %y: i8) -> (i8) {
%z = arith.constant 0 : i8
%c0 = arith.constant 0 : i8
%c1 = arith.constant 1 : i8
%2:2 = affine.for %4 = 0 to 4 iter_args(%5 = %x, %6 = %z) -> (i8, i8) {
// z = if y & (1 << i): z^x else z
%44 = arith.index_cast %4 : index to i8
%ysh = arith.shrsi %y, %44 : i8
%ysha = arith.andi %ysh, %c1 : i8
%91 = arith.trunci %ysha : i8 to i1
%10 = arith.xori %6, %5 : i8
%argz = arith.select %91, %10, %6 : i8
// check if MSB of x is set
%14 = arith.cmpi slt, %5, %c0 : i8
%11 = arith.constant 1 : i8
%12 = arith.shli %5, %11 : i8
%17 = arith.constant 27 : i8
%122 = arith.xori %12, %17 : i8
%argxx = arith.select %14, %122, %12 : i8
affine.yield %argxx, %argz : i8, i8
}
func.return %2#1 : i8
}

func.func private @inv_mix_single_column(%arg0: tensor<4xi8> {secret.secret}) -> tensor<4xi8> {
%c0 = arith.constant 0 : i8
%valA = arith.constant dense<[[0x0e, 0x0b, 0x0d, 0x09], [0x09, 0x0e, 0x0b, 0x0d], [0x0d, 0x09, 0x0e, 0x0b], [0x0b, 0x0d, 0x09, 0x0e]]> : tensor<4x4xi8>
%0 = tensor.splat %c0: tensor<4xi8>
%1 = linalg.generic {indexing_maps = [#mapA, #map1A, #map2A], iterator_types = ["parallel", "reduction"]} ins(%valA, %arg0 : tensor<4x4xi8>, tensor<4xi8>) outs(%0 : tensor<4xi8>) {
^bb0(%in: i8, %in_0: i8, %out: i8):
%1 = func.call @mul_gf256_4(%in_0, %in) : (i8, i8) -> i8
%2 = arith.xori %out, %1 : i8
linalg.yield %2 : i8
} -> tensor<4xi8>
return %1 : tensor<4xi8>
}

func.func @inv_mix_columns(%arg0: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%out = tensor.empty() : tensor<16xi8>
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%mix = affine.for %i = 0 to 4 iter_args(%arg1 = %out) -> (tensor<16xi8>) {
%index = arith.muli %i, %c4 : index
%index1 = arith.addi %index, %c1 : index
%index2 = arith.addi %index1, %c1 : index
%index3 = arith.addi %index2, %c1 : index
// extract_slice bufferizes to reinterpret_cast
%extracted1 = tensor.extract %arg0[%index] : tensor<16xi8>
%extracted2 = tensor.extract %arg0[%index1] : tensor<16xi8>
%extracted3 = tensor.extract %arg0[%index2] : tensor<16xi8>
%extracted4 = tensor.extract %arg0[%index3] : tensor<16xi8>
%extracted = tensor.from_elements %extracted1, %extracted2, %extracted3, %extracted4 : tensor<4xi8>
%mixed = func.call @inv_mix_single_column(%extracted) : (tensor<4xi8>) -> tensor<4xi8>
%emixed = tensor.extract %mixed[%c0] : tensor<4xi8>
%emixed1 = tensor.extract %mixed[%c1] : tensor<4xi8>
%emixed2 = tensor.extract %mixed[%c2] : tensor<4xi8>
%emixed3 = tensor.extract %mixed[%c3] : tensor<4xi8>
%inserted = tensor.insert %emixed into %arg1[%index] : tensor<16xi8>
%inserted1 = tensor.insert %emixed1 into %inserted[%index1] : tensor<16xi8>
%inserted2 = tensor.insert %emixed2 into %inserted1[%index2] : tensor<16xi8>
%inserted3 = tensor.insert %emixed3 into %inserted2[%index3] : tensor<16xi8>
affine.yield %inserted3 : tensor<16xi8>
}
return %mix : tensor<16xi8>
}
}
15 changes: 15 additions & 0 deletions tests/Examples/common/aes/inv_shift_rows.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Inverse Shift rows function, shifting the 4 rows by 0, 1, 2, and 3 steps to the left.

#map = affine_map<(d0) -> (d0)>
module {
func.func @inv_shift_rows(%arg0: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%0 = tensor.empty() : tensor<16xi8>
%indices = arith.constant dense<[0, 1, 2, 3, 7, 4, 5, 6, 10, 11, 8, 9, 13, 14, 15, 12]> : tensor<16xindex>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%indices : tensor<16xindex>) outs(%0 : tensor<16xi8>) {
^bb0(%in: index, %out: i8):
%extracted = tensor.extract %arg0[%in] : tensor<16xi8>
linalg.yield %extracted : i8
} -> tensor<16xi8>
return %1 : tensor<16xi8>
}
}
16 changes: 16 additions & 0 deletions tests/Examples/common/aes/inv_sub_bytes.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Inverse Byte substitution function.

#map = affine_map<(d0) -> (d0)>
module {
func.func @inv_sub_bytes(%arg0: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%cst = arith.constant dense<"0x52096ad53036a538bf40a39e81f3d7fb7ce339829b2fff87348e4344c4dee9cb547b9432a6c2233dee4c950b42fac34e082ea16628d924b2765ba2496d8bd12572f8f66486689816d4a45ccc5d65b6926c704850fdedb9da5e154657a78d9d8490d8ab008cbcd30af7e45805b8b34506d02c1e8fca3f0f02c1afbd0301138a6b3a9111414f67dcea97f2cfcef0b4e67396ac7422e7ad3585e2f937e81c75df6e47f11a711d29c5896fb7620eaa18be1bfc563e4bc6d279209adbc0fe78cd5af41fdda8338807c731b11210592780ec5f60517fa919b54a0d2de57a9f93c99cefa0e03b4dae2af5b0c8ebbb3c83539961172b047eba77d626e169146355210c7d"> : tensor<256xi8>
%0 = tensor.empty() : tensor<16xi8>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<16xi8>) outs(%0 : tensor<16xi8>) {
^bb0(%in: i8, %out: i8):
%2 = arith.index_cast %in : i8 to index
%extracted = tensor.extract %cst[%2] : tensor<256xi8>
linalg.yield %extracted : i8
} -> tensor<16xi8>
return %1 : tensor<16xi8>
}
}
80 changes: 80 additions & 0 deletions tests/Examples/common/aes/mix_columns.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Mix columns step, computing a linear combination of the columns.

#map = affine_map<(d0) -> (d0)>

#mapA = affine_map<(d0, d1) -> (d0, d1)> // output of 2
#map1A = affine_map<(d0, d1) -> (d1)> // output dim 1
#map2A = affine_map<(d0, d1) -> (d0)> // output dim 1
module {
func.func private @mul_gf256_2(%x: i8 {secret.secret}, %y: i8) -> (i8) {
%z = arith.constant 0 : i8
%c0 = arith.constant 0 : i8
%c1 = arith.constant 1 : i8
%2:2 = affine.for %4 = 0 to 2 iter_args(%5 = %x, %6 = %z) -> (i8, i8) {
// z = if y & (1 << i): z^x else z
%44 = arith.index_cast %4 : index to i8
%ysh = arith.shrsi %y, %44 : i8
%ysha = arith.andi %ysh, %c1 : i8
%91 = arith.trunci %ysha : i8 to i1
%10 = arith.xori %6, %5 : i8
%argz = arith.select %91, %10, %6 : i8
// check if MSB of x is set
%14 = arith.cmpi sle, %5, %c0 : i8
%11 = arith.constant 1 : i8
%12 = arith.shli %5, %11 : i8
%17 = arith.constant 27 : i8
%122 = arith.xori %12, %17 : i8
%argxx = arith.select %14, %122, %12 : i8
affine.yield %argxx, %argz : i8, i8
}
func.return %2#1 : i8
}

func.func private @mix_single_column(%arg0: tensor<4xi8> {secret.secret}) -> tensor<4xi8> {
// A 4x4 matrix with [2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]
// use gf256 multiplication with the vector
// use XOR to add
%c0 = arith.constant 0 : i8
%valA = arith.constant dense<[[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]> : tensor<4x4xi8>
%0 = tensor.splat %c0: tensor<4xi8>
%1 = linalg.generic {indexing_maps = [#mapA, #map1A, #map2A], iterator_types = ["parallel", "reduction"]} ins(%valA, %arg0 : tensor<4x4xi8>, tensor<4xi8>) outs(%0 : tensor<4xi8>) {
^bb0(%in: i8, %in_0: i8, %out: i8):
%1 = func.call @mul_gf256_2(%in_0, %in) : (i8, i8) -> i8
%2 = arith.xori %out, %1 : i8
linalg.yield %2 : i8
} -> tensor<4xi8>
return %1 : tensor<4xi8>
}

func.func @mix_columns(%arg0: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%out = tensor.empty() : tensor<16xi8>
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%mix = affine.for %i = 0 to 4 iter_args(%arg1 = %out) -> (tensor<16xi8>) {
%index = arith.muli %i, %c4 : index
%index1 = arith.addi %index, %c1 : index
%index2 = arith.addi %index1, %c1 : index
%index3 = arith.addi %index2, %c1 : index
// extract_slice bufferizes to reinterpret_cast
%extracted1 = tensor.extract %arg0[%index] : tensor<16xi8>
%extracted2 = tensor.extract %arg0[%index1] : tensor<16xi8>
%extracted3 = tensor.extract %arg0[%index2] : tensor<16xi8>
%extracted4 = tensor.extract %arg0[%index3] : tensor<16xi8>
%extracted = tensor.from_elements %extracted1, %extracted2, %extracted3, %extracted4 : tensor<4xi8>
%mixed = func.call @mix_single_column(%extracted) : (tensor<4xi8>) -> tensor<4xi8>
%emixed = tensor.extract %mixed[%c0] : tensor<4xi8>
%emixed1 = tensor.extract %mixed[%c1] : tensor<4xi8>
%emixed2 = tensor.extract %mixed[%c2] : tensor<4xi8>
%emixed3 = tensor.extract %mixed[%c3] : tensor<4xi8>
%inserted = tensor.insert %emixed into %arg1[%index] : tensor<16xi8>
%inserted1 = tensor.insert %emixed1 into %inserted[%index1] : tensor<16xi8>
%inserted2 = tensor.insert %emixed2 into %inserted1[%index2] : tensor<16xi8>
%inserted3 = tensor.insert %emixed3 into %inserted2[%index3] : tensor<16xi8>
affine.yield %inserted3 : tensor<16xi8>
}
return %mix : tensor<16xi8>
}
}
15 changes: 15 additions & 0 deletions tests/Examples/common/aes/shift_rows.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Shift rows function, shifting the 4 rows by 0, 1, 2, and 3 steps to the left.

#map = affine_map<(d0) -> (d0)>
module {
func.func @shift_rows(%arg0: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%0 = tensor.empty() : tensor<16xi8>
%indices = arith.constant dense<[0, 1, 2, 3, 5, 6, 7, 4, 10, 11, 8, 9, 15, 12, 13, 14]> : tensor<16xindex>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%indices : tensor<16xindex>) outs(%0 : tensor<16xi8>) {
^bb0(%in: index, %out: i8):
%extracted = tensor.extract %arg0[%in] : tensor<16xi8>
linalg.yield %extracted : i8
} -> tensor<16xi8>
return %1 : tensor<16xi8>
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
// Single sbox table lookup for AES encryption.
//
// RUN: heir-opt --mlir-to-cggi --scheme-to-tfhe-rs %s | heir-translate --emit-tfhe-rust --use-levels > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_sbox | FileCheck %s

// CHECK: 637C
// CHECK: BB16
// Byte substitution function.

#map = affine_map<(d0) -> (d0)>
module {
func.func @sub_bytes(%arg0: i8 {secret.secret}) -> i8 {
func.func @sub_bytes(%arg0: tensor<16xi8> {secret.secret}) -> tensor<16xi8> {
%cst = arith.constant dense<"0x637C777BF26B6FC53001672BFED7AB76CA82C97DFA5947F0ADD4A2AF9CA472C0B7FD9326363FF7CC34A5E5F171D8311504C723C31896059A071280E2EB27B27509832C1A1B6E5AA0523BD6B329E32F8453D100ED20FCB15B6ACBBE394A4C58CFD0EFAAFB434D338545F9027F503C9FA851A3408F929D38F5BCB6DA2110FFF3D2CD0C13EC5F974417C4A77E3D645D197360814FDC222A908846EEB814DE5E0BDBE0323A0A4906245CC2D3AC629195E479E7C8376D8DD54EA96C56F4EA657AAE08BA78252E1CA6B4C6E8DD741F4BBD8B8A703EB5664803F60E613557B986C11D9EE1F8981169D98E949B1E87E9CE5528DF8CA1890DBFE6426841992D0FB054BB16"> : tensor<256xi8>
%2 = arith.index_cast %arg0 : i8 to index
%extracted = tensor.extract %cst[%2] : tensor<256xi8>
return %extracted : i8
%0 = tensor.empty() : tensor<16xi8>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<16xi8>) outs(%0 : tensor<16xi8>) {
^bb0(%in: i8, %out: i8):
%2 = arith.index_cast %in : i8 to index
%extracted = tensor.extract %cst[%2] : tensor<256xi8>
linalg.yield %extracted : i8
} -> tensor<16xi8>
return %1 : tensor<16xi8>
}
}
3 changes: 1 addition & 2 deletions tests/Examples/openfhe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ These tests exercise OpenFHE codegen for the
[OpenFHE](https://github.com/openfheorg/openfhe-development) backend library,
including compiling the generated C++ source and running the resulting binary.

OpenFHE is added as a project-level dependency (unlike the `tfhe-rs` end-to-end
tests) and built from source.
OpenFHE is added as a project-level dependency and built from source.
Loading