diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 800f0b46259..d33d9b6c31d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -310,6 +310,12 @@ set(CUGRAPH_SOURCES src/sampling/detail/gather_one_hop_edgelist_mg_v32_e64.cu src/sampling/detail/remove_visited_vertices_from_frontier_sg_v32_e32.cu src/sampling/detail/remove_visited_vertices_from_frontier_sg_v64_e64.cu + src/sampling/detail/check_edge_bias_values_sg_v64_e64.cu + src/sampling/detail/check_edge_bias_values_sg_v32_e32.cu + src/sampling/detail/check_edge_bias_values_sg_v32_e64.cu + src/sampling/detail/check_edge_bias_values_mg_v64_e64.cu + src/sampling/detail/check_edge_bias_values_mg_v32_e32.cu + src/sampling/detail/check_edge_bias_values_mg_v32_e64.cu src/sampling/detail/sample_edges_sg_v64_e64.cu src/sampling/detail/sample_edges_sg_v32_e32.cu src/sampling/detail/sample_edges_sg_v32_e64.cu @@ -319,12 +325,12 @@ set(CUGRAPH_SOURCES src/sampling/detail/shuffle_and_organize_output_mg_v64_e64.cu src/sampling/detail/shuffle_and_organize_output_mg_v32_e32.cu src/sampling/detail/shuffle_and_organize_output_mg_v32_e64.cu - src/sampling/uniform_neighbor_sampling_mg_v32_e64.cpp - src/sampling/uniform_neighbor_sampling_mg_v32_e32.cpp - src/sampling/uniform_neighbor_sampling_mg_v64_e64.cpp - src/sampling/uniform_neighbor_sampling_sg_v32_e64.cpp - src/sampling/uniform_neighbor_sampling_sg_v32_e32.cpp - src/sampling/uniform_neighbor_sampling_sg_v64_e64.cpp + src/sampling/neighbor_sampling_mg_v32_e64.cpp + src/sampling/neighbor_sampling_mg_v32_e32.cpp + src/sampling/neighbor_sampling_mg_v64_e64.cpp + src/sampling/neighbor_sampling_sg_v32_e64.cpp + src/sampling/neighbor_sampling_sg_v32_e32.cpp + src/sampling/neighbor_sampling_sg_v64_e64.cpp src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu src/sampling/sampling_post_processing_sg_v64_e64.cu diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index 
bce484ece20..4cf18f01310 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -1872,115 +1872,6 @@ k_core(raft::handle_t const& handle, std::optional> core_numbers, bool do_expensive_check = false); -/** - * @brief Controls how we treat prior sources in sampling - * - * @param DEFAULT Add vertices encounted while sampling to the new frontier - * @param CARRY_OVER In addition to newly encountered vertices, include vertices - * used as sources in any previous frontier in the new frontier - * @param EXCLUDE Filter the new frontier to exclude any vertex that was - * used as a source in a previous frontier - */ -enum class prior_sources_behavior_t { DEFAULT = 0, CARRY_OVER, EXCLUDE }; - -/** - * @brief Uniform Neighborhood Sampling. - * - * This function traverses from a set of starting vertices, traversing outgoing edges and - * randomly selects from these outgoing neighbors to extract a subgraph. - * - * Output from this function is a tuple of vectors (src, dst, weight, edge_id, edge_type, hop, - * label, offsets), identifying the randomly selected edges. src is the source vertex, dst is the - * destination vertex, weight (optional) is the edge weight, edge_id (optional) identifies the edge - * id, edge_type (optional) identifies the edge type, hop identifies which hop the edge was - * encountered in. The label output (optional) identifes the vertex label. The offsets array - * (optional) will be described below and is dependent upon the input parameters. - * - * - * If @p starting_vertex_labels is not specified then no organization is applied to the output, the - * label and offsets values in the return set will be std::nullopt. - * - * If @p starting_vertex_labels is specified and @p label_to_output_comm_rank is not specified then - * the label output has values. This will also result in the output being sorted by vertex label. 
- * The offsets array in the return will be a CSR-style offsets array to identify the beginning of - * each label range in the data. `labels.size() == (offsets.size() - 1)`. - * - * If @p starting_vertex_labels is specified and @p label_to_output_comm_rank is specified then the - * label output has values. This will also result in the output being sorted by vertex label. The - * offsets array in the return will be a CSR-style offsets array to identify the beginning of each - * label range in the data. `labels.size() == (offsets.size() - 1)`. Additionally, the data will - * be shuffled so that all data with a particular label will be on the specified rank. - * - * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. - * @tparam edge_t Type of edge identifiers. Needs to be an integral type. - * @tparam weight_t Type of edge weights. Needs to be a floating point type. - * @tparam edge_type_t Type of edge type. Needs to be an integral type. - * @tparam label_t Type of label. Needs to be an integral type. - * @tparam store_transposed Flag indicating whether sources (if false) or destinations (if - * true) are major indices - * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) - * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and - * handles to various CUDA libraries) to run graph algorithms. - * @param graph_view Graph View object to generate NBR Sampling on. - * @param edge_weight_view Optional view object holding edge weights for @p graph_view. - * @param edge_id_view Optional view object holding edge ids for @p graph_view. - * @param edge_type_view Optional view object holding edge types for @p graph_view. - * @param starting_vertices Device span of starting vertex IDs for the sampling. - * In a multi-gpu context the starting vertices should be local to this GPU. 
- * @param starting_vertex_labels Optional device span of labels associted with each starting vertex - * for the sampling. - * @param label_to_output_comm_rank Optional tuple of device spans mapping label to a particular - * output rank. Element 0 of the tuple identifes the label, Element 1 of the tuple identifies the - * output rank. The label span must be sorted in ascending order. - * @param fan_out Host span defining branching out (fan-out) degree per source vertex for each - * level - * @param rng_state A pre-initialized raft::RngState object for generating random numbers - * @param return_hops boolean flag specifying if the hop information should be returned - * @param prior_sources_behavior Enum type defining how to handle prior sources, (defaults to - * DEFAULT) - * @param dedupe_sources boolean flag, if true then if a vertex v appears as a destination in hop X - * multiple times with the same label, it will only be passed once (for each label) as a source - * for the next hop. Default is false. - * @param with_replacement boolean flag specifying if random sampling is done with replacement - * (true); or, without replacement (false); default = true; - * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). 
- * @return tuple device vectors (vertex_t source_vertex, vertex_t destination_vertex, - * optional weight_t weight, optional edge_t edge id, optional edge_type_t edge type, - * optional int32_t hop, optional label_t label, optional size_t offsets) - */ -template -std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement = true, - prior_sources_behavior_t prior_sources_behavior = prior_sources_behavior_t::DEFAULT, - bool dedupe_sources = false, - bool do_expensive_check = false); - /* * @brief Compute triangle counts. 
* diff --git a/cpp/include/cugraph/graph.hpp b/cpp/include/cugraph/graph.hpp index 0ccd0cbc6df..0607b39153d 100644 --- a/cpp/include/cugraph/graph.hpp +++ b/cpp/include/cugraph/graph.hpp @@ -319,11 +319,20 @@ struct invalid_idx< template struct invalid_vertex_id : invalid_idx {}; +template +inline constexpr vertex_t invalid_vertex_id_v = invalid_vertex_id::value; + template struct invalid_edge_id : invalid_idx {}; -template -struct invalid_component_id : invalid_idx {}; +template +inline constexpr edge_t invalid_edge_id_v = invalid_edge_id::value; + +template +struct invalid_component_id : invalid_idx {}; + +template +inline constexpr component_t invalid_component_id_v = invalid_component_id::value; template __host__ __device__ std::enable_if_t::value, bool> is_valid_vertex( diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 971a0197d6f..fec1a07604e 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -15,10 +15,12 @@ */ #pragma once +#include #include #include #include +#include #include @@ -27,6 +29,217 @@ namespace cugraph { +/** + * @brief Controls how we treat prior sources in sampling + * + * @param DEFAULT Add vertices encountered while sampling to the new frontier + * @param CARRY_OVER In addition to newly encountered vertices, include vertices + * used as sources in any previous frontier in the new frontier + * @param EXCLUDE Filter the new frontier to exclude any vertex that was + * used as a source in a previous frontier + */ +enum class prior_sources_behavior_t { DEFAULT = 0, CARRY_OVER, EXCLUDE }; + +/** + * @brief Uniform Neighborhood Sampling. + * + * This function traverses from a set of starting vertices, traversing outgoing edges and + * randomly selects from these outgoing neighbors to extract a subgraph. 
+ * + * Output from this function is a tuple of vectors (src, dst, weight, edge_id, edge_type, hop, + * label, offsets), identifying the randomly selected edges. src is the source vertex, dst is the + * destination vertex, weight (optional) is the edge weight, edge_id (optional) identifies the edge + * id, edge_type (optional) identifies the edge type, hop identifies which hop the edge was + * encountered in. The label output (optional) identifies the vertex label. The offsets array + * (optional) will be described below and is dependent upon the input parameters. + * + * If @p starting_vertex_labels is not specified then no organization is applied to the output, the + * label and offsets values in the return set will be std::nullopt. + * + * If @p starting_vertex_labels is specified and @p label_to_output_comm_rank is not specified then + * the label output has values. This will also result in the output being sorted by vertex label. + * The offsets array in the return will be a CSR-style offsets array to identify the beginning of + * each label range in the data. `labels.size() == (offsets.size() - 1)`. + * + * If @p starting_vertex_labels is specified and @p label_to_output_comm_rank is specified then the + * label output has values. This will also result in the output being sorted by vertex label. The + * offsets array in the return will be a CSR-style offsets array to identify the beginning of each + * label range in the data. `labels.size() == (offsets.size() - 1)`. Additionally, the data will + * be shuffled so that all data with a particular label will be on the specified rank. + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam weight_t Type of edge weights. Needs to be a floating point type. + * @tparam edge_type_t Type of edge type. Needs to be an integral type. + * @tparam label_t Type of label. Needs to be an integral type. 
+ * @tparam store_transposed Flag indicating whether sources (if false) or destinations (if + * true) are major indices + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Graph View object to generate NBR Sampling on. + * @param edge_weight_view Optional view object holding edge weights for @p graph_view. + * @param edge_id_view Optional view object holding edge ids for @p graph_view. + * @param edge_type_view Optional view object holding edge types for @p graph_view. + * @param starting_vertices Device span of starting vertex IDs for the sampling. + * In a multi-gpu context the starting vertices should be local to this GPU. + * @param starting_vertex_labels Optional device span of labels associated with each starting vertex + * for the sampling. + * @param label_to_output_comm_rank Optional tuple of device spans mapping label to a particular + * output rank. Element 0 of the tuple identifies the label, Element 1 of the tuple identifies the + * output rank. The label span must be sorted in ascending order. + * @param fan_out Host span defining branching out (fan-out) degree per source vertex for each + * level + * @param rng_state A pre-initialized raft::RngState object for generating random numbers + * @param return_hops boolean flag specifying if the hop information should be returned + * @param prior_sources_behavior Enum type defining how to handle prior sources, (defaults to + * DEFAULT) + * @param dedupe_sources boolean flag, if true then if a vertex v appears as a destination in hop X + * multiple times with the same label, it will only be passed once (for each label) as a source + * for the next hop. Default is false. 
+ * @param with_replacement boolean flag specifying if random sampling is done with replacement + * (true); or, without replacement (false); default = true; + * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). + * @return tuple device vectors (vertex_t source_vertex, vertex_t destination_vertex, + * optional weight_t weight, optional edge_t edge id, optional edge_type_t edge type, + * optional int32_t hop, optional label_t label, optional size_t offsets) + */ +template +std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +uniform_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement = true, + prior_sources_behavior_t prior_sources_behavior = prior_sources_behavior_t::DEFAULT, + bool dedupe_sources = false, + bool do_expensive_check = false); + +/** + * @brief Biased Neighborhood Sampling. + * + * This function traverses from a set of starting vertices, traversing outgoing edges and + * randomly selects (with edge biases) from these outgoing neighbors to extract a subgraph. + * + * Output from this function is a tuple of vectors (src, dst, weight, edge_id, edge_type, hop, + * label, offsets), identifying the randomly selected edges. src is the source vertex, dst is the + * destination vertex, weight (optional) is the edge weight, edge_id (optional) identifies the edge + * id, edge_type (optional) identifies the edge type, hop identifies which hop the edge was + * encountered in. The label output (optional) identifes the vertex label. 
The offsets array + * (optional) will be described below and is dependent upon the input parameters. + * + * If @p starting_vertex_labels is not specified then no organization is applied to the output, the + * label and offsets values in the return set will be std::nullopt. + * + * If @p starting_vertex_labels is specified and @p label_to_output_comm_rank is not specified then + * the label output has values. This will also result in the output being sorted by vertex label. + * The offsets array in the return will be a CSR-style offsets array to identify the beginning of + * each label range in the data. `labels.size() == (offsets.size() - 1)`. + * + * If @p starting_vertex_labels is specified and @p label_to_output_comm_rank is specified then the + * label output has values. This will also result in the output being sorted by vertex label. The + * offsets array in the return will be a CSR-style offsets array to identify the beginning of each + * label range in the data. `labels.size() == (offsets.size() - 1)`. Additionally, the data will + * be shuffled so that all data with a particular label will be on the specified rank. + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam weight_t Type of edge weights. Needs to be a floating point type. + * @tparam edge_type_t Type of edge type. Needs to be an integral type. + * @tparam label_t Type of label. Needs to be an integral type. + * @tparam store_transposed Flag indicating whether sources (if false) or destinations (if + * true) are major indices + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Graph View object to generate NBR Sampling on. 
+ * @param edge_weight_view Optional view object holding edge weights for @p graph_view. + * @param edge_id_view Optional view object holding edge ids for @p graph_view. + * @param edge_type_view Optional view object holding edge types for @p graph_view. + * @param edge_bias_view View object holding edge biases (to be used in biased sampling) for @p + * graph_view. Bias values should be non-negative and the sum of edge bias values from any vertex + * should not exceed std::numeric_limits<bias_t>::max(). 0 bias value indicates that the + * corresponding edge can never be selected. + * @param starting_vertices Device span of starting vertex IDs for the sampling. + * In a multi-gpu context the starting vertices should be local to this GPU. + * @param starting_vertex_labels Optional device span of labels associated with each starting vertex + * for the sampling. + * @param label_to_output_comm_rank Optional tuple of device spans mapping label to a particular + * output rank. Element 0 of the tuple identifies the label, Element 1 of the tuple identifies the + * output rank. The label span must be sorted in ascending order. + * @param fan_out Host span defining branching out (fan-out) degree per source vertex for each + * level + * @param rng_state A pre-initialized raft::RngState object for generating random numbers + * @param return_hops boolean flag specifying if the hop information should be returned + * @param prior_sources_behavior Enum type defining how to handle prior sources, (defaults to + * DEFAULT) + * @param dedupe_sources boolean flag, if true then if a vertex v appears as a destination in hop X + * multiple times with the same label, it will only be passed once (for each label) as a source + * for the next hop. Default is false. 
+ * @param with_replacement boolean flag specifying if random sampling is done with replacement + * (true); or, without replacement (false); default = true; + * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). + * @return tuple device vectors (vertex_t source_vertex, vertex_t destination_vertex, + * optional weight_t weight, optional edge_t edge id, optional edge_type_t edge type, + * optional int32_t hop, optional label_t label, optional size_t offsets) + */ +template +std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement = true, + prior_sources_behavior_t prior_sources_behavior = prior_sources_behavior_t::DEFAULT, + bool dedupe_sources = false, + bool do_expensive_check = false); + /* * @brief renumber sampled edge list and compress to the (D)CSR|(D)CSC format. 
* diff --git a/cpp/include/cugraph/utilities/dataframe_buffer.hpp b/cpp/include/cugraph/utilities/dataframe_buffer.hpp index 450e816bd96..ab4c4eff6b5 100644 --- a/cpp/include/cugraph/utilities/dataframe_buffer.hpp +++ b/cpp/include/cugraph/utilities/dataframe_buffer.hpp @@ -68,24 +68,6 @@ auto get_dataframe_buffer_cend_tuple_impl(std::index_sequence, TupleType& } // namespace detail -template -struct dataframe_element { - using type = void; -}; - -template -struct dataframe_element...>> { - using type = thrust::tuple; -}; - -template -struct dataframe_element> { - using type = T; -}; - -template -using dataframe_element_t = typename dataframe_element::type; - template ::value>* = nullptr> auto allocate_dataframe_buffer(size_t buffer_size, rmm::cuda_stream_view stream_view) { @@ -207,14 +189,6 @@ auto get_dataframe_buffer_end(BufferType& buffer) std::make_index_sequence::value>(), buffer); } -template -struct dataframe_buffer_type { - using type = decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); -}; - -template -using dataframe_buffer_type_t = typename dataframe_buffer_type::type; - template , rmm::device_uvector>::value>* = nullptr> @@ -232,4 +206,30 @@ auto get_dataframe_buffer_cend(BufferType& buffer) std::make_index_sequence::value>(), buffer); } +template +struct dataframe_buffer_value_type { + using type = void; +}; + +template +struct dataframe_buffer_value_type> { + using type = T; +}; + +template +struct dataframe_buffer_value_type...>> { + using type = thrust::tuple; +}; + +template +using dataframe_buffer_value_type_t = typename dataframe_buffer_value_type::type; + +template +struct dataframe_buffer_type { + using type = decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); +}; + +template +using dataframe_buffer_type_t = typename dataframe_buffer_type::type; + } // namespace cugraph diff --git a/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh 
b/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh index ba6f7dea040..64b6aab9baf 100644 --- a/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh +++ b/cpp/src/prims/detail/sample_and_compute_local_nbr_indices.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -111,7 +112,6 @@ struct convert_pair_to_quadruplet_t { raft::device_span partitioned_local_value_displacements{}; // one partition per gpu in the same minor_comm raft::device_span tx_counts{}; - size_t stride{}; int minor_comm_size{}; value_t invalid_value{}; @@ -162,15 +162,17 @@ struct find_nth_valid_nbr_idx_t { edge_partition_device_view_t edge_partition{}; EdgePartitionEdgeMaskWrapper edge_partition_e_mask; VertexIterator major_first{}; + raft::device_span major_idx_to_unique_major_idx{}; thrust::tuple, raft::device_span> - major_valid_local_nbr_count_inclusive_sums{}; + unique_major_valid_local_nbr_count_inclusive_sums{}; __device__ edge_t operator()(thrust::tuple pair) const { - edge_t local_nbr_idx = thrust::get<0>(pair); - size_t major_idx = thrust::get<1>(pair); - auto major = *(major_first + major_idx); - auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + edge_t local_nbr_idx = thrust::get<0>(pair); + size_t major_idx = thrust::get<1>(pair); + size_t unique_major_idx = major_idx_to_unique_major_idx[major_idx]; + auto major = *(major_first + major_idx); + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); vertex_t const* indices{nullptr}; edge_t edge_offset{0}; [[maybe_unused]] edge_t local_degree{0}; @@ -194,9 +196,12 @@ struct find_nth_valid_nbr_idx_t { local_nbr_idx = find_nth_set_bits( (*edge_partition_e_mask).value_first(), edge_offset, local_degree, local_nbr_idx + 1); } else { - auto inclusive_sum_first = thrust::get<1>(major_valid_local_nbr_count_inclusive_sums).begin(); - auto start_offset = thrust::get<0>(major_valid_local_nbr_count_inclusive_sums)[major_idx]; - auto end_offset = 
thrust::get<0>(major_valid_local_nbr_count_inclusive_sums)[major_idx + 1]; + auto inclusive_sum_first = + thrust::get<1>(unique_major_valid_local_nbr_count_inclusive_sums).begin(); + auto start_offset = + thrust::get<0>(unique_major_valid_local_nbr_count_inclusive_sums)[unique_major_idx]; + auto end_offset = + thrust::get<0>(unique_major_valid_local_nbr_count_inclusive_sums)[unique_major_idx + 1]; auto word_idx = static_cast(thrust::distance(inclusive_sum_first + start_offset, thrust::upper_bound(thrust::seq, @@ -326,6 +331,83 @@ __global__ static void compute_valid_local_nbr_count_inclusive_sums_high_local_d } } +template +std::tuple::value_type>, + rmm::device_uvector, + std::vector, + std::vector> +compute_unique_keys(raft::handle_t const& handle, + KeyIterator aggregate_local_frontier_key_first, + std::vector const& local_frontier_displacements, + std::vector const& local_frontier_sizes) +{ + using key_t = typename thrust::iterator_traits::value_type; + + auto aggregate_local_frontier_unique_keys = + allocate_dataframe_buffer(0, handle.get_stream()); + auto aggregate_local_frontier_key_idx_to_unique_key_idx = rmm::device_uvector( + local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); + auto local_frontier_unique_key_displacements = + std::vector(local_frontier_displacements.size()); + auto local_frontier_unique_key_sizes = std::vector(local_frontier_sizes.size()); + + auto tmp_keys = allocate_dataframe_buffer( + local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); + for (size_t i = 0; i < local_frontier_displacements.size(); ++i) { + thrust::copy(handle.get_thrust_policy(), + aggregate_local_frontier_key_first + local_frontier_displacements[i], + aggregate_local_frontier_key_first + local_frontier_displacements[i] + + local_frontier_sizes[i], + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i]); + thrust::sort(handle.get_thrust_policy(), + 
get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i], + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i] + + local_frontier_sizes[i]); + local_frontier_unique_key_sizes[i] = thrust::distance( + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i], + thrust::unique(handle.get_thrust_policy(), + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i], + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i] + + local_frontier_sizes[i])); + } + std::exclusive_scan(local_frontier_unique_key_sizes.begin(), + local_frontier_unique_key_sizes.end(), + local_frontier_unique_key_displacements.begin(), + size_t{0}); + resize_dataframe_buffer( + aggregate_local_frontier_unique_keys, + local_frontier_unique_key_displacements.back() + local_frontier_unique_key_sizes.back(), + handle.get_stream()); + for (size_t i = 0; i < local_frontier_displacements.size(); ++i) { + thrust::copy(handle.get_thrust_policy(), + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i], + get_dataframe_buffer_begin(tmp_keys) + local_frontier_displacements[i] + + local_frontier_unique_key_sizes[i], + get_dataframe_buffer_begin(aggregate_local_frontier_unique_keys) + + local_frontier_unique_key_displacements[i]); + thrust::transform( + handle.get_thrust_policy(), + aggregate_local_frontier_key_first + local_frontier_displacements[i], + aggregate_local_frontier_key_first + local_frontier_displacements[i] + + local_frontier_sizes[i], + aggregate_local_frontier_key_idx_to_unique_key_idx.begin() + local_frontier_displacements[i], + cuda::proclaim_return_type( + [unique_key_first = get_dataframe_buffer_begin(aggregate_local_frontier_unique_keys) + + local_frontier_unique_key_displacements[i], + num_unique_keys = local_frontier_unique_key_sizes[i]] __device__(key_t key) { + return static_cast(thrust::distance( + unique_key_first, + thrust::lower_bound( + thrust::seq, unique_key_first, unique_key_first + 
num_unique_keys, key))); + })); + } + + return std::make_tuple(std::move(aggregate_local_frontier_unique_keys), + std::move(aggregate_local_frontier_key_idx_to_unique_key_idx), + std::move(local_frontier_unique_key_displacements), + std::move(local_frontier_unique_key_sizes)); +} + template std::tuple, rmm::device_uvector> compute_frontier_value_sums_and_partitioned_local_value_sum_displacements( @@ -518,6 +600,8 @@ rmm::device_uvector compute_uniform_sampling_index_without_replacement( size_t K) { #ifndef NO_CUGRAPH_OPS + assert(cugraph::invalid_edge_id_v == cugraph::ops::graph::INVALID_ID); + edge_t mid_partition_degree_range_last = static_cast(K * 10); // tuning parameter assert(mid_partition_degree_range_last > K); size_t high_partition_oversampling_K = K * 2; // tuning parameter @@ -543,7 +627,7 @@ rmm::device_uvector compute_uniform_sampling_index_without_replacement( frontier_degrees = raft::device_span(frontier_degrees.data(), frontier_degrees.size()), nbr_indices = raft::device_span(nbr_indices.data(), nbr_indices.size()), - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { + invalid_idx = cugraph::invalid_edge_id_v] __device__(size_t i) { auto frontier_idx = frontier_indices[i / K]; auto degree = frontier_degrees[frontier_idx]; auto sample_idx = static_cast(i % K); @@ -905,7 +989,7 @@ template void compute_biased_sampling_index_without_replacement( raft::handle_t const& handle, std::optional> - input_frontier_indices, // input_biases & input_degree_offsets + input_frontier_indices, // input_degree_offsets & input_biases // are already packed if std::nullopt raft::device_span input_degree_offsets, raft::device_span input_biases, // bias 0 edges can't be selected @@ -951,175 +1035,204 @@ void compute_biased_sampling_index_without_replacement( 1, handle.get_stream()); handle.sync_stream(); - rmm::device_uvector keys(num_pairs, handle.get_stream()); - cugraph::detail::uniform_random_fill( - handle.get_stream(), keys.data(), keys.size(), 
bias_t{0.0}, bias_t{1.0}, rng_state); - - if (input_frontier_indices) { - auto bias_first = thrust::make_transform_iterator( - thrust::make_counting_iterator(size_t{0}), - cuda::proclaim_return_type( - [input_biases, - input_degree_offsets, - frontier_indices = *input_frontier_indices, - packed_input_degree_offsets = raft::device_span( - (*packed_input_degree_offsets).data(), - (*packed_input_degree_offsets).size())] __device__(size_t i) { - auto it = thrust::upper_bound(thrust::seq, - packed_input_degree_offsets.begin() + 1, - packed_input_degree_offsets.end(), - i); - auto idx = thrust::distance(packed_input_degree_offsets.begin() + 1, it); - auto frontier_idx = frontier_indices[idx]; - return input_biases[input_degree_offsets[frontier_idx] + - (i - packed_input_degree_offsets[idx])]; + auto approx_edges_to_process_per_iteration = + static_cast(handle.get_device_properties().multiProcessorCount) * + (1 << 18) /* tuning parameter */; + auto [chunk_offsets, element_offsets] = cugraph::detail::compute_offset_aligned_element_chunks( + handle, + raft::device_span( + packed_input_degree_offsets ? (*packed_input_degree_offsets).data() + : input_degree_offsets.data(), + packed_input_degree_offsets ? 
(*packed_input_degree_offsets).size() + : input_degree_offsets.size()), + num_pairs, + approx_edges_to_process_per_iteration); + auto num_chunks = chunk_offsets.size() - 1; + for (size_t i = 0; i < num_chunks; ++i) { + auto num_chunk_pairs = element_offsets[i + 1] - element_offsets[i]; + rmm::device_uvector keys(num_chunk_pairs, handle.get_stream()); + + cugraph::detail::uniform_random_fill( + handle.get_stream(), keys.data(), keys.size(), bias_t{0.0}, bias_t{1.0}, rng_state); + + if (packed_input_degree_offsets) { + auto bias_first = thrust::make_transform_iterator( + thrust::make_counting_iterator(element_offsets[i]), + cuda::proclaim_return_type( + [input_biases, + input_degree_offsets, + frontier_indices = *input_frontier_indices, + packed_input_degree_offsets = raft::device_span( + (*packed_input_degree_offsets).data(), + (*packed_input_degree_offsets).size())] __device__(size_t i) { + auto it = thrust::upper_bound(thrust::seq, + packed_input_degree_offsets.begin() + 1, + packed_input_degree_offsets.end(), + i); + auto idx = thrust::distance(packed_input_degree_offsets.begin() + 1, it); + auto frontier_idx = frontier_indices[idx]; + return input_biases[input_degree_offsets[frontier_idx] + + (i - packed_input_degree_offsets[idx])]; + })); + thrust::transform( + handle.get_thrust_policy(), + keys.begin(), + keys.end(), + bias_first, + keys.begin(), + cuda::proclaim_return_type([] __device__(bias_t r, bias_t b) { + return b > 0.0 + ? cuda::std::min(-log(r) / b, std::numeric_limits::max()) + : std::numeric_limits< + bias_t>::infinity() /* inf used as invalid value (can't be selected) */; })); - thrust::transform( - handle.get_thrust_policy(), - keys.begin(), - keys.end(), - bias_first, - keys.begin(), - cuda::proclaim_return_type([] __device__(bias_t r, bias_t b) { - return b > 0.0 - ? 
cuda::std::min(-log(r) / b, std::numeric_limits::max()) - : std::numeric_limits< - bias_t>::infinity() /* inf used as invalid value (can't be selected) */; - })); - } else { - thrust::transform(handle.get_thrust_policy(), - keys.begin(), - keys.end(), - input_biases.begin(), - keys.begin(), - cuda::proclaim_return_type([] __device__(bias_t r, bias_t b) { - return b > 0.0 - ? cuda::std::min(-log(r) / b, std::numeric_limits::max()) - : std::numeric_limits::infinity() - /* inf used as invalid value (can't be selected) */; - })); - } - - rmm::device_uvector nbr_indices(keys.size(), handle.get_stream()); - thrust::tabulate( - handle.get_thrust_policy(), - nbr_indices.begin(), - nbr_indices.end(), - [offsets = packed_input_degree_offsets - ? raft::device_span((*packed_input_degree_offsets).data(), - (*packed_input_degree_offsets).size()) - : input_degree_offsets] __device__(size_t i) { - auto it = thrust::upper_bound(thrust::seq, offsets.begin() + 1, offsets.end(), i); - auto idx = thrust::distance(offsets.begin() + 1, it); - return static_cast(i - offsets[idx]); - }); - - // pick top K for each frontier index + } else { + thrust::transform(handle.get_thrust_policy(), + keys.begin(), + keys.end(), + input_biases.begin() + element_offsets[i], + keys.begin(), + cuda::proclaim_return_type([] __device__(bias_t r, bias_t b) { + return b > 0.0 ? cuda::std::min(-log(r) / b, + std::numeric_limits::max()) + : std::numeric_limits::infinity() + /* inf used as invalid value (can't be selected) */; + })); + } - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - size_t tmp_storage_bytes{0}; + rmm::device_uvector nbr_indices(keys.size(), handle.get_stream()); + thrust::tabulate( + handle.get_thrust_policy(), + nbr_indices.begin(), + nbr_indices.end(), + [offsets = packed_input_degree_offsets + ? 
raft::device_span((*packed_input_degree_offsets).data(), + (*packed_input_degree_offsets).size()) + : input_degree_offsets, + element_offset = element_offsets[i]] __device__(size_t i) { + auto it = thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), element_offset + i); + auto idx = thrust::distance(offsets.begin() + 1, it); + return static_cast((element_offset + i) - offsets[idx]); + }); - rmm::device_uvector segment_sorted_keys(keys.size(), handle.get_stream()); - rmm::device_uvector segment_sorted_nbr_indices(nbr_indices.size(), handle.get_stream()); - - cub::DeviceSegmentedSort::SortPairs( - static_cast(nullptr), - tmp_storage_bytes, - keys.data(), - segment_sorted_keys.data(), - nbr_indices.data(), - segment_sorted_nbr_indices.data(), - keys.size(), - input_frontier_indices ? (*input_frontier_indices).size() : (input_degree_offsets.size() - 1), - packed_input_degree_offsets ? (*packed_input_degree_offsets).begin() - : input_degree_offsets.begin(), - (packed_input_degree_offsets ? (*packed_input_degree_offsets).begin() - : input_degree_offsets.begin()) + - 1, - handle.get_stream()); - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - cub::DeviceSegmentedSort::SortPairs( - d_tmp_storage.data(), - tmp_storage_bytes, - keys.data(), - segment_sorted_keys.data(), - nbr_indices.data(), - segment_sorted_nbr_indices.data(), - keys.size(), - input_frontier_indices ? (*input_frontier_indices).size() : input_degree_offsets.size() - 1, - packed_input_degree_offsets ? (*packed_input_degree_offsets).begin() - : input_degree_offsets.begin(), - (packed_input_degree_offsets ? 
(*packed_input_degree_offsets).begin() - : input_degree_offsets.begin()) + - 1, - handle.get_stream()); + // pick top K for each frontier index + + rmm::device_uvector d_tmp_storage(0, handle.get_stream()); + size_t tmp_storage_bytes{0}; + + rmm::device_uvector segment_sorted_keys(keys.size(), handle.get_stream()); + rmm::device_uvector segment_sorted_nbr_indices(nbr_indices.size(), + handle.get_stream()); + + auto offset_first = thrust::make_transform_iterator( + (packed_input_degree_offsets ? (*packed_input_degree_offsets).begin() + : input_degree_offsets.begin()) + + chunk_offsets[i], + detail::shift_left_t{element_offsets[i]}); + cub::DeviceSegmentedSort::SortPairs(static_cast(nullptr), + tmp_storage_bytes, + keys.data(), + segment_sorted_keys.data(), + nbr_indices.data(), + segment_sorted_nbr_indices.data(), + keys.size(), + chunk_offsets[i + 1] - chunk_offsets[i], + offset_first, + offset_first + 1, + handle.get_stream()); + if (tmp_storage_bytes > d_tmp_storage.size()) { + d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); + } + cub::DeviceSegmentedSort::SortPairs(d_tmp_storage.data(), + tmp_storage_bytes, + keys.data(), + segment_sorted_keys.data(), + nbr_indices.data(), + segment_sorted_nbr_indices.data(), + keys.size(), + chunk_offsets[i + 1] - chunk_offsets[i], + offset_first, + offset_first + 1, + handle.get_stream()); - if (output_frontier_indices) { - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator((*output_frontier_indices).size() * K), - [input_degree_offsets = - packed_input_degree_offsets - ? 
raft::device_span((*packed_input_degree_offsets).data(), - (*packed_input_degree_offsets).size()) - : input_degree_offsets, - output_frontier_indices = *output_frontier_indices, - output_keys, - output_nbr_indices, - segment_sorted_keys = - raft::device_span(segment_sorted_keys.data(), segment_sorted_keys.size()), - segment_sorted_nbr_indices = raft::device_span( - segment_sorted_nbr_indices.data(), segment_sorted_nbr_indices.size()), - K, - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { - auto output_frontier_idx = output_frontier_indices[i / K]; - auto output_idx = output_frontier_idx * K + (i % K); - auto degree = input_degree_offsets[i / K + 1] - input_degree_offsets[i / K]; - auto input_idx = input_degree_offsets[i / K] + (i % K); - if ((i % K < degree) && - (segment_sorted_keys[input_idx] < std::numeric_limits::infinity())) { - if (output_keys) { (*output_keys)[output_idx] = segment_sorted_keys[input_idx]; } - output_nbr_indices[output_idx] = segment_sorted_nbr_indices[input_idx]; - } else { - if (output_keys) { - (*output_keys)[output_idx] = std::numeric_limits::infinity(); + if (output_frontier_indices) { + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator((chunk_offsets[i + 1] - chunk_offsets[i]) * K), + [input_degree_offsets = + packed_input_degree_offsets + ? 
raft::device_span((*packed_input_degree_offsets).data(), + (*packed_input_degree_offsets).size()) + : input_degree_offsets, + idx_offset = chunk_offsets[i] * K, + output_frontier_indices = *output_frontier_indices, + output_keys, + output_nbr_indices, + segment_sorted_keys = raft::device_span(segment_sorted_keys.data(), + segment_sorted_keys.size()), + segment_sorted_nbr_indices = raft::device_span( + segment_sorted_nbr_indices.data(), segment_sorted_nbr_indices.size()), + K, + invalid_idx = cugraph::invalid_edge_id_v] __device__(size_t i) { + auto idx = idx_offset + i; + auto key_idx = idx / K; + auto output_frontier_idx = output_frontier_indices[key_idx]; + auto output_idx = output_frontier_idx * K + (idx % K); + auto degree = input_degree_offsets[key_idx + 1] - input_degree_offsets[key_idx]; + auto segment_sorted_input_idx = + (input_degree_offsets[key_idx] - input_degree_offsets[idx_offset / K]) + (idx % K); + if (((idx % K) < degree) && (segment_sorted_keys[segment_sorted_input_idx] < + std::numeric_limits::infinity())) { + if (output_keys) { + (*output_keys)[output_idx] = segment_sorted_keys[segment_sorted_input_idx]; + } + output_nbr_indices[output_idx] = segment_sorted_nbr_indices[segment_sorted_input_idx]; + } else { + if (output_keys) { + (*output_keys)[output_idx] = std::numeric_limits::infinity(); + } + output_nbr_indices[output_idx] = invalid_idx; } - output_nbr_indices[output_idx] = invalid_idx; - } - }); - } else { - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(output_nbr_indices.size()), - [input_degree_offsets = - packed_input_degree_offsets - ? 
raft::device_span((*packed_input_degree_offsets).data(), - (*packed_input_degree_offsets).size()) - : input_degree_offsets, - output_keys, - output_nbr_indices, - segment_sorted_keys = - raft::device_span(segment_sorted_keys.data(), segment_sorted_keys.size()), - segment_sorted_nbr_indices = raft::device_span( - segment_sorted_nbr_indices.data(), segment_sorted_nbr_indices.size()), - K, - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { - auto degree = input_degree_offsets[i / K + 1] - input_degree_offsets[i / K]; - auto input_idx = input_degree_offsets[i / K] + (i % K); - if ((i % K < degree) && - (segment_sorted_keys[input_idx] < std::numeric_limits::infinity())) { - if (output_keys) { (*output_keys)[i] = segment_sorted_keys[input_idx]; } - output_nbr_indices[i] = segment_sorted_nbr_indices[input_idx]; - } else { - if (output_keys) { (*output_keys)[i] = std::numeric_limits::infinity(); } - output_nbr_indices[i] = invalid_idx; - } - }); + }); + } else { + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator((chunk_offsets[i + 1] - chunk_offsets[i]) * K), + [input_degree_offsets = + packed_input_degree_offsets + ? 
raft::device_span((*packed_input_degree_offsets).data(), + (*packed_input_degree_offsets).size()) + : input_degree_offsets, + idx_offset = chunk_offsets[i] * K, + output_keys, + output_nbr_indices, + segment_sorted_keys = raft::device_span(segment_sorted_keys.data(), + segment_sorted_keys.size()), + segment_sorted_nbr_indices = raft::device_span( + segment_sorted_nbr_indices.data(), segment_sorted_nbr_indices.size()), + K, + invalid_idx = cugraph::invalid_edge_id_v] __device__(size_t i) { + auto idx = idx_offset + i; + auto key_idx = idx / K; + auto degree = input_degree_offsets[key_idx + 1] - input_degree_offsets[key_idx]; + auto segment_sorted_input_idx = + (input_degree_offsets[key_idx] - input_degree_offsets[idx_offset / K]) + (idx % K); + if (((idx % K) < degree) && (segment_sorted_keys[segment_sorted_input_idx] < + std::numeric_limits::infinity())) { + if (output_keys) { + (*output_keys)[idx] = segment_sorted_keys[segment_sorted_input_idx]; + } + output_nbr_indices[idx] = segment_sorted_nbr_indices[segment_sorted_input_idx]; + } else { + if (output_keys) { (*output_keys)[idx] = std::numeric_limits::infinity(); } + output_nbr_indices[idx] = invalid_idx; + } + }); + } } } @@ -1256,16 +1369,12 @@ shuffle_and_compute_local_nbr_values(raft::handle_t const& handle, rmm::device_uvector&& sample_nbr_values, std::optional> frontier_partitioned_value_local_sum_displacements, - std::vector const& local_frontier_displacements, - std::vector const& local_frontier_sizes, size_t K, value_t invalid_value) { - int minor_comm_rank{0}; int minor_comm_size{1}; if constexpr (multi_gpu) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - minor_comm_rank = minor_comm.get_rank(); minor_comm_size = minor_comm.get_size(); } @@ -1302,7 +1411,6 @@ shuffle_and_compute_local_nbr_values(raft::handle_t const& handle, (*frontier_partitioned_value_local_sum_displacements).data(), (*frontier_partitioned_value_local_sum_displacements).size()), 
raft::device_span(d_tx_counts.data(), d_tx_counts.size()), - local_frontier_sizes[minor_comm_rank], minor_comm_size, invalid_value}); rmm::device_uvector tx_displacements(minor_comm_size, handle.get_stream()); @@ -1352,8 +1460,7 @@ shuffle_and_compute_local_nbr_values(raft::handle_t const& handle, std::inclusive_scan( rx_counts.begin(), rx_counts.end(), local_frontier_sample_offsets.begin() + 1); } else { - local_frontier_sample_offsets = - std::vector{size_t{0}, local_frontier_sizes[minor_comm_rank] * K}; + local_frontier_sample_offsets = std::vector{size_t{0}, sample_local_nbr_values.size()}; } return std::make_tuple(std::move(sample_local_nbr_values), @@ -1361,7 +1468,7 @@ shuffle_and_compute_local_nbr_values(raft::handle_t const& handle, std::move(local_frontier_sample_offsets)); } -// skip conversion if local neighbor index is cugraph::ops::graph::INVALID_ID +// skip conversion if local neighbor index is cugraph::invalid_edge_id_v template rmm::device_uvector convert_to_unmasked_local_nbr_idx( raft::handle_t const& handle, @@ -1381,13 +1488,22 @@ rmm::device_uvector convert_to_unmasked_local auto edge_mask_view = graph_view.edge_mask_view(); + auto [aggregate_local_frontier_unique_majors, + aggregate_local_frontier_major_idx_to_unique_major_idx, + local_frontier_unique_major_displacements, + local_frontier_unique_major_sizes] = + compute_unique_keys(handle, + aggregate_local_frontier_major_first, + local_frontier_displacements, + local_frontier_sizes); + // to avoid searching the entire neighbor list K times for high degree vertices with edge masking - auto local_frontier_valid_local_nbr_count_inclusive_sums = + auto local_frontier_unique_major_valid_local_nbr_count_inclusive_sums = compute_valid_local_nbr_count_inclusive_sums(handle, graph_view, - aggregate_local_frontier_major_first, - local_frontier_displacements, - local_frontier_sizes); + aggregate_local_frontier_unique_majors.begin(), + local_frontier_unique_major_displacements, + 
local_frontier_unique_major_sizes); auto sample_major_idx_first = thrust::make_transform_iterator( thrust::make_counting_iterator(size_t{0}), @@ -1424,14 +1540,20 @@ rmm::device_uvector convert_to_unmasked_local edge_partition, edge_partition_e_mask, edge_partition_frontier_major_first, + raft::device_span( + aggregate_local_frontier_major_idx_to_unique_major_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), thrust::make_tuple( raft::device_span( - std::get<0>(local_frontier_valid_local_nbr_count_inclusive_sums[i]).data(), - std::get<0>(local_frontier_valid_local_nbr_count_inclusive_sums[i]).size()), + std::get<0>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]).data(), + std::get<0>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]) + .size()), raft::device_span( - std::get<1>(local_frontier_valid_local_nbr_count_inclusive_sums[i]).data(), - std::get<1>(local_frontier_valid_local_nbr_count_inclusive_sums[i]).size()))}, - is_not_equal_t{cugraph::ops::graph::INVALID_ID}); + std::get<1>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]).data(), + std::get<1>(local_frontier_unique_major_valid_local_nbr_count_inclusive_sums[i]) + .size()))}, + is_not_equal_t{cugraph::invalid_edge_id_v}); } return std::move(local_nbr_indices); @@ -1451,6 +1573,9 @@ uniform_sample_and_compute_local_nbr_indices( size_t K, bool with_replacement) { +#ifndef NO_CUGRAPH_OPS + assert(cugraph::invalid_edge_id_v == cugraph::ops::graph::INVALID_ID); + using vertex_t = typename GraphViewType::vertex_type; using edge_t = typename GraphViewType::edge_type; using key_t = typename thrust::iterator_traits::value_type; @@ -1527,10 +1652,8 @@ uniform_sample_and_compute_local_nbr_indices( (*frontier_partitioned_local_degree_displacements).data(), (*frontier_partitioned_local_degree_displacements).size()) : std::nullopt, - local_frontier_displacements, - local_frontier_sizes, K, - cugraph::ops::graph::INVALID_ID); + 
cugraph::invalid_edge_id_v); // 4. convert neighbor indices in the neighbor list considering edge mask to neighbor indices in // the neighbor list ignoring edge mask @@ -1552,6 +1675,11 @@ uniform_sample_and_compute_local_nbr_indices( return std::make_tuple( std::move(local_nbr_indices), std::move(key_indices), std::move(local_frontier_sample_offsets)); +#else + CUGRAPH_FAIL("unimplemented."); + return std::make_tuple( + rmm::device_uvector(0, handle.get_stream()), std::nullopt, std::vector()); +#endif } template ( [offsets = raft::device_span( - aggregate_local_frontier_local_degree_offsets.data(), - aggregate_local_frontier_local_degree_offsets.size())] __device__(size_t i) { + aggregate_local_frontier_unique_key_local_degree_offsets.data(), + aggregate_local_frontier_unique_key_local_degree_offsets.size())] __device__(size_t i) { return static_cast(thrust::distance( offsets.begin() + 1, thrust::upper_bound(thrust::seq, offsets.begin() + 1, offsets.end(), i))); })); - thrust::inclusive_scan_by_key(handle.get_thrust_policy(), - key_first, - key_first + aggregate_local_frontier_biases.size(), - get_dataframe_buffer_begin(aggregate_local_frontier_biases), - get_dataframe_buffer_begin(aggregate_local_frontier_biases)); + thrust::inclusive_scan_by_key( + handle.get_thrust_policy(), + unique_key_first, + unique_key_first + aggregate_local_frontier_unique_key_biases.size(), + get_dataframe_buffer_begin(aggregate_local_frontier_unique_key_biases), + get_dataframe_buffer_begin(aggregate_local_frontier_unique_key_biases)); - auto aggregate_local_frontier_bias_segmented_local_inclusive_sums = - std::move(aggregate_local_frontier_biases); + auto aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums = + std::move(aggregate_local_frontier_unique_key_biases); auto aggregate_local_frontier_bias_local_sums = rmm::device_uvector( local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); - thrust::tabulate( - 
handle.get_thrust_policy(), - get_dataframe_buffer_begin(aggregate_local_frontier_bias_local_sums), - get_dataframe_buffer_end(aggregate_local_frontier_bias_local_sums), - [offsets = - raft::device_span(aggregate_local_frontier_local_degree_offsets.data(), - aggregate_local_frontier_local_degree_offsets.size()), - aggregate_local_frontier_bias_segmented_local_inclusive_sums = - raft::device_span( - aggregate_local_frontier_bias_segmented_local_inclusive_sums.data(), - aggregate_local_frontier_bias_segmented_local_inclusive_sums - .size())] __device__(size_t i) { - auto degree = offsets[i + 1] - offsets[i]; - if (degree > 0) { - return aggregate_local_frontier_bias_segmented_local_inclusive_sums[offsets[i] + degree - - 1]; - } else { - return bias_t{0.0}; - } - }); + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + thrust::tabulate( + handle.get_thrust_policy(), + get_dataframe_buffer_begin(aggregate_local_frontier_bias_local_sums) + + local_frontier_displacements[i], + get_dataframe_buffer_begin(aggregate_local_frontier_bias_local_sums) + + local_frontier_displacements[i] + local_frontier_sizes[i], + [key_idx_to_unique_key_idx = + raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + unique_key_local_degree_offsets = raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1), + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums = + raft::device_span( + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.data(), + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums + .size())] __device__(size_t i) { + auto unique_key_idx = key_idx_to_unique_key_idx[i]; + auto degree = unique_key_local_degree_offsets[unique_key_idx + 1] - + unique_key_local_degree_offsets[unique_key_idx]; + if 
(degree > 0) { + return aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums + [unique_key_local_degree_offsets[unique_key_idx] + degree - 1]; + } else { + return bias_t{0.0}; + } + }); + } rmm::device_uvector frontier_bias_sums(0, handle.get_stream()); std::optional> frontier_partitioned_bias_local_sum_displacements{ @@ -1684,7 +1834,7 @@ biased_sample_and_compute_local_nbr_indices( frontier_bias_sums = std::move(aggregate_local_frontier_bias_local_sums); } - rmm::device_uvector sample_random_numbers(frontier_bias_sums.size() * K, + rmm::device_uvector sample_random_numbers(local_frontier_sizes[minor_comm_rank] * K, handle.get_stream()); cugraph::detail::uniform_random_fill(handle.get_stream(), sample_random_numbers.data(), @@ -1716,8 +1866,6 @@ biased_sample_and_compute_local_nbr_indices( (*frontier_partitioned_bias_local_sum_displacements).data(), (*frontier_partitioned_bias_local_sum_displacements).size()) : std::nullopt, - local_frontier_displacements, - local_frontier_sizes, K, std::numeric_limits::infinity()); @@ -1736,22 +1884,30 @@ biased_sample_and_compute_local_nbr_indices( (*key_indices).data() + local_frontier_sample_offsets[i], local_frontier_sample_offsets[i + 1] - local_frontier_sample_offsets[i]) : thrust::nullopt, - aggregate_local_frontier_bias_segmented_local_inclusive_sums = raft::device_span( - aggregate_local_frontier_bias_segmented_local_inclusive_sums.data(), - aggregate_local_frontier_bias_segmented_local_inclusive_sums.size()), - local_degree_offsets = raft::device_span( - aggregate_local_frontier_local_degree_offsets.data() + local_frontier_displacements[i], - local_frontier_sizes[i] + 1), + key_idx_to_unique_key_idx = + raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums = + raft::device_span( + 
aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.data(), + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.size()), + unique_key_local_degree_offsets = raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1), invalid_random_number = std::numeric_limits::infinity(), - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { + invalid_idx = cugraph::invalid_edge_id_v] __device__(size_t i) { auto key_idx = key_indices ? (*key_indices)[i] : (i / K); + auto unique_key_idx = key_idx_to_unique_key_idx[key_idx]; auto local_random_number = sample_local_random_numbers[i]; if (local_random_number != invalid_random_number) { - auto local_degree = static_cast(local_degree_offsets[key_idx + 1] - - local_degree_offsets[key_idx]); + auto local_degree = + static_cast(unique_key_local_degree_offsets[unique_key_idx + 1] - + unique_key_local_degree_offsets[unique_key_idx]); auto inclusive_sum_first = - aggregate_local_frontier_bias_segmented_local_inclusive_sums.begin() + - local_degree_offsets[key_idx]; + aggregate_local_frontier_unique_key_bias_segmented_local_inclusive_sums.begin() + + unique_key_local_degree_offsets[unique_key_idx]; auto inclusive_sum_last = inclusive_sum_first + local_degree; auto local_nbr_idx = static_cast(thrust::distance( inclusive_sum_first, @@ -1770,10 +1926,25 @@ biased_sample_and_compute_local_nbr_indices( { rmm::device_uvector aggregate_local_frontier_local_degrees( local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); - thrust::adjacent_difference(handle.get_thrust_policy(), - aggregate_local_frontier_local_degree_offsets.begin() + 1, - aggregate_local_frontier_local_degree_offsets.end(), - aggregate_local_frontier_local_degrees.begin()); + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + thrust::tabulate( + 
handle.get_thrust_policy(), + aggregate_local_frontier_local_degrees.begin() + local_frontier_displacements[i], + aggregate_local_frontier_local_degrees.begin() + local_frontier_displacements[i] + + local_frontier_sizes[i], + [key_idx_to_unique_key_idx = + raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + unique_key_local_degree_offsets = raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1)] __device__(size_t i) { + auto unique_key_idx = key_idx_to_unique_key_idx[i]; + return unique_key_local_degree_offsets[unique_key_idx + 1] - + unique_key_local_degree_offsets[unique_key_idx]; + }); + } if (minor_comm_size > 1) { std::tie(frontier_degrees, frontier_partitioned_local_degree_displacements) = compute_frontier_value_sums_and_partitioned_local_value_sum_displacements( @@ -1836,22 +2007,29 @@ biased_sample_and_compute_local_nbr_indices( aggregate_low_local_frontier_indices.begin() + low_local_frontier_displacements[i], aggregate_low_local_frontier_indices.begin() + (low_local_frontier_displacements[i] + low_local_frontier_sizes[i]), - [aggregate_local_frontier_biases = raft::device_span( - aggregate_local_frontier_biases.data(), aggregate_local_frontier_biases.size()), - aggregate_local_frontier_local_degree_offsets = raft::device_span( - aggregate_local_frontier_local_degree_offsets.data(), - aggregate_local_frontier_local_degree_offsets.size()), + [key_idx_to_unique_key_idx = + raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + aggregate_local_frontier_unique_key_biases = + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), + unique_key_local_degree_offsets = raft::device_span( + 
aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1), zero_bias_frontier_indices = raft::device_span( zero_bias_frontier_indices.data(), zero_bias_frontier_indices.size()), zero_bias_local_nbr_indices = raft::device_span( zero_bias_local_nbr_indices.data(), zero_bias_local_nbr_indices.size()), input_offset = local_frontier_displacements[i], counter = counter.data()] __device__(size_t i) { - auto start_offset = aggregate_local_frontier_local_degree_offsets[input_offset + i]; - auto end_offset = aggregate_local_frontier_local_degree_offsets[input_offset + i + 1]; + auto unique_key_idx = key_idx_to_unique_key_idx[i]; + auto start_offset = unique_key_local_degree_offsets[unique_key_idx]; + auto end_offset = unique_key_local_degree_offsets[unique_key_idx + 1]; cuda::atomic_ref atomic_counter(*counter); for (auto j = start_offset; j < end_offset; ++j) { - if (aggregate_local_frontier_biases[j] == 0.0) { + if (aggregate_local_frontier_unique_key_biases[j] == 0.0) { auto idx = atomic_counter.fetch_add(size_t{1}, cuda::std::memory_order_relaxed); zero_bias_frontier_indices[idx] = i; zero_bias_local_nbr_indices[idx] = j - start_offset; @@ -1933,7 +2111,7 @@ biased_sample_and_compute_local_nbr_indices( raft::device_span(frontier_degrees.data(), frontier_degrees.size()), nbr_indices = raft::device_span(nbr_indices.data(), nbr_indices.size()), K, - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { + invalid_idx = cugraph::invalid_edge_id_v] __device__(size_t i) { auto first = thrust::lower_bound(thrust::seq, sorted_zero_bias_frontier_indices.begin(), sorted_zero_bias_frontier_indices.end(), @@ -2005,11 +2183,17 @@ biased_sample_and_compute_local_nbr_indices( aggregate_mid_local_frontier_local_degrees.begin() + mid_local_frontier_displacements[i], cuda::proclaim_return_type( - [offsets = raft::device_span( - aggregate_local_frontier_local_degree_offsets.data() 
+ + [key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + local_frontier_displacements[i], - local_frontier_sizes[i] + 1)] __device__(size_t i) { - return static_cast(offsets[i + 1] - offsets[i]); + local_frontier_sizes[i]), + unique_key_local_degree_offsets = raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1)] __device__(size_t i) { + auto unique_key_idx = key_idx_to_unique_key_idx[i]; + return static_cast(unique_key_local_degree_offsets[unique_key_idx + 1] - + unique_key_local_degree_offsets[unique_key_idx]); })); } @@ -2037,11 +2221,17 @@ biased_sample_and_compute_local_nbr_indices( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(mid_local_frontier_sizes[i]), - [aggregate_local_frontier_biases = raft::device_span( - aggregate_local_frontier_biases.data(), aggregate_local_frontier_biases.size()), - aggregate_local_frontier_local_degree_offsets = - raft::device_span(aggregate_local_frontier_local_degree_offsets.data(), - aggregate_local_frontier_local_degree_offsets.size()), + [key_idx_to_unique_key_idx = raft::device_span( + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + local_frontier_sizes[i]), + aggregate_local_frontier_unique_key_biases = + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), + unique_key_local_degree_offsets = raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1), mid_local_frontier_indices = raft::device_span( aggregate_mid_local_frontier_indices.data() + mid_local_frontier_displacements[i], mid_local_frontier_sizes[i]), @@ -2051,16 +2241,14 @@ 
biased_sample_and_compute_local_nbr_indices( aggregate_mid_local_frontier_local_degree_offsets = raft::device_span( aggregate_mid_local_frontier_local_degree_offsets.data(), aggregate_mid_local_frontier_local_degree_offsets.size()), - input_offset = local_frontier_displacements[i], output_offset = mid_local_frontier_displacements[i]] __device__(size_t i) { + auto unique_key_idx = key_idx_to_unique_key_idx[mid_local_frontier_indices[i]]; thrust::copy( thrust::seq, - aggregate_local_frontier_biases.begin() + - aggregate_local_frontier_local_degree_offsets[input_offset + - mid_local_frontier_indices[i]], - aggregate_local_frontier_biases.begin() + - aggregate_local_frontier_local_degree_offsets - [input_offset + (mid_local_frontier_indices[i] + 1)], + aggregate_local_frontier_unique_key_biases.begin() + + unique_key_local_degree_offsets[unique_key_idx], + aggregate_local_frontier_unique_key_biases.begin() + + unique_key_local_degree_offsets[unique_key_idx + 1], aggregate_mid_local_frontier_biases.begin() + aggregate_mid_local_frontier_local_degree_offsets[output_offset + i]); }); @@ -2202,16 +2390,26 @@ biased_sample_and_compute_local_nbr_indices( rmm::device_uvector aggregate_high_local_frontier_keys( aggregate_high_local_frontier_local_nbr_indices.size(), handle.get_stream()); for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + rmm::device_uvector unique_key_indices_for_key_indices( + high_local_frontier_sizes[i], handle.get_stream()); + thrust::gather( + handle.get_thrust_policy(), + aggregate_high_local_frontier_indices.data() + high_local_frontier_displacements[i], + aggregate_high_local_frontier_indices.data() + high_local_frontier_displacements[i] + + high_local_frontier_sizes[i], + aggregate_local_frontier_key_idx_to_unique_key_idx.data() + + local_frontier_displacements[i], + unique_key_indices_for_key_indices.begin()); compute_biased_sampling_index_without_replacement( handle, std::make_optional>( - 
aggregate_high_local_frontier_indices.data() + high_local_frontier_displacements[i], - high_local_frontier_sizes[i]), - raft::device_span(aggregate_local_frontier_local_degree_offsets.data() + - local_frontier_displacements[i], - local_frontier_sizes[i] + 1), - raft::device_span(aggregate_local_frontier_biases.data(), - aggregate_local_frontier_biases.size()), + unique_key_indices_for_key_indices.data(), unique_key_indices_for_key_indices.size()), + raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data() + + local_frontier_unique_key_displacements[i], + local_frontier_unique_key_sizes[i] + 1), + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), std::nullopt, raft::device_span(aggregate_high_local_frontier_local_nbr_indices.data() + high_local_frontier_displacements[i] * K, @@ -2356,19 +2554,27 @@ biased_sample_and_compute_local_nbr_indices( handle.get_thrust_policy(), frontier_indices.begin(), frontier_indices.begin() + frontier_partition_offsets[1], - [aggregate_local_frontier_biases = raft::device_span( - aggregate_local_frontier_biases.data(), aggregate_local_frontier_biases.size()), - aggregate_local_frontier_local_degree_offsets = - raft::device_span(aggregate_local_frontier_local_degree_offsets.data(), - aggregate_local_frontier_local_degree_offsets.size()), + [key_idx_to_unique_key_idx = + raft::device_span(aggregate_local_frontier_key_idx_to_unique_key_idx.data(), + aggregate_local_frontier_key_idx_to_unique_key_idx.size()), + aggregate_local_frontier_unique_key_biases = + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), + aggregate_local_frontier_unique_key_local_degree_offsets = raft::device_span( + aggregate_local_frontier_unique_key_local_degree_offsets.data(), + aggregate_local_frontier_unique_key_local_degree_offsets.size()), nbr_indices = raft::device_span(nbr_indices.data(), 
nbr_indices.size()), K, - invalid_idx = cugraph::ops::graph::INVALID_ID] __device__(size_t i) { - auto start_offset = aggregate_local_frontier_local_degree_offsets[i]; - auto degree = aggregate_local_frontier_local_degree_offsets[i + 1] - start_offset; - edge_t num_valid = 0; + invalid_idx = cugraph::invalid_edge_id_v] __device__(size_t i) { + auto unique_key_idx = key_idx_to_unique_key_idx[i]; + auto start_offset = + aggregate_local_frontier_unique_key_local_degree_offsets[unique_key_idx]; + auto degree = + aggregate_local_frontier_unique_key_local_degree_offsets[unique_key_idx + 1] - + start_offset; + edge_t num_valid = 0; for (size_t j = 0; j < degree; ++j) { - auto bias = aggregate_local_frontier_biases[start_offset + j]; + auto bias = aggregate_local_frontier_unique_key_biases[start_offset + j]; if (bias > 0.0) { *(nbr_indices.begin() + i * K + num_valid) = j; ++num_valid; @@ -2382,14 +2588,23 @@ biased_sample_and_compute_local_nbr_indices( auto mid_and_high_frontier_size = frontier_partition_offsets[3] - frontier_partition_offsets[1]; + rmm::device_uvector unique_key_indices_for_key_indices(mid_and_high_frontier_size, + handle.get_stream()); + thrust::gather( + handle.get_thrust_policy(), + frontier_indices.data() + frontier_partition_offsets[1], + frontier_indices.data() + frontier_partition_offsets[1] + mid_and_high_frontier_size, + aggregate_local_frontier_key_idx_to_unique_key_idx.begin(), + unique_key_indices_for_key_indices.begin()); compute_biased_sampling_index_without_replacement( handle, std::make_optional>( - frontier_indices.data() + frontier_partition_offsets[1], mid_and_high_frontier_size), - raft::device_span(aggregate_local_frontier_local_degree_offsets.data(), - aggregate_local_frontier_local_degree_offsets.size()), - raft::device_span(aggregate_local_frontier_biases.data(), - aggregate_local_frontier_biases.size()), + unique_key_indices_for_key_indices.data(), unique_key_indices_for_key_indices.size()), + raft::device_span( + 
aggregate_local_frontier_unique_key_local_degree_offsets.data(), + aggregate_local_frontier_unique_key_local_degree_offsets.size()), + raft::device_span(aggregate_local_frontier_unique_key_biases.data(), + aggregate_local_frontier_unique_key_biases.size()), std::make_optional>( frontier_indices.data() + frontier_partition_offsets[1], mid_and_high_frontier_size), raft::device_span(nbr_indices.data(), nbr_indices.size()), @@ -2408,10 +2623,8 @@ biased_sample_and_compute_local_nbr_indices( (*frontier_partitioned_local_degree_displacements).data(), (*frontier_partitioned_local_degree_displacements).size()) : std::nullopt, - local_frontier_displacements, - local_frontier_sizes, K, - cugraph::ops::graph::INVALID_ID); + cugraph::invalid_edge_id_v); } // 3. convert neighbor indices in the neighbor list considering edge mask to neighbor indices in diff --git a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh index a004741f719..7253fde8d4e 100644 --- a/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh +++ b/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -206,10 +207,13 @@ struct return_value_compute_offset_t { template std::tuple>, @@ -217,10 +221,13 @@ std::tuple>, per_v_random_select_transform_e(raft::handle_t const& handle, GraphViewType const& graph_view, VertexFrontierBucketType const& frontier, + EdgeBiasSrcValueInputWrapper edge_bias_src_value_input, + EdgeBiasDstValueInputWrapper edge_bias_dst_value_input, + EdgeBiasValueInputWrapper edge_bias_value_input, + EdgeBiasOp e_bias_op, EdgeSrcValueInputWrapper edge_src_value_input, EdgeDstValueInputWrapper edge_dst_value_input, EdgeValueInputWrapper edge_value_input, - EdgeBiasOp e_bias_op, EdgeOp e_op, raft::random::RngState& rng_state, size_t K, @@ -229,11 +236,10 @@ per_v_random_select_transform_e(raft::handle_t const& handle, bool 
do_expensive_check) { #ifndef NO_CUGRAPH_OPS - using vertex_t = typename GraphViewType::vertex_type; - using edge_t = typename GraphViewType::edge_type; - using key_t = typename VertexFrontierBucketType::key_type; - using key_buffer_t = - decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; + using key_t = typename VertexFrontierBucketType::key_type; + using key_buffer_t = dataframe_buffer_type_t; using edge_partition_src_input_device_view_t = std::conditional_t< std::is_same_v, @@ -315,13 +321,12 @@ per_v_random_select_transform_e(raft::handle_t const& handle, // 1. aggregate frontier - auto aggregate_local_frontier = - (minor_comm_size > 1) - ? std::make_optional( - local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()) - : std::nullopt; + std::optional aggregate_local_frontier{std::nullopt}; if (minor_comm_size > 1) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + + aggregate_local_frontier = allocate_dataframe_buffer( + local_frontier_displacements.back() + local_frontier_sizes.back(), handle.get_stream()); device_allgatherv(minor_comm, frontier.begin(), get_dataframe_buffer_begin(*aggregate_local_frontier), @@ -338,9 +343,9 @@ per_v_random_select_transform_e(raft::handle_t const& handle, std::vector local_frontier_sample_offsets{}; if constexpr (std::is_same_v>) { std::tie(sample_local_nbr_indices, sample_key_indices, local_frontier_sample_offsets) = uniform_sample_and_compute_local_nbr_indices( @@ -360,9 +365,9 @@ per_v_random_select_transform_e(raft::handle_t const& handle, graph_view, (minor_comm_size > 1) ? 
get_dataframe_buffer_begin(*aggregate_local_frontier) : frontier.begin(), - edge_src_value_input, - edge_dst_value_input, - edge_value_input, + edge_bias_src_value_input, + edge_bias_dst_value_input, + edge_bias_value_input, e_bias_op, local_frontier_displacements, local_frontier_sizes, @@ -433,7 +438,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_partition_dst_value_input, edge_partition_e_value_input, e_op, - cugraph::ops::graph::INVALID_ID, + cugraph::invalid_edge_id_v, to_thrust_optional(invalid_value), K}); } else { @@ -457,7 +462,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, edge_partition_dst_value_input, edge_partition_e_value_input, e_op, - cugraph::ops::graph::INVALID_ID, + cugraph::invalid_edge_id_v, to_thrust_optional(invalid_value), K}); } @@ -557,7 +562,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, count_valids_t{raft::device_span(sample_local_nbr_indices.data(), sample_local_nbr_indices.size()), K, - cugraph::ops::graph::INVALID_ID}); + cugraph::invalid_edge_id_v}); (*sample_offsets).set_element_to_zero_async(size_t{0}, handle.get_stream()); auto typecasted_sample_count_first = thrust::make_transform_iterator(sample_counts.begin(), typecast_t{}); @@ -574,7 +579,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, thrust::remove_if(handle.get_thrust_policy(), pair_first, pair_first + sample_local_nbr_indices.size(), - check_invalid_t{cugraph::ops::graph::INVALID_ID}); + check_invalid_t{cugraph::invalid_edge_id_v}); sample_local_nbr_indices.resize(0, handle.get_stream()); sample_local_nbr_indices.shrink_to_fit(handle.get_stream()); @@ -588,7 +593,7 @@ per_v_random_select_transform_e(raft::handle_t const& handle, #else CUGRAPH_FAIL("unimplemented."); return std::make_tuple(std::nullopt, - allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); + allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); #endif } @@ -653,10 +658,13 @@ 
per_v_random_select_transform_e(raft::handle_t const& handle, */ template std::tuple>, @@ -664,10 +672,13 @@ std::tuple>, per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, GraphViewType const& graph_view, VertexFrontierBucketType const& frontier, + EdgeBiasSrcValueInputWrapper edge_bias_src_value_input, + EdgeBiasDstValueInputWrapper edge_bias_dst_value_input, + EdgeBiasValueInputWrapper edge_bias_value_input, + EdgeBiasOp e_bias_op, EdgeSrcValueInputWrapper edge_src_value_input, EdgeDstValueInputWrapper edge_dst_value_input, EdgeValueInputWrapper edge_value_input, - EdgeBiasOp e_bias_op, EdgeOp e_op, raft::random::RngState& rng_state, size_t K, @@ -678,10 +689,13 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, return detail::per_v_random_select_transform_e(handle, graph_view, frontier, + edge_bias_src_value_input, + edge_bias_dst_value_input, + edge_bias_value_input, + e_bias_op, edge_src_value_input, edge_dst_value_input, edge_value_input, - e_bias_op, e_op, rng_state, K, @@ -768,14 +782,17 @@ per_v_random_select_transform_outgoing_e(raft::handle_t const& handle, handle, graph_view, frontier, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + edge_dummy_property_t{}.view(), + detail::constant_e_bias_op_t{}, edge_src_value_input, edge_dst_value_input, edge_value_input, - detail::constant_e_bias_op_t{}, e_op, rng_state, K, diff --git a/cpp/src/prims/property_op_utils.cuh b/cpp/src/prims/property_op_utils.cuh index 8d74e6be292..04ad22cbf71 100644 --- a/cpp/src/prims/property_op_utils.cuh +++ b/cpp/src/prims/property_op_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,6 @@ #include #include -#include #include #include diff --git a/cpp/src/sampling/detail/check_edge_bias_values.cuh b/cpp/src/sampling/detail/check_edge_bias_values.cuh new file mode 100644 index 00000000000..3b28df1037d --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values.cuh @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "prims/count_if_e.cuh" + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace cugraph { +namespace detail { + +template +std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view) +{ + auto num_negative_edge_weights = + count_if_e(handle, + graph_view, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + edge_bias_view, + [] __device__(vertex_t, vertex_t, auto, auto, bias_t b) { return b < 0.0; }); + + size_t num_overflows{0}; + { + auto bias_sums = compute_out_weight_sums(handle, graph_view, edge_bias_view); + num_overflows = thrust::count_if( + handle.get_thrust_policy(), bias_sums.begin(), bias_sums.end(), [] __device__(auto sum) { + return sum > std::numeric_limits::max(); + }); + } + + if constexpr (multi_gpu) { + num_negative_edge_weights = host_scalar_allreduce( + handle.get_comms(), num_negative_edge_weights, raft::comms::op_t::SUM, 
handle.get_stream()); + num_overflows = host_scalar_allreduce( + handle.get_comms(), num_overflows, raft::comms::op_t::SUM, handle.get_stream()); + } + + return std::make_tuple(num_negative_edge_weights, num_overflows); +} + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/check_edge_bias_values_mg_v32_e32.cu b/cpp/src/sampling/detail/check_edge_bias_values_mg_v32_e32.cu new file mode 100644 index 00000000000..41019b4ec4b --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values_mg_v32_e32.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sampling/detail/check_edge_bias_values.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/check_edge_bias_values_mg_v32_e64.cu b/cpp/src/sampling/detail/check_edge_bias_values_mg_v32_e64.cu new file mode 100644 index 00000000000..b8b3564fee7 --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values_mg_v32_e64.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sampling/detail/check_edge_bias_values.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/check_edge_bias_values_mg_v64_e64.cu b/cpp/src/sampling/detail/check_edge_bias_values_mg_v64_e64.cu new file mode 100644 index 00000000000..81c99ab1d8a --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values_mg_v64_e64.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sampling/detail/check_edge_bias_values.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/check_edge_bias_values_sg_v32_e32.cu b/cpp/src/sampling/detail/check_edge_bias_values_sg_v32_e32.cu new file mode 100644 index 00000000000..31b5ceeab77 --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values_sg_v32_e32.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sampling/detail/check_edge_bias_values.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/check_edge_bias_values_sg_v32_e64.cu b/cpp/src/sampling/detail/check_edge_bias_values_sg_v32_e64.cu new file mode 100644 index 00000000000..c8c28a5ad04 --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values_sg_v32_e64.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sampling/detail/check_edge_bias_values.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/check_edge_bias_values_sg_v64_e64.cu b/cpp/src/sampling/detail/check_edge_bias_values_sg_v64_e64.cu new file mode 100644 index 00000000000..42daa9f24a0 --- /dev/null +++ b/cpp/src/sampling/detail/check_edge_bias_values_sg_v64_e64.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sampling/detail/check_edge_bias_values.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +template std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/src/sampling/detail/prepare_next_frontier_impl.cuh b/cpp/src/sampling/detail/prepare_next_frontier_impl.cuh index 8ad5114eda6..5c04d628f09 100644 --- a/cpp/src/sampling/detail/prepare_next_frontier_impl.cuh +++ b/cpp/src/sampling/detail/prepare_next_frontier_impl.cuh @@ -18,10 +18,10 @@ #include "sampling/detail/sampling_utils.hpp" -#include #include #include #include +#include #include #include diff --git a/cpp/src/sampling/detail/sample_edges.cuh b/cpp/src/sampling/detail/sample_edges.cuh index 9b49f6a5b49..0c670c6507e 100644 --- a/cpp/src/sampling/detail/sample_edges.cuh +++ b/cpp/src/sampling/detail/sample_edges.cuh @@ -17,7 +17,6 @@ #pragma once #include "prims/per_v_random_select_transform_outgoing_e.cuh" -#include "prims/update_edge_src_dst_property.cuh" // ?? 
#include "prims/vertex_frontier.cuh" #include "structure/detail/structure_utils.cuh" @@ -66,6 +65,15 @@ struct sample_edges_op_t { } }; +template +struct sample_edge_biases_op_t { + auto __host__ __device__ + operator()(vertex_t, vertex_t, thrust::nullopt_t, thrust::nullopt_t, bias_t bias) const + { + return bias; + } +}; + struct segmented_fill_t { raft::device_span fill_values{}; raft::device_span segment_offsets{}; @@ -84,6 +92,7 @@ template std::tuple, @@ -97,6 +106,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -120,135 +130,305 @@ sample_edges(raft::handle_t const& handle, if (edge_weight_view) { if (edge_id_view) { if (edge_type_view) { - std::forward_as_tuple(sample_offsets, - std::tie(majors, minors, weights, edge_ids, edge_types)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - view_concat(*edge_weight_view, *edge_id_view, *edge_type_view), - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{ - std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, + std::tie(majors, minors, weights, edge_ids, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, *edge_id_view, *edge_type_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{ + std::nullopt}, + false); + } else { + 
std::forward_as_tuple(sample_offsets, + std::tie(majors, minors, weights, edge_ids, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, *edge_id_view, *edge_type_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{ + std::nullopt}, + false); + } } else { - std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights, edge_ids)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - view_concat(*edge_weight_view, *edge_id_view), - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights, edge_ids)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, *edge_id_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights, edge_ids)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, *edge_id_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } } else { if (edge_type_view) { - std::forward_as_tuple(sample_offsets, 
std::tie(majors, minors, weights, edge_types)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - view_concat(*edge_weight_view, *edge_type_view), - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, *edge_type_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, *edge_type_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } else { - std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - *edge_weight_view, - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + 
vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_weight_view, + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, weights)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_weight_view, + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } } } else { if (edge_id_view) { if (edge_type_view) { - std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_ids, edge_types)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - view_concat(*edge_id_view, *edge_type_view), - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_ids, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_id_view, *edge_type_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_ids, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + 
vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_id_view, *edge_type_view), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } else { - std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_ids)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - *edge_id_view, - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_ids)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_id_view, + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_ids)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_id_view, + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } } else { if (edge_type_view) { - std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_types)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - *edge_type_view, - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, 
- true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_type_view, + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors, edge_types)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_type_view, + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } else { - std::forward_as_tuple(sample_offsets, std::tie(majors, minors)) = - cugraph::per_v_random_select_transform_outgoing_e( - handle, - graph_view, - vertex_frontier.bucket(0), - edge_src_dummy_property_t{}.view(), - edge_dst_dummy_property_t{}.view(), - edge_dummy_property_t{}.view(), - sample_edges_op_t{}, - rng_state, - fanout, - with_replacement, - std::optional>{std::nullopt}, - true); + if (edge_bias_view) { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors)) = + cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + *edge_bias_view, + sample_edge_biases_op_t{}, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + edge_dummy_property_t{}.view(), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } else { + std::forward_as_tuple(sample_offsets, std::tie(majors, minors)) = + 
cugraph::per_v_random_select_transform_outgoing_e( + handle, + graph_view, + vertex_frontier.bucket(0), + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + edge_dummy_property_t{}.view(), + sample_edges_op_t{}, + rng_state, + fanout, + with_replacement, + std::optional>{std::nullopt}, + false); + } } } } diff --git a/cpp/src/sampling/detail/sample_edges_mg_v32_e32.cu b/cpp/src/sampling/detail/sample_edges_mg_v32_e32.cu index e456d6c1af3..c3e6670a490 100644 --- a/cpp/src/sampling/detail/sample_edges_mg_v32_e32.cu +++ b/cpp/src/sampling/detail/sample_edges_mg_v32_e32.cu @@ -30,6 +30,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -47,6 +48,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/sample_edges_mg_v32_e64.cu b/cpp/src/sampling/detail/sample_edges_mg_v32_e64.cu index 8dad10fb7b4..4628840499f 100644 --- a/cpp/src/sampling/detail/sample_edges_mg_v32_e64.cu +++ b/cpp/src/sampling/detail/sample_edges_mg_v32_e64.cu @@ -30,6 +30,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -47,6 +48,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, 
raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/sample_edges_mg_v64_e64.cu b/cpp/src/sampling/detail/sample_edges_mg_v64_e64.cu index b3edfdda3dc..45028503faf 100644 --- a/cpp/src/sampling/detail/sample_edges_mg_v64_e64.cu +++ b/cpp/src/sampling/detail/sample_edges_mg_v64_e64.cu @@ -30,6 +30,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -47,6 +48,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/sample_edges_sg_v32_e32.cu b/cpp/src/sampling/detail/sample_edges_sg_v32_e32.cu index f17e3c8a497..2b41b47bc21 100644 --- a/cpp/src/sampling/detail/sample_edges_sg_v32_e32.cu +++ b/cpp/src/sampling/detail/sample_edges_sg_v32_e32.cu @@ -30,6 +30,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -47,6 +48,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/sample_edges_sg_v32_e64.cu b/cpp/src/sampling/detail/sample_edges_sg_v32_e64.cu index 5571905b2b6..b1d664782f2 100644 --- 
a/cpp/src/sampling/detail/sample_edges_sg_v32_e64.cu +++ b/cpp/src/sampling/detail/sample_edges_sg_v32_e64.cu @@ -30,6 +30,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -47,6 +48,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/sample_edges_sg_v64_e64.cu b/cpp/src/sampling/detail/sample_edges_sg_v64_e64.cu index 6e1867eaa2a..bd434298cc0 100644 --- a/cpp/src/sampling/detail/sample_edges_sg_v64_e64.cu +++ b/cpp/src/sampling/detail/sample_edges_sg_v64_e64.cu @@ -30,6 +30,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, @@ -47,6 +48,7 @@ sample_edges(raft::handle_t const& handle, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/sampling_utils.hpp b/cpp/src/sampling/detail/sampling_utils.hpp index e56da2053d5..102f9ec58f7 100644 --- a/cpp/src/sampling/detail/sampling_utils.hpp +++ b/cpp/src/sampling/detail/sampling_utils.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include @@ -31,6 +31,27 @@ namespace detail { // in implementation, naming and documentation. 
We should review these and // consider updating things to support an arbitrary value for store_transposed +/** + * @brief Check edge bias values. + * + * Count the number of negative edge bias values & the number of vertices with the sum of their + * outgoing edge bias values exceeding std::numeric_limits<bias_t>::max(). + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam bias_t Type of edge bias values. Needs to be a floating point type. + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) or multi-GPU (true). + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Graph View object to generate neighbor sampling on. + * @param edge_bias_view View object holding edge bias values for @p graph_view. + */ +template +std::tuple check_edge_bias_values( + raft::handle_t const& handle, + graph_view_t const& graph_view, + edge_property_view_t edge_bias_view); + /** * @brief Gather edge list for specified vertices * @@ -72,7 +93,7 @@ gather_one_hop_edgelist( graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, - std::optional> edge_edge_type_view, + std::optional> edge_type_view, raft::device_span active_majors, std::optional> active_major_labels, bool do_expensive_check = false); @@ -107,6 +128,7 @@ template std::tuple, @@ -119,7 +141,8 @@ sample_edges(raft::handle_t const& handle, graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, - std::optional> edge_edge_type_view, + std::optional> edge_type_view, + std::optional> edge_bias_view, raft::random::RngState& rng_state, raft::device_span active_majors, std::optional> active_major_labels, diff --git a/cpp/src/sampling/detail/shuffle_and_organize_output_impl.cuh
b/cpp/src/sampling/detail/shuffle_and_organize_output_impl.cuh index e4942d0860c..ec14e99baec 100644 --- a/cpp/src/sampling/detail/shuffle_and_organize_output_impl.cuh +++ b/cpp/src/sampling/detail/shuffle_and_organize_output_impl.cuh @@ -64,143 +64,124 @@ void sort_sampled_tuples(raft::handle_t const& handle, std::optional>& edge_ids, std::optional>& edge_types, std::optional>& hops, - std::optional>& labels) + rmm::device_uvector& labels) { + rmm::device_uvector indices(majors.size(), handle.get_stream()); + thrust::sequence(handle.get_thrust_policy(), indices.begin(), indices.end(), size_t{0}); + rmm::device_uvector tmp_labels(indices.size(), handle.get_stream()); + auto tmp_hops = + hops ? std::make_optional>(indices.size(), handle.get_stream()) + : std::nullopt; + if (hops) { + thrust::sort( + handle.get_thrust_policy(), + indices.begin(), + indices.end(), + [labels = raft::device_span(labels.data(), labels.size()), + hops = raft::device_span(hops->data(), hops->size())] __device__(size_t l, + size_t r) { + return thrust::make_tuple(labels[l], hops[l]) < thrust::make_tuple(labels[r], hops[r]); + }); + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + thrust::make_zip_iterator(labels.begin(), hops->begin()), + thrust::make_zip_iterator(tmp_labels.begin(), tmp_hops->begin())); + hops = std::move(tmp_hops); + } else { + thrust::sort( + handle.get_thrust_policy(), + indices.begin(), + indices.end(), + [labels = raft::device_span(labels.data(), labels.size())] __device__( + size_t l, size_t r) { return labels[l] < labels[r]; }); + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + labels.begin(), + tmp_labels.begin()); + } + labels = std::move(tmp_labels); + + rmm::device_uvector tmp_majors(indices.size(), handle.get_stream()); + rmm::device_uvector tmp_minors(indices.size(), handle.get_stream()); + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + 
thrust::make_zip_iterator(majors.begin(), minors.begin()), + thrust::make_zip_iterator(tmp_majors.begin(), tmp_minors.begin())); + majors = std::move(tmp_majors); + minors = std::move(tmp_minors); + + auto tmp_weights = + weights ? std::make_optional>(indices.size(), handle.get_stream()) + : std::nullopt; + auto tmp_edge_ids = + edge_ids ? std::make_optional>(indices.size(), handle.get_stream()) + : std::nullopt; + auto tmp_edge_types = edge_types ? std::make_optional>( + indices.size(), handle.get_stream()) + : std::nullopt; if (weights) { if (edge_ids) { if (edge_types) { - if (hops) { - thrust::sort_by_key(handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator(majors.begin(), - minors.begin(), - weights->begin(), - edge_ids->begin(), - edge_types->begin())); - } else { - thrust::sort_by_key(handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator(majors.begin(), - minors.begin(), - weights->begin(), - edge_ids->begin(), - edge_types->begin())); - } + thrust::gather( + handle.get_thrust_policy(), + indices.begin(), + indices.end(), + thrust::make_zip_iterator(weights->begin(), edge_ids->begin(), edge_types->begin()), + thrust::make_zip_iterator( + tmp_weights->begin(), tmp_edge_ids->begin(), tmp_edge_types->begin())); } else { - if (hops) { - thrust::sort_by_key( - handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator( - majors.begin(), minors.begin(), weights->begin(), edge_ids->begin())); - } else { - thrust::sort_by_key( - handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator( - majors.begin(), minors.begin(), weights->begin(), edge_ids->begin())); - } + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + 
thrust::make_zip_iterator(weights->begin(), edge_ids->begin()), + thrust::make_zip_iterator(tmp_weights->begin(), tmp_edge_ids->begin())); } } else { if (edge_types) { - if (hops) { - thrust::sort_by_key( - handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator( - majors.begin(), minors.begin(), weights->begin(), edge_types->begin())); - } else { - thrust::sort_by_key( - handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator( - majors.begin(), minors.begin(), weights->begin(), edge_types->begin())); - } + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + thrust::make_zip_iterator(weights->begin(), edge_types->begin()), + thrust::make_zip_iterator(tmp_weights->begin(), tmp_edge_types->begin())); } else { - if (hops) { - thrust::sort_by_key( - handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator(majors.begin(), minors.begin(), weights->begin())); - } else { - thrust::sort_by_key( - handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator(majors.begin(), minors.begin(), weights->begin())); - } + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + weights->begin(), + tmp_weights->begin()); } } } else { if (edge_ids) { if (edge_types) { - if (hops) { - thrust::sort_by_key( - handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator( - majors.begin(), minors.begin(), edge_ids->begin(), edge_types->begin())); - } else { - thrust::sort_by_key( - handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator( - majors.begin(), minors.begin(), edge_ids->begin(), 
edge_types->begin())); - } + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + thrust::make_zip_iterator(edge_ids->begin(), edge_types->begin()), + thrust::make_zip_iterator(tmp_edge_ids->begin(), tmp_edge_types->begin())); } else { - if (hops) { - thrust::sort_by_key( - handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator(majors.begin(), minors.begin(), edge_ids->begin())); - } else { - thrust::sort_by_key( - handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator(majors.begin(), minors.begin(), edge_ids->begin())); - } + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + edge_ids->begin(), + tmp_edge_ids->begin()); } } else { if (edge_types) { - if (hops) { - thrust::sort_by_key( - handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator(majors.begin(), minors.begin(), edge_types->begin())); - } else { - thrust::sort_by_key( - handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator(majors.begin(), minors.begin(), edge_types->begin())); - } - } else { - if (hops) { - thrust::sort_by_key(handle.get_thrust_policy(), - thrust::make_zip_iterator(labels->begin(), hops->begin()), - thrust::make_zip_iterator(labels->end(), hops->end()), - thrust::make_zip_iterator(majors.begin(), minors.begin())); - } else { - thrust::sort_by_key(handle.get_thrust_policy(), - labels->begin(), - labels->end(), - thrust::make_zip_iterator(majors.begin(), minors.begin())); - } + thrust::gather(handle.get_thrust_policy(), + indices.begin(), + indices.end(), + edge_types->begin(), + tmp_edge_types->begin()); } } } + weights = std::move(tmp_weights); + edge_ids = std::move(tmp_edge_ids); + edge_types = std::move(tmp_edge_types); 
} template > offsets{std::nullopt}; if (labels) { - sort_sampled_tuples(handle, majors, minors, weights, edge_ids, edge_types, hops, labels); + sort_sampled_tuples(handle, majors, minors, weights, edge_ids, edge_types, hops, *labels); if (label_to_output_comm_rank) { CUGRAPH_EXPECTS(labels, "labels must be specified in order to shuffle sampling results"); @@ -744,7 +725,7 @@ shuffle_and_organize_output( } } - sort_sampled_tuples(handle, majors, minors, weights, edge_ids, edge_types, hops, labels); + sort_sampled_tuples(handle, majors, minors, weights, edge_ids, edge_types, hops, *labels); } size_t num_unique_labels = diff --git a/cpp/src/sampling/uniform_neighbor_sampling_impl.hpp b/cpp/src/sampling/neighbor_sampling_impl.hpp similarity index 80% rename from cpp/src/sampling/uniform_neighbor_sampling_impl.hpp rename to cpp/src/sampling/neighbor_sampling_impl.hpp index 21033783508..1785f934426 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/neighbor_sampling_impl.hpp @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -34,6 +36,7 @@ template @@ -45,12 +48,13 @@ std::tuple, std::optional>, std::optional>, std::optional>> -uniform_neighbor_sample_impl( +neighbor_sample_impl( raft::handle_t const& handle, graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, + std::optional> edge_bias_view, raft::device_span this_frontier_vertices, std::optional> this_frontier_vertex_labels, std::optional, raft::device_span>> @@ -63,10 +67,14 @@ uniform_neighbor_sample_impl( raft::random::RngState& rng_state, bool do_expensive_check) { -#ifdef NO_CUGRAPH_OPS +#ifdef NO_CUGRAPH_OPS // FIXME: this is relevant only when edge_bias_view.has_value() is false, + // this ifdef statement will be removed once we migrate relevant cugraph-ops + // functions to cugraph CUGRAPH_FAIL( - "uniform_neighbor_sample_impl not supported in this configuration, built 
with NO_CUGRAPH_OPS"); + "neighbor_sample_impl not supported in this configuration, built with NO_CUGRAPH_OPS"); #else + static_assert(std::is_floating_point_v); + CUGRAPH_EXPECTS(fan_out.size() > 0, "Invalid input argument: number of levels must be non-zero."); CUGRAPH_EXPECTS( fan_out.size() <= static_cast(std::numeric_limits::max()), @@ -83,6 +91,18 @@ uniform_neighbor_sample_impl( "cannot specify output GPU mapping without also specifying this_frontier_vertex_labels"); if (do_expensive_check) { + if (edge_bias_view) { + auto [num_negative_edge_weights, num_overflows] = + check_edge_bias_values(handle, graph_view, *edge_bias_view); + + CUGRAPH_EXPECTS( + num_negative_edge_weights == 0, + "Invalid input argument: input edge bias values should have non-negative values."); + CUGRAPH_EXPECTS(num_overflows == 0, + "Invalid input argument: sum of neighboring edge bias values should not " + "exceed std::numeric_limits::max() for any vertex."); + } + if (label_to_output_comm_rank) { CUGRAPH_EXPECTS(cugraph::detail::is_sorted(handle, std::get<0>(*label_to_output_comm_rank)), "Labels in label_to_output_comm_rank must be sorted"); @@ -145,6 +165,7 @@ uniform_neighbor_sample_impl( edge_weight_view, edge_id_view, edge_type_view, + edge_bias_view, rng_state, this_frontier_vertices, this_frontier_vertex_labels, @@ -353,23 +374,78 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check) { - CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); - - return detail::uniform_neighbor_sample_impl(handle, - graph_view, - edge_weight_view, - edge_id_view, - edge_type_view, - starting_vertices, - starting_vertex_labels, - label_to_output_comm_rank, - fan_out, - return_hops, - with_replacement, - prior_sources_behavior, - dedupe_sources, - rng_state, - do_expensive_check); + using bias_t = weight_t; // dummy + return detail::neighbor_sample_impl( + handle, + graph_view, + edge_weight_view, + edge_id_view, + edge_type_view, + std::nullopt, + starting_vertices, + 
starting_vertex_labels, + label_to_output_comm_rank, + fan_out, + return_hops, + with_replacement, + prior_sources_behavior, + dedupe_sources, + rng_state, + do_expensive_check); +} + +template +std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check) +{ + return detail::neighbor_sample_impl( + handle, + graph_view, + edge_weight_view, + edge_id_view, + edge_type_view, + edge_bias_view, + starting_vertices, + starting_vertex_labels, + label_to_output_comm_rank, + fan_out, + return_hops, + with_replacement, + prior_sources_behavior, + dedupe_sources, + rng_state, + do_expensive_check); } } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_mg_v32_e32.cpp b/cpp/src/sampling/neighbor_sampling_mg_v32_e32.cpp similarity index 53% rename from cpp/src/sampling/uniform_neighbor_sampling_mg_v32_e32.cpp rename to cpp/src/sampling/neighbor_sampling_mg_v32_e32.cpp index 3e816d8c9f9..f61c1c10c53 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_mg_v32_e32.cpp +++ b/cpp/src/sampling/neighbor_sampling_mg_v32_e32.cpp @@ -14,9 +14,10 @@ * limitations under the License. 
*/ -#include "uniform_neighbor_sampling_impl.hpp" +#include "neighbor_sampling_impl.hpp" #include +#include namespace cugraph { @@ -72,4 +73,58 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_mg.cpp b/cpp/src/sampling/neighbor_sampling_mg_v32_e64.cpp similarity index 66% rename from cpp/src/sampling/uniform_neighbor_sampling_mg.cpp rename to cpp/src/sampling/neighbor_sampling_mg_v32_e64.cpp index 
6e0c15c70dd..c37b353ae1c 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_mg.cpp +++ b/cpp/src/sampling/neighbor_sampling_mg_v32_e64.cpp @@ -14,26 +14,27 @@ * limitations under the License. */ -#include "uniform_neighbor_sampling_impl.hpp" +#include "neighbor_sampling_impl.hpp" #include +#include namespace cugraph { template std::tuple, rmm::device_uvector, std::optional>, - std::optional>, + std::optional>, std::optional>, std::optional>, std::optional>, std::optional>> uniform_neighbor_sample( raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, raft::device_span starting_vertices, std::optional> starting_vertex_labels, std::optional, raft::device_span>> @@ -48,7 +49,7 @@ uniform_neighbor_sample( template std::tuple, rmm::device_uvector, - std::optional>, + std::optional>, std::optional>, std::optional>, std::optional>, @@ -57,7 +58,7 @@ template std::tuple, uniform_neighbor_sample( raft::handle_t const& handle, graph_view_t const& graph_view, - std::optional> edge_weight_view, + std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, raft::device_span starting_vertices, @@ -72,46 +73,21 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); -template std::tuple, - rmm::device_uvector, +template std::tuple, + rmm::device_uvector, std::optional>, std::optional>, std::optional>, std::optional>, std::optional>, std::optional>> -uniform_neighbor_sample( +biased_neighbor_sample( raft::handle_t const& handle, - graph_view_t const& graph_view, + graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, 
raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, + edge_property_view_t edge_bias_view, raft::device_span starting_vertices, std::optional> starting_vertex_labels, std::optional, raft::device_span>> @@ -132,12 +108,13 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -uniform_neighbor_sample( +biased_neighbor_sample( raft::handle_t const& handle, graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, + edge_property_view_t edge_bias_view, raft::device_span starting_vertices, std::optional> starting_vertex_labels, std::optional, raft::device_span>> @@ -150,30 +127,4 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - 
} // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_mg_v64_e64.cpp b/cpp/src/sampling/neighbor_sampling_mg_v64_e64.cpp similarity index 53% rename from cpp/src/sampling/uniform_neighbor_sampling_mg_v64_e64.cpp rename to cpp/src/sampling/neighbor_sampling_mg_v64_e64.cpp index 8989f5f4284..ea3f6b466da 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_mg_v64_e64.cpp +++ b/cpp/src/sampling/neighbor_sampling_mg_v64_e64.cpp @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "uniform_neighbor_sampling_impl.hpp" +#include "neighbor_sampling_impl.hpp" #include +#include namespace cugraph { @@ -72,4 +73,58 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + 
raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_sg_v32_e32.cpp b/cpp/src/sampling/neighbor_sampling_sg_v32_e32.cpp similarity index 53% rename from cpp/src/sampling/uniform_neighbor_sampling_sg_v32_e32.cpp rename to cpp/src/sampling/neighbor_sampling_sg_v32_e32.cpp index 2c32653eba8..0f0affbb323 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_sg_v32_e32.cpp +++ b/cpp/src/sampling/neighbor_sampling_sg_v32_e32.cpp @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "uniform_neighbor_sampling_impl.hpp" +#include "neighbor_sampling_impl.hpp" #include +#include namespace cugraph { @@ -72,4 +73,58 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + 
std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_sg.cpp b/cpp/src/sampling/neighbor_sampling_sg_v32_e64.cpp similarity index 66% rename from cpp/src/sampling/uniform_neighbor_sampling_sg.cpp rename to cpp/src/sampling/neighbor_sampling_sg_v32_e64.cpp index d069cecfb5c..7ab0a8782ec 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_sg.cpp +++ b/cpp/src/sampling/neighbor_sampling_sg_v32_e64.cpp @@ -14,26 +14,27 @@ * limitations under the License. */ -#include "uniform_neighbor_sampling_impl.hpp" +#include "neighbor_sampling_impl.hpp" #include +#include namespace cugraph { template std::tuple, rmm::device_uvector, std::optional>, - std::optional>, + std::optional>, std::optional>, std::optional>, std::optional>, std::optional>> uniform_neighbor_sample( raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, raft::device_span starting_vertices, std::optional> starting_vertex_labels, std::optional, raft::device_span>> @@ -48,7 +49,7 @@ uniform_neighbor_sample( template std::tuple, rmm::device_uvector, - std::optional>, + std::optional>, std::optional>, std::optional>, std::optional>, @@ -57,7 +58,7 @@ template std::tuple, uniform_neighbor_sample( raft::handle_t const& handle, graph_view_t const& graph_view, - std::optional> edge_weight_view, + std::optional> edge_weight_view, 
std::optional> edge_id_view, std::optional> edge_type_view, raft::device_span starting_vertices, @@ -72,46 +73,21 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); -template std::tuple, - rmm::device_uvector, +template std::tuple, + rmm::device_uvector, std::optional>, std::optional>, std::optional>, std::optional>, std::optional>, std::optional>> -uniform_neighbor_sample( +biased_neighbor_sample( raft::handle_t const& handle, - graph_view_t const& graph_view, + graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, + edge_property_view_t edge_bias_view, raft::device_span starting_vertices, std::optional> starting_vertex_labels, std::optional, raft::device_span>> @@ -132,12 +108,13 @@ template std::tuple, std::optional>, std::optional>, std::optional>> -uniform_neighbor_sample( +biased_neighbor_sample( raft::handle_t const& handle, graph_view_t const& graph_view, std::optional> edge_weight_view, std::optional> edge_id_view, std::optional> edge_type_view, + edge_property_view_t edge_bias_view, raft::device_span starting_vertices, std::optional> starting_vertex_labels, std::optional, raft::device_span>> @@ -150,30 +127,4 @@ uniform_neighbor_sample( bool dedupe_sources, bool 
do_expensive_check); -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_sg_v64_e64.cpp b/cpp/src/sampling/neighbor_sampling_sg_v64_e64.cpp similarity index 53% rename from cpp/src/sampling/uniform_neighbor_sampling_sg_v64_e64.cpp rename to cpp/src/sampling/neighbor_sampling_sg_v64_e64.cpp index eb35faada28..70dd9a59842 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_sg_v64_e64.cpp +++ b/cpp/src/sampling/neighbor_sampling_sg_v64_e64.cpp @@ -14,9 +14,10 @@ * limitations under the License. 
*/ -#include "uniform_neighbor_sampling_impl.hpp" +#include "neighbor_sampling_impl.hpp" #include +#include namespace cugraph { @@ -72,4 +73,58 @@ uniform_neighbor_sample( bool dedupe_sources, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +biased_neighbor_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional, raft::device_span>> + label_to_output_comm_rank, + raft::host_span fan_out, + raft::random::RngState& rng_state, + bool return_hops, + bool with_replacement, + prior_sources_behavior_t prior_sources_behavior, + bool dedupe_sources, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_mg_v32_e64.cpp b/cpp/src/sampling/uniform_neighbor_sampling_mg_v32_e64.cpp deleted file mode 100644 index 3b1e6ba7d3e..00000000000 --- a/cpp/src/sampling/uniform_neighbor_sampling_mg_v32_e64.cpp +++ /dev/null @@ -1,75 
+0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "uniform_neighbor_sampling_impl.hpp" - -#include - -namespace cugraph { - -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - 
bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - -} // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_sg_v32_e64.cpp b/cpp/src/sampling/uniform_neighbor_sampling_sg_v32_e64.cpp deleted file mode 100644 index 23bb3137183..00000000000 --- a/cpp/src/sampling/uniform_neighbor_sampling_sg_v32_e64.cpp +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "uniform_neighbor_sampling_impl.hpp" - -#include - -namespace cugraph { - -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - -template std::tuple, - rmm::device_uvector, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>, - std::optional>> -uniform_neighbor_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> edge_weight_view, - std::optional> edge_id_view, - std::optional> edge_type_view, - raft::device_span starting_vertices, - std::optional> starting_vertex_labels, - std::optional, raft::device_span>> - label_to_output_comm_rank, - raft::host_span fan_out, - raft::random::RngState& rng_state, - bool return_hops, - bool with_replacement, - prior_sources_behavior_t prior_sources_behavior, - bool dedupe_sources, - bool do_expensive_check); - -} // namespace cugraph diff --git a/cpp/src/structure/induced_subgraph_impl.cuh b/cpp/src/structure/induced_subgraph_impl.cuh index 2e774497b78..b1ce8e6f51e 100644 --- a/cpp/src/structure/induced_subgraph_impl.cuh +++ b/cpp/src/structure/induced_subgraph_impl.cuh @@ -133,16 +133,9 @@ extract_induced_subgraphs( #endif // 1. 
check input arguments - CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); - if (do_expensive_check) { size_t should_be_zero{std::numeric_limits::max()}; - size_t num_aggregate_subgraph_vertices{}; raft::update_host(&should_be_zero, subgraph_offsets.data(), 1, handle.get_stream()); - raft::update_host(&num_aggregate_subgraph_vertices, - subgraph_offsets.data() + subgraph_offsets.size() - 1, - 1, - handle.get_stream()); handle.sync_stream(); CUGRAPH_EXPECTS(should_be_zero == 0, "Invalid input argument: subgraph_offsets[0] should be 0."); @@ -225,12 +218,6 @@ extract_induced_subgraphs( // vertex has a property of 1 vertex_frontier_t vertex_frontier(handle, 1); - std::vector h_subgraph_offsets(subgraph_offsets.size()); - raft::update_host(h_subgraph_offsets.data(), - subgraph_offsets.data(), - subgraph_offsets.size(), - handle.get_stream()); - graph_ids_v = detail::expand_sparse_offsets(subgraph_offsets, size_t{0}, handle.get_stream()); vertex_frontier.bucket(0).insert( diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index fd356ff8b89..c94ffed3b9f 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -45,6 +45,7 @@ add_library(cugraphtestutil STATIC cores/k_core_validate.cu structure/induced_subgraph_validate.cu sampling/random_walks_check_sg.cu + sampling/detail/nbr_sampling_validate.cu ../../thirdparty/mmio/mmio.c) target_compile_options(cugraphtestutil @@ -479,8 +480,12 @@ ConfigureTest(WEIGHTED_SIMILARITY_TEST link_prediction/weighted_similarity_test. 
ConfigureTest(RANDOM_WALKS_TEST sampling/sg_random_walks_test.cpp) ################################################################################################### -# - NBR SAMPLING tests ---------------------------------------------------------------------------- -ConfigureTest(UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/sg_uniform_neighbor_sampling.cu) +# - UNIFORM NBR SAMPLING tests -------------------------------------------------------------------- +ConfigureTest(UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/uniform_neighbor_sampling.cpp) + +################################################################################################### +# - BIASED NBR SAMPLING tests --------------------------------------------------------------------- +ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cpp) ################################################################################################### # - SAMPLING_POST_PROCESSING tests ---------------------------------------------------------------- @@ -729,8 +734,12 @@ if(BUILD_CUGRAPH_MG_TESTS) prims/mg_per_v_pair_transform_dst_nbr_weighted_intersection.cu) ############################################################################################### - # - MG NBR SAMPLING tests --------------------------------------------------------------------- - ConfigureTestMG(MG_UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/mg_uniform_neighbor_sampling.cu) + # - MG UNIFORM NBR SAMPLING tests ------------------------------------------------------------- + ConfigureTestMG(MG_UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/mg_uniform_neighbor_sampling.cpp) + + ############################################################################################### + # - MG BIASED NBR SAMPLING tests -------------------------------------------------------------- + ConfigureTestMG(MG_BIASED_NEIGHBOR_SAMPLING_TEST sampling/mg_biased_neighbor_sampling.cpp) 
############################################################################################### # - MG RANDOM_WALKS tests --------------------------------------------------------------------- diff --git a/cpp/tests/centrality/katz_centrality_test.cpp b/cpp/tests/centrality/katz_centrality_test.cpp index 7c8a22221c0..190007c38e5 100644 --- a/cpp/tests/centrality/katz_centrality_test.cpp +++ b/cpp/tests/centrality/katz_centrality_test.cpp @@ -214,7 +214,8 @@ class Tests_KatzCentrality rmm::device_uvector d_unrenumbered_katz_centralities(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_katz_centralities) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_katz_centralities); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_katz_centralities); h_cugraph_katz_centralities = cugraph::test::to_host(handle, d_unrenumbered_katz_centralities); } else { diff --git a/cpp/tests/community/k_truss_test.cpp b/cpp/tests/community/k_truss_test.cpp index c8010422e42..424d52f2067 100644 --- a/cpp/tests/community/k_truss_test.cpp +++ b/cpp/tests/community/k_truss_test.cpp @@ -224,10 +224,11 @@ class Tests_KTruss : public ::testing::TestWithParam( + handle, d_cugraph_srcs, d_cugraph_dsts, *d_cugraph_wgts); } else { std::tie(d_sorted_cugraph_srcs, d_sorted_cugraph_dsts) = - cugraph::test::sort(handle, d_cugraph_srcs, d_cugraph_dsts); + cugraph::test::sort(handle, d_cugraph_srcs, d_cugraph_dsts); } auto h_cugraph_srcs = cugraph::test::to_host(handle, d_sorted_cugraph_srcs); diff --git a/cpp/tests/components/weakly_connected_components_test.cpp b/cpp/tests/components/weakly_connected_components_test.cpp index 2dd82316b00..7b909c6f594 100644 --- a/cpp/tests/components/weakly_connected_components_test.cpp +++ b/cpp/tests/components/weakly_connected_components_test.cpp @@ -170,7 +170,8 @@ class Tests_WeaklyConnectedComponent if (renumber) { rmm::device_uvector d_unrenumbered_components(size_t{0}, handle.get_stream()); std::tie(std::ignore, 
d_unrenumbered_components) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_components); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_components); h_cugraph_components = cugraph::test::to_host(handle, d_unrenumbered_components); } else { h_cugraph_components = cugraph::test::to_host(handle, d_components); diff --git a/cpp/tests/cores/core_number_test.cpp b/cpp/tests/cores/core_number_test.cpp index fb6f26278af..ca0174202c2 100644 --- a/cpp/tests/cores/core_number_test.cpp +++ b/cpp/tests/cores/core_number_test.cpp @@ -300,7 +300,8 @@ class Tests_CoreNumber if (renumber) { rmm::device_uvector d_unrenumbered_core_numbers(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_core_numbers) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_core_numbers); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_core_numbers); h_cugraph_core_numbers = cugraph::test::to_host(handle, d_unrenumbered_core_numbers); } else { h_cugraph_core_numbers = cugraph::test::to_host(handle, d_core_numbers); diff --git a/cpp/tests/link_analysis/hits_test.cpp b/cpp/tests/link_analysis/hits_test.cpp index f1b2a0ef0df..31ed5537a6b 100644 --- a/cpp/tests/link_analysis/hits_test.cpp +++ b/cpp/tests/link_analysis/hits_test.cpp @@ -255,7 +255,8 @@ class Tests_Hits : public ::testing::TestWithParam d_unrenumbered_initial_random_hubs(0, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_initial_random_hubs) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, *d_initial_random_hubs); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, *d_initial_random_hubs); h_initial_random_hubs = cugraph::test::to_host(handle, d_unrenumbered_initial_random_hubs); } else { @@ -277,7 +278,7 @@ class Tests_Hits : public ::testing::TestWithParam d_unrenumbered_hubs(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_hubs) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_hubs); 
+ cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_hubs); h_cugraph_hits = cugraph::test::to_host(handle, d_unrenumbered_hubs); } else { h_cugraph_hits = cugraph::test::to_host(handle, d_hubs); diff --git a/cpp/tests/link_analysis/pagerank_test.cpp b/cpp/tests/link_analysis/pagerank_test.cpp index 9219832ac63..196476d6756 100644 --- a/cpp/tests/link_analysis/pagerank_test.cpp +++ b/cpp/tests/link_analysis/pagerank_test.cpp @@ -282,9 +282,9 @@ class Tests_PageRank vertex_t{0}, graph_view.number_of_vertices()); std::tie(d_unrenumbered_personalization_vertices, d_unrenumbered_personalization_values) = - cugraph::test::sort_by_key(handle, - d_unrenumbered_personalization_vertices, - d_unrenumbered_personalization_values); + cugraph::test::sort_by_key(handle, + d_unrenumbered_personalization_vertices, + d_unrenumbered_personalization_values); h_unrenumbered_personalization_vertices = cugraph::test::to_host(handle, d_unrenumbered_personalization_vertices); @@ -327,7 +327,8 @@ class Tests_PageRank if (renumber) { rmm::device_uvector d_unrenumbered_pageranks(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_pageranks) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_pageranks); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_pageranks); h_cugraph_pageranks = cugraph::test::to_host(handle, d_unrenumbered_pageranks); } else { h_cugraph_pageranks = cugraph::test::to_host(handle, d_pageranks); diff --git a/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu b/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu index b99dbf16107..49c14631839 100644 --- a/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu +++ b/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu @@ -49,11 +49,12 @@ #include -template +template struct e_bias_op_t { - __device__ weight_t operator()(vertex_t, vertex_t, property_t, property_t, weight_t w) const + __device__ bias_t + operator()(vertex_t, 
vertex_t, thrust::nullopt_t, thrust::nullopt_t, bias_t bias) const { - return w; + return bias; } }; @@ -216,10 +217,13 @@ class Tests_MGPerVRandomSelectTransformOutgoingE *handle_, mg_graph_view, mg_vertex_frontier.bucket(bucket_idx_cur), + cugraph::edge_src_dummy_property_t{}.view(), + cugraph::edge_dst_dummy_property_t{}.view(), + *mg_edge_weight_view, + e_bias_op_t{}, mg_src_prop.view(), mg_dst_prop.view(), *mg_edge_weight_view, - e_bias_op_t{}, e_op_t{}, rng_state, prims_usecase.K, diff --git a/cpp/tests/sampling/biased_neighbor_sampling.cpp b/cpp/tests/sampling/biased_neighbor_sampling.cpp new file mode 100644 index 00000000000..8ee3ab27833 --- /dev/null +++ b/cpp/tests/sampling/biased_neighbor_sampling.cpp @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "detail/nbr_sampling_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include + +#include +#include +#include + +#include + +struct Biased_Neighbor_Sampling_Usecase { + std::vector fanout{{-1}}; + int32_t batch_size{10}; + bool flag_replacement{true}; + + bool edge_masking{false}; + bool check_correctness{true}; +}; + +template +class Tests_Biased_Neighbor_Sampling + : public ::testing::TestWithParam> { + public: + Tests_Biased_Neighbor_Sampling() {} + + static void SetUpTestCase() {} + static void TearDownTestCase() {} + + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test( + std::tuple const& param) + { + auto [biased_neighbor_sampling_usecase, input_usecase] = param; + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Construct graph"); + } + + auto [graph, edge_weights, renumber_map_labels] = + cugraph::test::construct_graph( + handle, input_usecase, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto graph_view = graph.view(); + auto edge_weight_view = + edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; + + std::optional> edge_mask{std::nullopt}; + if (biased_neighbor_sampling_usecase.edge_masking) { + edge_mask = + cugraph::test::generate::edge_property(handle, graph_view, 2); + graph_view.attach_edge_mask((*edge_mask).view()); + } + + constexpr float select_probability{0.05}; + + // FIXME: Update the tests to initialize RngState and use it instead + // of seed... 
+ constexpr uint64_t seed{0}; + + raft::random::RngState rng_state(seed); + + auto random_sources = cugraph::select_random_vertices( + handle, + graph_view, + std::optional>{std::nullopt}, + rng_state, + std::max(static_cast(graph_view.number_of_vertices() * select_probability), + std::min(static_cast(graph_view.number_of_vertices()), size_t{1})), + false, + false); + + // + // Now we'll assign the vertices to batches + // + rmm::device_uvector random_numbers(random_sources.size(), handle.get_stream()); + + cugraph::detail::uniform_random_fill(handle.get_stream(), + random_numbers.data(), + random_numbers.size(), + float{0}, + float{1}, + rng_state); + + std::tie(random_numbers, random_sources) = cugraph::test::sort_by_key( + handle, std::move(random_numbers), std::move(random_sources)); + + random_numbers.resize(0, handle.get_stream()); + random_numbers.shrink_to_fit(handle.get_stream()); + + auto batch_number = std::make_optional>(0, handle.get_stream()); + + batch_number = cugraph::test::sequence( + handle, random_sources.size(), biased_neighbor_sampling_usecase.batch_size, int32_t{0}); + + rmm::device_uvector random_sources_copy(random_sources.size(), handle.get_stream()); + + raft::copy(random_sources_copy.data(), + random_sources.data(), + random_sources.size(), + handle.get_stream()); + + std::optional, raft::device_span>> + label_to_output_comm_rank_mapping{std::nullopt}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Biased neighbor sampling"); + } + + auto&& [src_out, dst_out, wgt_out, edge_id, edge_type, hop, labels, offsets] = + cugraph::biased_neighbor_sample( + handle, + graph_view, + edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + *edge_weight_view, + raft::device_span{random_sources_copy.data(), random_sources.size()}, + batch_number ? 
std::make_optional(raft::device_span{batch_number->data(), + batch_number->size()}) + : std::nullopt, + label_to_output_comm_rank_mapping, + raft::host_span(biased_neighbor_sampling_usecase.fanout.data(), + biased_neighbor_sampling_usecase.fanout.size()), + rng_state, + true, + biased_neighbor_sampling_usecase.flag_replacement); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (biased_neighbor_sampling_usecase.check_correctness) { + // First validate that the extracted edges are actually a subset of the + // edges in the input graph + rmm::device_uvector vertices(2 * src_out.size(), handle.get_stream()); + raft::copy(vertices.data(), src_out.data(), src_out.size(), handle.get_stream()); + raft::copy( + vertices.data() + src_out.size(), dst_out.data(), dst_out.size(), handle.get_stream()); + vertices = cugraph::test::sort(handle, std::move(vertices)); + vertices = cugraph::test::unique(handle, std::move(vertices)); + + rmm::device_uvector d_subgraph_offsets(2, handle.get_stream()); + std::vector h_subgraph_offsets({0, vertices.size()}); + + raft::update_device(d_subgraph_offsets.data(), + h_subgraph_offsets.data(), + h_subgraph_offsets.size(), + handle.get_stream()); + + rmm::device_uvector src_compare(0, handle.get_stream()); + rmm::device_uvector dst_compare(0, handle.get_stream()); + std::optional> wgt_compare{std::nullopt}; + + std::tie(src_compare, dst_compare, wgt_compare, std::ignore) = extract_induced_subgraphs( + handle, + graph_view, + edge_weight_view, + raft::device_span(d_subgraph_offsets.data(), 2), + raft::device_span(vertices.data(), vertices.size()), + true); + + ASSERT_TRUE(cugraph::test::validate_extracted_graph_is_subgraph( + handle, src_compare, dst_compare, wgt_compare, src_out, dst_out, wgt_out)); + + if (random_sources.size() < 100) { + // This validation is too expensive for large number of vertices + 
ASSERT_TRUE( + cugraph::test::validate_sampling_depth(handle, + std::move(src_out), + std::move(dst_out), + std::move(wgt_out), + std::move(random_sources), + biased_neighbor_sampling_usecase.fanout.size())); + } + } + } +}; + +using Tests_Biased_Neighbor_Sampling_File = + Tests_Biased_Neighbor_Sampling; + +using Tests_Biased_Neighbor_Sampling_Rmat = + Tests_Biased_Neighbor_Sampling; + +TEST_P(Tests_Biased_Neighbor_Sampling_File, CheckInt32Int32Float) +{ + auto param = GetParam(); + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Biased_Neighbor_Sampling_File, CheckInt32Int64Float) +{ + auto param = GetParam(); + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Biased_Neighbor_Sampling_File, CheckInt64Int64Float) +{ + auto param = GetParam(); + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Biased_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + 
cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_Biased_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_Biased_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, false, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, false, true, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, true, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, true, true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false, 0)))); + +CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/sampling/detail/nbr_sampling_utils.cuh b/cpp/tests/sampling/detail/nbr_sampling_validate.cu similarity index 64% rename from cpp/tests/sampling/detail/nbr_sampling_utils.cuh rename to cpp/tests/sampling/detail/nbr_sampling_validate.cu index be93990eb9c..61731e2e15c 100644 --- a/cpp/tests/sampling/detail/nbr_sampling_utils.cuh +++ b/cpp/tests/sampling/detail/nbr_sampling_validate.cu @@ -14,12 +14,6 @@ * limitations under the License. 
*/ -// Andrei Schaffer, aschaffer@nvidia.com -// -#pragma once - -#include "utilities/base_fixture.hpp" -#include "utilities/device_comm_wrapper.hpp" #include "utilities/test_graphs.hpp" #include "utilities/thrust_wrapper.hpp" @@ -28,7 +22,6 @@ #include #include #include -#include #include @@ -48,8 +41,6 @@ #include #include -#include - #include #include #include @@ -110,7 +101,7 @@ struct ArithmeticZipEqual { }; template -void validate_extracted_graph_is_subgraph( +bool validate_extracted_graph_is_subgraph( raft::handle_t const& handle, rmm::device_uvector const& src, rmm::device_uvector const& dst, @@ -119,123 +110,104 @@ void validate_extracted_graph_is_subgraph( rmm::device_uvector const& subgraph_dst, std::optional> const& subgraph_wgt) { - ASSERT_EQ(wgt.has_value(), subgraph_wgt.has_value()); + if (wgt.has_value() != subgraph_wgt.has_value()) { return false; } rmm::device_uvector src_v(src.size(), handle.get_stream()); rmm::device_uvector dst_v(dst.size(), handle.get_stream()); - rmm::device_uvector subgraph_src_v(subgraph_src.size(), handle.get_stream()); - rmm::device_uvector subgraph_dst_v(subgraph_dst.size(), handle.get_stream()); - raft::copy(src_v.data(), src.data(), src.size(), handle.get_stream()); raft::copy(dst_v.data(), dst.data(), dst.size(), handle.get_stream()); - raft::copy(subgraph_src_v.data(), subgraph_src.data(), subgraph_src.size(), handle.get_stream()); - raft::copy(subgraph_dst_v.data(), subgraph_dst.data(), subgraph_dst.size(), handle.get_stream()); - - size_t dist{0}; + size_t num_invalids{0}; if (wgt) { rmm::device_uvector wgt_v(wgt->size(), handle.get_stream()); - rmm::device_uvector subgraph_wgt_v(subgraph_wgt->size(), handle.get_stream()); - raft::copy(wgt_v.data(), wgt->data(), wgt->size(), handle.get_stream()); - raft::copy( - subgraph_wgt_v.data(), subgraph_wgt->data(), subgraph_wgt->size(), handle.get_stream()); auto graph_iter = thrust::make_zip_iterator(thrust::make_tuple(src_v.begin(), dst_v.begin(), wgt_v.begin())); - auto 
subgraph_iter = thrust::make_zip_iterator( - thrust::make_tuple(subgraph_src_v.begin(), subgraph_dst_v.begin(), subgraph_wgt_v.begin())); - thrust::sort( handle.get_thrust_policy(), graph_iter, graph_iter + src_v.size(), ArithmeticZipLess{}); - thrust::sort(handle.get_thrust_policy(), - subgraph_iter, - subgraph_iter + subgraph_src_v.size(), - ArithmeticZipLess{}); - auto graph_iter_end = thrust::unique( handle.get_thrust_policy(), graph_iter, graph_iter + src_v.size(), ArithmeticZipEqual{}); - auto subgraph_iter_end = thrust::unique(handle.get_thrust_policy(), - subgraph_iter, - subgraph_iter + subgraph_src_v.size(), - ArithmeticZipEqual{}); - auto new_size = thrust::distance(graph_iter, graph_iter_end); src_v.resize(new_size, handle.get_stream()); dst_v.resize(new_size, handle.get_stream()); wgt_v.resize(new_size, handle.get_stream()); - new_size = thrust::distance(subgraph_iter, subgraph_iter_end); - subgraph_src_v.resize(new_size, handle.get_stream()); - subgraph_dst_v.resize(new_size, handle.get_stream()); - subgraph_wgt_v.resize(new_size, handle.get_stream()); - - rmm::device_uvector tmp_src(new_size, handle.get_stream()); - rmm::device_uvector tmp_dst(new_size, handle.get_stream()); - rmm::device_uvector tmp_wgt(new_size, handle.get_stream()); - - auto tmp_subgraph_iter = thrust::make_zip_iterator( - thrust::make_tuple(tmp_src.begin(), tmp_dst.begin(), tmp_wgt.begin())); - - auto tmp_subgraph_iter_end = thrust::set_difference(handle.get_thrust_policy(), - subgraph_iter, - subgraph_iter + subgraph_src_v.size(), - graph_iter, - graph_iter + src_v.size(), - tmp_subgraph_iter, - ArithmeticZipLess{}); - - dist = thrust::distance(tmp_subgraph_iter, tmp_subgraph_iter_end); + auto subgraph_iter = thrust::make_zip_iterator( + thrust::make_tuple(subgraph_src.begin(), subgraph_dst.begin(), subgraph_wgt->begin())); + num_invalids = + thrust::count_if(handle.get_thrust_policy(), + subgraph_iter, + subgraph_iter + subgraph_src.size(), + [graph_iter, new_size] 
__device__(auto tup) { + return (thrust::binary_search( + thrust::seq, graph_iter, graph_iter + new_size, tup) == false); + }); } else { auto graph_iter = thrust::make_zip_iterator(thrust::make_tuple(src_v.begin(), dst_v.begin())); - auto subgraph_iter = - thrust::make_zip_iterator(thrust::make_tuple(subgraph_src_v.begin(), subgraph_dst_v.begin())); - thrust::sort( handle.get_thrust_policy(), graph_iter, graph_iter + src_v.size(), ArithmeticZipLess{}); - thrust::sort(handle.get_thrust_policy(), - subgraph_iter, - subgraph_iter + subgraph_src_v.size(), - ArithmeticZipLess{}); - auto graph_iter_end = thrust::unique( handle.get_thrust_policy(), graph_iter, graph_iter + src_v.size(), ArithmeticZipEqual{}); - auto subgraph_iter_end = thrust::unique(handle.get_thrust_policy(), - subgraph_iter, - subgraph_iter + subgraph_src_v.size(), - ArithmeticZipEqual{}); - auto new_size = thrust::distance(graph_iter, graph_iter_end); src_v.resize(new_size, handle.get_stream()); dst_v.resize(new_size, handle.get_stream()); - new_size = thrust::distance(subgraph_iter, subgraph_iter_end); - subgraph_src_v.resize(new_size, handle.get_stream()); - subgraph_dst_v.resize(new_size, handle.get_stream()); - - rmm::device_uvector tmp_src(new_size, handle.get_stream()); - rmm::device_uvector tmp_dst(new_size, handle.get_stream()); - - auto tmp_subgraph_iter = thrust::make_zip_iterator(tmp_src.begin(), tmp_dst.begin()); - - auto tmp_subgraph_iter_end = thrust::set_difference(handle.get_thrust_policy(), - subgraph_iter, - subgraph_iter + subgraph_src_v.size(), - graph_iter, - graph_iter + src_v.size(), - tmp_subgraph_iter, - ArithmeticZipLess{}); - - dist = thrust::distance(tmp_subgraph_iter, tmp_subgraph_iter_end); + auto subgraph_iter = + thrust::make_zip_iterator(thrust::make_tuple(subgraph_src.begin(), subgraph_dst.begin())); + num_invalids = + thrust::count_if(handle.get_thrust_policy(), + subgraph_iter, + subgraph_iter + subgraph_src.size(), + [graph_iter, new_size] __device__(auto tup) { + 
return (thrust::binary_search( + thrust::seq, graph_iter, graph_iter + new_size, tup) == false); + }); } - ASSERT_EQ(0, dist); + return (num_invalids == 0); } +template bool validate_extracted_graph_is_subgraph( + raft::handle_t const& handle, + rmm::device_uvector const& src, + rmm::device_uvector const& dst, + std::optional> const& wgt, + rmm::device_uvector const& subgraph_src, + rmm::device_uvector const& subgraph_dst, + std::optional> const& subgraph_wgt); + +template bool validate_extracted_graph_is_subgraph( + raft::handle_t const& handle, + rmm::device_uvector const& src, + rmm::device_uvector const& dst, + std::optional> const& wgt, + rmm::device_uvector const& subgraph_src, + rmm::device_uvector const& subgraph_dst, + std::optional> const& subgraph_wgt); + +template bool validate_extracted_graph_is_subgraph( + raft::handle_t const& handle, + rmm::device_uvector const& src, + rmm::device_uvector const& dst, + std::optional> const& wgt, + rmm::device_uvector const& subgraph_src, + rmm::device_uvector const& subgraph_dst, + std::optional> const& subgraph_wgt); + +template bool validate_extracted_graph_is_subgraph( + raft::handle_t const& handle, + rmm::device_uvector const& src, + rmm::device_uvector const& dst, + std::optional> const& wgt, + rmm::device_uvector const& subgraph_src, + rmm::device_uvector const& subgraph_dst, + std::optional> const& subgraph_wgt); + template -void validate_sampling_depth(raft::handle_t const& handle, +bool validate_sampling_depth(raft::handle_t const& handle, rmm::device_uvector&& d_src, rmm::device_uvector&& d_dst, std::optional>&& d_wgt, @@ -304,12 +276,39 @@ void validate_sampling_depth(raft::handle_t const& handle, } } - ASSERT_EQ(0, - thrust::count_if(handle.get_thrust_policy(), - d_distances.begin(), - d_distances.end(), - [max_depth] __device__(auto d) { return d > max_depth; })); + return (thrust::count_if(handle.get_thrust_policy(), + d_distances.begin(), + d_distances.end(), + [max_depth] __device__(auto d) { return 
d > max_depth; }) == 0); } +template bool validate_sampling_depth(raft::handle_t const& handle, + rmm::device_uvector&& d_src, + rmm::device_uvector&& d_dst, + std::optional>&& d_wgt, + rmm::device_uvector&& d_source_vertices, + int max_depth); + +template bool validate_sampling_depth(raft::handle_t const& handle, + rmm::device_uvector&& d_src, + rmm::device_uvector&& d_dst, + std::optional>&& d_wgt, + rmm::device_uvector&& d_source_vertices, + int max_depth); + +template bool validate_sampling_depth(raft::handle_t const& handle, + rmm::device_uvector&& d_src, + rmm::device_uvector&& d_dst, + std::optional>&& d_wgt, + rmm::device_uvector&& d_source_vertices, + int max_depth); + +template bool validate_sampling_depth(raft::handle_t const& handle, + rmm::device_uvector&& d_src, + rmm::device_uvector&& d_dst, + std::optional>&& d_wgt, + rmm::device_uvector&& d_source_vertices, + int max_depth); + } // namespace test } // namespace cugraph diff --git a/cpp/tests/sampling/detail/nbr_sampling_validate.hpp b/cpp/tests/sampling/detail/nbr_sampling_validate.hpp new file mode 100644 index 00000000000..8e46d3a5782 --- /dev/null +++ b/cpp/tests/sampling/detail/nbr_sampling_validate.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include + +#include + +// utilities for testing / verification of Nbr Sampling functionality: +// +namespace cugraph { +namespace test { + +template +bool validate_extracted_graph_is_subgraph( + raft::handle_t const& handle, + rmm::device_uvector const& src, + rmm::device_uvector const& dst, + std::optional> const& wgt, + rmm::device_uvector const& subgraph_src, + rmm::device_uvector const& subgraph_dst, + std::optional> const& subgraph_wgt); + +template +bool validate_sampling_depth(raft::handle_t const& handle, + rmm::device_uvector&& d_src, + rmm::device_uvector&& d_dst, + std::optional>&& d_wgt, + rmm::device_uvector&& d_source_vertices, + int max_depth); + +} // namespace test +} // namespace cugraph diff --git a/cpp/tests/sampling/mg_biased_neighbor_sampling.cpp b/cpp/tests/sampling/mg_biased_neighbor_sampling.cpp new file mode 100644 index 00000000000..000b0a31e4b --- /dev/null +++ b/cpp/tests/sampling/mg_biased_neighbor_sampling.cpp @@ -0,0 +1,351 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "detail/nbr_sampling_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/device_comm_wrapper.hpp" +#include "utilities/mg_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include + +#include + +struct Biased_Neighbor_Sampling_Usecase { + std::vector fanout{{-1}}; + int32_t batch_size{10}; + bool with_replacement{true}; + + bool edge_masking{false}; + bool check_correctness{true}; +}; + +template +class Tests_MGBiased_Neighbor_Sampling + : public ::testing::TestWithParam> { + public: + Tests_MGBiased_Neighbor_Sampling() {} + + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } + + static void TearDownTestCase() { handle_.reset(); } + + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test(std::tuple const& param) + { + auto [biased_neighbor_sampling_usecase, input_usecase] = param; + + HighResTimer hr_timer{}; + + // 1. create MG graph + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG construct graph"); + } + + auto [mg_graph, mg_edge_weights, mg_renumber_map_labels] = + cugraph::test::construct_graph( + *handle_, + input_usecase, + true /* test_weighted */, + true /* renumber */, + false /* drop_self_loops */, + false /* drop_multi_edges */); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto mg_graph_view = mg_graph.view(); + auto mg_edge_weight_view = + mg_edge_weights ? 
std::make_optional((*mg_edge_weights).view()) : std::nullopt; + + std::optional> edge_mask{std::nullopt}; + if (biased_neighbor_sampling_usecase.edge_masking) { + edge_mask = cugraph::test::generate::edge_property( + *handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + + // + // Test is designed like GNN sampling. We'll select 5% of vertices to be included in sampling + // batches + // + + constexpr float select_probability{0.05}; + + raft::random::RngState rng_state(handle_->get_comms().get_rank()); + + auto random_sources = cugraph::select_random_vertices( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + rng_state, + std::max(static_cast(mg_graph_view.number_of_vertices() * select_probability), + std::min(static_cast(mg_graph_view.number_of_vertices()), size_t{1})), + false, + false); + + // + // Now we'll assign the vertices to batches + // + + rmm::device_uvector random_numbers(random_sources.size(), handle_->get_stream()); + + cugraph::detail::uniform_random_fill(handle_->get_stream(), + random_numbers.data(), + random_numbers.size(), + float{0}, + float{1}, + rng_state); + + std::tie(random_numbers, random_sources) = cugraph::test::sort_by_key( + *handle_, std::move(random_numbers), std::move(random_sources)); + + random_numbers.resize(0, handle_->get_stream()); + random_numbers.shrink_to_fit(handle_->get_stream()); + + auto seed_sizes = cugraph::host_scalar_allgather( + handle_->get_comms(), random_sources.size(), handle_->get_stream()); + size_t num_seeds = std::reduce(seed_sizes.begin(), seed_sizes.end()); + size_t num_batches = (num_seeds + biased_neighbor_sampling_usecase.batch_size - 1) / + biased_neighbor_sampling_usecase.batch_size; + + std::vector seed_offsets(seed_sizes.size()); + std::exclusive_scan(seed_sizes.begin(), seed_sizes.end(), seed_offsets.begin(), size_t{0}); + + auto batch_number = cugraph::test::modulo_sequence( + *handle_, random_sources.size(), num_batches, 
seed_offsets[handle_->get_comms().get_rank()]); + + rmm::device_uvector unique_batches(num_batches, handle_->get_stream()); + cugraph::detail::sequence_fill( + handle_->get_stream(), unique_batches.data(), unique_batches.size(), int32_t{0}); + + auto comm_ranks = cugraph::test::modulo_sequence( + *handle_, num_batches, handle_->get_comms().get_size(), int32_t{0}); + + rmm::device_uvector random_sources_copy(random_sources.size(), handle_->get_stream()); + + raft::copy(random_sources_copy.data(), + random_sources.data(), + random_sources.size(), + handle_->get_stream()); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG biased_neighbor_sample"); + } + + auto&& [src_out, dst_out, wgt_out, edge_id, edge_type, hop, labels, offsets] = + cugraph::biased_neighbor_sample( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + *mg_edge_weight_view, + raft::device_span{random_sources_copy.data(), random_sources.size()}, + std::make_optional( + raft::device_span{batch_number.data(), batch_number.size()}), + std::make_optional(std::make_tuple( + raft::device_span{unique_batches.data(), unique_batches.size()}, + raft::device_span{comm_ranks.data(), comm_ranks.size()})), + raft::host_span(biased_neighbor_sampling_usecase.fanout.data(), + biased_neighbor_sampling_usecase.fanout.size()), + rng_state, + true, + biased_neighbor_sampling_usecase.with_replacement); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (biased_neighbor_sampling_usecase.check_correctness) { + // Consolidate results on GPU 0 + auto mg_start_src = cugraph::test::device_gatherv( + *handle_, raft::device_span{random_sources.data(), random_sources.size()}); + 
auto mg_aggregate_src = cugraph::test::device_gatherv( + *handle_, raft::device_span{src_out.data(), src_out.size()}); + auto mg_aggregate_dst = cugraph::test::device_gatherv( + *handle_, raft::device_span{dst_out.data(), dst_out.size()}); + auto mg_aggregate_wgt = + wgt_out ? std::make_optional(cugraph::test::device_gatherv( + *handle_, raft::device_span{wgt_out->data(), wgt_out->size()})) + : std::nullopt; + + // First validate that the extracted edges are actually a subset of the + // edges in the input graph + rmm::device_uvector vertices(2 * mg_aggregate_src.size(), handle_->get_stream()); + raft::copy( + vertices.data(), mg_aggregate_src.data(), mg_aggregate_src.size(), handle_->get_stream()); + raft::copy(vertices.data() + mg_aggregate_src.size(), + mg_aggregate_dst.data(), + mg_aggregate_dst.size(), + handle_->get_stream()); + vertices = cugraph::test::sort(*handle_, std::move(vertices)); + vertices = cugraph::test::unique(*handle_, std::move(vertices)); + + vertices = cugraph::detail::shuffle_int_vertices_to_local_gpu_by_vertex_partitioning( + *handle_, std::move(vertices), mg_graph_view.vertex_partition_range_lasts()); + + vertices = cugraph::test::sort(*handle_, std::move(vertices)); + vertices = cugraph::test::unique(*handle_, std::move(vertices)); + + rmm::device_uvector d_subgraph_offsets(2, handle_->get_stream()); + std::vector h_subgraph_offsets({0, vertices.size()}); + + raft::update_device(d_subgraph_offsets.data(), + h_subgraph_offsets.data(), + h_subgraph_offsets.size(), + handle_->get_stream()); + + rmm::device_uvector src_compare(0, handle_->get_stream()); + rmm::device_uvector dst_compare(0, handle_->get_stream()); + std::optional> wgt_compare{std::nullopt}; + std::tie(src_compare, dst_compare, wgt_compare, std::ignore) = extract_induced_subgraphs( + *handle_, + mg_graph_view, + mg_edge_weight_view, + raft::device_span(d_subgraph_offsets.data(), 2), + raft::device_span(vertices.data(), vertices.size()), + true); + + auto 
mg_aggregate_src_compare = cugraph::test::device_gatherv( + *handle_, raft::device_span{src_compare.data(), src_compare.size()}); + auto mg_aggregate_dst_compare = cugraph::test::device_gatherv( + *handle_, raft::device_span{dst_compare.data(), dst_compare.size()}); + auto mg_aggregate_wgt_compare = + wgt_compare + ? std::make_optional(cugraph::test::device_gatherv( + *handle_, + raft::device_span{wgt_compare->data(), wgt_compare->size()})) + : std::nullopt; + + if (handle_->get_comms().get_rank() == 0) { + ASSERT_TRUE(cugraph::test::validate_extracted_graph_is_subgraph(*handle_, + mg_aggregate_src_compare, + mg_aggregate_dst_compare, + mg_aggregate_wgt_compare, + mg_aggregate_src, + mg_aggregate_dst, + mg_aggregate_wgt)); + + if (random_sources.size() < 100) { + // This validation is too expensive for large number of vertices + if (mg_aggregate_src.size() > 0) { + ASSERT_TRUE(cugraph::test::validate_sampling_depth( + *handle_, + std::move(mg_aggregate_src), + std::move(mg_aggregate_dst), + std::move(mg_aggregate_wgt), + std::move(mg_start_src), + biased_neighbor_sampling_usecase.fanout.size())); + } + } + } + } + } + + private: + static std::unique_ptr handle_; +}; + +template +std::unique_ptr Tests_MGBiased_Neighbor_Sampling::handle_ = + nullptr; + +using Tests_MGBiased_Neighbor_Sampling_File = + Tests_MGBiased_Neighbor_Sampling; + +using Tests_MGBiased_Neighbor_Sampling_Rmat = + Tests_MGBiased_Neighbor_Sampling; + +TEST_P(Tests_MGBiased_Neighbor_Sampling_File, CheckInt32Int32Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGBiased_Neighbor_Sampling_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGBiased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGBiased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) +{ + 
run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_MGBiased_Neighbor_Sampling_File, + ::testing::Combine( + ::testing::Values(Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_MGBiased_Neighbor_Sampling_Rmat, + ::testing::Combine(::testing::Values(Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), + ::testing::Values( + // cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false)))); + cugraph::test::Rmat_Usecase(5, 16, 0.57, 0.19, 0.19, 0, false, false)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_MGBiased_Neighbor_Sampling_Rmat, + ::testing::Combine( + ::testing::Values(Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, false, false, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, false, true, false}, + Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, true, false, false}, + 
Biased_Neighbor_Sampling_Usecase{{4, 10}, 1024, true, true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/sampling/mg_uniform_neighbor_sampling.cu b/cpp/tests/sampling/mg_uniform_neighbor_sampling.cpp similarity index 85% rename from cpp/tests/sampling/mg_uniform_neighbor_sampling.cu rename to cpp/tests/sampling/mg_uniform_neighbor_sampling.cpp index 22c1a4a7edf..ee9651d5e5d 100644 --- a/cpp/tests/sampling/mg_uniform_neighbor_sampling.cu +++ b/cpp/tests/sampling/mg_uniform_neighbor_sampling.cpp @@ -14,14 +14,15 @@ * limitations under the License. */ -#include "detail/nbr_sampling_utils.cuh" +#include "detail/nbr_sampling_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/device_comm_wrapper.hpp" #include "utilities/mg_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" +#include "utilities/test_graphs.hpp" -#include - -#include -#include -#include +#include +#include #include @@ -29,6 +30,8 @@ struct Uniform_Neighbor_Sampling_Usecase { std::vector fanout{{-1}}; int32_t batch_size{10}; bool with_replacement{true}; + + bool edge_masking{false}; bool check_correctness{true}; }; @@ -81,6 +84,13 @@ class Tests_MGUniform_Neighbor_Sampling auto mg_edge_weight_view = mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; + std::optional> edge_mask{std::nullopt}; + if (uniform_neighbor_sampling_usecase.edge_masking) { + edge_mask = cugraph::test::generate::edge_property( + *handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + // // Test is designed like GNN sampling. 
We'll select 5% of vertices to be included in sampling // batches @@ -113,16 +123,12 @@ class Tests_MGUniform_Neighbor_Sampling float{1}, rng_state); - thrust::sort_by_key(handle_->get_thrust_policy(), - random_numbers.begin(), - random_numbers.end(), - random_sources.begin()); + std::tie(random_numbers, random_sources) = cugraph::test::sort_by_key( + *handle_, std::move(random_numbers), std::move(random_sources)); random_numbers.resize(0, handle_->get_stream()); random_numbers.shrink_to_fit(handle_->get_stream()); - rmm::device_uvector batch_number(random_sources.size(), handle_->get_stream()); - auto seed_sizes = cugraph::host_scalar_allgather( handle_->get_comms(), random_sources.size(), handle_->get_stream()); size_t num_seeds = std::reduce(seed_sizes.begin(), seed_sizes.end()); @@ -132,24 +138,15 @@ class Tests_MGUniform_Neighbor_Sampling std::vector seed_offsets(seed_sizes.size()); std::exclusive_scan(seed_sizes.begin(), seed_sizes.end(), seed_offsets.begin(), size_t{0}); - thrust::tabulate( - handle_->get_thrust_policy(), - batch_number.begin(), - batch_number.end(), - [seed_offset = seed_offsets[handle_->get_comms().get_rank()], - num_batches] __device__(int32_t index) { return (seed_offset + index) % num_batches; }); + auto batch_number = cugraph::test::modulo_sequence( + *handle_, random_sources.size(), num_batches, seed_offsets[handle_->get_comms().get_rank()]); rmm::device_uvector unique_batches(num_batches, handle_->get_stream()); - rmm::device_uvector comm_ranks(num_batches, handle_->get_stream()); - cugraph::detail::sequence_fill( handle_->get_stream(), unique_batches.data(), unique_batches.size(), int32_t{0}); - thrust::tabulate(handle_->get_thrust_policy(), - comm_ranks.begin(), - comm_ranks.end(), - [num_gpus = handle_->get_comms().get_size()] __device__(auto index) { - return index % num_gpus; - }); + + auto comm_ranks = cugraph::test::modulo_sequence( + *handle_, num_batches, handle_->get_comms().get_size(), int32_t{0}); rmm::device_uvector 
random_sources_copy(random_sources.size(), handle_->get_stream()); @@ -233,15 +230,14 @@ class Tests_MGUniform_Neighbor_Sampling mg_aggregate_dst.data(), mg_aggregate_dst.size(), handle_->get_stream()); - thrust::sort(handle_->get_thrust_policy(), vertices.begin(), vertices.end()); - auto vertices_end = - thrust::unique(handle_->get_thrust_policy(), vertices.begin(), vertices.end()); - vertices.resize(thrust::distance(vertices.begin(), vertices_end), handle_->get_stream()); + vertices = cugraph::test::sort(*handle_, std::move(vertices)); + vertices = cugraph::test::unique(*handle_, std::move(vertices)); vertices = cugraph::detail::shuffle_int_vertices_to_local_gpu_by_vertex_partitioning( *handle_, std::move(vertices), mg_graph_view.vertex_partition_range_lasts()); - thrust::sort(handle_->get_thrust_policy(), vertices.begin(), vertices.end()); + vertices = cugraph::test::sort(*handle_, std::move(vertices)); + vertices = cugraph::test::unique(*handle_, std::move(vertices)); rmm::device_uvector d_subgraph_offsets(2, handle_->get_stream()); std::vector h_subgraph_offsets({0, vertices.size()}); @@ -340,8 +336,10 @@ INSTANTIATE_TEST_SUITE_P( file_test, Tests_MGUniform_Neighbor_Sampling_File, ::testing::Combine( - ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{10, 25}, 128, false, true}, - Uniform_Neighbor_Sampling_Usecase{{10, 25}, 128, true, true}), + ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), @@ -351,8 +349,10 @@ INSTANTIATE_TEST_SUITE_P( rmat_small_test, Tests_MGUniform_Neighbor_Sampling_Rmat, ::testing::Combine( - 
::testing::Values(Uniform_Neighbor_Sampling_Usecase{{10, 25}, 128, false, true}, - Uniform_Neighbor_Sampling_Usecase{{10, 25}, 128, true, true}), + ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), ::testing::Values( // cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false)))); cugraph::test::Rmat_Usecase(5, 16, 0.57, 0.19, 0.19, 0, false, false)))); @@ -365,8 +365,10 @@ INSTANTIATE_TEST_SUITE_P( factor (to avoid running same benchmarks more than once) */ Tests_MGUniform_Neighbor_Sampling_Rmat, ::testing::Combine( - ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{10, 25}, 128, false, false}, - Uniform_Neighbor_Sampling_Usecase{{10, 25}, 128, true, false}), + ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/sampling/sampling_post_processing_test.cu b/cpp/tests/sampling/sampling_post_processing_test.cu index 3bca382a2eb..ecec1d0ed89 100644 --- a/cpp/tests/sampling/sampling_post_processing_test.cu +++ b/cpp/tests/sampling/sampling_post_processing_test.cu @@ -1604,47 +1604,47 @@ INSTANTIATE_TEST_SUITE_P( SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, false, false, false}, SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, false, true, false}, SamplingPostProcessing_Usecase{1, 64, {5, 10, 15}, true, true, true, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, false, false, 
false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, false, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, false, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, false, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, true, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, true, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, true, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, false, true, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, false, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, false, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, false, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, false, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, true, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, true, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, true, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {10}, true, true, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, false, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, false, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, false, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, false, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, true, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, true, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, false, true, true, false, false, false}, + 
SamplingPostProcessing_Usecase{128, 64, {10}, false, true, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, false, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, false, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, false, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, false, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, true, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, true, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, true, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {10}, true, true, true, false, true, false}, SamplingPostProcessing_Usecase{ - 256, 64, {5, 10, 15}, false, false, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, false, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, false, true, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, false, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, false, true, true, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 
10, 15}, true, false, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, false, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, false, true, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, false, true, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, false, false, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, false, true, false}, - SamplingPostProcessing_Usecase{256, 64, {5, 10, 15}, true, true, true, true, false, false}), + 128, 64, {5, 10, 15}, false, false, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, false, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, false, false, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, false, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, false, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, false, true, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, true, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, true, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, true, false, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, 
true, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, true, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, false, true, true, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, false, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, false, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, false, false, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, false, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, false, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, false, true, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, true, false, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, true, false, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, true, false, true, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, true, true, false, false, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, true, true, false, true, false}, + SamplingPostProcessing_Usecase{128, 64, {5, 10, 15}, true, true, true, true, false, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/sampling/sg_uniform_neighbor_sampling.cu b/cpp/tests/sampling/uniform_neighbor_sampling.cpp similarity index 69% rename from cpp/tests/sampling/sg_uniform_neighbor_sampling.cu rename to cpp/tests/sampling/uniform_neighbor_sampling.cpp index 1b038e2b6c4..6ab11cc6c29 100644 --- a/cpp/tests/sampling/sg_uniform_neighbor_sampling.cu +++ b/cpp/tests/sampling/uniform_neighbor_sampling.cpp @@ -14,21 +14,22 @@ * limitations under the License. 
*/ -#include "detail/nbr_sampling_utils.cuh" +#include "detail/nbr_sampling_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/property_generator_utilities.hpp" -#include - -#include -#include -#include +#include +#include #include struct Uniform_Neighbor_Sampling_Usecase { std::vector fanout{{-1}}; int32_t batch_size{10}; - bool check_correctness{true}; bool flag_replacement{true}; + + bool edge_masking{false}; + bool check_correctness{true}; }; template @@ -45,9 +46,11 @@ class Tests_Uniform_Neighbor_Sampling virtual void TearDown() {} template - void run_current_test(Uniform_Neighbor_Sampling_Usecase const& uniform_neighbor_sampling_usecase, - input_usecase_t const& input_usecase) + void run_current_test( + std::tuple const& param) { + auto [uniform_neighbor_sampling_usecase, input_usecase] = param; + raft::handle_t handle{}; HighResTimer hr_timer{}; @@ -70,11 +73,14 @@ class Tests_Uniform_Neighbor_Sampling auto edge_weight_view = edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; - // - // Test is designed like GNN sampling. We'll select 90% of vertices - // to be included in sampling batches - // - constexpr float select_probability{0.9}; + std::optional> edge_mask{std::nullopt}; + if (uniform_neighbor_sampling_usecase.edge_masking) { + edge_mask = + cugraph::test::generate::edge_property(handle, graph_view, 2); + graph_view.attach_edge_mask((*edge_mask).view()); + } + + constexpr float select_probability{0.05}; // FIXME: Update the tests to initialize RngState and use it instead // of seed... 
@@ -104,22 +110,16 @@ class Tests_Uniform_Neighbor_Sampling float{1}, rng_state); - thrust::sort_by_key(handle.get_thrust_policy(), - random_numbers.begin(), - random_numbers.end(), - random_sources.begin()); + std::tie(random_numbers, random_sources) = cugraph::test::sort_by_key( + handle, std::move(random_numbers), std::move(random_sources)); random_numbers.resize(0, handle.get_stream()); random_numbers.shrink_to_fit(handle.get_stream()); - auto batch_number = - std::make_optional>(random_sources.size(), handle.get_stream()); + auto batch_number = std::make_optional>(0, handle.get_stream()); - thrust::tabulate(handle.get_thrust_policy(), - batch_number->begin(), - batch_number->end(), - [batch_size = uniform_neighbor_sampling_usecase.batch_size] __device__( - int32_t index) { return index / batch_size; }); + batch_number = cugraph::test::sequence( + handle, random_sources.size(), uniform_neighbor_sampling_usecase.batch_size, int32_t{0}); rmm::device_uvector random_sources_copy(random_sources.size(), handle.get_stream()); @@ -151,6 +151,11 @@ class Tests_Uniform_Neighbor_Sampling uniform_neighbor_sampling_usecase.flag_replacement), std::exception); #else + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Uniform neighbor sampling"); + } + auto&& [src_out, dst_out, wgt_out, edge_id, edge_type, hop, labels, offsets] = cugraph::uniform_neighbor_sample( handle, @@ -169,6 +174,12 @@ class Tests_Uniform_Neighbor_Sampling true, uniform_neighbor_sampling_usecase.flag_replacement); + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + if (uniform_neighbor_sampling_usecase.check_correctness) { // First validate that the extracted edges are actually a subset of the // edges in the input graph @@ -176,10 +187,8 @@ class Tests_Uniform_Neighbor_Sampling 
raft::copy(vertices.data(), src_out.data(), src_out.size(), handle.get_stream()); raft::copy( vertices.data() + src_out.size(), dst_out.data(), dst_out.size(), handle.get_stream()); - thrust::sort(handle.get_thrust_policy(), vertices.begin(), vertices.end()); - auto vertices_end = - thrust::unique(handle.get_thrust_policy(), vertices.begin(), vertices.end()); - vertices.resize(thrust::distance(vertices.begin(), vertices_end), handle.get_stream()); + vertices = cugraph::test::sort(handle, std::move(vertices)); + vertices = cugraph::test::unique(handle, std::move(vertices)); rmm::device_uvector d_subgraph_offsets(2, handle.get_stream()); std::vector h_subgraph_offsets({0, vertices.size()}); @@ -201,17 +210,18 @@ class Tests_Uniform_Neighbor_Sampling raft::device_span(vertices.data(), vertices.size()), true); - cugraph::test::validate_extracted_graph_is_subgraph( - handle, src_compare, dst_compare, wgt_compare, src_out, dst_out, wgt_out); + ASSERT_TRUE(cugraph::test::validate_extracted_graph_is_subgraph( + handle, src_compare, dst_compare, wgt_compare, src_out, dst_out, wgt_out)); if (random_sources.size() < 100) { // This validation is too expensive for large number of vertices - cugraph::test::validate_sampling_depth(handle, - std::move(src_out), - std::move(dst_out), - std::move(wgt_out), - std::move(random_sources), - uniform_neighbor_sampling_usecase.fanout.size()); + ASSERT_TRUE( + cugraph::test::validate_sampling_depth(handle, + std::move(src_out), + std::move(dst_out), + std::move(wgt_out), + std::move(random_sources), + uniform_neighbor_sampling_usecase.fanout.size())); } } #endif @@ -226,46 +236,48 @@ using Tests_Uniform_Neighbor_Sampling_Rmat = TEST_P(Tests_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float) { - auto param = GetParam(); - run_current_test(std::get<0>(param), std::get<1>(param)); + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); } TEST_P(Tests_Uniform_Neighbor_Sampling_File, CheckInt32Int64Float) { - auto 
param = GetParam(); - run_current_test(std::get<0>(param), std::get<1>(param)); + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); } TEST_P(Tests_Uniform_Neighbor_Sampling_File, CheckInt64Int64Float) { - auto param = GetParam(); - run_current_test(std::get<0>(param), std::get<1>(param)); + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); } TEST_P(Tests_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float) { - auto param = GetParam(); - run_current_test(std::get<0>(param), std::get<1>(param)); + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } TEST_P(Tests_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float) { - auto param = GetParam(); - run_current_test(std::get<0>(param), std::get<1>(param)); + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } TEST_P(Tests_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { - auto param = GetParam(); - run_current_test(std::get<0>(param), std::get<1>(param)); + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } INSTANTIATE_TEST_SUITE_P( file_test, Tests_Uniform_Neighbor_Sampling_File, ::testing::Combine( - ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{2}, 100, true, true}, - Uniform_Neighbor_Sampling_Usecase{{2}, 100, true, false}), + ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), @@ -275,7 +287,10 @@ INSTANTIATE_TEST_SUITE_P( rmat_small_test, Tests_Uniform_Neighbor_Sampling_Rmat, ::testing::Combine( - 
::testing::Values(Uniform_Neighbor_Sampling_Usecase{{2}, 10, false, true}), + ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true}), ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0)))); INSTANTIATE_TEST_SUITE_P( @@ -286,7 +301,10 @@ INSTANTIATE_TEST_SUITE_P( factor (to avoid running same benchmarks more than once) */ Tests_Uniform_Neighbor_Sampling_Rmat, ::testing::Combine( - ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{2}, 500, false, true}), + ::testing::Values(Uniform_Neighbor_Sampling_Usecase{{4, 10}, 1024, false, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 1024, false, true, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 1024, true, false, false}, + Uniform_Neighbor_Sampling_Usecase{{4, 10}, 1024, true, true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false, 0)))); CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/structure/induced_subgraph_test.cpp b/cpp/tests/structure/induced_subgraph_test.cpp index d8f55b65bda..73168819600 100644 --- a/cpp/tests/structure/induced_subgraph_test.cpp +++ b/cpp/tests/structure/induced_subgraph_test.cpp @@ -16,6 +16,7 @@ #include "structure/induced_subgraph_validate.hpp" #include "utilities/base_fixture.hpp" #include "utilities/conversion_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" #include "utilities/test_graphs.hpp" #include @@ -89,6 +90,8 @@ extract_induced_subgraph_reference(std::vector const& offsets, struct InducedSubgraph_Usecase { std::vector subgraph_sizes{}; bool test_weighted{false}; + + bool edge_masking{false}; bool check_correctness{false}; }; @@ -134,6 +137,13 @@ class Tests_InducedSubgraph auto edge_weight_view = edge_weights ? 
std::make_optional((*edge_weights).view()) : std::nullopt; + std::optional> edge_mask{std::nullopt}; + if (induced_subgraph_usecase.edge_masking) { + edge_mask = + cugraph::test::generate::edge_property(handle, graph_view, 2); + graph_view.attach_edge_mask((*edge_mask).view()); + } + // Construct random subgraph vertex lists raft::random::RngState rng_state(0); @@ -267,12 +277,19 @@ INSTANTIATE_TEST_SUITE_P( karate_test, Tests_InducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{0}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{0}, false, false}, + InducedSubgraph_Usecase{std::vector{0}, false, true}, + InducedSubgraph_Usecase{std::vector{1}, false, false}, InducedSubgraph_Usecase{std::vector{1}, false, true}, + InducedSubgraph_Usecase{std::vector{10}, false, false}, InducedSubgraph_Usecase{std::vector{10}, false, true}, + InducedSubgraph_Usecase{std::vector{34}, false, false}, InducedSubgraph_Usecase{std::vector{34}, false, true}, + InducedSubgraph_Usecase{std::vector{10, 0, 5}, false, false}, InducedSubgraph_Usecase{std::vector{10, 0, 5}, false, true}, + InducedSubgraph_Usecase{std::vector{9, 3, 10}, false, false}, InducedSubgraph_Usecase{std::vector{9, 3, 10}, false, true}, + InducedSubgraph_Usecase{std::vector{5, 12, 13}, true, false}, InducedSubgraph_Usecase{std::vector{5, 12, 13}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); @@ -280,7 +297,9 @@ INSTANTIATE_TEST_SUITE_P( web_google_test, Tests_InducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{250, 130, 15}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{250, 130, 15}, false, false}, + InducedSubgraph_Usecase{std::vector{250, 130, 15}, false, true}, + InducedSubgraph_Usecase{std::vector{250, 130, 15}, true, false}, InducedSubgraph_Usecase{std::vector{125, 300, 70}, true, true}), 
::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx")))); @@ -288,7 +307,9 @@ INSTANTIATE_TEST_SUITE_P( ljournal_2008_test, Tests_InducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{300, 20, 400}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{9130, 1200, 300}, false, false}, + InducedSubgraph_Usecase{std::vector{9130, 1200, 300}, false, true}, + InducedSubgraph_Usecase{std::vector{9130, 1200, 300}, true, false}, InducedSubgraph_Usecase{std::vector{9130, 1200, 300}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx")))); @@ -296,8 +317,10 @@ INSTANTIATE_TEST_SUITE_P( webbase_1M_test, Tests_InducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{700}, false, true}, - InducedSubgraph_Usecase{std::vector{500}, true, true}), + ::testing::Values(InducedSubgraph_Usecase{std::vector{700}, false, false}, + InducedSubgraph_Usecase{std::vector{700}, false, true}, + InducedSubgraph_Usecase{std::vector{700}, true, false}, + InducedSubgraph_Usecase{std::vector{700}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/structure/mg_induced_subgraph_test.cu b/cpp/tests/structure/mg_induced_subgraph_test.cu index 2ed909b9955..9f8ef1debcf 100644 --- a/cpp/tests/structure/mg_induced_subgraph_test.cu +++ b/cpp/tests/structure/mg_induced_subgraph_test.cu @@ -19,6 +19,7 @@ #include "utilities/conversion_utilities.hpp" #include "utilities/device_comm_wrapper.hpp" #include "utilities/mg_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" #include "utilities/test_graphs.hpp" #include @@ -43,6 +44,8 @@ struct InducedSubgraph_Usecase { std::vector subgraph_sizes{}; bool test_weighted{false}; + + bool edge_masking{false}; bool check_correctness{false}; }; @@ -89,6 +92,13 @@ class Tests_MGInducedSubgraph 
auto mg_edge_weight_view = mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; + std::optional> edge_mask{std::nullopt}; + if (induced_subgraph_usecase.edge_masking) { + edge_mask = cugraph::test::generate::edge_property( + *handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + int my_rank = handle_->get_comms().get_rank(); // Construct random subgraph vertex lists @@ -295,12 +305,19 @@ INSTANTIATE_TEST_SUITE_P( karate_test, Tests_MGInducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{0}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{0}, false, false}, + InducedSubgraph_Usecase{std::vector{0}, false, true}, + InducedSubgraph_Usecase{std::vector{1}, false, false}, InducedSubgraph_Usecase{std::vector{1}, false, true}, + InducedSubgraph_Usecase{std::vector{10}, false, false}, InducedSubgraph_Usecase{std::vector{10}, false, true}, + InducedSubgraph_Usecase{std::vector{34}, false, false}, InducedSubgraph_Usecase{std::vector{34}, false, true}, + InducedSubgraph_Usecase{std::vector{10, 0, 5}, false, false}, InducedSubgraph_Usecase{std::vector{10, 0, 5}, false, true}, + InducedSubgraph_Usecase{std::vector{9, 3, 10}, false, false}, InducedSubgraph_Usecase{std::vector{9, 3, 10}, false, true}, + InducedSubgraph_Usecase{std::vector{5, 12, 13}, true, false}, InducedSubgraph_Usecase{std::vector{5, 12, 13}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); @@ -308,7 +325,9 @@ INSTANTIATE_TEST_SUITE_P( web_google_test, Tests_MGInducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{250, 130, 15}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{250, 130, 15}, false, false}, + InducedSubgraph_Usecase{std::vector{250, 130, 15}, false, true}, + InducedSubgraph_Usecase{std::vector{125, 300, 70}, true, false}, InducedSubgraph_Usecase{std::vector{125, 
300, 70}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx")))); @@ -316,7 +335,9 @@ INSTANTIATE_TEST_SUITE_P( ljournal_2008_test, Tests_MGInducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{300, 20, 400}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{300, 20, 400}, false, false}, + InducedSubgraph_Usecase{std::vector{300, 20, 400}, false, true}, + InducedSubgraph_Usecase{std::vector{9130, 1200, 300}, true, false}, InducedSubgraph_Usecase{std::vector{9130, 1200, 300}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx")))); @@ -324,7 +345,9 @@ INSTANTIATE_TEST_SUITE_P( webbase_1M_test, Tests_MGInducedSubgraph_File, ::testing::Combine( - ::testing::Values(InducedSubgraph_Usecase{std::vector{700}, false, true}, + ::testing::Values(InducedSubgraph_Usecase{std::vector{700}, false, false}, + InducedSubgraph_Usecase{std::vector{700}, false, true}, + InducedSubgraph_Usecase{std::vector{500}, true, false}, InducedSubgraph_Usecase{std::vector{500}, true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); diff --git a/cpp/tests/traversal/bfs_test.cpp b/cpp/tests/traversal/bfs_test.cpp index fda80f1c191..8d3cdb3d24b 100644 --- a/cpp/tests/traversal/bfs_test.cpp +++ b/cpp/tests/traversal/bfs_test.cpp @@ -206,10 +206,12 @@ class Tests_BFS : public ::testing::TestWithParam d_unrenumbered_distances(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_distances) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_distances); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_distances); rmm::device_uvector d_unrenumbered_predecessors(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_predecessors) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_predecessors); + cugraph::test::sort_by_key( + handle, 
*d_renumber_map_labels, d_predecessors); h_cugraph_distances = cugraph::test::to_host(handle, d_unrenumbered_distances); h_cugraph_predecessors = cugraph::test::to_host(handle, d_unrenumbered_predecessors); } else { diff --git a/cpp/tests/traversal/sssp_test.cpp b/cpp/tests/traversal/sssp_test.cpp index ee236e72cdc..3eff1a8e106 100644 --- a/cpp/tests/traversal/sssp_test.cpp +++ b/cpp/tests/traversal/sssp_test.cpp @@ -206,10 +206,12 @@ class Tests_SSSP : public ::testing::TestWithParam d_unrenumbered_distances(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_distances) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_distances); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_distances); rmm::device_uvector d_unrenumbered_predecessors(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_predecessors) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_predecessors); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_predecessors); h_cugraph_distances = cugraph::test::to_host(handle, d_unrenumbered_distances); h_cugraph_predecessors = cugraph::test::to_host(handle, d_unrenumbered_predecessors); diff --git a/cpp/tests/utilities/conversion_utilities_impl.cuh b/cpp/tests/utilities/conversion_utilities_impl.cuh index 748a5731b89..9f8fdcf6ed9 100644 --- a/cpp/tests/utilities/conversion_utilities_impl.cuh +++ b/cpp/tests/utilities/conversion_utilities_impl.cuh @@ -430,7 +430,8 @@ mg_vertex_property_values_to_sg_vertex_property_values( static_cast((*sg_renumber_map).size())); } - std::tie(sg_vertices, sg_values) = cugraph::test::sort_by_key(handle, sg_vertices, sg_values); + std::tie(sg_vertices, sg_values) = cugraph::test::sort_by_key( + handle, std::move(sg_vertices), std::move(sg_values)); if (mg_vertices) { return std::make_tuple(std::move(sg_vertices), std::move(sg_values)); diff --git a/cpp/tests/utilities/thrust_wrapper.cu b/cpp/tests/utilities/thrust_wrapper.cu index 
93bb8a04e87..8d26ac1f2fe 100644 --- a/cpp/tests/utilities/thrust_wrapper.cu +++ b/cpp/tests/utilities/thrust_wrapper.cu @@ -16,10 +16,6 @@ #include "utilities/thrust_wrapper.hpp" -#include - -#include - #include #include @@ -31,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -38,12 +35,12 @@ namespace cugraph { namespace test { -template -value_buffer_type sort(raft::handle_t const& handle, value_buffer_type const& values) +template +cugraph::dataframe_buffer_type_t sort( + raft::handle_t const& handle, cugraph::dataframe_buffer_type_t const& values) { - auto sorted_values = - cugraph::allocate_dataframe_buffer>( - values.size(), handle.get_stream()); + auto sorted_values = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(values), handle.get_stream()); thrust::copy(handle.get_thrust_policy(), cugraph::get_dataframe_buffer_begin(values), @@ -57,76 +54,85 @@ value_buffer_type sort(raft::handle_t const& handle, value_buffer_type const& va return sorted_values; } -template -std::tuple sort(raft::handle_t const& handle, - value_buffer_type const& first, - value_buffer_type const& second) +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector const& values); + +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector const& values); + +template +cugraph::dataframe_buffer_type_t sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& values) +{ + auto sorted_values = std::move(values); + + thrust::sort(handle.get_thrust_policy(), + cugraph::get_dataframe_buffer_begin(sorted_values), + cugraph::get_dataframe_buffer_end(sorted_values)); + + return sorted_values; +} + +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector&& values); + +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector&& values); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort(raft::handle_t const& 
handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second) { auto sorted_first = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(first), handle.get_stream()); auto sorted_second = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - - auto execution_policy = handle.get_thrust_policy(); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(first), - cugraph::get_dataframe_buffer_end(first), - cugraph::get_dataframe_buffer_begin(sorted_first)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(second), - cugraph::get_dataframe_buffer_end(second), - cugraph::get_dataframe_buffer_begin(sorted_second)); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(first), handle.get_stream()); + + auto input_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(first), + cugraph::get_dataframe_buffer_begin(second)); + auto output_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), + cugraph::get_dataframe_buffer_begin(sorted_second)); + thrust::copy(handle.get_thrust_policy(), + input_first, + input_first + size_dataframe_buffer(first), + output_first); thrust::sort( - execution_policy, - thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), - cugraph::get_dataframe_buffer_begin(sorted_second)), - thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first) + first.size(), - cugraph::get_dataframe_buffer_begin(sorted_second) + first.size())); + handle.get_thrust_policy(), output_first, output_first + size_dataframe_buffer(sorted_first)); return std::make_tuple(std::move(sorted_first), std::move(sorted_second)); } -template rmm::device_uvector sort(raft::handle_t const& handle, - rmm::device_uvector const& values); - -template rmm::device_uvector sort(raft::handle_t const& handle, - 
rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort( +template std::tuple, rmm::device_uvector> sort( raft::handle_t const& handle, rmm::device_uvector const& first, rmm::device_uvector const& second); -template std::tuple, rmm::device_uvector> sort( +template std::tuple, rmm::device_uvector> sort( raft::handle_t const& handle, rmm::device_uvector const& first, rmm::device_uvector const& second); -template -std::tuple sort_by_key(raft::handle_t const& handle, - key_buffer_type const& keys, - value_buffer_type const& values) +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& keys, + cugraph::dataframe_buffer_type_t const& values) { auto sorted_keys = - cugraph::allocate_dataframe_buffer>( - keys.size(), handle.get_stream()); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(keys), handle.get_stream()); auto sorted_values = - cugraph::allocate_dataframe_buffer>( - keys.size(), handle.get_stream()); - - auto execution_policy = handle.get_thrust_policy(); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(keys), - cugraph::get_dataframe_buffer_end(keys), - cugraph::get_dataframe_buffer_begin(sorted_keys)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(values), - cugraph::get_dataframe_buffer_end(values), - cugraph::get_dataframe_buffer_begin(sorted_values)); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(keys), handle.get_stream()); - thrust::sort_by_key(execution_policy, + auto input_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(keys), + cugraph::get_dataframe_buffer_begin(values)); + thrust::copy(handle.get_thrust_policy(), + input_first, + input_first + size_dataframe_buffer(keys), + thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_keys), + cugraph::get_dataframe_buffer_begin(sorted_values))); + 
thrust::sort_by_key(handle.get_thrust_policy(), cugraph::get_dataframe_buffer_begin(sorted_keys), cugraph::get_dataframe_buffer_end(sorted_keys), cugraph::get_dataframe_buffer_begin(sorted_values)); @@ -134,93 +140,179 @@ std::tuple sort_by_key(raft::handle_t const& return std::make_tuple(std::move(sorted_keys), std::move(sorted_values)); } -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, 
rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); -template std::tuple, rmm::device_uvector> sort_by_key( +template std::tuple, + std::tuple, rmm::device_uvector>> +sort_by_key>( raft::handle_t const& handle, rmm::device_uvector const& keys, - rmm::device_uvector const& values); + std::tuple, rmm::device_uvector> const& values); -template std::tuple, rmm::device_uvector> sort_by_key( +template std::tuple, + std::tuple, rmm::device_uvector>> +sort_by_key>( raft::handle_t const& handle, rmm::device_uvector const& keys, - rmm::device_uvector const& values); + std::tuple, rmm::device_uvector> const& values); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& keys, + cugraph::dataframe_buffer_type_t&& values) +{ + auto sorted_keys = std::move(keys); + auto sorted_values = std::move(values); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); + thrust::sort_by_key(handle.get_thrust_policy(), + 
cugraph::get_dataframe_buffer_begin(sorted_keys), + cugraph::get_dataframe_buffer_end(sorted_keys), + cugraph::get_dataframe_buffer_begin(sorted_values)); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); + return std::make_tuple(std::move(sorted_keys), std::move(sorted_values)); +} -template -std::tuple sort_by_key( - raft::handle_t const& handle, - key_buffer_type const& first, - key_buffer_type const& second, - value_buffer_type const& values) +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& 
handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template +std::tuple, + cugraph::dataframe_buffer_type_t, + cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second, + cugraph::dataframe_buffer_type_t const& values) { - auto sorted_first = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - auto sorted_second = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - auto sorted_values = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - - auto execution_policy = handle.get_thrust_policy(); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(first), - cugraph::get_dataframe_buffer_end(first), - cugraph::get_dataframe_buffer_begin(sorted_first)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(second), - cugraph::get_dataframe_buffer_end(second), - cugraph::get_dataframe_buffer_begin(sorted_second)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(values), - cugraph::get_dataframe_buffer_end(values), - cugraph::get_dataframe_buffer_begin(sorted_values)); - thrust::sort_by_key( - execution_policy, + auto sorted_first = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(first), handle.get_stream()); + auto sorted_second = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(first), handle.get_stream()); + auto sorted_values = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(first), handle.get_stream()); + + auto input_first = 
thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(first), + cugraph::get_dataframe_buffer_begin(second), + cugraph::get_dataframe_buffer_begin(values)); + thrust::copy(handle.get_thrust_policy(), + input_first, + input_first + size_dataframe_buffer(first), + thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), + cugraph::get_dataframe_buffer_begin(sorted_second), + cugraph::get_dataframe_buffer_begin(sorted_values))); + auto sorted_key_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), - cugraph::get_dataframe_buffer_begin(sorted_second)), - thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first) + first.size(), - cugraph::get_dataframe_buffer_begin(sorted_second) + first.size()), - cugraph::get_dataframe_buffer_begin(sorted_values)); + cugraph::get_dataframe_buffer_begin(sorted_second)); + thrust::sort_by_key(handle.get_thrust_policy(), + sorted_key_first, + sorted_key_first + cugraph::size_dataframe_buffer(sorted_first), + cugraph::get_dataframe_buffer_begin(sorted_values)); return std::make_tuple( std::move(sorted_first), std::move(sorted_second), std::move(sorted_values)); @@ -228,43 +320,109 @@ std::tuple sort_by_key( template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); template std:: 
tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); + +template +cugraph::dataframe_buffer_type_t unique(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& values) +{ + auto last = thrust::unique(handle.get_thrust_policy(), + cugraph::get_dataframe_buffer_begin(values), + cugraph::get_dataframe_buffer_end(values)); + cugraph::resize_dataframe_buffer( + values, + thrust::distance(cugraph::get_dataframe_buffer_begin(values), last), + handle.get_stream()); + cugraph::shrink_to_fit_dataframe_buffer(values, handle.get_stream()); + + return std::move(values); +} -template std::tuple, - std::tuple, rmm::device_uvector>> -sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& keys, - std::tuple, rmm::device_uvector> const& values); +template rmm::device_uvector unique(raft::handle_t const& handle, + rmm::device_uvector&& values); -template std::tuple, - std::tuple, rmm::device_uvector>> -sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& keys, - std::tuple, rmm::device_uvector> const& values); +template rmm::device_uvector unique(raft::handle_t const& handle, + rmm::device_uvector&& values); + +template +cugraph::dataframe_buffer_type_t sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + value_t init) +{ + auto values = 
cugraph::allocate_dataframe_buffer(length, handle.get_stream()); + if (repeat_count == 1) { + thrust::sequence(handle.get_thrust_policy(), values.begin(), values.end(), init); + } else { + thrust::tabulate(handle.get_thrust_policy(), + values.begin(), + values.end(), + [repeat_count, init] __device__(size_t i) { + return init + static_cast(i / repeat_count); + }); + } + + return values; +} + +template rmm::device_uvector sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + int32_t init); + +template rmm::device_uvector sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + int64_t init); + +template +cugraph::dataframe_buffer_type_t modulo_sequence(raft::handle_t const& handle, + size_t length, + value_t modulo, + value_t init) +{ + auto values = cugraph::allocate_dataframe_buffer(length, handle.get_stream()); + thrust::tabulate( + handle.get_thrust_policy(), values.begin(), values.end(), [modulo, init] __device__(size_t i) { + return static_cast((init + i) % modulo); + }); + + return values; +} + +template rmm::device_uvector modulo_sequence(raft::handle_t const& handle, + size_t length, + int32_t modulo, + int32_t init); + +template rmm::device_uvector modulo_sequence(raft::handle_t const& handle, + size_t length, + int64_t modulo, + int64_t init); template vertex_t max_element(raft::handle_t const& handle, raft::device_span vertices) diff --git a/cpp/tests/utilities/thrust_wrapper.hpp b/cpp/tests/utilities/thrust_wrapper.hpp index c4b87126f50..cd8bc33308f 100644 --- a/cpp/tests/utilities/thrust_wrapper.hpp +++ b/cpp/tests/utilities/thrust_wrapper.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include + #include #include @@ -26,25 +28,57 @@ namespace cugraph { namespace test { -template -value_buffer_type sort(raft::handle_t const& handle, value_buffer_type const& values); - -template -std::tuple sort(raft::handle_t const& handle, - value_buffer_type const& first, - value_buffer_type const& second); - -template 
-std::tuple sort_by_key(raft::handle_t const& handle, - key_buffer_type const& keys, - value_buffer_type const& values); - -template -std::tuple sort_by_key( - raft::handle_t const& handle, - key_buffer_type const& first, - key_buffer_type const& second, - value_buffer_type const& values); +template +cugraph::dataframe_buffer_type_t sort( + raft::handle_t const& handle, cugraph::dataframe_buffer_type_t const& values); + +template +cugraph::dataframe_buffer_type_t sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& values); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& keys, + cugraph::dataframe_buffer_type_t const& values); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& keys, + cugraph::dataframe_buffer_type_t&& values); + +template +std::tuple, + cugraph::dataframe_buffer_type_t, + cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second, + cugraph::dataframe_buffer_type_t const& values); + +template +cugraph::dataframe_buffer_type_t unique( + raft::handle_t const& handle, cugraph::dataframe_buffer_type_t&& values); + +template +cugraph::dataframe_buffer_type_t sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + value_t init); + +// return (init + i) % modulo, where i = [0, length) +template +cugraph::dataframe_buffer_type_t modulo_sequence(raft::handle_t const& handle, + size_t length, + value_t modulo, + value_t init); template vertex_t max_element(raft::handle_t const& handle, raft::device_span vertices);