Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 88 additions & 30 deletions extensions/keccak256/circuit/cuda/include/p3_keccakf.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,24 @@ __device__ __constant__ inline uint64_t RC[NUM_ROUNDS] = {
0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800AULL, 0x800000008000000AULL,
0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL
};

// Lookup tables for performing the combined rho/pi step in place.
//
// The pi mapping (x,y) -> (y, 2x+3y mod 5) leaves lane (0,0) where it is and moves the
// remaining 24 lanes around one single cycle. RHO_PI_CYCLE_IDX lists that cycle as flat
// state indices (y*5 + x); index 0 is therefore absent.
//
// Usage: for i in 0..22 the lane at RHO_PI_CYCLE_IDX[i] is overwritten with the lane at
// RHO_PI_CYCLE_IDX[i + 1] rotated left by RHO_PI_CYCLE_ROT[i]. The cycle is closed by
// storing the ORIGINAL lane at RHO_PI_CYCLE_IDX[0], rotated left by RHO_PI_CYCLE_ROT[23],
// into RHO_PI_CYCLE_IDX[23] (so that lane has to be saved before the walk starts).
__device__ __constant__ inline int RHO_PI_CYCLE_IDX[24] = {
    1,  6,  9,  22, 14, 20,
    2,  12, 13, 19, 23, 15,
    4,  24, 21, 8,  16, 5,
    3,  18, 17, 11, 7,  10
};
__device__ __constant__ inline uint8_t RHO_PI_CYCLE_ROT[24] = {
    44, 20, 61, 39, 18, 62,
    43, 25, 8,  56, 41, 27,
    14, 2,  55, 45, 36, 28,
    21, 15, 10, 6,  3,  1
};

} // namespace keccak256

