diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index 95f81e856d..b965a66128 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -55,6 +55,11 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU) src/metric_ops/metric_ops_host.cpp src/input_combine_ops/input_combine_gpu.cpp) + if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) + list(APPEND fbgemm_gpu_sources_cpu_static + src/faster_hash_ops/faster_hash.cpp) + endif() + if(NVML_LIB_PATH OR FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) message(STATUS "Adding merge_pooled_embeddings sources") list(APPEND fbgemm_gpu_sources_cpu_static @@ -122,6 +127,11 @@ if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_CPU) src/sparse_ops/sparse_reorder_batched_ad.cu src/sparse_ops/sparse_segment_sum_csr.cu src/sparse_ops/sparse_zipf.cu) + + if(NOT FBGEMM_BUILD_VARIANT STREQUAL BUILD_VARIANT_ROCM) + list(APPEND fbgemm_gpu_sources_gpu_static + src/faster_hash_ops/faster_hash.cu) + endif() endif() diff --git a/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/common_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/common_utils.cuh new file mode 100644 index 0000000000..bc60bf0a19 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/common_utils.cuh @@ -0,0 +1,137 @@ +/* + * The MIT License (MIT) + * + * Copyright (C) 2016 ExplosionAI GmbH, 2014-2015 Matthew Honnibal, 2016 spaCy + * GmbH + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#define AT_DISPATCH_INTEGER_TYPES(TYPE, NAME, HINT, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Int, HINT, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, HINT, __VA_ARGS__)) + +namespace fbgemm_gpu { + +#if defined(TORBOREC_CUDA) +#define TORBOREC_INLINE __device__ __host__ __inline__ +#else +#define TORBOREC_INLINE inline +#endif + +// NOLINTNEXTLINE: +TORBOREC_INLINE uint64_t +murmur_hash3_2x64(const uint64_t x, const uint64_t y, const uint64_t seed) { + const uint64_t c1 = 0x87c37b91114253d5; + const uint64_t c2 = 0x4cf5ad432745937f; + + uint64_t h1 = seed; + uint64_t h2 = seed; + + // First 64-bit block + uint64_t k1 = x; + k1 *= c1; + k1 = (k1 << 31) | (k1 >> (64 - 31)); + k1 *= c2; + h1 ^= k1; + h1 = (h1 << 27) | (h1 >> (64 - 27)); + h1 += h2; + h1 = h1 * 5 + 0x52dce729; + + // Second 64-bit block + uint64_t k2 = y; + k2 *= c2; + k2 = (k2 << 33) | (k2 >> (64 - 33)); + k2 *= c1; + h2 ^= k2; + h2 = (h2 << 31) | (h2 >> (64 - 31)); + h2 += h1; + h2 = h2 * 5 + 0x38495ab5; + + // Finalization + h1 ^= 16; + h2 ^= 16; + h1 += h2; + h2 += h1; + h1 ^= h1 >> 33; + h1 *= 0xff51afd7ed558ccd; + h1 ^= h1 >> 33; + h1 *= 0xc4ceb9fe1a85ec53; + h1 ^= h1 >> 33; + h2 ^= h2 >> 33; + h2 *= 0xff51afd7ed558ccd; + h2 ^= h2 >> 33; + h2 *= 0xc4ceb9fe1a85ec53; + h2 ^= h2 >> 33; + h1 += h2; + h2 += h1; + + return h1 ^ h2; +} + +// NOLINTNEXTLINE: +template +TORBOREC_INLINE int64_t next_output_index( + int64_t output_index, + int64_t modulo, + int64_t& /* max_probe_local */) { + static_assert(CIRCULAR_PROBE); + return (output_index + 1) % modulo; +} + +// NOLINTNEXTLINE: +template <> +TORBOREC_INLINE int64_t next_output_index( + int64_t output_index, + int64_t modulo, + int64_t& max_probe_local) { + output_index = (output_index + 1) % modulo; + if (output_index == 0) { + // circular, using max_probe_local to control exit. + max_probe_local = 0; + } + return output_index; +} + +TORBOREC_INLINE bool is_eviction_enabled( + bool readonly, + int eviction_threshold, + int eviction_policy) { + return !readonly && (eviction_threshold > 0 || eviction_policy > 0); +} + +#undef TORBOREC_INLINE + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/faster_hash_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/faster_hash_ops.cuh new file mode 100644 index 0000000000..0633879b4b --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/faster_hash_ops.cuh @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +/// @defgroup faster-hash-ops CUDA Operators +/// The following are CUDA Operators + +namespace fbgemm_gpu { + +using at::Tensor; + +///@ingroup faster-hash-ops +/// +/// CUDA implementation of zero collision hash +/// +/// @param output the output tensor that will be modified in place +/// @param evict_slots the slots that will be evicted +/// @param input the input tensor +/// @param identities the identity tensor +/// @param max_probe the maximum number of probes +/// @param circular_probe whether to use circular probe +/// @param cur_hour the current hour +/// @param readonly whether to use readonly mode +/// @param support_evict whether to support evict +/// @param local_sizes the local sizes tensor +/// @param offsets the offsets tensor +/// @param hash_identity whether to hash the identity +/// @param metadata the metadata tensor +/// @param disable_fallback whether to disable fallback +/// @param input_metadata the input metadata tensor +/// @param eviction_threshold the eviction threshold +/// @param eviction_policy the eviction policy +/// @param opt_in_prob the opt-in probability +/// @param num_reserved_slots the number of reserved slots +/// @param opt_in_rands the opt-in randoms tensor +/// +/// @return None +template +void _zero_collision_hash_cuda( + Tensor& output, + Tensor& evict_slots, + const Tensor& input, + Tensor& identities, + int64_t max_probe, + bool circular_probe, + int64_t cur_hour, + bool readonly, + bool support_evict, + const std::optional& local_sizes, + const std::optional& offsets, + int32_t hash_identity, + const std::optional& metadata, + bool disable_fallback, + const std::optional& input_metadata, + int64_t eviction_threshold, + int64_t eviction_policy, + int64_t opt_in_prob, + int64_t num_reserved_slots, + const std::optional& opt_in_rands); + +///@ingroup faster-hash-ops +/// +/// CUDA implementation of murmurhash3 +/// +/// @param input the input tensor +/// @param y the y value +/// @param seed the seed value + +/// @return the output tensor +Tensor murmur_hash3_cuda(const Tensor& input, int64_t y, int64_t seed); + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/faster_hash_ops.h b/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/faster_hash_ops.h new file mode 100644 index 0000000000..c0a0c2f415 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/faster_hash_ops/faster_hash_ops.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace fbgemm_gpu { + +using at::Tensor; + +/// @defgroup faster-hash-ops CPP Operators +/// + +/// @ingroup faster-hash-ops +/// +/// Create buffers for identity table and metadata table for ZCH +/// +/// @param size The target tensor dimensions +/// @param support_evict Whether to support eviction +/// @param device The device to allocate the tensor on +/// @param long_type Whether to use long type for the tensor +/// +/// @return A tuple of two tensors, the first tensor is the +// identity table and the second tensor is the metadata table +std::tuple create_zch_buffer_cpu( + const int64_t size, + bool support_evict, + std::optional device, + bool long_type); + +/// @ingroup faster-hash-ops +/// +/// Murmur hash operator for CPU +/// +/// @param input The input tensor +/// @param y The y value +/// @param seed The seed value +/// +/// @return The output hash value +Tensor murmur_hash3_cpu(const Tensor& input, int64_t y, int64_t seed); + +/// @ingroup faster-hash-ops +/// +/// Zero collision hash operator for CPU +/// +/// @param input The input tensor +/// @param identities The identity table +/// @param max_probe The maximum number of probes +/// @param circular_probe Whether to use circular probe +/// @param exp_hours The number of hours before identity table item's +/// expirition +/// @param readonly Whether to use readonly mode +/// @param local_sizes The local sizes tensor +/// @param offsets The offsets tensor +/// @param metadata The metadata tensor +/// @param output_on_uvm Whether to output on UVM +/// @param disable_fallback Whether to disable fallback +/// @param _modulo_identity_DPRECATED The modulo identity +/// @param input_metadata The input metadata tensor +/// @param eviction_threshold The eviction threshold +/// @param eviction_policy The eviction policy +/// @param opt_in_prob The opt-in probability +/// @param num_reserved_slots The number of reserved slots +/// @param opt_in_rands The opt-in randoms tensor +/// +/// @return A tuple of two tensors, the first tensor is the +/// output tensor and the second tensor is the slots to be evicted +std::tuple zero_collision_hash_cpu( + const Tensor& input, + Tensor& identities, + int64_t max_probe, + bool circular_probe, + int64_t exp_hours, + bool readonly, + const std::optional& local_sizes, + const std::optional& offsets, + const std::optional& metadata, + bool /* output_on_uvm */, + bool disable_fallback, + bool _modulo_identity_DPRECATED, + const std::optional& input_metadata, + int64_t eviction_threshold, + int64_t /* eviction_policy */, + int64_t opt_in_prob, + int64_t num_reserved_slots, + const std::optional& opt_in_rands); + +/// @ingroup faster-hash-ops +/// +/// Zero collision hash operator for data on meta device +/// +/// @param input The input tensor +/// @param identities The identity table +/// @param max_probe The maximum number of probes +/// @param circular_probe Whether to use circular probe +/// @param exp_hours The number of hours before identity table item's expirition +/// @param readonly Whether to use readonly mode +/// @param local_sizes The local sizes tensor +/// @param offsets The offsets tensor +/// @param metadata The metadata tensor +/// @param output_on_uvm Whether to output on UVM +/// @param disable_fallback Whether to disable fallback +/// @param _modulo_identity_DPRECATED The modulo identity +/// @param input_metadata The input metadata tensor +/// @param eviction_threshold The eviction threshold +/// @param eviction_policy The eviction policy +/// 
@param opt_in_prob The opt-in probability +/// @param num_reserved_slots The number of reserved slots +/// @param opt_in_rands The opt-in randoms tensor +/// +/// @return A tuple of two tensors, the first tensor is the +/// output tensor and the second tensor is the slots to be evicted +std::tuple zero_collision_hash_meta( + const Tensor& input, + Tensor& /* identities */, + int64_t /* max_probe */, + bool /* circular_probe */, + int64_t /* exp_hours */, + bool /* readonly */, + const std::optional& /* local_sizes */, + const std::optional& /* offsets */, + const std::optional& /* metadata */, + bool /* output_on_uvm */, + bool /* disable_fallback */, + bool /* _modulo_identity_DPRECATED */, + const std::optional& /* input_metadata */, + int64_t /* eviction_threshold */, + int64_t /* eviction_policy */, + int64_t /* opt_in_prob */, + int64_t /* num_reserved_slots */, + const std::optional& /* opt_in_rands */); + +/// @ingroup faster-hash-ops +/// +/// Murmur hash operator for Meta device +/// +/// @param input The input tensor +/// @param y The y value +/// @param seed The seed value +Tensor murmur_hash3_meta(const Tensor& input, int64_t y, int64_t seed); + +// /// @ingroup faster-hash-ops +// /// +// /// process one item for zero collision hash +// /// +// /// @param input The input tensor +// /// @param output The output tensor +// /// @param identities The identity table +// /// @param modulo The modulo +// /// @param max_probe The maximum number of probes +// /// @param local_sizes The local sizes tensor +// /// @param offsets The offsets tensor +// /// @param opt_in_prob The opt-in probability +// /// @param num_reserved_slots The number of reserved slots +// /// +// /// @return A template with the following parameters: +// /// DISABLE_FALLBACK: Whether to disable fallback +// /// HASH_IDENTITY: The hash identity +// /// CIRCULAR_PROBE: Whether to use circular probe +// /// HAS_OFFSET: Whether to have offset +// /// - TInput: The type of the input tensor +// /// - TIdentity: The type of the identity table +// template < +// bool DISABLE_FALLBACK, +// int32_t HASH_IDENTITY, +// bool CIRCULAR_PROBE, +// bool HAS_OFFSET, +// typename TInput, +// typename TIdentity> +// void process_item_zch( +// const at::PackedTensorAccessor64& input, +// at::PackedTensorAccessor64 output, +// const at::PackedTensorAccessor64& identities, +// int64_t modulo, +// int64_t max_probe, +// const int64_t* const local_sizes, +// const int64_t* const offsets, +// int64_t opt_in_prob, +// int64_t num_reserved_slots) + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp new file mode 100644 index 0000000000..0c30156e6e --- /dev/null +++ b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cpp @@ -0,0 +1,612 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include // @manual +#include // @manual +#include // @manual +#include // @manual +#include "c10/core/ScalarType.h" // @manual +#include "c10/core/TensorImpl.h" // @manual +#include "fbgemm_gpu/faster_hash_ops/common_utils.cuh" // @manual +#include "fbgemm_gpu/faster_hash_ops/faster_hash_ops.h" // @manual + +/* Inference ONLY op */ + +#define FASTER_HASH_CPU_INTRO_OP_PARALLEL 0 + +namespace fbgemm_gpu { + +using at::Tensor; + +namespace { +static constexpr int32_t kDefaultTensor = -1; +static constexpr int64_t kMaxIdentityNum = INT32_MAX; + +template < + bool DISABLE_FALLBACK, + int32_t HASH_IDENTITY, + bool CIRCULAR_PROBE, + bool HAS_OFFSET, + typename TInput, + typename TIdentity> +void process_item_zch( + const at::PackedTensorAccessor64& input, + at::PackedTensorAccessor64 output, + const at::PackedTensorAccessor64& identities, + int64_t modulo, + int64_t max_probe, + const int64_t* const local_sizes, + const int64_t* const offsets, + int64_t opt_in_prob, + int64_t num_reserved_slots) { + // Do we need multi-threading here considering prediction are already + // multi-threaded over requests? + + int64_t total_items = input.size(0); + +#ifdef FASTER_HASH_CPU_INTRO_OP_PARALLEL + at::parallel_for( + 0, + total_items, + FASTER_HASH_CPU_INTRO_OP_PARALLEL, + [&](int64_t t_begin, int64_t t_end) { +#else + int64_t t_begin = 0; + int64_t t_end = total_items; +#endif + for (auto process_index = t_begin; process_index < t_end; + ++process_index) { + auto item = input[process_index]; + int64_t offset = 0; + if constexpr (HAS_OFFSET) { + modulo = local_sizes[process_index]; + offset = offsets[process_index]; + } + + auto hash = murmur_hash3_2x64(static_cast(item), 0, 0); + auto opt_in_block_size = + opt_in_prob == -1 ? modulo : modulo - num_reserved_slots; + auto output_index = + static_cast(hash % opt_in_block_size); // Local idx + TIdentity identity; + + if constexpr (HASH_IDENTITY == 1) { + identity = static_cast( + murmur_hash3_2x64( + static_cast(item), + 0x17, // seed + 0) % + kMaxIdentityNum); + } else if constexpr (HASH_IDENTITY == 2) { + identity = static_cast(item % kMaxIdentityNum); + } else { + identity = static_cast(item); + } + + auto max_probe_local = max_probe; + while (max_probe_local-- > 0) { + auto insert_idx = output_index + offset; + auto current_slot_identity = identities[insert_idx][0]; + // Inference treat empty slot (kDefaultTensor) as collision and + // continue next probe + if (current_slot_identity == identity) { + break; + } + + output_index = next_output_index( + output_index, + opt_in_block_size, // only probe within the opt-in block + max_probe_local); + } + + // can't find a slot (all slot full after probing) + if (max_probe_local < 0) { + if constexpr (DISABLE_FALLBACK) { + output_index = -1; + offset = 0; + } else { + output_index = opt_in_prob == -1 + ? static_cast(hash % modulo) + : opt_in_block_size + + static_cast(hash % num_reserved_slots); + } + } + + output[process_index] = output_index + offset; + } +#ifdef FASTER_HASH_CPU_INTRO_OP_PARALLEL + }); +#endif +} + +template +void _zero_collision_hash_cpu_out( + Tensor& output, + const Tensor& input, + const Tensor& identities, + int64_t max_probe, + const bool circular_probe, + const std::optional& local_sizes, + const std::optional& offsets, + int32_t hash_identity, + bool disable_fallback, + int64_t opt_in_prob, + int64_t num_reserved_slots) { + int64_t modulo = identities.size(0); + auto* local_sizes_ptr = + local_sizes.has_value() ? 
local_sizes->data_ptr() : nullptr; + auto* offsets_ptr = + offsets.has_value() ? offsets->data_ptr() : nullptr; + +#define INVOKE_KERNEL( \ + DISABLE_FALLBACK, HASH_IDENTITY, CIRCULAR_PROBE, HAS_OFFSET) \ + { \ + process_item_zch< \ + DISABLE_FALLBACK, \ + HASH_IDENTITY, \ + CIRCULAR_PROBE, \ + HAS_OFFSET, \ + TInput, \ + TIdentity>( \ + input.packed_accessor64(), \ + output.packed_accessor64(), \ + identities.packed_accessor64(), \ + modulo, \ + max_probe, \ + local_sizes_ptr, \ + offsets_ptr, \ + opt_in_prob, \ + num_reserved_slots); \ + } + +#define INVOKE_HASH_IDENTITY(HASH_IDENTITY, CIRCULAR_PROBE, HAS_OFFSET) \ + { \ + if (disable_fallback) { \ + INVOKE_KERNEL(true, HASH_IDENTITY, CIRCULAR_PROBE, HAS_OFFSET) \ + } else { \ + INVOKE_KERNEL(false, HASH_IDENTITY, CIRCULAR_PROBE, HAS_OFFSET) \ + } \ + } + +#define INVOKE_KERNEL_CIRCULAR_PROBE(CIRCULAR_PROBE, HAS_OFFSET) \ + { \ + if (hash_identity == 1) { \ + INVOKE_HASH_IDENTITY(1, CIRCULAR_PROBE, HAS_OFFSET); \ + } \ + if (hash_identity == 2) { \ + INVOKE_HASH_IDENTITY(2, CIRCULAR_PROBE, HAS_OFFSET); \ + } else { \ + INVOKE_HASH_IDENTITY(0, CIRCULAR_PROBE, HAS_OFFSET); \ + } \ + } + +#define INVOKE_KERNEL_HAS_OFFSET(HAS_OFFSET) \ + { \ + if (circular_probe) { \ + INVOKE_KERNEL_CIRCULAR_PROBE(true, HAS_OFFSET); \ + } else { \ + INVOKE_KERNEL_CIRCULAR_PROBE(false, HAS_OFFSET); \ + } \ + } + + if (local_sizes_ptr != nullptr) { + INVOKE_KERNEL_HAS_OFFSET(true); + } else { + INVOKE_KERNEL_HAS_OFFSET(false); + } + +#undef INVOKE_KERNEL_HAS_OFFSET +#undef INVOKE_KERNEL_CIRCULAR_PROBE +#undef INVOKE_HASH_IDENTITY +#undef INVOKE_KERNEL +} + +} // namespace + +std::tuple zero_collision_hash_meta( + const Tensor& input, + Tensor& /* identities */, + int64_t /* max_probe */, + bool /* circular_probe */, + int64_t /* exp_hours */, + bool /* readonly */, + const std::optional& /* local_sizes */, + const std::optional& /* offsets */, + const std::optional& /* metadata */, + bool /* output_on_uvm */, + bool /* disable_fallback */, + bool /* _modulo_identity_DPRECATED */, + const std::optional& /* input_metadata */, + int64_t /* eviction_threshold */, + int64_t /* eviction_policy */, + int64_t /* opt_in_prob */, + int64_t /* num_reserved_slots */, + const std::optional& /* opt_in_rands */) { + auto out = + at::zeros_symint({input.sym_numel()}, input.options().dtype(at::kLong)); + auto evcit_slots = at::zeros_symint({0}, input.options()); + return {input, evcit_slots}; +} + +Tensor murmur_hash3_meta(const Tensor& input, int64_t y, int64_t seed) { + auto hash = murmur_hash3_2x64( + input.item().to(), + static_cast(y), + static_cast(seed)); + return at::scalar_tensor( + hash, c10::TensorOptions().dtype(at::kLong).device(at::kCPU)); +} + +std::tuple create_zch_buffer_cpu( + const int64_t size, + bool support_evict, + std::optional device, + bool long_type) { + Tensor metadata; + auto identity = at::full( + {size, 1}, + kDefaultTensor, + c10::TensorOptions() + .dtype(long_type ? 
at::kLong : at::kInt) + .device(device)); + if (support_evict) { + metadata = at::full( + {size, 1}, + kDefaultTensor, + c10::TensorOptions().dtype(at::kInt).device(device)); + } + return {identity, metadata}; +} + +void zero_collision_hash_cpu_out( + Tensor& output, + const Tensor& input, + const Tensor& identities, + int64_t max_probe, + bool circular_probe, + const std::optional& local_sizes, + const std::optional& offsets, + bool disable_fallback, + bool _modulo_identity_DPRECATED, + int64_t opt_in_prob, + int64_t num_reserved_slots) { + TORCH_CHECK(output.is_cpu()); + TORCH_CHECK(output.dtype() == torch::kInt64); + + TORCH_CHECK(input.is_cpu()); + TORCH_CHECK(identities.dim() == 2); + + int hash_identity = _modulo_identity_DPRECATED ? 2 : 1; + if (identities.dtype() == input.dtype()) { + hash_identity = 0; + } + if (input.dtype() == torch::kInt32) { + TORCH_CHECK(identities.dtype() == torch::kInt32); + } + + if (local_sizes.has_value()) { + TORCH_CHECK(local_sizes->is_cpu()); + TORCH_CHECK(input.numel() == local_sizes->numel()); + } + if (offsets.has_value()) { + TORCH_CHECK(offsets->is_cpu()); + TORCH_CHECK(input.numel() == offsets->numel()); + } + if (opt_in_prob != -1) { + TORCH_CHECK(opt_in_prob > 0 && opt_in_prob < 100); + TORCH_CHECK(num_reserved_slots > 0); + } + if (num_reserved_slots != -1) { + TORCH_CHECK(opt_in_prob != -1); + } + + AT_DISPATCH_INTEGER_TYPES( + input.scalar_type(), "zero_collision_hash_input", input_t, [&]() { + AT_DISPATCH_INTEGER_TYPES( + identities.scalar_type(), + "zero_collision_hash_identity", + identity_t, + [&]() { + _zero_collision_hash_cpu_out( + output, + input, + identities, + max_probe, + circular_probe, + local_sizes, + offsets, + hash_identity, + disable_fallback, + opt_in_prob, + num_reserved_slots); + }); + }); +} + +std::tuple zero_collision_hash_cpu( + const Tensor& input, + Tensor& identities, + int64_t max_probe, + bool circular_probe, + int64_t exp_hours, + bool readonly, + const std::optional& local_sizes, + const std::optional& offsets, + const std::optional& metadata, + bool /* output_on_uvm */, + bool disable_fallback, + bool _modulo_identity_DPRECATED, + const std::optional& input_metadata, + int64_t eviction_threshold, + int64_t /* eviction_policy */, + int64_t opt_in_prob, + int64_t num_reserved_slots, + const std::optional& opt_in_rands) { + TORCH_CHECK(exp_hours == -1); + TORCH_CHECK(readonly); + TORCH_CHECK(metadata.has_value() == false); + TORCH_CHECK(input_metadata.has_value() == false); + TORCH_CHECK(eviction_threshold == -1); + TORCH_CHECK(opt_in_rands.has_value() == false); + + int64_t output_size = input.size(0); + c10::TensorOptions options = + c10::TensorOptions().dtype(at::kLong).device(input.device()); + Tensor output = at::empty({output_size}, options); + + // evict_slots will contains the index to be evcited, '-1' will be ignored. 
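+  // The CPU path asserts readonly == true and eviction_threshold == -1 above,
+  // so eviction never runs here: evict_slots is returned as a
+  // default-constructed (undefined) tensor and only the CUDA path fills it.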
+ Tensor evict_slots; + + if (output_size == 0) { + return {output, evict_slots}; + } + + zero_collision_hash_cpu_out( + output, + input, + identities, + max_probe, + circular_probe, + local_sizes, + offsets, + disable_fallback, + _modulo_identity_DPRECATED, + opt_in_prob, + num_reserved_slots); + + return {output, evict_slots}; +} + +Tensor murmur_hash3_cpu(const Tensor& input, int64_t y, int64_t seed) { + TORCH_CHECK(input.is_cpu()); + TORCH_CHECK(input.dtype() == torch::kInt64); + TORCH_CHECK(input.dim() == 1); + + return at::scalar_tensor( + murmur_hash3_2x64( + input.item().to(), + static_cast(y), + static_cast(seed)), + c10::TensorOptions().dtype(at::kLong)); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + // Create identities buffer. As we need everything to be -1. + // One could also create themsleves, as long as follow the protocol: + // 1. all value should be -1. + // 2. the tensor should be two dimensions. + // 3. if support evict, need two columns, otherwise, just one column. + // + // Args: + // size: define identities tensor size. + // support_evict: whether we support evict. + // + // Result: + // Tuple[tensor, tensor] for identities and metadata. + // identity: Shape (D, 2) with size(1) = 1 + // metadata: Shape (D, 2) with size(1) = 1 + // + // For other examples, consult the unittests. + m.def( + "create_zch_buffer(" + "int size, " + "bool support_evict=False," + "Device? device=None," + "bool long_type=False" + ") -> (Tensor, Tensor)"); + // Default impl + m.impl("create_zch_buffer", TORCH_FN(create_zch_buffer_cpu)); + + // technically this is not zero collision, but low collision. Trade-off + // between probing speed. (Setting probes to a large value and a larger + // identities tensor size could make it zero collision.) + // + // Here we have a few features: + // 1. probing to find next available slot for hash collision to reduce + // collision. + // 2. non circular probing - as this will be used in local rank, and later + // in publish stage, we will combine all local rank as a global tensor, + // hence non circular probing could make sure probing logic problems. + // 3. eviction - a slot could be evited if it's not been used for a while. + // 4. readonly mode - use for inference, in inference, we don't need atomic + // operation as everything are readonly. + // + // Args: + // input: ids to find slots. Shape (D) + // identities: a tensor which stores identities for ids. Shape (D, 1). + // max_probe: max probing, reach max will fall back to original hash + // position. recommend use 128. + // circular_probe: when hitting end of identities tensor, circular to + // beginning of identities tensor to find slots or not. + // exp_hours (to be deprecated): how many hours without any updates + // considering as slot for eviction. setting as -1 means + // disabling eviction. + // readonly: enable readonly mode or not. Perf will be much faster. + // local_sizes: local size for each chunk. Used to recover the index in + // sharded case. + // offsets: offsets for each chunk. Used to recover the index in sharded + // case. + // disable_fallback: the fallback behavior when an ID does not exist. If + // true, -1 is returned, which indicates it fails to find a + // position for this ID. If false, the position of the first + // probe is returned. + // input_metadata: the metadata for each individual ID. It will become the + // metadata of the slot if the ID is accepted to that slot. + // While it is often used to represent an ID's TTL, the meaning + // can vary. 
+ // eviction_threshold: the threshold selected for eviction. Kernel makes + // an + // eviction decision based on the existing metadata associated + // with slots and the eviction threshold. + // eviction_policy: the kernel based on the eviction policy. + // 0: No eviction or TTL based eviction. + // 1: LRU based eviction timestamped on the hour. + // opt_in_prob: the probability of a new ID being opted in (valid range: 1 + // to 99). If -1, all new IDs are opted in (100%). + // num_reserved_slots: the number of slots reserved (located in the tail) + // for IDs that are not opted in. A non-zero value is required + // when opt-in is enabled. -1 indicates no reserved slots (100% + // opt-in). If the size of embedding table is x, and + // num_reserved_slots is y, then the size of the opt-in block + // will be (x - y). + // opt_in_rands: the random numbers used to determine whether incoming IDs + // should be accepted when opt-in is enabled. Its generated by + // caller of the kernel and its size needs to be identical to + // the input size. Each new ID will be accepted only if its + // rand number is less than opt_in_prob. + // Result: + // identities index tensor: the slots found for the ids. Shape (D) + // evict slots: the index to identities tensor, indicating which slots got + // evicted. note, need to remove '-1' index. + // + // For other examples, consult the unittests. + m.def( + "zero_collision_hash(" + "Tensor input, " + "Tensor identities, " + "int max_probe, " + "bool circular_probe=False, " + "int exp_hours=-1, " + "bool readonly=False, " + "Tensor? local_sizes=None, " + "Tensor? offsets=None, " + "Tensor? metadata=None, " + "bool output_on_uvm=False, " + "bool disable_fallback=False, " + "bool _modulo_identity_DPRECATED=False, " + "Tensor? input_metadata=None, " + "int eviction_threshold=-1, " + "int eviction_policy=0, " + "int opt_in_prob=-1, " + "int num_reserved_slots=-1, " + "Tensor? opt_in_rands=None " + ") -> (Tensor, Tensor)"); + + // define the + m.def( + "murmur_hash3(" + "Tensor input, " + "int y, " + "int seed" + ") -> Tensor"); + // Default impl + m.impl("murmur_hash3", TORCH_FN(murmur_hash3_cpu)); +} + +TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { + m.impl( + "create_zch_buffer", + torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(create_zch_buffer_cpu))); + + m.impl( + "zero_collision_hash", + torch::dispatch( + c10::DispatchKey::CPU, TORCH_FN(zero_collision_hash_cpu))); + + m.impl( + "murmur_hash3", + torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(murmur_hash3_cpu))); +} + +TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { + m.impl( + "zero_collision_hash", + torch::dispatch( + c10::DispatchKey::Meta, TORCH_FN(zero_collision_hash_meta))); + m.impl( + "murmur_hash3", + torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(murmur_hash3_meta))); +} + +} // namespace fbgemm_gpu + +namespace torch::jit { + +using at::Tensor; + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) +REGISTER_NATIVE_OPERATOR_FUNCTOR( + fbgemm::operators::zero_collision_hash_cpu, + fb_zero_collision_hash_cpu, + [](Node* n) -> SROperator { + if (!n->matches(torch::schema("fbgemm::zero_collision_hash(" + "Tensor input, " + "Tensor identities, " + "int max_probe, " + "bool circular_probe=False, " + "int exp_hours=-1, " + "bool readonly=False, " + "Tensor? local_sizes=None, " + "Tensor? offsets=None, " + "Tensor? metadata=None, " + "bool output_on_uvm=False, " + "bool disable_fallback=False, " + "bool _modulo_identity_DPRECATED=False, " + "Tensor? 
input_metadata=None, " + "int eviction_threshold=-1, " + "int eviction_policy=0, " + "int opt_in_prob=-1, " + "int num_reserved_slots=-1, " + "Tensor? opt_in_rands=None" + ") -> (Tensor, Tensor)"))) { + LogAndDumpSchema(n); + return nullptr; + } + return [](ProcessedNode* p_node) { + const auto& input = p_node->Input(0).toTensor(); + const auto& identities = p_node->Input(1).toTensor(); + const auto max_probe = p_node->Input(2).toInt(); + const auto circular_probe = p_node->Input(3).toBool(); + + const auto& local_sizes = p_node->Input(6).toOptional(); + const auto& offsets = p_node->Input(7).toOptional(); + const auto& disable_fallback = p_node->Input(10).to(); + const auto& _modulo_identity_DPRECATED = p_node->Input(11).to(); + const auto opt_in_prob = p_node->Input(15).toInt(); + const auto num_reserved_slots = p_node->Input(16).toInt(); + + if (p_node->Output(0).isNone()) { + const at::ScalarType output_type = kLong; + p_node->Output(0) = torch::jit::create_empty_from(input, output_type); + } + auto& out_t = p_node->Output(0).toTensor(); + fbgemm_gpu::zero_collision_hash_cpu_out( + out_t, + input, + identities, + max_probe, + circular_probe, + local_sizes, + offsets, + disable_fallback, + _modulo_identity_DPRECATED, + num_reserved_slots, + opt_in_prob); + }; + }); +} // namespace torch::jit diff --git a/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu new file mode 100644 index 0000000000..d86da305f1 --- /dev/null +++ b/fbgemm_gpu/src/faster_hash_ops/faster_hash.cu @@ -0,0 +1,879 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include // @manual +#include +#include // @manual +#include +#include +#include +#include +#include // @manual +#include // @manual +#include +#include + +#define TORBOREC_CUDA +#include "fbgemm_gpu/faster_hash_ops/common_utils.cuh" // @manual +#include "fbgemm_gpu/faster_hash_ops/faster_hash_ops.cuh" // @manual + +namespace fbgemm_gpu { + +namespace { +using at::Tensor; + +static constexpr int32_t kDefaultTensor = -1; +static constexpr int64_t kMaxIdentityNum = INT32_MAX; +static constexpr int64_t kMaxHours = INT32_MAX; +static constexpr int64_t kSecondsInHour = 60 * 60; + +template +__device__ __inline__ T CAS(T* data, T cmp, T val) { + return atomicCAS(data, cmp, val); +} + +template <> +__device__ __inline__ int64_t +CAS(int64_t* data, int64_t cmp, int64_t val) { + return static_cast(atomicCAS( + reinterpret_cast(data), + static_cast(cmp), + static_cast(val))); +} + +template +__device__ __inline__ void update_metadata( + int32_t* /* metadata */, + int64_t /* output_index */, + int32_t /* metadata_val */) { + static_assert(METADATA_COUNT != 1); + // no op. 
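+  // METADATA_COUNT == 0 means eviction is disabled and no metadata tensor is
+  // passed, so this primary template intentionally does nothing. The <1>
+  // specialization below uses atomicMax so concurrent probes keep the largest
+  // metadata value for a slot (e.g. the most recent last-seen hour).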
+} + +template <> +__device__ __inline__ void update_metadata<1>( + int32_t* metadata, + int64_t output_index, + int32_t metadata_val) { + atomicMax(metadata + output_index, metadata_val); +} + +template +__device__ __inline__ void update_metadata_lru( + int32_t* /* metadata */, + int64_t /* output_index */, + int32_t /* val */, + int32_t* /* process_lock */) { + static_assert(METADATA_COUNT != 1); + // no-op +} + +template <> +__device__ __inline__ void update_metadata_lru<1>( + int32_t* metadata, + int64_t output_index, + int32_t val, + int32_t* process_lock) { + // These should be atomic as we release process lock as last step + atomicExch(metadata + output_index, val); + // Release process lock from index + atomicExch(process_lock + output_index, kDefaultTensor); +} + +template +__device__ __inline__ int64_t check_min( + int32_t /* process_index */, + int32_t* /* metadata */, + int64_t min_index, + int64_t /* output_index */, + int64_t /* offset */, + int32_t& /* min_hours */, + int32_t* /* process_lock */, + at::PackedTensorAccessor64 /* identities */, + TIdentity& /* min_slot_identity */, + int32_t /* eviction_threshold */, + std::enable_if_t* = nullptr) { + static_assert(METADATA_COUNT == 0); + // For inference, we keep the same min_index until the ID is found. + return min_index; +} + +template +__device__ __inline__ int64_t check_min( + int32_t process_index, + int32_t* metadata, + int64_t min_index, + int64_t output_index, + int64_t offset, + int32_t& min_hours, + int32_t* process_lock, + at::PackedTensorAccessor64 identities, + TIdentity& min_slot_identity, + int32_t eviction_threshold, + std::enable_if_t* = nullptr) { + static_assert(METADATA_COUNT == 1); + // There could be a case, one id has already occupy the slot, + // and last update hour is not written yet, while the other id checking the + // slot for min index, then it would '-1' in this case, hence we need to + // wait. + auto insert_idx = output_index + offset; + int32_t last_seen = kDefaultTensor; + while (true) { + last_seen = + atomicCAS(metadata + insert_idx, kDefaultTensor, kDefaultTensor); + if (last_seen != kDefaultTensor) { + break; + } + } + + // only check those expired slots + if (eviction_threshold > last_seen && min_hours > last_seen) { + // Try to lock index for thread + auto old_pid = + atomicCAS(process_lock + insert_idx, kDefaultTensor, process_index); + if (old_pid == kDefaultTensor) { + // Index locked for this thread + // Check if value is still same and not updated by other thread + if (last_seen == *(metadata + insert_idx)) { + if (min_index != -1) { + // Release lock on previous min_index + atomicCAS( + process_lock + min_index + offset, process_index, kDefaultTensor); + } + // Update min_index to current + min_index = output_index; + min_hours = last_seen; + min_slot_identity = identities[insert_idx][0]; + } else { + // Value updated by other thread. Release lock on this index + atomicCAS(process_lock + insert_idx, process_index, kDefaultTensor); + } + } + } + return min_index; +} + +template +__device__ __inline__ bool check_evict( + int32_t* /* metadata */, + int64_t /* output_index */, + int32_t /* eviction_threshold */) { + static_assert(METADATA_COUNT != 1); + return false; +} + +template <> +__device__ __inline__ bool check_evict<1>( + int32_t* metadata, + int64_t output_index, + int32_t eviction_threshold) { + // In rare case, one id may have already occupied the slot but its metadata + // has not been written yet, while the other id checking the slot's eviction + // status. 
Therefore, wait until the metadata is not -1. + int32_t identity_metadata = kDefaultTensor; + while (true) { + identity_metadata = + atomicCAS(metadata + output_index, kDefaultTensor, kDefaultTensor); + if (identity_metadata != kDefaultTensor) { + break; + } + } + + return eviction_threshold > identity_metadata; +} + +template +__device__ __inline__ bool check_and_maybe_update_slot( + TIdentity* identities_slot, + TIdentity identity, + TIdentity& old_value, + std::enable_if_t* = nullptr) { + static_assert(READONLY); + old_value = *identities_slot; + if (old_value == identity) { + return true; + } + return false; +} + +template +__device__ __inline__ bool check_and_maybe_update_slot( + TIdentity* identities_slot, + TIdentity identity, + TIdentity& old_value, + std::enable_if_t* = nullptr) { + static_assert(!READONLY); + old_value = + CAS(identities_slot, static_cast(kDefaultTensor), identity); + if ((old_value == identity) || + (old_value == static_cast(kDefaultTensor))) { + return true; + } + return false; +} + +template +__device__ __inline__ int64_t get_identity_slot( + at::PackedTensorAccessor64 identities, + TIdentity identity, + int64_t output_index, + int64_t offset, + int64_t modulo, + int64_t max_probe) { + while (max_probe-- > 0) { + auto insert_idx = output_index + offset; + auto current_slot_identity = identities[insert_idx][0]; + if (current_slot_identity == kDefaultTensor) { + // Hits end but still don't find, don't disable eviction. + return -1; + } else if (current_slot_identity == identity) { + // there is identity in probing distance, we shouldn't evict. + return output_index; + } + + output_index = + next_output_index(output_index, modulo, max_probe); + } + + // Nothing found, don't disable eviction. + return -1; +} + +template < + int32_t EVICTION_POLICY, + bool DISABLE_FALLBACK, + int32_t HASH_IDENTITY, + int32_t METADATA_COUNT, + bool CIRCULAR_PROBE, + bool READONLY, + typename TInput, + typename TIdentity> +__global__ void process_item_zch( + const at::PackedTensorAccessor64 input, + at::PackedTensorAccessor64 output, + int64_t* evict_slots, + at::PackedTensorAccessor64 identities, + int64_t modulo, + int64_t max_probe, + int32_t cur_hour, + const int64_t* const local_sizes, + const int64_t* const offsets, + int32_t* metadata, + const int32_t* const input_metadata, + int32_t eviction_threshold, + int32_t* /* process_lock */, + int64_t opt_in_prob, + int64_t num_reserved_slots, + const int32_t* const opt_in_rands, + TORCH_DSA_KERNEL_ARGS, + std::enable_if_t* = nullptr) { + static_assert(EVICTION_POLICY == 0); + + // Stride loop: + // https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ + // NOLINTNEXTLINE: Implicitly casting + auto total_items = input.size(0); + for (auto process_index = blockIdx.x * blockDim.x + threadIdx.x; + process_index < total_items; + // NOLINTNEXTLINE: Implicitly casting + process_index += blockDim.x * gridDim.x) { + auto item = input[process_index]; + if (local_sizes != nullptr) { + modulo = local_sizes[process_index]; + } + int64_t offset = 0; + if (offsets != nullptr) { + offset = offsets[process_index]; + } + // for backward compatibility: previous implementation assigns cur_hour + // to metadata + int32_t metadata_val = + input_metadata != nullptr ? input_metadata[process_index] : cur_hour; + + auto hash = murmur_hash3_2x64(static_cast(item), 0, 0); + auto opt_in_block_size = + opt_in_prob == -1 ? 
modulo : modulo - num_reserved_slots; + auto output_index = + static_cast(hash % opt_in_block_size); // Local idx + TIdentity identity; + + if constexpr (HASH_IDENTITY == 1) { + identity = static_cast( + murmur_hash3_2x64( + static_cast(item), + 0x17, // seed + 0) % + kMaxIdentityNum); + } else if (HASH_IDENTITY == 2) { + identity = static_cast(item % kMaxIdentityNum); + } else { + identity = item; + } + + // probing. + auto max_probe_local = max_probe; + TIdentity old_value = kDefaultTensor; + + // In eviction mode. We might run into case that an ID has already + // had a slot, in position hash(id) + 50 due to probing. + // Now between hash(id) to hash(id) + 50 has an expiration slot, during + // next look up of this id, we will expire that slot and put this id in + // that slot, if this id is very popular id, this id will start from + // ground zero, and it's not ideal. + // Our solution to solve this is to quickly check all probing location, + // see if our id has already existed, if existed, then we don't need + // eviction. Has to note, there might be very rare cases the id slot got + // evicted, it's OK, then we will colide this once with hash(id), and next + // time, it would pick up an expired slot. + // Also, we don't need lock here. As we are readonly here and other + // concurrent write should have no impact on us. + int64_t identity_slot = get_identity_slot( + identities, + identity, + output_index, + offset, + opt_in_block_size, + max_probe); + + bool opt_in = true; + if (identity_slot == -1 && opt_in_rands != nullptr && + opt_in_rands[process_index] >= opt_in_prob) { + // ID with rand value > opt_in_prob will not be accepted and will + // instead be assigned to one of the reserved slots. + opt_in = false; + output_index = + opt_in_block_size + static_cast(hash % num_reserved_slots); + update_metadata( + metadata, output_index + offset, metadata_val); + } + + while (max_probe_local-- > 0 && opt_in) { + auto insert_idx = output_index + offset; + if (check_and_maybe_update_slot( + &identities[insert_idx][0], identity, old_value)) { + update_metadata(metadata, insert_idx, metadata_val); + break; + } + + if (identity_slot == -1 && + check_evict( + metadata, insert_idx, eviction_threshold)) { + auto current_slot_value = + CAS(&identities[insert_idx][0], old_value, identity); + if (current_slot_value == old_value || current_slot_value == identity) { + evict_slots[process_index] = insert_idx; + update_metadata(metadata, insert_idx, metadata_val); + break; + } + } + + output_index = next_output_index( + output_index, + opt_in_block_size, // only probe within the opt-in block + max_probe_local); + } + + // can't find a slot (all slot full after probing), collide + if (max_probe_local < 0) { + if constexpr (DISABLE_FALLBACK) { + output_index = -1; + offset = 0; + } else { + output_index = opt_in_prob == -1 ? 
static_cast(hash % modulo) + : opt_in_block_size + + static_cast(hash % num_reserved_slots); + } + } + + output[process_index] = output_index + offset; + } +} + +template < + int32_t EVICTION_POLICY, + bool DISABLE_FALLBACK, + int32_t HASH_IDENTITY, + int32_t METADATA_COUNT, + bool CIRCULAR_PROBE, + bool READONLY, + typename TInput, + typename TIdentity> +__global__ void process_item_zch( + const at::PackedTensorAccessor64 input, + at::PackedTensorAccessor64 output, + int64_t* evict_slots, + at::PackedTensorAccessor64 identities, + int64_t modulo, + int64_t max_probe, + int32_t cur_hour, + const int64_t* const local_sizes, + const int64_t* const offsets, + int32_t* metadata, + const int32_t* const input_metadata, + int32_t eviction_threshold, + int32_t* process_lock, + int64_t /* opt_in_prob */, + int64_t /* num_reserved_slots */, + const int32_t* const /* opt_in_rands */, + TORCH_DSA_KERNEL_ARGS, + std::enable_if_t* = nullptr) { + static_assert(EVICTION_POLICY == 1); + + // Stride loop: + // https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ + // NOLINTNEXTLINE: Implicitly casting + auto total_items = input.size(0); + + for (auto process_index = blockIdx.x * blockDim.x + threadIdx.x; + process_index < total_items; + // NOLINTNEXTLINE: Implicitly casting + process_index += blockDim.x * gridDim.x) { + auto item = input[process_index]; + if (local_sizes != nullptr) { + modulo = local_sizes[process_index]; + } + int64_t offset = 0; + if (offsets != nullptr) { + offset = offsets[process_index]; + } + int32_t metadata_val = + input_metadata != nullptr ? input_metadata[process_index] : cur_hour; + auto hash = murmur_hash3_2x64(static_cast(item), 0, 0); + auto output_index = static_cast(hash % modulo); // Local idx + TIdentity identity; + + if constexpr (HASH_IDENTITY == 1) { + identity = static_cast( + murmur_hash3_2x64( + static_cast(item), + 0x17, // seed + 0) % + kMaxIdentityNum); + } else if (HASH_IDENTITY == 2) { + identity = static_cast(item % kMaxIdentityNum); + } else { + identity = item; + } + + // probing. + auto max_probe_local = max_probe; + TIdentity old_value = kDefaultTensor; + + int64_t min_index = -1; // local_index; initially set it as -1 + int32_t min_hours = kMaxHours; + // tracks the existing value of canddiate slot may be evicted during + // probing; + TIdentity min_slot_identity = kDefaultTensor; + while (max_probe_local-- > 0) { + auto insert_idx = output_index + offset; + if (check_and_maybe_update_slot( + &identities[insert_idx][0], identity, old_value)) { + update_metadata_lru( + metadata, insert_idx, metadata_val, process_lock); + break; + } + + min_index = check_min( + process_index, + metadata, + min_index, + output_index, + offset, + min_hours, + process_lock, + identities, + min_slot_identity, + eviction_threshold); + + output_index = next_output_index( + output_index, modulo, max_probe_local); + } + + if (max_probe_local < 0) { + if (min_index == -1) { + // Can't find a min slot due to identities completing for slots in + // probing distance; This case should not be hit frequently. Cases + // like: + // 1. Hashes are concentrated in a probing distance + // 2. Probing distance is too small + // 3. 
in eval mode, can't find the identity in probing distance + if constexpr (DISABLE_FALLBACK) { + output_index = -1; + offset = 0; + output[process_index] = output_index + offset; + return; + } else { + // collide + output_index = static_cast(hash % modulo); + } + } else { + // find an expire slot to evict + output_index = min_index; + // do evict only in training mode + // directly return output_index in eval mode (readonly = True) + if constexpr (!READONLY) { + auto insert_idx = output_index + offset; + CAS( + &identities[insert_idx][0], min_slot_identity, identity); + update_metadata_lru( + metadata, insert_idx, metadata_val, process_lock); + evict_slots[process_index] = insert_idx; + } + } + } + output[process_index] = output_index + offset; + } +} + +} // namespace + +template +void _zero_collision_hash_cuda( + Tensor& output, + Tensor& evict_slots, + const Tensor& input, + Tensor& identities, + int64_t max_probe, + bool circular_probe, + int64_t cur_hour, + bool readonly, + bool support_evict, + const std::optional& local_sizes, + const std::optional& offsets, + int32_t hash_identity, + const std::optional& metadata, + bool disable_fallback, + const std::optional& input_metadata, + int64_t eviction_threshold, + int64_t eviction_policy, + int64_t opt_in_prob, + int64_t num_reserved_slots, + const std::optional& opt_in_rands) { + constexpr int64_t kThreads = 256L; + auto block_size = kThreads; + // check at::cuda::getCurrentDeviceProperties() is not null + TORCH_CHECK(at::cuda::getCurrentDeviceProperties() != nullptr); + auto grid_size = std::min( + (input.numel() + block_size - 1) / block_size, + 128L * at::cuda::getCurrentDeviceProperties()->multiProcessorCount); + int64_t modulo = identities.size(0); + + // auxiliary data structure to lock each slot + std::optional process_lock; + if (eviction_policy == 1 && metadata.has_value()) { + process_lock = at::full( + {modulo, 1}, + kDefaultTensor, + c10::TensorOptions().dtype(at::kInt).device(metadata->device())); + } +#define INVOKE_KERNEL( \ + EVICTION_POLICY, \ + DISABLE_FALLBACK, \ + HASH_IDENTITY, \ + METADATA_COUNT, \ + CIRCULAR_PROBE, \ + READONLY) \ + { \ + TORCH_DSA_KERNEL_LAUNCH( \ + (process_item_zch< \ + EVICTION_POLICY, \ + DISABLE_FALLBACK, \ + HASH_IDENTITY, \ + METADATA_COUNT, \ + CIRCULAR_PROBE, \ + READONLY, \ + TInput, \ + TIdentity>), \ + grid_size, \ + block_size, \ + 0, \ + at::cuda::getCurrentCUDAStream(), \ + input.packed_accessor64(), \ + output.packed_accessor64(), \ + support_evict ? evict_slots.data_ptr() : nullptr, \ + identities.packed_accessor64(), \ + modulo, \ + max_probe, \ + static_cast(cur_hour), \ + local_sizes.has_value() ? local_sizes->data_ptr() : nullptr, \ + offsets.has_value() ? offsets->data_ptr() : nullptr, \ + metadata.has_value() ? metadata->data_ptr() : nullptr, \ + input_metadata.has_value() ? input_metadata->data_ptr() \ + : nullptr, \ + static_cast(eviction_threshold), \ + process_lock.has_value() ? process_lock->data_ptr() \ + : nullptr, \ + opt_in_prob, \ + num_reserved_slots, \ + opt_in_rands.has_value() ? 
opt_in_rands->data_ptr() \ + : nullptr); \ + } + +#define INVOKE_KERNEL_EVICT_POLICY( \ + DISABLE_FALLBACK, HASH_IDENTITY, METADATA_COUNT, CIRCULAR_PROBE, READONLY) \ + { \ + if (eviction_policy == 0) { \ + INVOKE_KERNEL( \ + 0, \ + DISABLE_FALLBACK, \ + HASH_IDENTITY, \ + METADATA_COUNT, \ + CIRCULAR_PROBE, \ + READONLY); \ + } else { \ + INVOKE_KERNEL( \ + 1, \ + DISABLE_FALLBACK, \ + HASH_IDENTITY, \ + METADATA_COUNT, \ + CIRCULAR_PROBE, \ + READONLY); \ + } \ + } + +#define INVOKE_HASH_IDENTITY( \ + HASH_IDENTITY, METADATA_COUNT, CIRCULAR_PROBE, READONLY) \ + { \ + if (disable_fallback) { \ + INVOKE_KERNEL_EVICT_POLICY( \ + true, HASH_IDENTITY, METADATA_COUNT, CIRCULAR_PROBE, READONLY) \ + } else { \ + INVOKE_KERNEL_EVICT_POLICY( \ + false, HASH_IDENTITY, METADATA_COUNT, CIRCULAR_PROBE, READONLY) \ + } \ + } + +#define INVOKE_KERNEL_METADATA_COUNT(METADATA_COUNT, CIRCULAR_PROBE, READONLY) \ + { \ + if (hash_identity == 1) { \ + INVOKE_HASH_IDENTITY(1, METADATA_COUNT, CIRCULAR_PROBE, READONLY); \ + } else if (hash_identity == 2) { \ + INVOKE_HASH_IDENTITY(2, METADATA_COUNT, CIRCULAR_PROBE, READONLY); \ + } else { \ + INVOKE_HASH_IDENTITY(0, METADATA_COUNT, CIRCULAR_PROBE, READONLY); \ + } \ + } + +#define INVOKE_KERNEL_CIRCULAR_PROBE(CIRCULAR_PROBE, READONLY) \ + { \ + if (support_evict) { \ + INVOKE_KERNEL_METADATA_COUNT(1, CIRCULAR_PROBE, READONLY); \ + } else { \ + INVOKE_KERNEL_METADATA_COUNT(0, CIRCULAR_PROBE, READONLY); \ + } \ + } + +#define INVOKE_KERNEL_READ_ONLY(READONLY) \ + { \ + if (circular_probe) { \ + INVOKE_KERNEL_CIRCULAR_PROBE(true, READONLY); \ + } else { \ + INVOKE_KERNEL_CIRCULAR_PROBE(false, READONLY); \ + } \ + } + + if (readonly) { + INVOKE_KERNEL_READ_ONLY(true); + } else { + INVOKE_KERNEL_READ_ONLY(false); + } + +#undef INVOKE_KERNEL_READ_ONLY +#undef INVOKE_KERNEL_CIRCULAR_PROBE +#undef INVOKE_KERNEL_METADATA_COUNT +#undef INVOKE_HASH_IDENTITY +#undef INVOKE_KERNEL +} + +Tensor murmur_hash3_cuda(const Tensor& input, int64_t y, int64_t seed) { + auto hash = murmur_hash3_2x64( + input.item().to(), + static_cast(y), + static_cast(seed)); + return at::scalar_tensor(hash, input.options()); +} + +std::tuple zero_collision_hash_cuda( + const Tensor& input, + Tensor& identities, + int64_t max_probe, + bool circular_probe, + int64_t exp_hours, // to be deprecated + bool readonly, + const std::optional& local_sizes, + const std::optional& offsets, + const std::optional& metadata, + bool output_on_uvm, + bool disable_fallback, + bool _modulo_identity_DPRECATED, + const std::optional& input_metadata, + int64_t eviction_threshold, + int64_t eviction_policy, + int64_t opt_in_prob, + int64_t num_reserved_slots, + const std::optional& opt_in_rands) { + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(identities.dim() == 2); + + int32_t hash_identity = _modulo_identity_DPRECATED ? 
1 : 2; + if (identities.dtype() == input.dtype()) { + hash_identity = 0; + } + if (input.dtype() == torch::kInt32) { + TORCH_CHECK(identities.dtype() == torch::kInt32); + } + + if (input_metadata.has_value()) { + TORCH_CHECK(exp_hours == -1); + TORCH_CHECK(input_metadata->size(0) == input.size(0)); + TORCH_CHECK(eviction_threshold != -1); + TORCH_CHECK(eviction_policy == 0 || eviction_policy == 1); + } + if (eviction_threshold != -1) { + TORCH_CHECK(eviction_policy == 0 || eviction_policy == 1); + TORCH_CHECK(input_metadata.has_value()); + } + + std::time_t now_c = time(nullptr); + auto hours = static_cast(now_c) / kSecondsInHour; + auto cur_hour = hours % kMaxHours; + + if (exp_hours > 0) { + TORCH_CHECK(!input_metadata.has_value()); + TORCH_CHECK(eviction_threshold == -1); + + // for backward compatibility: previous implementation uses cur_hour - + // exp_hours as threshold + // note the eviction criteria is the same: eviction_threshold > + // identity_metadata (last-seen hour) + eviction_threshold = cur_hour - exp_hours; + } + + bool support_evict = + is_eviction_enabled(readonly, eviction_threshold, eviction_policy); + + TORCH_CHECK( + !support_evict || metadata.has_value(), + "support_evict=", + support_evict, + "metadata is null"); + TORCH_CHECK( + support_evict || !metadata.has_value(), + "support_evict=", + support_evict, + "metadata is not null"); + + if (metadata.has_value()) { + TORCH_CHECK(metadata->dim() == 2); + TORCH_CHECK(metadata->is_cuda()); + TORCH_CHECK(metadata->size(0) == identities.size(0)); + } + // offsets and local_sizes are null in training; not null during + // inference/eval + if (local_sizes.has_value()) { + TORCH_CHECK(local_sizes->is_cuda()); + TORCH_CHECK(input.numel() == local_sizes->numel()); + } + if (offsets.has_value()) { + TORCH_CHECK(offsets->is_cuda()); + TORCH_CHECK(input.numel() == offsets->numel()); + } + if (opt_in_prob != -1) { + TORCH_CHECK(opt_in_prob > 0 && opt_in_prob < 100); + TORCH_CHECK(num_reserved_slots > 0); + } + if (num_reserved_slots != -1) { + TORCH_CHECK(opt_in_prob != -1); + } + if (opt_in_rands.has_value()) { + TORCH_CHECK(opt_in_prob != -1); + TORCH_CHECK(opt_in_rands->size(0) == input.size(0)); + TORCH_CHECK(opt_in_rands->dtype() == torch::kInt32); + } + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + int64_t output_size = input.size(0); + c10::TensorOptions options; + + if (output_on_uvm) { + options = + c10::TensorOptions().dtype(at::kLong).device(at::kCPU).pinned_memory( + true); + } else { + options = c10::TensorOptions().dtype(at::kLong).device(input.device()); + } + + Tensor output = at::empty({output_size}, options); + + // evict_slots will contains the index to be evcited, '-1' will be ignored. 
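+  // evict_slots is only allocated when support_evict is true. It is
+  // pre-filled with the -1 sentinel and, after the kernel completes, it is
+  // compacted via masked_select + torch::_unique so callers only see slots
+  // that were actually evicted.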
+ Tensor evict_slots; + if (support_evict) { + evict_slots = at::full( + {output_size}, + static_cast(kDefaultTensor), + c10::TensorOptions().dtype(at::kLong).device(input.device())); + } + + if (output_size == 0) { + return {output, evict_slots}; + } + + AT_DISPATCH_INTEGER_TYPES( + input.scalar_type(), "zero_collision_hash_input", input_t, [&]() { + AT_DISPATCH_INTEGER_TYPES( + identities.scalar_type(), + "zero_collision_hash_identity", + identity_t, + [&]() { + _zero_collision_hash_cuda( + output, + evict_slots, + input, + identities, + max_probe, + circular_probe, + cur_hour, + readonly, + support_evict, + local_sizes, + offsets, + hash_identity, + metadata, + disable_fallback, + input_metadata, + eviction_threshold, + eviction_policy, + opt_in_prob, + num_reserved_slots, + opt_in_rands); + }); + }); + + if (support_evict) { + evict_slots = std::get<0>(torch::_unique( + evict_slots.masked_select(evict_slots != kDefaultTensor))); + } + if (output_on_uvm) { + C10_CUDA_CHECK(cudaDeviceSynchronize()); + } + return {output, evict_slots}; +} + +// Register operators +TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { + m.impl( + "zero_collision_hash", + torch::dispatch( + c10::DispatchKey::CUDA, TORCH_FN(zero_collision_hash_cuda))); + m.impl( + "murmur_hash3", + torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(murmur_hash3_cuda))); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/test/faster_hash_test.py b/fbgemm_gpu/test/faster_hash_test.py new file mode 100644 index 0000000000..512c1d6691 --- /dev/null +++ b/fbgemm_gpu/test/faster_hash_test.py @@ -0,0 +1,1715 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
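+# A rough sketch of how the ops under test are invoked (values illustrative,
+# mirroring the tests below):
+#
+#   identities, metadata = torch.ops.fbgemm.create_zch_buffer(
+#       200, support_evict=False, device=torch.device("cuda")
+#   )
+#   output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
+#       input=ids, identities=identities, max_probe=100, circular_probe=True
+#   )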
+ +# pyre-strict +import unittest +from enum import IntEnum + +# pyre-ignore[21] +import fbgemm_gpu # noqa: F401 + +import torch + +# check if we are in open source env to decide how to import necessary modules +try: + # pyre-ignore[21] + from fbgemm_gpu import open_source # noqa: F401 + + # pyre-ignore[21] + from test_utils import ( # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils + gpu_unavailable, + skipIfRocm, + ) +except Exception: + from fbgemm_gpu.test.test_utils import gpu_unavailable, skipIfRocm + + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:faster_hash_ops") + + +class HashZchKernelEvictionPolicy(IntEnum): + THRESHOLD_EVICTION = 0 + LRU_EVICTION = 1 + + +class FasterHashTest(unittest.TestCase): + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_simple_zch_no_evict(self) -> None: + # no evict + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 200, device=torch.device("cuda") + ) + numbers = torch.arange(0, 100, dtype=torch.int64, device="cuda") + local_sizes = torch.ones_like(numbers) * 100 + + output1, evict_slots1 = torch.ops.fbgemm.zero_collision_hash( + input=numbers, + identities=identities, + max_probe=100, + circular_probe=True, + local_sizes=local_sizes, + offsets=torch.zeros_like(numbers), + ) + output2, evict_slots2 = torch.ops.fbgemm.zero_collision_hash( + input=numbers + 100, + identities=identities, + max_probe=100, + circular_probe=True, + local_sizes=local_sizes, + offsets=torch.ones_like(numbers) * 100, + ) + + self.assertEqual( + torch.unique(output1).tolist(), + numbers.tolist(), + f"{torch.unique(output1).tolist()=} != {numbers.tolist()=}", + ) + self.assertEqual(torch.unique(output2).tolist(), (numbers + 100).tolist()) + self.assertTrue(torch.all(identities != -1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=numbers + 100, + identities=identities, + max_probe=100, + circular_probe=True, + exp_hours=-1, + readonly=True, + local_sizes=local_sizes, + offsets=torch.ones_like(numbers) * 100, + ) + self.assertTrue(torch.equal(output2, output_readonly)) + + # CPU + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + input=numbers.cpu() + 100, + identities=identities.cpu(), + max_probe=100, + circular_probe=True, + exp_hours=-1, + readonly=True, + local_sizes=local_sizes.cpu(), + offsets=torch.ones_like(numbers).cpu() * 100, + ) + self.assertTrue( + torch.equal(output2.cpu(), output_readonly_cpu), + f"{output2.cpu()=} != {output_readonly_cpu=}", + ) + + # other numbers. + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda") + ) + numbers_100_200 = torch.arange(100, 200, dtype=torch.int64, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=True, + ) + self.assertEqual(torch.unique(output).tolist(), numbers.tolist()) + self.assertTrue(torch.all(identities != -1)) + + # readonly lookup. 
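The readonly lookup below repeats the previous query without mutating the table. Earlier in this test, a single 200-row buffer is addressed as two 100-row shards via `local_sizes` and `offsets`; a hedged sketch of that calling convention (all names and sizes mirror the test above, nothing here is new API):

```python
import torch
import fbgemm_gpu  # noqa: F401

# One physical buffer, treated as two logical shards of 100 rows each.
identities, _ = torch.ops.fbgemm.create_zch_buffer(200, device=torch.device("cuda"))

ids = torch.arange(0, 100, dtype=torch.int64, device="cuda")
local_sizes = torch.full_like(ids, 100)      # each id probes within a 100-row shard

out0, _ = torch.ops.fbgemm.zero_collision_hash(
    input=ids,
    identities=identities,
    max_probe=100,
    circular_probe=True,
    local_sizes=local_sizes,
    offsets=torch.zeros_like(ids),           # shard 0 starts at row 0
)
out1, _ = torch.ops.fbgemm.zero_collision_hash(
    input=ids + 100,
    identities=identities,
    max_probe=100,
    circular_probe=True,
    local_sizes=local_sizes,
    offsets=torch.full_like(ids, 100),       # shard 1 starts at row 100
)
# out0 falls in [0, 100) and out1 in [100, 200), as asserted in the test.
```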
+ output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # CPU + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + input=numbers_100_200.cpu(), + identities=identities.cpu(), + max_probe=100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output_readonly.cpu(), output_readonly_cpu)) + + # no evict + no circular probe + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda") + ) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=False, + ) + self.assertFalse(torch.all(identities != -1)) + unique_indices = torch.unique(output) + all_indices = torch.arange(identities.size(0), device="cuda") + not_select_indices = torch.isin(all_indices, unique_indices, invert=True) + self.assertTrue(torch.all(identities[unique_indices] != -1)) + self.assertTrue(torch.all(identities[not_select_indices] == -1)) + + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # CPU + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + input=numbers.cpu(), + identities=identities.cpu(), + max_probe=100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output_readonly.cpu(), output_readonly_cpu)) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_simple_zch_no_evict_rand(self) -> None: + # no evict - rand number. + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda") + ) + random_numbers = torch.randint(0, 100, (100,), device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=True, + ) + + for i in range(100): + to_test = output[random_numbers == i] + if len(to_test) > 0: + self.assertTrue(torch.all(to_test == to_test[0])) + + unique_indices = torch.unique(output) + all_indices = torch.arange(identities.size(0), device="cuda") + not_select_indices = torch.isin(all_indices, unique_indices, invert=True) + self.assertTrue(torch.all(identities[unique_indices] != -1)) + self.assertTrue(torch.all(identities[not_select_indices] == -1)) + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. 
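Before the readonly re-check below, a note on `circular_probe`, since both settings appear in these tests: with circular probing the probe position wraps around the (shard of the) table, so every row can eventually be reached; without it, probing stops once the end of the block is hit, which is why the non-circular assertions above tolerate rows that remain -1. A rough pure-Python model of the probe order under that reading (illustrative only; the real kernel hashes the id to pick the starting row and also handles locking and eviction):

```python
def probe_sequence(start: int, size: int, max_probe: int, circular: bool) -> list:
    """Rows visited for one id that first lands on `start` in a `size`-row block."""
    seq, idx = [], start
    for _ in range(max_probe):
        seq.append(idx)
        nxt = (idx + 1) % size
        if not circular and nxt == 0:
            # reached the end of the block without wrapping: give up
            break
        idx = nxt
    return seq

print(probe_sequence(97, 100, max_probe=6, circular=True))   # [97, 98, 99, 0, 1, 2]
print(probe_sequence(97, 100, max_probe=6, circular=False))  # [97, 98, 99]
```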
+ output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # CPU + output_readonly_cpu, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output.cpu(), output_readonly_cpu)) + + # no evict + no circular probe + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda") + ) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=False, + ) + unique_indices_no_circular = torch.unique(output) + all_indices_no_circular = torch.arange(identities.size(0), device="cuda") + not_select_indices_no_circular = torch.isin( + all_indices_no_circular, unique_indices_no_circular, invert=True + ) + self.assertTrue(torch.all(identities[unique_indices_no_circular] != -1)) + self.assertTrue(torch.all(identities[not_select_indices_no_circular] == -1)) + self.assertTrue(unique_indices_no_circular.size(0) <= unique_indices.size(0)) + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # CPU + output_readonly_cpu, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers.cpu(), + identities.cpu(), + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output.cpu(), output_readonly_cpu)) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_simple_zch_evict(self) -> None: + # evict + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + numbers = torch.arange(0, 100, dtype=torch.int64, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertEqual(torch.unique(output).tolist(), numbers.tolist()) + self.assertTrue(evict_slots.numel() == 0) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # evict with all expired hours. + metadata[:, 0] -= 7 * 24 + 1 + numbers_100_200 = torch.arange(100, 200, dtype=torch.int64, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertEqual(torch.unique(output).tolist(), numbers.tolist()) + self.assertTrue(torch.all(evict_slots != -1)) + self.assertEqual(torch.unique(evict_slots).tolist(), numbers.tolist()) + + # readonly lookup. 
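The eviction exercised above is time based. Per the host-side code earlier in this diff, the legacy `exp_hours` argument is translated into `eviction_threshold = cur_hour - exp_hours`, and an occupied slot becomes evictable once its stored last-seen hour drops below that threshold; that is exactly why the test ages `metadata[:, 0]` by `7 * 24 + 1` hours before re-inserting. A small worked restatement of the arithmetic (pure Python, the hour values are made up):

```python
exp_hours = 7 * 24                 # keep ids seen within the last week
cur_hour = 480_000                 # hypothetical current hour counter
eviction_threshold = cur_hour - exp_hours

def is_expired(last_seen_hour: int) -> bool:
    # Mirrors "eviction_threshold > identity_metadata (last-seen hour)".
    return last_seen_hour < eviction_threshold

print(is_expired(cur_hour - 1))              # False: touched an hour ago
print(is_expired(cur_hour - (7 * 24 + 1)))   # True: aged past the one-week TTL
```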
+ output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # evict + no circular probe + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=False, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertFalse(torch.all(identities != -1)) + unique_indices = torch.unique(output) + all_indices = torch.arange(identities.size(0), device="cuda") + not_select_indices = torch.isin(all_indices, unique_indices, invert=True) + self.assertTrue(torch.all(identities[unique_indices] != -1)) + self.assertTrue(torch.all(identities[not_select_indices] == -1)) + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # evict with all expired hours + no circular probe + evict_slot_candidate_mask = metadata[:, 0] != -1 + evict_slot_candidates = torch.nonzero(evict_slot_candidate_mask) + self.assertTrue(evict_slot_candidates.size(0) != 0) + metadata[evict_slot_candidate_mask, 0] -= 7 * 24 + 1 + old_time_value = metadata[evict_slot_candidates[0], 0] + numbers_100_200 = torch.arange(100, 200, dtype=torch.int64, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=False, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertTrue(torch.all(torch.isin(evict_slots, evict_slot_candidates))) + self.assertTrue(torch.all(metadata[evict_slots][:, 0] != old_time_value)) + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_simple_zch_evict_with_rand_unique_numbers(self) -> None: + # evict - rand number. 
+ identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + random_numbers = torch.unique(torch.randint(0, 100, (100,), device="cuda")) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + + for i in range(100): + to_test = output[random_numbers == i] + if len(to_test) > 0: + self.assertTrue(torch.all(to_test == to_test[0])) + + unique_indices = torch.unique(output) + all_indices = torch.arange(identities.size(0), device="cuda") + not_select_indices = torch.isin(all_indices, unique_indices, invert=True) + self.assertTrue(torch.all(identities[unique_indices] != -1)) + self.assertTrue(torch.all(identities[not_select_indices] == -1)) + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities[:, 0].unsqueeze(1), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # evict - rand number + no circular probe + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=False, + exp_hours=7 * 24, + metadata=metadata, + ) + + for i in range(100): + to_test = output[random_numbers == i] + if len(to_test) > 0: + self.assertTrue(torch.all(to_test == to_test[0])) + + unique_indices_no_circular = torch.unique(output) + all_indices_no_circular = torch.arange(identities.size(0), device="cuda") + not_select_indices_no_circular = torch.isin( + all_indices_no_circular, unique_indices_no_circular, invert=True + ) + self.assertTrue(torch.all(identities[unique_indices_no_circular] != -1)) + self.assertTrue(torch.all(identities[not_select_indices_no_circular] == -1)) + self.assertTrue(unique_indices_no_circular.size(0) <= unique_indices.size(0)) + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1)) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers, + identities, + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # evict with all expired hours + no circular probe + evict_slot_candidate_mask = metadata[:, 0] != -1 + evict_slot_candidates = torch.nonzero(evict_slot_candidate_mask) + self.assertTrue(evict_slot_candidates.size(0) != 0) + metadata[evict_slot_candidate_mask, 0] -= 7 * 24 + 1 + old_time_value = metadata[evict_slot_candidates[0], 0] + random_numbers_100_200 = torch.unique( + torch.randint(100, 200, (100,), device="cuda") + ) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers_100_200, + identities, + 100, + circular_probe=False, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertTrue(torch.all(torch.isin(evict_slots, evict_slot_candidates))) + self.assertTrue(torch.all(metadata[evict_slots][:, 0] != old_time_value)) + + unique_elements, counts = torch.unique( + identities[identities[:, 0] != -1][:, 0], return_counts=True + ) + self.assertTrue(torch.all(counts == 1), counts) + + # readonly lookup. 
+ output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + random_numbers_100_200, + identities[:, 0].unsqueeze(1), + 100, + circular_probe=False, + exp_hours=-1, + readonly=True, + ) + self.assertTrue( + torch.equal(output, output_readonly), f"{output=}, {output_readonly=}" + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_eviction_during_lookup(self) -> None: + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + numbers_0_99 = torch.arange(0, 99, dtype=torch.int64, device="cuda") + output_0_99, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_0_99, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + empty_slots = identities[:, 0] == -1 + self.assertTrue(torch.sum(empty_slots) == 1, torch.sum(empty_slots)) + + # insert number 101, should be able to fill all slots. + numbers = torch.tensor([101], dtype=torch.int64, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertTrue(torch.all(identities[:, 0] != -1)) + + # make none 101 slots expired. + metadata[~empty_slots, 0] -= 7 * 24 + 1 + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + unique_elements, counts = torch.unique(identities[:, 0], return_counts=True) + self.assertTrue(torch.all(counts == 1)) + self.assertTrue(evict_slots.numel() == 0) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_0_99, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue(torch.equal(output_0_99, output_readonly)) + + # evict some slot. + numbers = torch.tensor([102], dtype=torch.int64, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + ) + self.assertTrue(evict_slots.numel() == 1) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_output_on_uvm(self) -> None: + # no evict + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 200, device=torch.device("cuda") + ) + numbers = torch.arange(0, 100, dtype=torch.int64, device="cuda") + + output, _ = torch.ops.fbgemm.zero_collision_hash( + input=numbers, + identities=identities, + max_probe=100, + circular_probe=True, + output_on_uvm=True, + ) + + self.assertTrue(output.device.type == "cpu") + + add_on = torch.arange(100, 200, dtype=torch.int64) + self.assertTrue( + # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `int`. 
+ torch.equal(((output + add_on) + (output - add_on)), 2 * output) + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_int64_nohash_identity(self) -> None: + # no evict + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda"), support_evict=True, long_type=True + ) + numbers = torch.arange(2**33, 2**33 + 100, dtype=torch.int64, device="cuda") + + output, _ = torch.ops.fbgemm.zero_collision_hash( + input=numbers, + identities=identities, + max_probe=100, + circular_probe=True, + readonly=False, + exp_hours=7 * 24, + metadata=metadata, + ) + + self.assertTrue( + torch.equal( + torch.sort(identities[identities != -1].view(-1))[0], + numbers, + ), + f"{identities=} vs {numbers=}", + ) + + numbers_100_200 = torch.arange( + 2**33 + 100, 2**33 + 200, dtype=torch.int64, device="cuda" + ) + metadata[:, 0] -= 7 * 24 + 1 + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=numbers_100_200, + identities=identities, + max_probe=100, + circular_probe=True, + readonly=False, + exp_hours=7 * 24, + metadata=metadata, + ) + + expect_indices = list(range(100)) + self.assertEqual(torch.unique(output).tolist(), expect_indices) + self.assertTrue(torch.all(evict_slots != -1)) + self.assertEqual(torch.unique(evict_slots).tolist(), expect_indices) + self.assertTrue( + torch.equal( + torch.sort(identities[identities != -1].view(-1))[0], + numbers_100_200, + ), + f"{identities=} vs {numbers_100_200=}", + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_int32_nohash_identity(self) -> None: + # no evict + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda"), support_evict=True, long_type=False + ) + numbers = torch.arange(2**33, 2**33 + 100, dtype=torch.int32, device="cuda") + + output, _ = torch.ops.fbgemm.zero_collision_hash( + input=numbers, + identities=identities, + max_probe=100, + circular_probe=True, + readonly=False, + exp_hours=7 * 24, + metadata=metadata, + ) + + self.assertTrue( + torch.equal( + torch.sort(identities[identities != -1].view(-1))[0], + numbers, + ), + f"{identities=} vs {numbers=}", + ) + + numbers_100_200 = torch.arange( + 2**33 + 100, 2**33 + 200, dtype=torch.int32, device="cuda" + ) + metadata[:, 0] -= 7 * 24 + 1 + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=numbers_100_200, + identities=identities, + max_probe=100, + circular_probe=True, + readonly=False, + exp_hours=7 * 24, + metadata=metadata, + ) + + expect_indices = list(range(100)) + self.assertEqual(torch.unique(output).tolist(), expect_indices) + self.assertTrue(torch.all(evict_slots != -1)) + self.assertEqual(torch.unique(evict_slots).tolist(), expect_indices) + self.assertTrue( + torch.equal( + torch.sort(identities[identities != -1].view(-1))[0], + numbers_100_200, + ), + f"{identities=} vs {numbers_100_200=}", + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_fallback(self) -> None: + # init and add some ids + identities, _ = torch.ops.fbgemm.create_zch_buffer( + 100, device=torch.device("cuda"), long_type=True + ) + ids = torch.arange(0, 100, device="cuda") + output, _ = torch.ops.fbgemm.zero_collision_hash( + input=ids, + identities=identities, + max_probe=100, + circular_probe=True, + readonly=False, + ) + + # non-readonly and fallback enabled + ids = torch.arange(90, 120, device="cuda") + remapped_ids, _ = 
torch.ops.fbgemm.zero_collision_hash(
+            input=ids,
+            identities=identities,
+            max_probe=100,
+            circular_probe=True,
+            readonly=False,
+            disable_fallback=False,
+        )
+        # all ids (including non-existent ones) are mapped to a position
+        self.assertTrue(torch.all(remapped_ids != -1))
+
+        # readonly and fallback enabled
+        ids = torch.arange(90, 120, device="cuda")
+        remapped_ids, _ = torch.ops.fbgemm.zero_collision_hash(
+            input=ids,
+            identities=identities,
+            max_probe=100,
+            circular_probe=True,
+            readonly=True,
+            disable_fallback=False,
+        )
+        # all ids (including non-existent ones) are mapped to a position
+        self.assertTrue(torch.all(remapped_ids != -1))
+
+        # non-readonly and fallback disabled
+        ids = torch.arange(90, 120, device="cuda")
+        remapped_ids, _ = torch.ops.fbgemm.zero_collision_hash(
+            input=ids,
+            identities=identities,
+            max_probe=100,
+            circular_probe=True,
+            readonly=False,
+            disable_fallback=True,
+        )
+        # existing ids are mapped to a position; non-existent ones are mapped to -1
+        self.assertTrue(
+            torch.equal(
+                torch.index_select(
+                    identities, 0, remapped_ids[remapped_ids != -1]
+                ).squeeze(),
+                torch.arange(90, 100, device="cuda"),
+            )
+        )
+        self.assertTrue(torch.all(remapped_ids[-20:] == -1))
+
+        # readonly and fallback disabled
+        ids = torch.arange(90, 120, device="cuda")
+        remapped_ids, _ = torch.ops.fbgemm.zero_collision_hash(
+            input=ids,
+            identities=identities,
+            max_probe=100,
+            circular_probe=True,
+            readonly=True,
+            disable_fallback=True,
+        )
+        # existing ids are mapped to a position; non-existent ones are mapped to -1
+        self.assertTrue(
+            torch.equal(
+                torch.index_select(
+                    identities, 0, remapped_ids[remapped_ids != -1]
+                ).squeeze(),
+                torch.arange(90, 100, device="cuda"),
+            )
+        )
+        self.assertTrue(torch.all(remapped_ids[-20:] == -1))
+
+    @skipIfRocm("The CUDA kernel is not supported on ROCm")
+    @unittest.skipIf(*gpu_unavailable)
+    def test_simple_zch_individual_score_evict(self) -> None:
+        # evict
+        identities, metadata = torch.ops.fbgemm.create_zch_buffer(
+            100, support_evict=True, long_type=True, device=torch.device("cuda")
+        )
+        numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda")
+        input_metadata_500_600 = torch.arange(
+            500, 600, dtype=torch.int32, device="cuda"
+        )
+        output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
+            numbers_0_100,
+            identities,
+            100,
+            circular_probe=True,
+            metadata=metadata,
+            input_metadata=input_metadata_500_600,
+            eviction_threshold=100,
+        )
+        self.assertEqual(torch.unique(output).tolist(), numbers_0_100.tolist())
+        self.assertEqual(
+            torch.unique(metadata).tolist(), input_metadata_500_600.tolist()
+        )
+        self.assertTrue(evict_slots.numel() == 0)
+
+        # readonly lookup.
+        output_readonly, _ = torch.ops.fbgemm.zero_collision_hash(
+            numbers_0_100,
+            identities,
+            100,
+            circular_probe=True,
+            readonly=True,
+        )
+        self.assertTrue(torch.equal(output, output_readonly))
+
+        numbers_100_200 = torch.arange(100, 200, dtype=torch.int64, device="cuda")
+        input_metadata_600_700 = torch.arange(
+            600, 700, dtype=torch.int32, device="cuda"
+        )
+
+        # evict by setting eviction_threshold to 550 (the half of the slots whose
+        # eviction scores are less than 550 will be evicted)
+        output, evict_slots = torch.ops.fbgemm.zero_collision_hash(
+            numbers_100_200,
+            identities,
+            100,
+            circular_probe=True,
+            metadata=metadata,
+            input_metadata=input_metadata_600_700,
+            eviction_threshold=550,
+        )
+
+        self.assertEqual(evict_slots.numel(), 50)
+        self.assertTrue(torch.all(metadata >= 550))
+
+        # readonly lookup.
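Before the readonly lookup that follows, it is worth restating the rule this test relies on: every incoming id carries a score in `input_metadata`, the score is written into the slot's metadata row, and on a later pass an occupied slot may be reclaimed when its stored score is below `eviction_threshold` (here 550, so the fifty slots scored 500-549 are the ones evicted). A pure-Python restatement of that predicate, not the kernel code:

```python
def is_score_evictable(stored_score: int, eviction_threshold: int) -> bool:
    # After the second pass the test asserts every surviving slot has
    # metadata >= eviction_threshold, i.e. exactly the complement of this.
    return stored_score < eviction_threshold

stored_scores = range(500, 600)    # scores written by the first pass above
evictable = [s for s in stored_scores if is_score_evictable(s, 550)]
print(len(evictable))              # 50, matching evict_slots.numel()
```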
+ output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=True, + readonly=True, + ) + self.assertTrue(torch.equal(output, output_readonly)) + + # attempt to update with lower input_metadata values + metadata0 = metadata.clone() + input_metadata_0_100 = torch.arange(0, 100, dtype=torch.int32, device="cuda") + output_lower_metadata, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_100_200, + identities, + 100, + circular_probe=True, + metadata=metadata, + input_metadata=input_metadata_0_100, + eviction_threshold=550, + ) + + self.assertTrue(torch.equal(output_lower_metadata, output)) + # metadata should not be overwritten + self.assertTrue(torch.equal(metadata, metadata0)) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_lru_evict(self) -> None: + # No evict + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda") + + cur_hour = 500 + ttl = 72 + + input_metadata = torch.full_like( + numbers_0_100, + ttl + cur_hour, + dtype=torch.int32, + device="cuda", + ) + + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100, + identities, + 100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=input_metadata, + eviction_threshold=cur_hour, + ) + self.assertEqual( + torch.unique(output).tolist(), numbers_0_100.tolist(), f"{output=}" + ) + self.assertTrue(torch.all(metadata != -1), metadata) + self.assertTrue(evict_slots.numel() == 0) + self.assertEqual( + torch.unique(identities).tolist(), numbers_0_100.tolist(), f"{identities=}" + ) + + # readonly lookup. 
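The readonly lookups just below leave the table untouched. For the LRU_EVICTION runs in this test, the convention is: `input_metadata` holds the hour at which an id should expire (current hour plus a TTL), `eviction_threshold` is the current hour, and once the table is full, slots whose stored hour has fallen below the threshold become candidates, with the stalest ones reclaimed first (the test only asserts the 20 evicted slots fall within the 40 least recently updated). A small sketch of that bookkeeping, with made-up hour values:

```python
cur_hour, ttl = 500, 72

# What each incoming id writes into its slot: its expiry hour.
stored_hour = cur_hour + ttl                 # 572

# A later pass runs at a new current hour and uses it as the threshold.
new_cur_hour = 600
eviction_threshold = new_cur_hour

def is_lru_candidate(hour_in_slot: int) -> bool:
    # Reclaimable once the stored expiry hour is older than the threshold.
    return hour_in_slot < eviction_threshold

print(is_lru_candidate(stored_hour))         # True: 572 < 600, the entry aged out
print(is_lru_candidate(new_cur_hour + ttl))  # False: a freshly refreshed entry
```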
+ output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100, + identities, + 100, + circular_probe=True, + readonly=True, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + ) + self.assertTrue(output.tolist(), output_readonly.tolist()) + + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + ) + self.assertTrue( + torch.equal(output_readonly_cpu, output_readonly.cpu()), + f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}", + ) + + numbers_100_120 = torch.arange(100, 120, dtype=torch.int64, device="cuda") + new_cur_hour = 600 + new_input_metadata = torch.full_like( + numbers_100_120, + ttl + new_cur_hour, + dtype=torch.int32, + device="cuda", + ) + + # modify metadata to set different update hours to trigger LRU eviction + metadata = torch.randint( + 500, (100, 1), dtype=torch.int32, device=metadata.device + ) + + # arrange metadata in update order + eviction_order = ( + torch.sort(metadata, 0) + .indices.index_select(1, torch.tensor([0], device=metadata.device)) + .squeeze() + ) + + # all rows were occupied, do evict for all input numbers + # evict by LRU + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=numbers_100_120, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=new_input_metadata, + eviction_threshold=new_cur_hour, + ) + self.assertEqual(evict_slots.numel(), 20) + self.assertTrue( + set(evict_slots.tolist()).issubset(set(eviction_order[:40].tolist())), + f"{evict_slots=}, {eviction_order=}", + ) + + self.assertTrue( + torch.equal( + torch.sort(identities[identities >= 100])[0], + torch.sort(numbers_100_120)[0], + ), + f"{identities=} vs {numbers_100_120=}", + ) + + self.assertTrue( + torch.equal(evict_slots, torch.sort(output)[0]), + f"{evict_slots=} vs {output=}", + ) + self.assertTrue( + torch.equal( + torch.nonzero(metadata >= 500), torch.nonzero(identities >= 100) + ), + f"{torch.nonzero(metadata >= 500)=} vs {torch.nonzero(identities >= 100)=}", + ) + + # readonly lookup again + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_100_120, + identities, + 100, + circular_probe=True, + readonly=True, + ) + self.assertTrue(output.tolist(), output_readonly.tolist()) + + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_100_120.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue( + torch.equal(output_readonly_cpu, output_readonly.cpu()), + f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}", + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_lru_evict_with_unexpired_slots(self) -> None: + # No evict + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda") + ) + numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda") + + cur_hour = 1000 + ttl = 72 + + input_metadata = torch.full_like( + numbers_0_100, + ttl + cur_hour, + dtype=torch.int32, + device="cuda", + ) + + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100, + identities, + 100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + 
eviction_threshold=cur_hour, + input_metadata=input_metadata, + ) + self.assertEqual( + torch.unique(output).tolist(), numbers_0_100.tolist(), f"{output=}" + ) + self.assertTrue(torch.all(metadata != -1), metadata) + self.assertTrue(evict_slots.numel() == 0) + self.assertEqual( + torch.unique(identities).tolist(), numbers_0_100.tolist(), f"{identities=}" + ) + + # readonly lookup. + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100, + identities, + 100, + circular_probe=True, + readonly=True, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + ) + self.assertTrue(output.tolist(), output_readonly.tolist()) + + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + ) + self.assertTrue( + torch.equal(output_readonly_cpu, output_readonly.cpu()), + f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}", + ) + + numbers_100_150 = torch.arange(100, 150, dtype=torch.int64, device="cuda") + + # 20 slots expired, 80 unexpired + metadata_to_update = torch.randint( + 500, 1050, (20, 1), dtype=torch.int32, device=metadata.device + ) + metadata[0:20] = metadata_to_update + + metadata_index_0_20 = torch.arange( + 0, 20, dtype=torch.int64, device=metadata.device + ) + + new_cur_hour = 1050 + new_input_metadata = torch.full_like( + numbers_100_150, + ttl + new_cur_hour, + dtype=torch.int32, + device="cuda", + ) + + # all rows were occupied, do evict by LRU + TTL rule + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=numbers_100_150, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + eviction_threshold=new_cur_hour, + input_metadata=new_input_metadata, + ) + self.assertEqual(evict_slots.numel(), 20) + self.assertTrue( + torch.equal( + torch.sort(evict_slots)[0], + torch.sort(metadata_index_0_20)[0], + ), + f"{evict_slots=}, {metadata_index_0_20=}", + ) + + self.assertTrue(torch.all(metadata[0:20][0] == 1050 + ttl)) + self.assertTrue(torch.all(metadata[20:][0] == 1000 + ttl)) + + self.assertEqual(identities[identities >= 100].numel(), 20) + self.assertTrue(torch.all(identities[20:][0] < 100)) + + # readonly lookup - gpu + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_100_150, + identities, + 100, + circular_probe=True, + readonly=True, + ) + self.assertTrue(output.tolist(), output_readonly.tolist()) + + # readonly lookup - cpu + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_100_150.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue( + torch.equal(output_readonly_cpu, output_readonly.cpu()), + f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}", + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_rand_numbers_zch_lru_evict(self) -> None: + # No evict + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, device=torch.device("cuda"), long_type=True + ) + numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda") + + cur_hour = 1000 + ttl = 24 + + input_metadata = torch.full_like( + numbers_0_100, + ttl + cur_hour, # TTL 24h + dtype=torch.int32, + device="cuda", + ) + + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100, + 
identities, + 100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + eviction_threshold=cur_hour, + input_metadata=input_metadata, + ) + + self.assertEqual(torch.unique(output).tolist(), numbers_0_100.tolist()) + self.assertTrue(evict_slots.numel() == 0) + + # a tensor with 60 numbers with duplicates + random_numbers_100_150 = torch.randint( + 100, 150, (60,), dtype=torch.int64, device="cuda" + ) + + new_cur_hour = 1025 + new_input_metadata = torch.full_like( + random_numbers_100_150, + ttl + new_cur_hour, + dtype=torch.int32, + device="cuda", + ) + + # all rows were occupied, do evict for all input numbers + # evict by LRU + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=random_numbers_100_150, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + eviction_threshold=new_cur_hour, + input_metadata=new_input_metadata, + ) + + self.assertLessEqual(evict_slots.numel(), 60) + self.assertTrue( + torch.equal( + torch.unique(identities[identities >= 100]), + torch.unique(random_numbers_100_150), + ), + f"{torch.unique(identities[identities >= 100])=} vs {torch.unique(random_numbers_100_150)=}", + ) + + self.assertTrue( + torch.equal( + torch.nonzero(metadata >= 1025), torch.nonzero(identities >= 100) + ), + f"{torch.nonzero(metadata >= 1025)=} vs {torch.nonzero(identities >= 100)=}", + ) + + # readonly lookup again + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + random_numbers_100_150, + identities, + 100, + circular_probe=True, + readonly=True, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + ) + self.assertTrue(output.tolist(), output_readonly.tolist()) + + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + random_numbers_100_150.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + ) + self.assertTrue( + torch.equal(output_readonly_cpu, output_readonly.cpu()), + f"{output_readonly_cpu=} v.s {output_readonly.cpu()=}", + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_lru_evict_with_offsets(self) -> None: + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 200, + device=torch.device("cuda"), + long_type=True, + support_evict=True, + ) + + numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda") + local_sizes = torch.ones_like(numbers_0_100) * 100 + + cur_hour = 1000 + ttl = 24 + input_metadata = torch.full_like( + numbers_0_100, + ttl + cur_hour, # TTL 24h + dtype=torch.int32, + device="cuda", + ) + + output1, evict_slots1 = torch.ops.fbgemm.zero_collision_hash( + input=numbers_0_100, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=input_metadata, + eviction_threshold=cur_hour, + local_sizes=local_sizes, + offsets=torch.zeros_like(numbers_0_100), + ) + + output2, evict_slots2 = torch.ops.fbgemm.zero_collision_hash( + input=numbers_0_100 + 100, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=input_metadata, + eviction_threshold=cur_hour, + local_sizes=local_sizes, + offsets=torch.ones_like(numbers_0_100) * 100, + ) + + self.assertEqual( + torch.unique(output1).tolist(), + numbers_0_100.tolist(), + 
f"{torch.unique(output1).tolist()=} != {numbers_0_100.tolist()=}", + ) + + self.assertEqual(torch.unique(output2).tolist(), (numbers_0_100 + 100).tolist()) + # verify all the rows in each batch are occupied + self.assertTrue(torch.all(identities[0:99][0] != -1)) + self.assertTrue(torch.all(identities[100:199][0] != -1)) + + # no eviction + self.assertTrue(evict_slots1.numel() == 0) + self.assertTrue(evict_slots2.numel() == 0) + + # readonly lookup. + output_readonly, evict_slots = torch.ops.fbgemm.zero_collision_hash( + input=numbers_0_100 + 100, + identities=identities, + max_probe=100, + circular_probe=True, + exp_hours=-1, + readonly=True, + local_sizes=local_sizes, + offsets=torch.ones_like(numbers_0_100) * 100, + ) + self.assertTrue(torch.equal(output2, output_readonly)) + + # a tensor with 60 numbers with duplicates + random_numbers_200_250 = torch.randint( + 200, 250, (60,), dtype=torch.int64, device="cuda" + ) + # second input batch + random_numbers_300_350 = random_numbers_200_250 + 100 + + # modify metadata to set different timestamps in the range of [500, 1024) + metadata = torch.randint( + 500, 1024, (200, 1), dtype=torch.int32, device=metadata.device + ) + new_cur_hour = 1025 + new_input_metadata = torch.full_like( + random_numbers_200_250, + ttl + new_cur_hour, # TTL 24h + dtype=torch.int32, + device="cuda", + ) + + local_sizes2 = torch.ones_like(random_numbers_200_250) * 100 + # all rows were occupied, do evict for all input numbers + # evict by LRU + output3, evict_slots3 = torch.ops.fbgemm.zero_collision_hash( + input=random_numbers_200_250, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=new_input_metadata, + eviction_threshold=new_cur_hour, + local_sizes=local_sizes2, + offsets=torch.zeros_like(random_numbers_200_250), + ) + + output4, evict_slots4 = torch.ops.fbgemm.zero_collision_hash( + input=random_numbers_300_350, + identities=identities, + max_probe=100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=new_input_metadata, + eviction_threshold=new_cur_hour, + local_sizes=local_sizes2, + offsets=torch.ones_like(random_numbers_300_350) * 100, + ) + + # num of evicted slots may be < 60 when all slots were being probed/locked by other ids, this id will fall back to original slot (collide) and no eviction + self.assertLessEqual(evict_slots3.numel(), 60) + self.assertLessEqual(evict_slots4.numel(), 60) + + # verify index stored in evict_slot/output should within each batch's boundary + self.assertTrue(torch.all(evict_slots3 < 100) and torch.all(evict_slots3 >= 0)) + self.assertTrue( + torch.all(evict_slots4 < 200) and torch.all(evict_slots4 >= 100) + ) + self.assertTrue(torch.all(output3 < 100) and torch.all(output3 >= 0)) + self.assertTrue(torch.all(output4 < 200) and torch.all(output4 >= 100)) + + self.assertTrue( + set(evict_slots3.tolist()).issubset(set(output3.tolist())), + f"{evict_slots3=}, {torch.sort(output3)[0]=}", + ) + self.assertTrue( + set(evict_slots4.tolist()).issubset(set(output4.tolist())), + f"{evict_slots4=}, {torch.sort(output4)[0]=}", + ) + + # verify values stored in identities + first_half = identities.view(-1)[0:99] + second_half = identities.view(-1)[100:199] + self.assertTrue( + set(first_half[first_half >= 200].tolist()).issubset( + set(random_numbers_200_250.tolist()) + ), + f"{set(first_half[first_half >= 200].tolist())=}, 
{set(random_numbers_200_250.tolist())=}", + ) + self.assertTrue( + set(second_half[second_half >= 200].tolist()).issubset( + set(random_numbers_300_350.tolist()) + ), + f"{set(second_half[second_half >= 300].tolist())=}, {set(random_numbers_300_350.tolist())=}", + ) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_opt_in_with_prob(self) -> None: + zch_size = 100 + num_reserved_slots = 10 + num_opt_in_slots = zch_size - num_reserved_slots + opt_in_prob = 20 + + # without eviction + identities, _ = torch.ops.fbgemm.create_zch_buffer( + zch_size, support_evict=False, long_type=True, device=torch.device("cuda") + ) + numbers = torch.arange(0, 100, dtype=torch.int64, device="cuda") + opt_in_rands = torch.arange(0, 100, dtype=torch.int32, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + opt_in_rands=opt_in_rands, + ) + + self.assertTrue(torch.sum((output >= 0) & (output < num_opt_in_slots)) == 20) + self.assertTrue( + torch.sum((output >= num_opt_in_slots) & (output < zch_size)) == 80 + ) + identities_opt_in_slots = identities[:num_opt_in_slots] + identities_opt_in_slots_occupied = identities_opt_in_slots[ + identities_opt_in_slots != -1 + ] + self.assertTrue( + torch.equal( + torch.unique(identities_opt_in_slots_occupied), + torch.arange(0, 20, dtype=torch.int64, device="cuda"), + ) + ) + identities_reserved_slots = identities[num_opt_in_slots:] + self.assertTrue(torch.all(identities_reserved_slots == -1)) + + # with eviction + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + zch_size, support_evict=True, long_type=True, device=torch.device("cuda") + ) + numbers = torch.arange(0, 100, dtype=torch.int64, device="cuda") + opt_in_rands = torch.arange(0, 100, dtype=torch.int32, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + opt_in_rands=opt_in_rands, + ) + + self.assertTrue(torch.sum((output >= 0) & (output < num_opt_in_slots)) == 20) + self.assertTrue( + torch.sum((output >= num_opt_in_slots) & (output < zch_size)) == 80 + ) + identities_opt_in_slots = identities[:num_opt_in_slots] + identities_opt_in_slots_occupied = identities_opt_in_slots[ + identities_opt_in_slots != -1 + ] + self.assertTrue( + torch.equal( + torch.unique(identities_opt_in_slots_occupied), + torch.arange(0, 20, dtype=torch.int64, device="cuda"), + ) + ) + identities_reserved_slots = identities[num_opt_in_slots:] + self.assertTrue(torch.all(identities_reserved_slots == -1)) + + # readonly lookup + numbers_0_20 = torch.arange(0, 20, dtype=torch.int64, device="cuda") + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_0_20, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + ) + self.assertTrue( + torch.all((output_readonly >= 0) & (output_readonly < num_opt_in_slots)) + ) + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_0_20.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + ) + self.assertTrue(torch.equal(output_readonly_cpu, output_readonly.cpu())) + + numbers_20_100 = 
torch.arange(20, 100, dtype=torch.int64, device="cuda") + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_20_100, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + ) + self.assertTrue( + torch.all( + (output_readonly >= num_opt_in_slots) & (output_readonly < zch_size) + ) + ) + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + numbers_20_100.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + ) + self.assertTrue(torch.equal(output_readonly_cpu, output_readonly.cpu())) + + # fill in all slots in the opt-in block and start eviction + opt_in_rands = torch.full_like(numbers, 0, dtype=torch.int32, device="cuda") + torch.ops.fbgemm.zero_collision_hash( + numbers, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + opt_in_rands=opt_in_rands, + ) + identities_opt_in_slots = identities[:num_opt_in_slots] + self.assertTrue(torch.all(identities_opt_in_slots != -1)) + + metadata[:, 0] -= 7 * 24 + 1 + + # number 101/102 are expected to be probed in opt-in/preserved blocks, respectively + number_101_102 = torch.tensor([101, 102], dtype=torch.int64, device="cuda") + opt_in_rands_101_102 = torch.tensor([10, 80], dtype=torch.int32, device="cuda") + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + number_101_102, + identities, + 100, + circular_probe=True, + exp_hours=7 * 24, + metadata=metadata, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + opt_in_rands=opt_in_rands_101_102, + ) + self.assertTrue(output[0] < num_opt_in_slots) + self.assertTrue(output[1] >= num_opt_in_slots) + self.assertTrue(evict_slots.numel() == 1) + self.assertTrue( + evict_slots[0] < num_opt_in_slots + ) # no eviction in reserved block + + output_readonly, _ = torch.ops.fbgemm.zero_collision_hash( + number_101_102, + identities, + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + ) + self.assertTrue(torch.equal(output_readonly, output)) + output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash( + number_101_102.cpu(), + identities.cpu(), + 100, + circular_probe=True, + exp_hours=-1, + readonly=True, + opt_in_prob=opt_in_prob, + num_reserved_slots=num_reserved_slots, + ) + self.assertTrue(torch.equal(output_readonly_cpu, output_readonly.cpu())) + + @skipIfRocm("The CUDA kernel is not supported on ROCm") + @unittest.skipIf(*gpu_unavailable) + def test_zch_lru_evict_train_eval(self) -> None: + identities, metadata = torch.ops.fbgemm.create_zch_buffer( + 100, support_evict=True, long_type=True, device=torch.device("cuda") + ) + numbers_0_100 = torch.arange(0, 100, dtype=torch.int64, device="cuda") + cur_hour = 1000 + ttl = 24 + input_metadata = torch.full_like( + numbers_0_100, + ttl + cur_hour, # TTL 24h + dtype=torch.int32, + device="cuda", + ) + output, evict_slots = torch.ops.fbgemm.zero_collision_hash( + numbers_0_100, + identities, + 100, + circular_probe=True, + metadata=metadata, + eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value, + input_metadata=input_metadata, + eviction_threshold=cur_hour, + ) + + self.assertTrue( + torch.equal( + torch.sort(identities[identities != -1].view(-1))[0], + numbers_0_100, + ), + f"{identities=}", + ) + 
self.assertTrue(evict_slots.numel() == 0)
+
+        identities_copy = identities.detach().clone()
+        numbers_80_120 = torch.arange(80, 120, dtype=torch.int64, device="cuda")
+        # gpu - readonly lookup: eval
+        output_readonly, evictions = torch.ops.fbgemm.zero_collision_hash(
+            numbers_80_120,
+            identities,
+            100,
+            circular_probe=True,
+            readonly=True,
+            eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
+        )
+
+        # check identities are not changed during readonly lookup
+        self.assertTrue(
+            torch.equal(identities_copy, identities),
+            f"{identities_copy=} v.s {identities=}",
+        )
+        self.assertTrue(evictions is None)
+
+        # [80, 100) will be found in the identities table; [100, 120) can't be found
+        for idx in range(0, 20):
+            self.assertEqual(
+                identities[output_readonly[idx]],
+                numbers_80_120[idx],
+                f"{idx=}, {identities=}, {output_readonly=}, {numbers_80_120[idx]=}",
+            )
+
+        for idx in range(20, 40):
+            self.assertNotEqual(
+                identities[output_readonly[idx]],
+                numbers_80_120[idx],
+                f"{idx=}, {identities=}, {output_readonly=}, {numbers_80_120[idx]=}",
+            )
+
+        output_readonly_cpu, _ = torch.ops.fbgemm.zero_collision_hash(
+            numbers_80_120.cpu(),
+            identities.cpu(),
+            100,
+            circular_probe=True,
+            readonly=True,
+            eviction_policy=HashZchKernelEvictionPolicy.LRU_EVICTION.value,
+        )
+        self.assertTrue(
+            torch.equal(output_readonly_cpu, output_readonly.cpu()),
+            f"{output_readonly_cpu=} v.s {output_readonly=}",
+        )
+
+    @skipIfRocm("The CUDA kernel is not supported on ROCm")
+    @unittest.skipIf(*gpu_unavailable)
+    def test_murmur_hash(self) -> None:
+        # test on cpu
+        input_item = torch.tensor([10000], dtype=torch.int64)
+        output_item_first_round = torch.ops.fbgemm.murmur_hash3(input_item, 0, 0)
+        output_item_second_round = torch.ops.fbgemm.murmur_hash3(input_item, 0, 0)
+        self.assertTrue(torch.equal(output_item_first_round, output_item_second_round))
+        # test on gpu
+        input_item = torch.tensor([10000], dtype=torch.int64, device="cuda")
+        output_item_first_round = torch.ops.fbgemm.murmur_hash3(input_item, 0, 0)
+        output_item_second_round = torch.ops.fbgemm.murmur_hash3(input_item, 0, 0)
+        self.assertTrue(torch.equal(output_item_first_round, output_item_second_round))
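The murmur_hash3 test above only checks determinism: hashing the same tensor twice with the same trailing arguments yields identical results, on CPU and on GPU. A closing usage sketch under the same assumptions as the tests (op library loaded, CUDA optional; the multi-element input is illustrative):

```python
import torch
import fbgemm_gpu  # noqa: F401

x = torch.tensor([10000, 20000, 30000], dtype=torch.int64)

h1 = torch.ops.fbgemm.murmur_hash3(x, 0, 0)
h2 = torch.ops.fbgemm.murmur_hash3(x, 0, 0)
assert torch.equal(h1, h2)  # deterministic for identical inputs and arguments

# The same call works on CUDA tensors, as the GPU half of the test exercises.
if torch.cuda.is_available():
    xg = x.to("cuda")
    assert torch.equal(
        torch.ops.fbgemm.murmur_hash3(xg, 0, 0),
        torch.ops.fbgemm.murmur_hash3(xg, 0, 0),
    )
```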