Skip to content

Commit c0dc36a

Browse files
committed
WIP
1 parent e79787b commit c0dc36a

File tree

4 files changed

+75
-26
lines changed

4 files changed

+75
-26
lines changed

Diff for: include/cuco/detail/equal_wrapper.cuh

+8-5
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ namespace detail {
2626
* @brief Enum of equality comparison results
2727
*/
2828
// ENUM VALUE MATTERS, DO NOT CHANGE
29-
enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, AVAILABLE = 3 };
29+
enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, ERASED = 3 };
3030

3131
enum class is_insert : bool { YES, NO };
3232

@@ -97,10 +97,13 @@ struct equal_wrapper {
9797
__device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept
9898
{
9999
if constexpr (IsInsert == is_insert::YES) {
100-
return (cuco::detail::bitwise_compare(rhs, empty_sentinel_) or
101-
cuco::detail::bitwise_compare(rhs, erased_sentinel_))
102-
? equal_result::AVAILABLE
103-
: this->equal_to(lhs, rhs);
100+
if (cuco::detail::bitwise_compare(rhs, empty_sentinel_)) {
101+
return equal_result::EMPTY;
102+
} else if (cuco::detail::bitwise_compare(rhs, erased_sentinel_)) {
103+
return equal_result::ERASED;
104+
} else {
105+
return this->equal_to(lhs, rhs);
106+
}
104107
} else {
105108
return cuco::detail::bitwise_compare(rhs, empty_sentinel_) ? equal_result::EMPTY
106109
: this->equal_to(lhs, rhs);

Diff for: include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

+57-17
Original file line number | Diff line number | Diff line change
@@ -383,7 +383,12 @@ class open_addressing_ref_impl {
383383
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
384384
auto const init_idx = *probing_iter;
385385

386+
[[maybe_unused]] auto probing_iter_copy = probing_iter;
387+
[[maybe_unused]] bool erased = false;
388+
[[maybe_unused]] bool empty_after_erased = false;
389+
386390
while (true) {
391+
[[maybe_unused]] continue_after_erased:
387392
auto const bucket_slots = storage_ref_[*probing_iter];
388393

389394
for (auto& slot_content : bucket_slots) {
@@ -393,23 +398,54 @@ class open_addressing_ref_impl {
393398
if constexpr (not allows_duplicates) {
394399
// If the key is already in the container, return false
395400
if (eq_res == detail::equal_result::EQUAL) { return false; }
401+
if (eq_res == detail::equal_result::ERASED and not erased and not empty_after_erased) {
402+
erased = true;
403+
probing_iter_copy = probing_iter;
404+
}
405+
if (eq_res == detail::equal_result::EMPTY and erased and not empty_after_erased) {
406+
empty_after_erased = true;
407+
probing_iter = probing_iter_copy;
408+
goto continue_after_erased;
409+
}
396410
}
397-
if (eq_res == detail::equal_result::AVAILABLE) {
398-
auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
399-
switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
400-
slot_content,
401-
val)) {
402-
case insert_result::DUPLICATE: {
403-
if constexpr (allows_duplicates) {
404-
[[fallthrough]];
405-
} else {
406-
return false;
411+
412+
if (not erased or empty_after_erased) {
413+
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
414+
auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
415+
switch (
416+
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
417+
slot_content,
418+
val)) {
419+
case insert_result::DUPLICATE: {
420+
if constexpr (allows_duplicates) {
421+
[[fallthrough]];
422+
} else {
423+
return false;
424+
}
407425
}
426+
case insert_result::CONTINUE: continue;
427+
case insert_result::SUCCESS: return true;
408428
}
409-
case insert_result::CONTINUE: continue;
410-
case insert_result::SUCCESS: return true;
411429
}
412430
}
431+
432+
// if (eq_res == detail::equal_result::AVAILABLE) {
433+
// auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content);
434+
// switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() +
435+
// intra_bucket_index,
436+
// slot_content,
437+
// val)) {
438+
// case insert_result::DUPLICATE: {
439+
// if constexpr (allows_duplicates) {
440+
// [[fallthrough]];
441+
// } else {
442+
// return false;
443+
// }
444+
// }
445+
// case insert_result::CONTINUE: continue;
446+
// case insert_result::SUCCESS: return true;
447+
// }
448+
// }
413449
}
414450
++probing_iter;
415451
if (*probing_iter == init_idx) { return false; }
@@ -442,8 +478,10 @@ class open_addressing_ref_impl {
442478
for (auto i = 0; i < bucket_size; ++i) {
443479
switch (
444480
this->predicate_.operator()<is_insert::YES>(key, this->extract_key(bucket_slots[i]))) {
445-
case detail::equal_result::AVAILABLE:
446-
return bucket_probing_results{detail::equal_result::AVAILABLE, i};
481+
case detail::equal_result::EMPTY:
482+
return bucket_probing_results{detail::equal_result::EMPTY, i};
483+
case detail::equal_result::ERASED:
484+
return bucket_probing_results{detail::equal_result::ERASED, i};
447485
case detail::equal_result::EQUAL: {
448486
if constexpr (allows_duplicates) {
449487
continue;
@@ -463,7 +501,8 @@ class open_addressing_ref_impl {
463501
if (group.any(state == detail::equal_result::EQUAL)) { return false; }
464502
}
465503

466-
auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
504+
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
505+
(state == detail::equal_result::ERASED));
467506
if (group_contains_available) {
468507
auto const src_lane = __ffs(group_contains_available) - 1;
469508
auto const status =
@@ -538,7 +577,7 @@ class open_addressing_ref_impl {
538577
}
539578
return {iterator{&bucket_ptr[i]}, false};
540579
}
541-
if (eq_res == detail::equal_result::AVAILABLE) {
580+
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
542581
switch (this->attempt_insert_stable(bucket_ptr + i, bucket_slots[i], val)) {
543582
case insert_result::SUCCESS: {
544583
if constexpr (has_payload) {
@@ -626,7 +665,8 @@ class open_addressing_ref_impl {
626665
return {iterator{reinterpret_cast<value_type*>(res)}, false};
627666
}
628667

629-
auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
668+
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
669+
(state == detail::equal_result::ERASED));
630670
if (group_contains_available) {
631671
auto const src_lane = __ffs(group_contains_available) - 1;
632672
auto const res = group.shfl(reinterpret_cast<intptr_t>(slot_ptr), src_lane);

Diff for: include/cuco/detail/static_map/static_map_ref.inl

+6-4
Original file line number | Diff line number | Diff line change
@@ -510,7 +510,7 @@ class operator_impl<
510510
payload_ref.store(val.second, cuda::memory_order_relaxed);
511511
return;
512512
}
513-
if (eq_res == detail::equal_result::AVAILABLE) {
513+
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
514514
if (attempt_insert_or_assign(slot_ptr, val)) { return; }
515515
}
516516
}
@@ -571,7 +571,8 @@ class operator_impl<
571571
return;
572572
}
573573

574-
auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
574+
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
575+
(state == detail::equal_result::ERASED));
575576
if (group_contains_available) {
576577
auto const src_lane = __ffs(group_contains_available) - 1;
577578
auto const status =
@@ -883,7 +884,7 @@ class operator_impl<
883884
op(cuda::atomic_ref<T, Scope>{slot_ptr->second}, val.second);
884885
return false;
885886
}
886-
if (eq_res == detail::equal_result::AVAILABLE) {
887+
if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) {
887888
switch (ref_.attempt_insert_or_apply<UseDirectApply>(slot_ptr, slot_content, val, op)) {
888889
case insert_result::SUCCESS: return true;
889890
case insert_result::DUPLICATE: {
@@ -970,7 +971,8 @@ class operator_impl<
970971
return false;
971972
}
972973

973-
auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
974+
auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or
975+
(state == detail::equal_result::ERASED));
974976
if (group_contains_available) {
975977
auto const src_lane = __ffs(group_contains_available) - 1;
976978
auto const status = [&, target_idx = intra_bucket_index]() {

Diff for: tests/static_map/erase_test.cu

+4
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,10 @@ void test_erase(Map& map, size_type num_keys)
7575
REQUIRE(cuco::test::all_of(
7676
d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{}));
7777

78+
// tests #606
79+
map.insert(pairs_begin + num_keys / 2, pairs_begin + num_keys);
80+
// TODO insert_and_find, insert_or_assign, insert_or_apply
81+
7882
map.erase(keys_begin + num_keys / 2, keys_begin + num_keys);
7983
REQUIRE(map.size() == 0);
8084
}

0 commit comments

Comments
 (0)