chore(gpu): use same balanced decomposition code as in the CPU code
IceTDrinker committed Oct 28, 2024
1 parent f497bf0 commit 3f44959
Showing 11 changed files with 70 additions and 24 deletions.
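
For context: a balanced decomposition keeps every extracted digit in the signed range [-B/2, B/2] with B = 2^base_log, instead of [0, B), by propagating a carry into the next level whenever a digit exceeds B/2. A small self-contained sketch of the idea (illustrative only, not taken from the repository; base and value chosen arbitrarily):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t base_log = 4;   // B = 2^4 = 16
  const uint32_t level_count = 3;
  const int64_t B = 1ll << base_log;
  const int64_t value = 2042;    // arbitrary 12-bit value to decompose
  int64_t state = value;
  int64_t digits[3] = {0};
  int64_t recomposed = 0;
  for (uint32_t j = 0; j < level_count; j++) {
    int64_t d = state % B;       // plain digit in [0, B)
    state /= B;
    if (d > B / 2) {             // balance: move the digit into [-B/2, B/2]
      d -= B;
      state += 1;                // and carry into the next level
    }
    digits[j] = d;
    recomposed += d * (1ll << (base_log * j));
  }
  // The digits stay in [-B/2, B/2] and still recompose to the original
  // value modulo 2^(base_log * level_count).
  printf("digits: %lld %lld %lld -> %lld (expected %lld)\n",
         (long long)digits[0], (long long)digits[1], (long long)digits[2],
         (long long)recomposed, (long long)value);
  return 0;
}
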
5 changes: 0 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
@@ -33,11 +33,6 @@ public:

mask_mod_b = (1ll << base_log) - 1ll;
current_level = level_count;
-    int tid = threadIdx.x;
-    for (int i = 0; i < num_poly * params::opt; i++) {
-      state[tid] >>= (sizeof(T) * 8 - base_log * level_count);
-      tid += params::degree / params::opt;
-    }
synchronize_threads_in_block();
}

8 changes: 3 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
@@ -71,9 +71,8 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,

// This loop distribution seems to benefit the global mem reads
for (int i = start_i; i < end_i; i++) {
-    Torus a_i = round_to_closest_multiple(block_lwe_array_in[i], base_log,
-        level_count);
-    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
+    Torus state =
+        init_decomposer_state(block_lwe_array_in[i], base_log, level_count);

for (int j = level_count - 1; j >= 0; j--) {
// Levels are stored in reverse order
@@ -202,9 +201,8 @@ __device__ void packing_keyswitch_lwe_ciphertext_into_glwe_ciphertext(
// Iterate through all lwe elements
for (int i = 0; i < lwe_dimension_in; i++) {
// Round and prepare decomposition
-    Torus a_i = round_to_closest_multiple(lwe_in[i], base_log, level_count);
+    Torus state = init_decomposer_state(lwe_in[i], base_log, level_count);

-    Torus state = a_i >> (sizeof(Torus) * 8 - base_log * level_count);
Torus mod_b_mask = (1ll << base_log) - 1ll;

// block of key for current lwe coefficient (cur_input_lwe[i])
17 changes: 17 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
@@ -54,6 +54,23 @@ __device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
return res << shift;
}

+template <typename T>
+__device__ inline T init_decomposer_state(T input, uint32_t base_log,
+                                          uint32_t level_count) {
+  const T rep_bit_count = level_count * base_log;
+  const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
+  T res = input >> (non_rep_bit_count - 1);
+  T rounding_bit = res & 1;
+  res += 1;
+  res = res >> 1;
+  T mod_mask = (T)(-1) >> non_rep_bit_count;
+  res = res & mod_mask;
+  T shifted_random = rounding_bit << (rep_bit_count - 1);
+  T need_balance =
+      (((res - (T)(1)) | shifted_random) & res) >> (rep_bit_count - 1);
+  return res - (need_balance << rep_bit_count);
+}
+
template <typename T>
__device__ __forceinline__ void modulus_switch(T input, T &output,
uint32_t log_modulus) {
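
A minimal host-side sketch (not part of this commit; 64-bit torus assumed) that mirrors the new init_decomposer_state helper and checks it against the effect of the old path used by the keyswitch (round_to_closest_multiple followed by the shift): the two agree on the representable bits, and shifting the state back up recovers the rounded torus value.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Host mirror of the device helper above, for uint64_t (illustrative only).
static uint64_t init_decomposer_state_host(uint64_t input, uint32_t base_log,
                                           uint32_t level_count) {
  const uint64_t rep = (uint64_t)base_log * level_count;
  const uint64_t non_rep = 64 - rep;
  uint64_t res = input >> (non_rep - 1);
  uint64_t rounding_bit = res & 1;
  res = (res + 1) >> 1;
  uint64_t mod_mask = UINT64_MAX >> non_rep;
  res &= mod_mask;
  uint64_t shifted_random = rounding_bit << (rep - 1);
  uint64_t need_balance = (((res - 1) | shifted_random) & res) >> (rep - 1);
  // Read as two's complement, the state should lie in [-2^(rep-1), 2^(rep-1)].
  return res - (need_balance << rep);
}

int main() {
  const uint32_t base_log = 4, level_count = 3; // 12 representable bits
  const uint64_t non_rep = 64 - base_log * level_count;
  const uint64_t input = 0x123456789ABCDEF0ull;

  uint64_t state = init_decomposer_state_host(input, base_log, level_count);

  // Old path: round to the closest multiple of 2^non_rep, then shift down.
  uint64_t rounded = ((input + (1ull << (non_rep - 1))) >> non_rep) << non_rep;
  uint64_t old_state = rounded >> non_rep;

  uint64_t rep_mask = UINT64_MAX >> non_rep;
  assert((state & rep_mask) == (old_state & rep_mask));
  assert((state << non_rep) == rounded);
  printf("state = 0x%llx, rounded input = 0x%llx\n",
         (unsigned long long)state, (unsigned long long)rounded);
  return 0;
}
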
@@ -117,8 +117,8 @@ __global__ void device_programmable_bootstrap_amortized(

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator_rotated, base_log, level_count, glwe_dimension + 1);

// Initialize the polynomial multiplication via FFT arrays
@@ -117,8 +117,8 @@ __global__ void device_programmable_bootstrap_cg(

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator_rotated, base_log, level_count);

synchronize_threads_in_block();
@@ -98,8 +98,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator, base_log, level_count);

// Decompose the accumulator. Each block gets one level of the
@@ -107,8 +107,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator, base_log, level_count);

synchronize_threads_in_block();
@@ -224,8 +224,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator, base_log, level_count);

// Decompose the accumulator. Each block gets one level of the
@@ -121,8 +121,8 @@ __global__ void device_programmable_bootstrap_tbc(

// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator_rotated, base_log, level_count);

synchronize_threads_in_block();
@@ -105,8 +105,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
-    round_to_closest_multiple_inplace<Torus, params::opt,
-        params::degree / params::opt>(
+    init_decomposer_state_inplace<Torus, params::opt,
+        params::degree / params::opt>(
accumulator, base_log, level_count);

// Decompose the accumulator. Each block gets one level of the
36 changes: 36 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
@@ -160,6 +160,42 @@ __device__ void round_to_closest_multiple_inplace(T *rotated_acc, int base_log,
}
}

+/*
+ * Receives num_poly concatenated polynomials of type T. For each performs a
+ * rounding to increase accuracy of the PBS. Calculates inplace.
+ *
+ * By default, it works on a single polynomial.
+ */
+template <typename T, int elems_per_thread, int block_size>
+__device__ void init_decomposer_state_inplace(T *rotated_acc, int base_log,
+                                              int level_count,
+                                              uint32_t num_poly = 1) {
+  constexpr int degree = block_size * elems_per_thread;
+  for (int z = 0; z < num_poly; z++) {
+    T *rotated_acc_slice = (T *)rotated_acc + (ptrdiff_t)(z * degree);
+    int tid = threadIdx.x;
+    for (int i = 0; i < elems_per_thread; i++) {
+      T x_acc = rotated_acc_slice[tid];
+
+      const T rep_bit_count = level_count * base_log;
+      const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
+      T res_acc = x_acc >> (non_rep_bit_count - 1);
+      T rounding_bit = res_acc & 1;
+      res_acc += 1;
+      res_acc = res_acc >> 1;
+      T mod_mask = (T)(-1) >> non_rep_bit_count;
+      res_acc = res_acc & mod_mask;
+      T shifted_random = rounding_bit << (rep_bit_count - 1);
+      T need_balance = (((res_acc - (T)(1)) | shifted_random) & res_acc) >>
+                       (rep_bit_count - 1);
+      res_acc = res_acc - (need_balance << rep_bit_count);
+
+      rotated_acc_slice[tid] = res_acc;
+      tid = tid + block_size;
+    }
+  }
+}
+
template <typename Torus, class params>
__device__ void add_to_torus(double2 *m_values, Torus *result,
bool init_torus = false) {
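
To illustrate the indexing in init_decomposer_state_inplace: each of the block_size threads starts at its threadIdx.x and strides by block_size, so one block visits every coefficient of each degree = block_size * elems_per_thread polynomial exactly once. A toy host-side emulation of that mapping (hypothetical sizes, not repository code):

#include <cstdio>

int main() {
  // Toy stand-ins for params::degree / params::opt and params::opt.
  const int block_size = 4;
  const int elems_per_thread = 2;
  const int degree = block_size * elems_per_thread;
  const int num_poly = 2;

  int visits[num_poly * degree] = {0};

  for (int z = 0; z < num_poly; z++) {           // one slice per polynomial
    for (int tid = 0; tid < block_size; tid++) { // emulate the thread block
      int idx = tid;
      for (int i = 0; i < elems_per_thread; i++) {
        visits[z * degree + idx] += 1;           // where the rounding happens
        idx += block_size;
      }
    }
  }

  for (int k = 0; k < num_poly * degree; k++)    // every coefficient: once
    printf("coefficient %d visited %d time(s)\n", k, visits[k]);
  return 0;
}
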
