Skip to content

Commit

Permalink
fix(gpu): add missing sync clause to CG and TBC Multi-bit PBS
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Oct 28, 2024
1 parent 3f44959 commit 9497757
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 25 deletions.
1 change: 0 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/crypto/gadget.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ public:

mask_mod_b = (1ll << base_log) - 1ll;
current_level = level_count;
synchronize_threads_in_block();
}

// Decomposes all polynomials at once
Expand Down
14 changes: 7 additions & 7 deletions backends/tfhe-cuda-backend/cuda/src/crypto/torus.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ __device__ inline T round_to_closest_multiple(T x, uint32_t base_log,
}

template <typename T>
__device__ inline T init_decomposer_state(T input, uint32_t base_log,
uint32_t level_count) {
__device__ __forceinline__ T init_decomposer_state(T input, uint32_t base_log,
uint32_t level_count) {
const T rep_bit_count = level_count * base_log;
const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
T res = input >> (non_rep_bit_count - 1);
T rounding_bit = res & 1;
res += 1;
res = res >> 1;
T rounding_bit = res & (T)1;
res += (T)1;
res = res >> (T)1;
T mod_mask = (T)(-1) >> non_rep_bit_count;
res = res & mod_mask;
T shifted_random = rounding_bit << (rep_bit_count - 1);
T shifted_random = rounding_bit << (rep_bit_count - (T)1);
T need_balance =
(((res - (T)(1)) | shifted_random) & res) >> (rep_bit_count - 1);
(((res - (T)(1)) | shifted_random) & res) >> (rep_bit_count - (T)1);
return res - (need_balance << rep_bit_count);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
}

for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
synchronize_threads_in_block();
// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
init_decomposer_state_inplace<Torus, params::opt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ __global__ void __launch_bounds__(params::degree / params::opt)
}

for (int i = 0; (i + lwe_offset) < lwe_dimension && i < lwe_chunk_size; i++) {
synchronize_threads_in_block();
// Perform a rounding to increase the accuracy of the
// bootstrapped ciphertext
init_decomposer_state_inplace<Torus, params::opt,
Expand Down
18 changes: 2 additions & 16 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,8 @@ __device__ void init_decomposer_state_inplace(T *rotated_acc, int base_log,
T *rotated_acc_slice = (T *)rotated_acc + (ptrdiff_t)(z * degree);
int tid = threadIdx.x;
for (int i = 0; i < elems_per_thread; i++) {
T x_acc = rotated_acc_slice[tid];

const T rep_bit_count = level_count * base_log;
const T non_rep_bit_count = sizeof(T) * 8 - rep_bit_count;
T res_acc = x_acc >> (non_rep_bit_count - 1);
T rounding_bit = res_acc & 1;
res_acc += 1;
res_acc = res_acc >> 1;
T mod_mask = (T)(-1) >> non_rep_bit_count;
res_acc = res_acc & mod_mask;
T shifted_random = rounding_bit << (rep_bit_count - 1);
T need_balance = (((res_acc - (T)(1)) | shifted_random) & res_acc) >>
(rep_bit_count - 1);
res_acc = res_acc - (need_balance << rep_bit_count);

rotated_acc_slice[tid] = res_acc;
rotated_acc_slice[tid] =
init_decomposer_state(rotated_acc_slice[tid], base_log, level_count);
tid = tid + block_size;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ fn lwe_encrypt_multi_bit_pbs_decrypt_custom_mod<
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);

let decoded = round_decode(decrypted.0, delta) % msg_modulus;

assert_eq!(decoded, f(msg));
}
}
Expand Down

0 comments on commit 9497757

Please sign in to comment.