|
33 | 33 | #include <cub/block/block_scan.cuh>
|
34 | 34 | #include <cuco/static_set.cuh>
|
35 | 35 | #include <thrust/fill.h>
|
| 36 | +#include <thrust/iterator/counting_iterator.h> |
36 | 37 | #include <thrust/iterator/transform_output_iterator.h>
|
| 38 | +#include <thrust/iterator/zip_iterator.h> |
37 | 39 | #include <thrust/sequence.h>
|
38 | 40 |
|
39 | 41 | #include <cstddef>
|
@@ -79,14 +81,9 @@ class build_keys_fn {
|
79 | 81 |
|
80 | 82 | /**
|
81 | 83 | * @brief Device output transform functor to construct `size_type` with `cuco::pair<hash_value_type,
|
82 |
| - * lhs_index_type>` or `cuco::pair<hash_value_type, rhs_index_type>` |
| 84 | + * rhs_index_type>` |
83 | 85 | */
|
84 | 86 | struct output_fn {
|
85 |
| - __device__ constexpr cudf::size_type operator()( |
86 |
| - cuco::pair<hash_value_type, lhs_index_type> const& x) const |
87 |
| - { |
88 |
| - return static_cast<cudf::size_type>(x.second); |
89 |
| - } |
90 | 87 | __device__ constexpr cudf::size_type operator()(
|
91 | 88 | cuco::pair<hash_value_type, rhs_index_type> const& x) const
|
92 | 89 | {
|
@@ -176,15 +173,33 @@ distinct_hash_join<HasNested>::inner_join(rmm::cuda_stream_view stream,
|
176 | 173 | auto const iter = cudf::detail::make_counting_transform_iterator(
|
177 | 174 | 0, build_keys_fn<decltype(d_probe_hasher), lhs_index_type>{d_probe_hasher});
|
178 | 175 |
|
179 |
| - auto const build_indices_begin = |
180 |
| - thrust::make_transform_output_iterator(build_indices->begin(), output_fn{}); |
181 |
| - auto const probe_indices_begin = |
182 |
| - thrust::make_transform_output_iterator(probe_indices->begin(), output_fn{}); |
183 |
| - |
184 |
| - auto const [probe_indices_end, _] = this->_hash_table.retrieve( |
185 |
| - iter, iter + probe_table_num_rows, probe_indices_begin, build_indices_begin, {stream.value()}); |
| 176 | + auto found_indices = rmm::device_uvector<size_type>(probe_table_num_rows, stream); |
| 177 | + auto const found_begin = |
| 178 | + thrust::make_transform_output_iterator(found_indices.begin(), output_fn{}); |
| 179 | + |
| 180 | + // TODO conditional find for nulls once `cuco::static_set::find_if` is added |
| 181 | + // If `idx` is within the range `[0, probe_table_num_rows)` and `found_indices[idx]` is not equal |
| 182 | + // to `JoinNoneValue`, then `idx` has a match in the hash set. |
| 183 | + this->_hash_table.find_async(iter, iter + probe_table_num_rows, found_begin, stream.value()); |
| 184 | + |
| 185 | + auto const tuple_iter = cudf::detail::make_counting_transform_iterator( |
| 186 | + 0, |
| 187 | + cuda::proclaim_return_type<thrust::tuple<size_type, size_type>>( |
| 188 | + [found_iter = found_indices.begin()] __device__(size_type idx) { |
| 189 | + return thrust::tuple{*(found_iter + idx), idx}; |
| 190 | + })); |
| 191 | + auto const output_begin = |
| 192 | + thrust::make_zip_iterator(build_indices->begin(), probe_indices->begin()); |
| 193 | + auto const output_end = |
| 194 | + thrust::copy_if(rmm::exec_policy_nosync(stream), |
| 195 | + tuple_iter, |
| 196 | + tuple_iter + probe_table_num_rows, |
| 197 | + found_indices.begin(), |
| 198 | + output_begin, |
| 199 | + cuda::proclaim_return_type<bool>( |
| 200 | + [] __device__(size_type idx) { return idx != JoinNoneValue; })); |
| 201 | + auto const actual_size = std::distance(output_begin, output_end); |
186 | 202 |
|
187 |
| - auto const actual_size = std::distance(probe_indices_begin, probe_indices_end); |
188 | 203 | build_indices->resize(actual_size, stream);
|
189 | 204 | probe_indices->resize(actual_size, stream);
|
190 | 205 |
|
|
0 commit comments