diff --git a/extensions/keccak256/circuit/cuda/include/p3_keccakf.cuh b/extensions/keccak256/circuit/cuda/include/p3_keccakf.cuh index 77bf925d0f..699d69a38f 100644 --- a/extensions/keccak256/circuit/cuda/include/p3_keccakf.cuh +++ b/extensions/keccak256/circuit/cuda/include/p3_keccakf.cuh @@ -25,11 +25,24 @@ __device__ __constant__ inline uint64_t RC[NUM_ROUNDS] = { 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800AULL, 0x800000008000000AULL, 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL }; + +// In-place rho/pi permutation cycle (24 elements, flat index y*5+x). +// Follows (x,y) -> (y, 2x+3y mod 5); (0,0) is a fixed point. +// cycle[i] receives its value from cycle[(i+1) % 24] with the listed rotation. +__device__ __constant__ inline int RHO_PI_CYCLE_IDX[24] = {1, 6, 9, 22, 14, 20, 2, 12, + 13, 19, 23, 15, 4, 24, 21, 8, + 16, 5, 3, 18, 17, 11, 7, 10}; +__device__ __constant__ inline uint8_t RHO_PI_CYCLE_ROT[24] = {44, 20, 61, 39, 18, 62, 43, 25, + 8, 56, 41, 27, 14, 2, 55, 45, + 36, 28, 21, 15, 10, 6, 3, 1}; + } // namespace keccak256 namespace p3_keccak_air { using keccak256::R; using keccak256::RC; +using keccak256::RHO_PI_CYCLE_IDX; +using keccak256::RHO_PI_CYCLE_ROT; // Plonky3 KeccakCols structure (from p3_keccak_air) // Must match exactly for trace compatibility @@ -69,6 +82,54 @@ template struct KeccakCols { inline constexpr size_t NUM_KECCAK_COLS = sizeof(KeccakCols); +// Apply one keccak-f round in-place without trace writes. +// Used by phase 1 to advance state between rounds. +static __device__ __noinline__ void apply_round_in_place( + uint32_t round, + uint64_t current_state[5][5] +) { + // Theta: C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]) + uint64_t state_c[5]; +#pragma unroll 5 + for (auto x = 0; x < 5; x++) { + state_c[x] = current_state[0][x] ^ current_state[1][x] ^ current_state[2][x] ^ + current_state[3][x] ^ current_state[4][x]; + } + + // Theta: A'[x, y] = A[x, y] ^ D[x], where D[x] = C[x-1] ^ ROT(C[x+1], 1) + for (int x = 0; x < 5; x++) { + uint64_t d = state_c[(x + 4) % 5] ^ ROTL64(state_c[(x + 1) % 5], 1); +#pragma unroll 5 + for (int y = 0; y < 5; y++) { + current_state[y][x] ^= d; + } + } + + // Rho/Pi: B[x, y] = ROT(A'[x, y], R[x][y]) via 24-element permutation cycle + uint64_t *flat_state = ¤t_state[0][0]; + uint64_t temp = ROTL64(flat_state[RHO_PI_CYCLE_IDX[0]], RHO_PI_CYCLE_ROT[23]); +#pragma unroll 1 + for (int i = 0; i < 23; i++) { + flat_state[RHO_PI_CYCLE_IDX[i]] = + ROTL64(flat_state[RHO_PI_CYCLE_IDX[i + 1]], RHO_PI_CYCLE_ROT[i]); + } + flat_state[RHO_PI_CYCLE_IDX[23]] = temp; + + // Chi: A''[x, y] = B[x, y] ^ (~B[x+1, y] & B[x+2, y]), in-place with 2 temps per row + for (int y = 0; y < 5; y++) { + uint64_t t0 = current_state[y][0]; + uint64_t t1 = current_state[y][1]; + current_state[y][0] = t0 ^ ((~t1) & current_state[y][2]); + current_state[y][1] = t1 ^ ((~current_state[y][2]) & current_state[y][3]); + current_state[y][2] ^= (~current_state[y][3]) & current_state[y][4]; + current_state[y][3] ^= (~current_state[y][4]) & t0; + current_state[y][4] ^= (~t0) & t1; + } + + // Iota: A'''[0, 0] = A''[0, 0] ^ RC[round] + current_state[0][0] ^= RC[round]; +} + // tracegen matching plonky3 // `row` must have first NUM_KECCAK_COLS columns matching KeccakCols static __device__ __noinline__ void generate_trace_row_for_round( @@ -88,45 +149,42 @@ static __device__ __noinline__ void generate_trace_row_for_round( COL_WRITE_BITS(row, KeccakCols, c[x], state_c[x]); } - // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], ROTL1(C[x + 1, z - 1])). - uint64_t state_c_prime[5]; -#pragma unroll 5 - for (auto x = 0; x < 5; x++) { - state_c_prime[x] = state_c[x] ^ state_c[(x + 4) % 5] ^ ROTL64(state_c[(x + 1) % 5], 1); - COL_WRITE_BITS(row, KeccakCols, c_prime[x], state_c_prime[x]); - } - - // Populate A'. To avoid shifting indices, we rewrite - // A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1]) - // as - // A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]). + // Populate C'[x, z] and A'[x, y] using scalar d = C[x-1] ^ ROTL(C[x+1], 1). + // Avoids materializing state_c_prime[5] array (~10 regs saved). for (int x = 0; x < 5; x++) { + uint64_t d = state_c[(x + 4) % 5] ^ ROTL64(state_c[(x + 1) % 5], 1); + COL_WRITE_BITS(row, KeccakCols, c_prime[x], state_c[x] ^ d); + #pragma unroll 5 for (int y = 0; y < 5; y++) { - current_state[y][x] ^= state_c[x] ^ state_c_prime[x]; + current_state[y][x] ^= d; COL_WRITE_BITS(row, KeccakCols, a_prime[y][x], current_state[y][x]); } } - // Rotate the current state to get the B array. - uint64_t state_b[5][5]; - for (auto i = 0; i < 5; i++) { -#pragma unroll 5 - for (auto j = 0; j < 5; j++) { - auto new_i = (i + 3 * j) % 5; - auto new_j = i; - state_b[j][i] = ROTL64(current_state[new_j][new_i], R[new_i][new_j]); - } + // In-place rho/pi using the 24-element permutation cycle. + // Avoids allocating state_b[5][5] (~50 regs saved). + uint64_t *flat_state = ¤t_state[0][0]; + uint64_t temp = ROTL64(flat_state[RHO_PI_CYCLE_IDX[0]], RHO_PI_CYCLE_ROT[23]); + // Prevent unrolling to avoid code bloat and register pressure from 23 simultaneous rotations. +#pragma unroll 1 + for (int i = 0; i < 23; i++) { + flat_state[RHO_PI_CYCLE_IDX[i]] = + ROTL64(flat_state[RHO_PI_CYCLE_IDX[i + 1]], RHO_PI_CYCLE_ROT[i]); } - - // Populate A'' as A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). - for (int i = 0; i < 5; i++) { -#pragma unroll 5 - for (int j = 0; j < 5; j++) { - current_state[i][j] = - state_b[i][j] ^ ((~state_b[i][(j + 1) % 5]) & state_b[i][(j + 2) % 5]); - } + flat_state[RHO_PI_CYCLE_IDX[23]] = temp; + + // Populate A'' = B[x,y] ^ (~B[x+1,y] & B[x+2,y]), in-place chi with 2 temps per row. + for (int y = 0; y < 5; y++) { + uint64_t t0 = current_state[y][0]; + uint64_t t1 = current_state[y][1]; + current_state[y][0] = t0 ^ ((~t1) & current_state[y][2]); + current_state[y][1] = t1 ^ ((~current_state[y][2]) & current_state[y][3]); + current_state[y][2] ^= (~current_state[y][3]) & current_state[y][4]; + current_state[y][3] ^= (~current_state[y][4]) & t0; + current_state[y][4] ^= (~t0) & t1; } + uint16_t *state_limbs = reinterpret_cast(¤t_state[0][0]); COL_WRITE_ARRAY(row, KeccakCols, a_prime_prime, state_limbs); diff --git a/extensions/keccak256/circuit/cuda/src/keccakf_perm.cu b/extensions/keccak256/circuit/cuda/src/keccakf_perm.cu index 4c7dc58ff9..591346ca68 100644 --- a/extensions/keccak256/circuit/cuda/src/keccakf_perm.cu +++ b/extensions/keccak256/circuit/cuda/src/keccakf_perm.cu @@ -19,30 +19,38 @@ using p3_keccak_air::U64_LIMBS; #define KECCAKF_PERM_WRITE(FIELD, VALUE) COL_WRITE_VALUE(row, KeccakfPermCols, FIELD, VALUE) #define KECCAKF_PERM_WRITE_ARRAY(FIELD, VALUES) COL_WRITE_ARRAY(row, KeccakfPermCols, FIELD, VALUES) -// Main kernel for KeccakfPermChip trace generation -// Each thread processes one permutation (24 rows) -__global__ void keccakf_perm_tracegen( - Fp *d_trace, - size_t height, +static constexpr uint32_t KECCAK_STATE_WORDS = 25; + +// Two-phase keccak-f trace generation for coalesced column-major stores. +// Trace layout is trace[col * height + row]; threads writing adjacent rows coalesce. +// +// Phase 1 (1 thread per permutation): +// - runs 24 keccak-f rounds (theta/rho/pi/chi/iota) +// - stores the 25-lane u64 round-input state before each round +// into scratch: d_round_states[perm][round][lane] (~4.8 KB/perm) +// +// Phase 2 (1 thread per row = 1 round of 1 permutation): +// - loads round-input state from scratch +// - recomputes that round to materialize intermediates (c, c', a', a'', ...) +// - writes all 2634 trace columns + +// Phase 1: compute keccak-f, store round-input states to scratch +// each thread processes one permutation (24 rounds) +__global__ void keccakf_perm_phase1( + uint64_t *__restrict__ d_round_states, // [blocks_to_fill][24][25] uint32_t num_records, - uint32_t blocks_to_fill, // = ceil(height / 24) + uint32_t blocks_to_fill, DeviceBufferConstView d_records ) { - uint32_t block_idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (block_idx >= blocks_to_fill) { + uint32_t perm_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (perm_idx >= blocks_to_fill) { return; } - // Initialize state - will be transformed by generate_trace_row_for_round __align__(16) uint64_t current_state[5][5] = {0}; - __align__(16) uint64_t initial_state[5][5] = {0}; - - uint32_t timestamp = 0; - if (block_idx < num_records) { - auto const &rec = d_records[block_idx]; - timestamp = rec.timestamp; + if (perm_idx < num_records) { + auto const &rec = d_records[perm_idx]; // Convert preimage bytes to u64 state with coordinate transposition // generate_trace_row_for_round expects current_state[y][x] = A[x][y] (keccak notation) @@ -57,116 +65,70 @@ __global__ void keccakf_perm_tracegen( val |= static_cast(rec.preimage_buffer_bytes[keccak_offset * 8 + j]) << (j * 8); } - // Store as current_state[y][x] = A[x][y] for generate_trace_row_for_round current_state[y][x] = val; - initial_state[y][x] = val; } } } - // Convert initial state to u16 limbs for preimage columns - uint16_t *initial_state_limbs = reinterpret_cast(initial_state); - - // Calculate how many rows to fill for this block - size_t rows_for_this_block = NUM_ROUNDS; - if (block_idx == blocks_to_fill - 1) { - // Last block might have fewer rows - size_t remaining = height - block_idx * NUM_ROUNDS; - if (remaining < NUM_ROUNDS) { - rows_for_this_block = remaining; + // Store round-input state before each round, then advance + uint64_t *flat = ¤t_state[0][0]; + for (uint32_t round_idx = 0; round_idx < NUM_ROUNDS; round_idx++) { + size_t off = (static_cast(perm_idx) * NUM_ROUNDS + round_idx) * KECCAK_STATE_WORDS; +#pragma unroll + for (uint32_t i = 0; i < KECCAK_STATE_WORDS; i++) { + d_round_states[off + i] = flat[i]; } + p3_keccak_air::apply_round_in_place(round_idx, current_state); } +} - // Generate 24 round rows - for (uint32_t round_idx = 0; round_idx < rows_for_this_block; round_idx++) { - size_t row_idx = block_idx * NUM_ROUNDS + round_idx; - RowSlice row(d_trace + row_idx, height); - - // Fill zero first for safety - row.fill_zero(0, sizeof(KeccakfPermCols)); - - if (block_idx < num_records) { - // Valid record: fill preimage and compute keccak-f trace - - // Fill preimage columns (same for all rounds within a permutation) - COL_WRITE_ARRAY(row, KeccakfPermCols, inner.preimage, initial_state_limbs); - - // Fill 'a' input state - on first round, same as preimage - // On subsequent rounds, copy from previous row's output - if (round_idx == 0) { - COL_WRITE_ARRAY(row, KeccakfPermCols, inner.a, initial_state_limbs); - } else { - // Copy previous round's output to this round's input - // a[y][x] gets a_prime_prime_prime[0][0] for (x,y)=(0,0), else a_prime_prime[y][x] - RowSlice prev_row(d_trace + row_idx - 1, height); - for (int y = 0; y < 5; y++) { - for (int x = 0; x < 5; x++) { - for (int limb = 0; limb < U64_LIMBS; limb++) { - Fp val; - if (x == 0 && y == 0) { - val = prev_row[COL_INDEX( - KeccakfPermCols, inner.a_prime_prime_prime_0_0_limbs[limb] - )]; - } else { - val = prev_row[COL_INDEX( - KeccakfPermCols, inner.a_prime_prime[y][x][limb] - )]; - } - KECCAKF_PERM_WRITE(inner.a[y][x][limb], val); - } - } - } - } +// Phase 2: write column-major trace from cached round states +// Each thread writes one row; adjacent threads write adjacent rows (coalesced) +__global__ void keccakf_perm_phase2( + Fp *__restrict__ d_trace, + size_t height, + uint32_t num_records, + DeviceBufferConstView d_records, + uint64_t const *__restrict__ d_round_states // [blocks_to_fill][24][25] +) { + size_t row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx >= height) { + return; + } - // Generate trace row for this round (updates current_state in-place) - p3_keccak_air::generate_trace_row_for_round(row, round_idx, current_state); + uint32_t perm_idx = static_cast(row_idx / NUM_ROUNDS); + uint32_t round_idx = static_cast(row_idx % NUM_ROUNDS); - // Set export flag and timestamp on last round - if (round_idx == NUM_ROUNDS - 1) { - KECCAKF_PERM_WRITE(inner._export, 1); - KECCAKF_PERM_WRITE(timestamp, timestamp); - } else { - KECCAKF_PERM_WRITE(inner._export, 0); - KECCAKF_PERM_WRITE(timestamp, 0); - } - } else { - // Dummy block: generate valid keccak-f trace with zero state, export=0 - // The KeccakAir constraints require all intermediate columns (C, C', A', A'', etc.) - // to be properly computed, so we can't just zero them out. - - // Fill preimage with zeros (already zeroed) - // Fill 'a' input state - if (round_idx == 0) { - // a = preimage = zeros (already zeroed) - } else { - // Copy previous round's output to this round's input - RowSlice prev_row(d_trace + row_idx - 1, height); - for (int y = 0; y < 5; y++) { - for (int x = 0; x < 5; x++) { - for (int limb = 0; limb < U64_LIMBS; limb++) { - Fp val; - if (x == 0 && y == 0) { - val = prev_row[COL_INDEX( - KeccakfPermCols, inner.a_prime_prime_prime_0_0_limbs[limb] - )]; - } else { - val = prev_row[COL_INDEX( - KeccakfPermCols, inner.a_prime_prime[y][x][limb] - )]; - } - KECCAKF_PERM_WRITE(inner.a[y][x][limb], val); - } - } - } - } + // Load round-input state from scratch + __align__(16) uint64_t current_state[5][5]; + size_t off = (static_cast(perm_idx) * NUM_ROUNDS + round_idx) * KECCAK_STATE_WORDS; + uint64_t *flat = ¤t_state[0][0]; +#pragma unroll + for (uint32_t i = 0; i < KECCAK_STATE_WORDS; i++) { + flat[i] = d_round_states[off + i]; + } - // Generate trace row for this round (using current_state which is zero for dummy blocks) - p3_keccak_air::generate_trace_row_for_round(row, round_idx, current_state); + RowSlice row(d_trace + row_idx, height); - // Dummy rows: export must be 0, timestamp = 0 - KECCAKF_PERM_WRITE(inner._export, 0); - KECCAKF_PERM_WRITE(timestamp, 0); - } + // Fill preimage columns (invariant across rounds, read from round 0 of this permutation) + size_t preimage_off = static_cast(perm_idx) * NUM_ROUNDS * KECCAK_STATE_WORDS; + auto const *preimage_limbs = reinterpret_cast(&d_round_states[preimage_off]); + KECCAKF_PERM_WRITE_ARRAY(inner.preimage, preimage_limbs); + + // Fill 'a' input state from current round state + uint16_t *state_limbs = reinterpret_cast(¤t_state[0][0]); + COL_WRITE_ARRAY(row, KeccakfPermCols, inner.a, state_limbs); + + // Generate trace row for this round (writes c, c_prime, a_prime, a_prime_prime, etc.) + p3_keccak_air::generate_trace_row_for_round(row, round_idx, current_state); + + // Set export flag and timestamp on last round of valid records + if (perm_idx < num_records && round_idx == NUM_ROUNDS - 1) { + KECCAKF_PERM_WRITE(inner._export, 1); + KECCAKF_PERM_WRITE(timestamp, d_records[perm_idx].timestamp); + } else { + KECCAKF_PERM_WRITE(inner._export, 0); + KECCAKF_PERM_WRITE(timestamp, 0); } } @@ -178,16 +140,32 @@ extern "C" int _keccakf_perm_tracegen( size_t height, size_t width, DeviceBufferConstView d_records, - size_t num_records + size_t num_records, + uint64_t *d_round_states, + size_t round_state_words ) { assert((height & (height - 1)) == 0); assert(width == sizeof(KeccakfPermCols)); uint32_t blocks_to_fill = div_ceil(height, uint32_t(NUM_ROUNDS)); + assert( + round_state_words >= static_cast(blocks_to_fill) * NUM_ROUNDS * KECCAK_STATE_WORDS + ); + + // Phase 1: compute keccak-f, store round states to scratch + auto [p1_grid, p1_block] = kernel_launch_params(blocks_to_fill, 128); + keccakf_perm_phase1<<>>( + d_round_states, static_cast(num_records), blocks_to_fill, d_records + ); + int result = CHECK_KERNEL(); + if (result != 0) { + return result; + } - auto [grid, block] = kernel_launch_params(blocks_to_fill, 256); - keccakf_perm_tracegen<<>>( - d_trace, height, static_cast(num_records), blocks_to_fill, d_records + // Phase 2: write trace with coalesced stores (one thread per row) + auto [p2_grid, p2_block] = kernel_launch_params(height, 256); + keccakf_perm_phase2<<>>( + d_trace, height, static_cast(num_records), d_records, d_round_states ); return CHECK_KERNEL(); } diff --git a/extensions/keccak256/circuit/src/cuda/cuda_abi.rs b/extensions/keccak256/circuit/src/cuda/cuda_abi.rs index cb6869e014..3755148f4d 100644 --- a/extensions/keccak256/circuit/src/cuda/cuda_abi.rs +++ b/extensions/keccak256/circuit/src/cuda/cuda_abi.rs @@ -110,6 +110,8 @@ pub mod keccakf_perm { width: usize, d_records: DeviceBufferView, num_records: usize, + d_round_states: *mut u64, + round_state_words: usize, ) -> i32; } @@ -120,6 +122,7 @@ pub mod keccakf_perm { height: usize, d_records: &DeviceBuffer, num_records: usize, + d_round_states: &DeviceBuffer, ) -> Result<(), CudaError> { assert!(height.is_power_of_two() || height == 0); CudaError::from_result(_keccakf_perm_tracegen( @@ -128,6 +131,8 @@ pub mod keccakf_perm { d_trace.len() / height, d_records.view(), num_records, + d_round_states.as_mut_ptr(), + d_round_states.len(), )) } } diff --git a/extensions/keccak256/circuit/src/cuda/mod.rs b/extensions/keccak256/circuit/src/cuda/mod.rs index d7dab02914..e0566cebd7 100644 --- a/extensions/keccak256/circuit/src/cuda/mod.rs +++ b/extensions/keccak256/circuit/src/cuda/mod.rs @@ -163,6 +163,11 @@ impl Chip for KeccakfPermChipGpu { let trace_height = next_power_of_two_or_zero(num_records * NUM_ROUNDS); let d_trace = DeviceMatrix::::with_capacity(trace_height, trace_width); + // Scratch buffer for two-phase tracegen: 25 u64 lanes per round per permutation. + // 24 rounds * 25 lanes * 8 bytes = 4800 bytes/perm, vs 24 * 2634 * 4 = 252864 bytes/perm + // for the trace matrix (~1.9% overhead). + let blocks_to_fill = trace_height.div_ceil(NUM_ROUNDS); + let d_round_states = DeviceBuffer::::with_capacity(blocks_to_fill * NUM_ROUNDS * 25); unsafe { cuda_abi::keccakf_perm::tracegen( @@ -170,6 +175,7 @@ impl Chip for KeccakfPermChipGpu { trace_height, &d_records, num_records, + &d_round_states, ) .unwrap(); }