Skip to content

Commit 5dffa8c

Browse files
Donghang Lufacebook-github-bot
authored andcommitted
Implement MPC-AES decryption circuit (#447)
Summary: Pull Request resolved: #447 Differential Revision: D41143077 fbshipit-source-id: e7d7e8f8054e77261baafeb65307605eb9155088
1 parent 156a113 commit 5dffa8c

File tree

3 files changed

+164
-4
lines changed

3 files changed

+164
-4
lines changed

fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class AesCircuit : public IAesCircuit<BitType> {
4949
void inverseMixColumnsInPlace(WordType& src) const;
5050

5151
void shiftRowInPlace(std::array<WordType, 4>& src) const;
52-
52+
void inverseShiftRowInPlace(std::array<WordType, 4>& src) const;
5353
#ifdef AES_CIRCUIT_TEST_FRIENDS
5454
AES_CIRCUIT_TEST_FRIENDS;
5555
#endif

fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,65 @@ std::vector<BitType> AesCircuit<BitType>::encrypt_impl(
7878
return convertFromWords(plaintextBlocks);
7979
}
8080

81+
// implementation based on https://engineering.purdue.edu/kak/compsec/NewLectures/Lecture8.pdf
8182
template <typename BitType>
8283
std::vector<BitType> AesCircuit<BitType>::decrypt_impl(
83-
const std::vector<BitType>& /* ciphertext */,
84-
const std::vector<BitType>& /* expandedDecKey */) const {
85-
throw std::runtime_error("Not implemented!");
84+
const std::vector<BitType>& ciphertext,
85+
const std::vector<BitType>& expandedDecKey) const {
86+
// prepare input
87+
auto ciphertextBlocks = convertToWords(ciphertext);
88+
auto roundKeys = convertToWords(expandedDecKey);
89+
size_t blockNo = ciphertextBlocks.size();
90+
91+
int round = 10;
92+
// pre-round
93+
for (int block = 0; block < blockNo; ++block) {
94+
for (int word = 0; word < 4; ++word) {
95+
for (int byte = 0; byte < 4; ++byte) {
96+
for (int bit = 0; bit < 8; ++bit) {
97+
ciphertextBlocks[block][word][byte][bit] =
98+
ciphertextBlocks[block][word][byte][bit] ^
99+
roundKeys[round][word][byte][bit];
100+
}
101+
}
102+
}
103+
}
104+
// rounds 1 - 10
105+
for (int round = 9; round >= 0; --round) {
106+
// InverseShiftRows
107+
for (int block = 0; block < blockNo; ++block) {
108+
inverseShiftRowInPlace(ciphertextBlocks[block]);
109+
}
110+
// InverseSbox
111+
for (int block = 0; block < blockNo; ++block) {
112+
for (int word = 0; word < 4; ++word) {
113+
for (int byte = 0; byte < 4; ++byte) {
114+
inverseSBoxInPlace(ciphertextBlocks[block][word][byte]);
115+
}
116+
}
117+
}
118+
// AddRoundKey
119+
for (int block = 0; block < blockNo; ++block) {
120+
for (int word = 0; word < 4; ++word) {
121+
for (int byte = 0; byte < 4; ++byte) {
122+
for (int bit = 0; bit < 8; ++bit) {
123+
ciphertextBlocks[block][word][byte][bit] =
124+
ciphertextBlocks[block][word][byte][bit] ^
125+
roundKeys[round][word][byte][bit];
126+
}
127+
}
128+
}
129+
}
130+
// InverseMixColumns except for 10-th Round
131+
if (round != 0) {
132+
for (int block = 0; block < blockNo; ++block) {
133+
for (int word = 0; word < 4; ++word) {
134+
inverseMixColumnsInPlace(ciphertextBlocks[block][word]);
135+
}
136+
}
137+
}
138+
}
139+
return convertFromWords(ciphertextBlocks);
86140
}
87141

88142
template <typename BitType>
@@ -494,4 +548,23 @@ void AesCircuit<BitType>::shiftRowInPlace(std::array<WordType, 4>& src) const {
494548
std::swap(src[1][row], src[0][row]);
495549
}
496550

551+
template <typename BitType>
552+
void AesCircuit<BitType>::inverseShiftRowInPlace(
553+
std::array<WordType, 4>& src) const {
554+
// 1st row is not shifted, 2nd row shifted right by 1
555+
int row = 1;
556+
std::swap(src[2][row], src[3][row]);
557+
std::swap(src[1][row], src[2][row]);
558+
std::swap(src[0][row], src[1][row]);
559+
// 3rd row shifted right by 2
560+
row++;
561+
std::swap(src[0][row], src[2][row]);
562+
std::swap(src[1][row], src[3][row]);
563+
// 4th row shifted right by 3
564+
row++;
565+
std::swap(src[0][row], src[1][row]);
566+
std::swap(src[1][row], src[2][row]);
567+
std::swap(src[2][row], src[3][row]);
568+
}
569+
497570
} // namespace fbpcf::mpc_std_lib::aes_circuit

fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ class AesCircuitTests : public AesCircuit<BitType> {
6262
}
6363
}
6464

