Skip to content

Commit

Permalink
Some additional kernel thread index refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Sep 13, 2023
1 parent 1668c2c commit 248ccab
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 29 deletions.
17 changes: 10 additions & 7 deletions cpp/benchmarks/join/generate_input_tables.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/detail/utilities/device_atomics.cuh>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/error.hpp>
Expand All @@ -33,7 +34,7 @@

// Initialize one curandState per thread so later kernels can draw random
// numbers independently. Seeded with a fixed value (1234) so benchmark
// inputs are reproducible across runs.
//
// Launch: 1D grid; threads beyond `nstates` exit without touching memory.
__global__ static void init_curand(curandState* state, int const nstates)
{
  // 64-bit global thread id from cudf's grid helper, narrowed to int
  // (nstates is an int, so valid ids always fit).
  int const ithread = static_cast<int>(cudf::detail::grid_1d::global_thread_id());

  // Guard: the grid may be padded past nstates by the launch configuration.
  if (ithread < nstates) { curand_init(1234ULL, ithread, 0, state + ithread); }
}
Expand All @@ -45,13 +46,14 @@ __global__ static void init_build_tbl(key_type* const build_tbl,
curandState* state,
int const num_states)
{
auto const start_idx = blockIdx.x * blockDim.x + threadIdx.x;
auto const stride = blockDim.x * gridDim.x;
auto const start_idx = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();
assert(start_idx < num_states);

curandState localState = state[start_idx];

for (size_type idx = start_idx; idx < build_tbl_size; idx += stride) {
for (thread_index_type tidx = start_idx; tidx < build_tbl_size; tidx += stride) {
auto const idx = static_cast<size_type>(tidx);
double const x = curand_uniform_double(&localState);

build_tbl[idx] = static_cast<key_type>(x * (build_tbl_size / multiplicity));
Expand All @@ -70,13 +72,14 @@ __global__ void init_probe_tbl(key_type* const probe_tbl,
curandState* state,
int const num_states)
{
auto const start_idx = blockIdx.x * blockDim.x + threadIdx.x;
auto const stride = blockDim.x * gridDim.x;
auto const start_idx = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();
assert(start_idx < num_states);

curandState localState = state[start_idx];

for (size_type idx = start_idx; idx < probe_tbl_size; idx += stride) {
for (thread_index_type tidx = start_idx; tidx < probe_tbl_size; tidx += stride) {
auto const idx = static_cast<size_type>(tidx);
key_type val;
double x = curand_uniform_double(&localState);

Expand Down
32 changes: 18 additions & 14 deletions cpp/benchmarks/type_dispatcher/type_dispatcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,30 @@ constexpr int block_size = 256;
// Applies functor F to every element of each of the `n_cols` column buffers
// in `A`, in place, without any type dispatch (types fixed at compile time).
//
// Launch: 1D grid; a grid-stride loop covers all `n_rows` rows regardless of
// grid size. The loop index is the 64-bit thread_index_type to avoid overflow
// when blockIdx.x * blockDim.x exceeds INT_MAX.
template <FunctorType functor_type, class T>
__global__ void no_dispatching_kernel(T** A, cudf::size_type n_rows, cudf::size_type n_cols)
{
  using F           = Functor<T, functor_type>;
  auto tidx         = cudf::detail::grid_1d::global_thread_id();
  auto const stride = cudf::detail::grid_1d::grid_stride();
  while (tidx < n_rows) {
    // Fix: was `static_cast<cudf::size_type>(tid)` — `tid` is undeclared;
    // the loop variable is `tidx`. Narrowing is safe inside the bounds check.
    auto const index = static_cast<cudf::size_type>(tidx);
    for (int c = 0; c < n_cols; c++) {
      A[c][index] = F::f(A[c][index]);
    }
    tidx += stride;
  }
}

// This is for HOST_DISPATCHING: the element type T was dispatched on the
// host, so the kernel applies functor F to the typed column data directly.
//
// Launch: 1D grid; grid-stride loop over source_column.size() rows using the
// 64-bit thread_index_type to avoid index overflow on large grids.
template <FunctorType functor_type, class T>
__global__ void host_dispatching_kernel(cudf::mutable_column_device_view source_column)
{
  using F           = Functor<T, functor_type>;
  T* A              = source_column.data<T>();
  auto tidx         = cudf::detail::grid_1d::global_thread_id();
  auto const stride = cudf::detail::grid_1d::grid_stride();
  while (tidx < source_column.size()) {
    // Fix: was `static_cast<cudf::size_type>(tid)` — `tid` is undeclared;
    // the loop variable is `tidx`. Narrowing is safe inside the bounds check.
    auto const index = static_cast<cudf::size_type>(tidx);
    A[index]         = F::f(A[index]);
    tidx += stride;
  }
}

Expand Down Expand Up @@ -127,14 +131,14 @@ template <FunctorType functor_type>
// Dispatches on each column's runtime type for every row (DEVICE_DISPATCHING),
// invoking RowHandle to apply the functor element-wise.
//
// Launch: 1D grid; grid-stride loop over n_rows using the 64-bit
// thread_index_type to avoid index overflow on large grids.
__global__ void device_dispatching_kernel(cudf::mutable_table_device_view source)
{
  cudf::size_type const n_rows = source.num_rows();
  auto tidx                    = cudf::detail::grid_1d::global_thread_id();
  auto const stride            = cudf::detail::grid_1d::grid_stride();
  while (tidx < n_rows) {
    // Fix: the refactor removed the old `index` declaration but the dispatcher
    // call still uses it — recompute it from tidx, matching the sibling kernels.
    auto const index = static_cast<cudf::size_type>(tidx);
    for (cudf::size_type i = 0; i < source.num_columns(); i++) {
      cudf::type_dispatcher(
        source.column(i).type(), RowHandle<functor_type>{}, source.column(i), index);
    }
    tidx += stride;
  }  // while
}

Expand Down
17 changes: 9 additions & 8 deletions cpp/include/cudf/detail/copy_if_else.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@ __launch_bounds__(block_size) __global__
mutable_column_device_view out,
size_type* __restrict__ const valid_count)
{
size_type const tid = threadIdx.x + blockIdx.x * block_size;
int const warp_id = tid / warp_size;
auto tidx = cudf::detail::grid_1d::global_thread_id();
auto const stride = cudf::detail::grid_1d::grid_stride();
int const warp_id = tidx / warp_size;
size_type const warps_per_grid = gridDim.x * block_size / warp_size;

// begin/end indices for the column data
size_type begin = 0;
size_type end = out.size();
size_type const begin = 0;
size_type const end = out.size();
// warp indices. since 1 warp == 32 threads == sizeof(bitmask_type) * 8,
// each warp will process one (32 bit) of the validity mask via
// __ballot_sync()
size_type warp_begin = cudf::word_index(begin);
size_type warp_end = cudf::word_index(end - 1);
size_type const warp_begin = cudf::word_index(begin);
size_type const warp_end = cudf::word_index(end - 1);

// lane id within the current warp
constexpr size_type leader_lane{0};
Expand All @@ -65,8 +66,8 @@ __launch_bounds__(block_size) __global__

// current warp.
size_type warp_cur = warp_begin + warp_id;
size_type index = tid;
while (warp_cur <= warp_end) {
auto const index = static_cast<size_type>(tidx);
auto const opt_value =
(index < end) ? (filter(index) ? lhs[index] : rhs[index]) : thrust::nullopt;
if (opt_value) { out.element<T>(index) = static_cast<T>(*opt_value); }
Expand All @@ -84,7 +85,7 @@ __launch_bounds__(block_size) __global__

// next grid
warp_cur += warps_per_grid;
index += block_size * gridDim.x;
tidx += stride;
}

if (has_nulls) {
Expand Down

0 comments on commit 248ccab

Please sign in to comment.