@@ -62,6 +62,28 @@ class AesCircuitTests : public AesCircuit<BitType> {
62
62
}
63
63
}
64
64
65
+ void testInverseShiftRowInPlace (std::vector<bool > plaintext) {
66
+ std::array<std::array<std::array<bool , 8 >, 4 >, 4 > block;
67
+ for (int k = 0 ; k < 4 ; ++k) {
68
+ for (int i = 0 ; i < 4 ; i++) {
69
+ for (int j = 0 ; j < 8 ; j++) {
70
+ block[k][i][j] = plaintext[32 * k + 8 * i + j];
71
+ }
72
+ }
73
+ }
74
+
75
+ AesCircuit<bool >::inverseShiftRowInPlace (block);
76
+ for (int k = 0 ; k < 4 ; ++k) {
77
+ for (int i = 0 ; i < 4 ; i++) {
78
+ for (int j = 0 ; j < 8 ; j++) {
79
+ EXPECT_EQ (
80
+ block[k][i][j],
81
+ plaintext[32 * ((((k - i) % 4 ) + 4 ) % 4 ) + 8 * i + j]);
82
+ }
83
+ }
84
+ }
85
+ }
86
+
65
87
void testWordConversion () {
66
88
using ByteType = std::array<bool , 8 >;
67
89
using WordType = std::array<ByteType, 4 >;
@@ -159,6 +181,12 @@ TEST(AesCircuitTest, testShiftRowInPlace) {
159
181
test.testShiftRowInPlace (plaintext);
160
182
}
161
183
184
+ TEST (AesCircuitTest, testInverseShiftRowInPlace) {
185
+ auto plaintext = generateRandomPlaintext ();
186
+ AesCircuitTests<bool > test;
187
+ test.testInverseShiftRowInPlace (plaintext);
188
+ }
189
+
162
190
TEST (AesCircuitTest, testWordConversion) {
163
191
AesCircuitTests<bool > test;
164
192
test.testWordConversion ();
@@ -352,6 +380,65 @@ TEST(AesCircuitTest, testAesCircuitEncrypt) {
352
380
testAesCircuitEncrypt (std::make_unique<AesCircuitFactory<bool >>());
353
381
}
354
382
383
+ void testAesCircuitDecrypt (
384
+ std::shared_ptr<AesCircuitFactory<bool >> AesCircuitFactory) {
385
+ auto AesCircuit = AesCircuitFactory->create ();
386
+
387
+ std::random_device rd;
388
+ std::mt19937_64 e (rd ());
389
+ std::uniform_int_distribution<uint8_t > dist (0 , 0xFF );
390
+ size_t blockNo = dist (e);
391
+
392
+ // generate random key
393
+ __m128i key = _mm_set_epi32 (dist (e), dist (e), dist (e), dist (e));
394
+ // generate random plaintext
395
+ std::vector<uint8_t > plaintext;
396
+ plaintext.reserve (blockNo * 16 );
397
+ for (int i = 0 ; i < blockNo * 16 ; ++i) {
398
+ plaintext.push_back (dist (e));
399
+ }
400
+ std::vector<__m128i> plaintextAES;
401
+ loadValueToLocalAes (plaintext, plaintextAES);
402
+
403
+ // expand key
404
+ engine::util::Aes truthAes (key);
405
+ auto expandedKey = truthAes.expandEncryptionKey (key);
406
+ // extract key and plaintext
407
+ std::vector<uint8_t > extractedKeys;
408
+ extractedKeys.reserve (176 );
409
+ for (auto keyb : expandedKey) {
410
+ loadValueFromLocalAes (keyb, extractedKeys);
411
+ }
412
+
413
+ // convert key and plaintext into bool vector
414
+ std::vector<bool > keyBits;
415
+ keyBits.reserve (1408 );
416
+ int8VecToBinaryVec (extractedKeys, keyBits);
417
+ std::vector<bool > plaintextBits;
418
+ plaintextBits.reserve (blockNo * 128 );
419
+ int8VecToBinaryVec (plaintext, plaintextBits);
420
+
421
+ // encrypt in real aes
422
+ truthAes.encryptInPlace (plaintextAES);
423
+
424
+ // extract ciphertext in real aes
425
+ std::vector<uint8_t > ciphertextTruth;
426
+ ciphertextTruth.reserve (blockNo * 16 );
427
+ for (auto b : plaintextAES) {
428
+ loadValueFromLocalAes (b, ciphertextTruth);
429
+ }
430
+ std::vector<bool > cipherextBitsTruth;
431
+ cipherextBitsTruth.reserve (blockNo * 128 );
432
+ int8VecToBinaryVec (ciphertextTruth, cipherextBitsTruth);
433
+ // decrypt this ciphertext using our decrypt circuit
434
+ auto decryptionBits = AesCircuit->decrypt (cipherextBitsTruth, keyBits);
435
+ testVectorEq (decryptionBits, plaintextBits);
436
+ }
437
+
438
+ TEST (AesCircuitTest, testAesCircuitDecrypt) {
439
+ testAesCircuitDecrypt (std::make_unique<AesCircuitFactory<bool >>());
440
+ }
441
+
355
442
void testAesCircuitCtr (
356
443
std::shared_ptr<AesCircuitCtrFactory<bool >> AesCircuitCtrFactory) {
357
444
auto AesCircuitCtr = AesCircuitCtrFactory->create ();
0 commit comments