65+
void testInverseShiftRowInPlace(std::vector<bool> plaintext) {
66+
std::array<std::array<std::array<bool, 8>, 4>, 4> block;
67+
for (int k = 0; k < 4; ++k) {
68+
for (int i = 0; i < 4; i++) {
69+
for (int j = 0; j < 8; j++) {
70+
block[k][i][j] = plaintext[32 * k + 8 * i + j];
71+
}
72+
}
73+
}
74+
75+
AesCircuit<bool>::inverseShiftRowInPlace(block);
76+
for (int k = 0; k < 4; ++k) {
77+
for (int i = 0; i < 4; i++) {
78+
for (int j = 0; j < 8; j++) {
79+
EXPECT_EQ(
80+
block[k][i][j],
81+
plaintext[32 * ((((k - i) % 4) + 4) % 4) + 8 * i + j]);
82+
}
83+
}
84+
}
85+
}
86+
6587
void testWordConversion() {
6688
using ByteType = std::array<bool, 8>;
6789
using WordType = std::array<ByteType, 4>;
@@ -159,6 +181,12 @@ TEST(AesCircuitTest, testShiftRowInPlace) {
159181
test.testShiftRowInPlace(plaintext);
160182
}
161183

184+
TEST(AesCircuitTest, testInverseShiftRowInPlace) {
185+
auto plaintext = generateRandomPlaintext();
186+
AesCircuitTests<bool> test;
187+
test.testInverseShiftRowInPlace(plaintext);
188+
}
189+
162190
TEST(AesCircuitTest, testWordConversion) {
163191
AesCircuitTests<bool> test;
164192
test.testWordConversion();
@@ -352,6 +380,65 @@ TEST(AesCircuitTest, testAesCircuitEncrypt) {
352380
testAesCircuitEncrypt(std::make_unique<AesCircuitFactory<bool>>());
353381
}
354382

383+
void testAesCircuitDecrypt(
384+
std::shared_ptr<AesCircuitFactory<bool>> AesCircuitFactory) {
385+
auto AesCircuit = AesCircuitFactory->create();
386+
387+
std::random_device rd;
388+
std::mt19937_64 e(rd());
389+
std::uniform_int_distribution<uint8_t> dist(0, 0xFF);
390+
size_t blockNo = dist(e);
391+
392+
// generate random key
393+
__m128i key = _mm_set_epi32(dist(e), dist(e), dist(e), dist(e));
394+
// generate random plaintext
395+
std::vector<uint8_t> plaintext;
396+
plaintext.reserve(blockNo * 16);
397+
for (int i = 0; i < blockNo * 16; ++i) {
398+
plaintext.push_back(dist(e));
399+
}
400+
std::vector<__m128i> plaintextAES;
401+
loadValueToLocalAes(plaintext, plaintextAES);
402+
403+
// expand key
404+
engine::util::Aes truthAes(key);
405+
auto expandedKey = truthAes.expandEncryptionKey(key);
406+
// extract key and plaintext
407+
std::vector<uint8_t> extractedKeys;
408+
extractedKeys.reserve(176);
409+
for (auto keyb : expandedKey) {
410+
loadValueFromLocalAes(keyb, extractedKeys);
411+
}
412+
413+
// convert key and plaintext into bool vector
414+
std::vector<bool> keyBits;
415+
keyBits.reserve(1408);
416+
int8VecToBinaryVec(extractedKeys, keyBits);
417+
std::vector<bool> plaintextBits;
418+
plaintextBits.reserve(blockNo * 128);
419+
int8VecToBinaryVec(plaintext, plaintextBits);
420+
421+
// encrypt in real aes
422+
truthAes.encryptInPlace(plaintextAES);
423+
424+
// extract ciphertext in real aes
425+
std::vector<uint8_t> ciphertextTruth;
426+
ciphertextTruth.reserve(blockNo * 16);
427+
for (auto b : plaintextAES) {
428+
loadValueFromLocalAes(b, ciphertextTruth);
429+
}
430+
std::vector<bool> cipherextBitsTruth;
431+
cipherextBitsTruth.reserve(blockNo * 128);
432+
int8VecToBinaryVec(ciphertextTruth, cipherextBitsTruth);
433+
// decrypt this ciphertext using our decrypt circuit
434+
auto decryptionBits = AesCircuit->decrypt(cipherextBitsTruth, keyBits);
435+
testVectorEq(decryptionBits, plaintextBits);
436+
}
437+
438+
TEST(AesCircuitTest, testAesCircuitDecrypt) {
439+
testAesCircuitDecrypt(std::make_unique<AesCircuitFactory<bool>>());
440+
}
441+
355442
void testAesCircuitCtr(
356443
std::shared_ptr<AesCircuitCtrFactory<bool>> AesCircuitCtrFactory) {
357444
auto AesCircuitCtr = AesCircuitCtrFactory->create();

0 commit comments

Comments
 (0)