namespace p3_keccak_air {
using keccak256::R;
using keccak256::RC;
using keccak256::RHO_PI_CYCLE_IDX;
using keccak256::RHO_PI_CYCLE_ROT;

// Plonky3 KeccakCols structure (from p3_keccak_air)
// Must match exactly for trace compatibility
Expand Down Expand Up @@ -69,6 +82,54 @@ template <typename T> struct KeccakCols {

// Width of the KeccakCols layout, measured as the byte size of the struct when it is
// instantiated with a one-byte element type.
// NOTE(review): this equals the number of columns only if every member of KeccakCols<T>
// is built from T with no padding — confirm against the struct definition.
inline constexpr size_t NUM_KECCAK_COLS = sizeof(KeccakCols<uint8_t>);

// Runs a single keccak-f round (theta, rho/pi, chi, iota) directly on `current_state`
// and writes nothing to the trace. Phase 1 uses it to step the state from one round to
// the next.
//
//   round         - selects the iota constant RC[round]; not range-checked here, so the
//                   caller must pass a value below NUM_ROUNDS.
//   current_state - the 25 lanes, addressed as current_state[y][x]; overwritten with the
//                   post-round state.
static __device__ __noinline__ void apply_round_in_place(
    uint32_t round,
    uint64_t current_state[5][5]
) {
    // Theta, part 1: column parities C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4].
    uint64_t parity[5];
#pragma unroll 5
    for (auto col = 0; col < 5; col++) {
        parity[col] = current_state[0][col] ^ current_state[1][col] ^ current_state[2][col] ^
                      current_state[3][col] ^ current_state[4][col];
    }

    // Theta, part 2: A'[x,y] = A[x,y] ^ D[x] with D[x] = C[x-1] ^ ROT(C[x+1], 1).
    // D[x] is kept as a scalar and folded into the whole column at once.
    for (int col = 0; col < 5; col++) {
        uint64_t delta = parity[(col + 4) % 5] ^ ROTL64(parity[(col + 1) % 5], 1);
#pragma unroll 5
        for (int row = 0; row < 5; row++) {
            current_state[row][col] ^= delta;
        }
    }

    // Rho/Pi: B = rotated and permuted A', done in place by walking the 24-lane cycle.
    // View the state as 25 consecutive lanes (flat index y*5 + x).
    uint64_t *lanes = &current_state[0][0];
    // The first lane of the cycle is overwritten on the first step, so capture its
    // rotated value up front; it becomes the last lane of the cycle.
    uint64_t wrap_around = ROTL64(lanes[RHO_PI_CYCLE_IDX[0]], RHO_PI_CYCLE_ROT[23]);
    // Kept as a rolled loop on purpose (unroll factor 1).
#pragma unroll 1
    for (int step = 0; step < 23; step++) {
        const int dst = RHO_PI_CYCLE_IDX[step];
        const int src = RHO_PI_CYCLE_IDX[step + 1];
        lanes[dst] = ROTL64(lanes[src], RHO_PI_CYCLE_ROT[step]);
    }
    lanes[RHO_PI_CYCLE_IDX[23]] = wrap_around;

    // Chi: A''[x,y] = B[x,y] ^ (~B[x+1,y] & B[x+2,y]), one row at a time.
    // Only lanes 0 and 1 are needed after they have been overwritten (by lanes 3 and 4),
    // so two saved copies per row are enough to do the update in place.
    for (int row = 0; row < 5; row++) {
        uint64_t *b = current_state[row];
        const uint64_t b0 = b[0];
        const uint64_t b1 = b[1];
        b[0] = b0 ^ ((~b1) & b[2]);
        b[1] = b1 ^ ((~b[2]) & b[3]);
        b[2] = b[2] ^ ((~b[3]) & b[4]);
        b[3] = b[3] ^ ((~b[4]) & b0);
        b[4] = b[4] ^ ((~b0) & b1);
    }

    // Iota: fold the round constant into lane (0,0).
    lanes[0] ^= RC[round];
}

// tracegen matching plonky3
// `row` must have first NUM_KECCAK_COLS columns matching KeccakCols
static __device__ __noinline__ void generate_trace_row_for_round(
Expand All @@ -88,45 +149,42 @@ static __device__ __noinline__ void generate_trace_row_for_round(
COL_WRITE_BITS(row, KeccakCols, c[x], state_c[x]);
}

// Populate C'[x, z] = xor(C[x, z], C[x - 1, z], ROTL1(C[x + 1, z - 1])).
uint64_t state_c_prime[5];
#pragma unroll 5
for (auto x = 0; x < 5; x++) {
state_c_prime[x] = state_c[x] ^ state_c[(x + 4) % 5] ^ ROTL64(state_c[(x + 1) % 5], 1);
COL_WRITE_BITS(row, KeccakCols, c_prime[x], state_c_prime[x]);
}

// Populate A'. To avoid shifting indices, we rewrite
// A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1])
// as
// A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]).
// Populate C'[x, z] and A'[x, y] using scalar d = C[x-1] ^ ROTL(C[x+1], 1).
// Avoids materializing state_c_prime[5] array (~10 regs saved).
for (int x = 0; x < 5; x++) {
uint64_t d = state_c[(x + 4) % 5] ^ ROTL64(state_c[(x + 1) % 5], 1);
COL_WRITE_BITS(row, KeccakCols, c_prime[x], state_c[x] ^ d);

#pragma unroll 5
for (int y = 0; y < 5; y++) {
current_state[y][x] ^= state_c[x] ^ state_c_prime[x];
current_state[y][x] ^= d;
COL_WRITE_BITS(row, KeccakCols, a_prime[y][x], current_state[y][x]);
}
}

// Rotate the current state to get the B array.
uint64_t state_b[5][5];
for (auto i = 0; i < 5; i++) {
#pragma unroll 5
for (auto j = 0; j < 5; j++) {
auto new_i = (i + 3 * j) % 5;
auto new_j = i;
state_b[j][i] = ROTL64(current_state[new_j][new_i], R[new_i][new_j]);
}
// In-place rho/pi using the 24-element permutation cycle.
// Avoids allocating state_b[5][5] (~50 regs saved).
uint64_t *flat_state = &current_state[0][0];
uint64_t temp = ROTL64(flat_state[RHO_PI_CYCLE_IDX[0]], RHO_PI_CYCLE_ROT[23]);
// Prevent unrolling to avoid code bloat and register pressure from 23 simultaneous rotations.
#pragma unroll 1
for (int i = 0; i < 23; i++) {
flat_state[RHO_PI_CYCLE_IDX[i]] =
ROTL64(flat_state[RHO_PI_CYCLE_IDX[i + 1]], RHO_PI_CYCLE_ROT[i]);
}

// Populate A'' as A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])).
for (int i = 0; i < 5; i++) {
#pragma unroll 5
for (int j = 0; j < 5; j++) {
current_state[i][j] =
state_b[i][j] ^ ((~state_b[i][(j + 1) % 5]) & state_b[i][(j + 2) % 5]);
}
flat_state[RHO_PI_CYCLE_IDX[23]] = temp;

// Populate A'' = B[x,y] ^ (~B[x+1,y] & B[x+2,y]), in-place chi with 2 temps per row.
for (int y = 0; y < 5; y++) {
uint64_t t0 = current_state[y][0];
uint64_t t1 = current_state[y][1];
current_state[y][0] = t0 ^ ((~t1) & current_state[y][2]);
current_state[y][1] = t1 ^ ((~current_state[y][2]) & current_state[y][3]);
current_state[y][2] ^= (~current_state[y][3]) & current_state[y][4];
current_state[y][3] ^= (~current_state[y][4]) & t0;
current_state[y][4] ^= (~t0) & t1;
}

uint16_t *state_limbs = reinterpret_cast<uint16_t *>(&current_state[0][0]);
COL_WRITE_ARRAY(row, KeccakCols, a_prime_prime, state_limbs);

Expand Down
Loading
Loading