Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
G-Cornett committed Jul 31, 2024
1 parent a55d88c commit 597530b
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions cpp/src/sampling/random_walks_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ struct biased_sample_edges_op_t {
}
};

template <typename vertex_t, typename bias_t>
template <typename vertex_t, typename bias_t, typename weight_t>
struct node2vec_random_walk_e_bias_op_t {
bias_t p_;
bias_t q_;
Expand All @@ -106,7 +106,7 @@ struct node2vec_random_walk_e_bias_op_t {
raft::device_span<vertex_t const> prev_vertices_{};

// Unweighted Bias Operator
template <typename W = bias_t>
template <typename W = weight_t>
__device__ std::enable_if_t<std::is_same_v<W, void>, bias_t> operator()(
thrust::tuple<vertex_t, vertex_t> tagged_src,
vertex_t dst,
Expand Down Expand Up @@ -143,7 +143,7 @@ struct node2vec_random_walk_e_bias_op_t {
}

// Weighted Biase Operator
template <typename W = bias_t>
template <typename W = weight_t>
__device__ std::enable_if_t<!std::is_same_v<W, void>, bias_t> operator()(
thrust::tuple<vertex_t, vertex_t> tagged_src,
vertex_t dst,
Expand Down Expand Up @@ -302,7 +302,7 @@ struct biased_selector {

// Create data structs for results
rmm::device_uvector<vertex_t> minors(0, handle.get_stream());
rmm::device_uvector<weight_t> weights(0, handle.get_stream());
rmm::device_uvector<weight_t> weights(0, handle.get_stream());

auto vertex_weight_sum = compute_out_weight_sums(handle, graph_view, *edge_weight_view);
edge_src_property_t<GraphViewType, weight_t> edge_src_out_weight_sums(handle, graph_view);
Expand Down Expand Up @@ -330,7 +330,6 @@ struct biased_selector {

// Return results
return std::make_tuple(std::move(minors), std::move(weights));

}
};

Expand Down Expand Up @@ -395,7 +394,7 @@ struct node2vec_selector {
cugraph::edge_src_dummy_property_t{}.view(),
cugraph::edge_dst_dummy_property_t{}.view(),
*edge_weight_view,
node2vec_random_walk_e_bias_op_t<vertex_t, weight_t>{
node2vec_random_walk_e_bias_op_t<vertex_t, weight_t, weight_t>{
p_,
q_,
raft::device_span<size_t const>(intersection_offsets.data(), intersection_offsets.size()),
Expand Down Expand Up @@ -423,7 +422,7 @@ struct node2vec_selector {
cugraph::edge_src_dummy_property_t{}.view(),
cugraph::edge_dst_dummy_property_t{}.view(),
cugraph::edge_dummy_property_t{}.view(),
node2vec_random_walk_e_bias_op_t<vertex_t, weight_t>{
node2vec_random_walk_e_bias_op_t<vertex_t, weight_t, void>{
p_,
q_,
raft::device_span<size_t const>(intersection_offsets.data(), intersection_offsets.size()),
Expand Down Expand Up @@ -572,18 +571,23 @@ random_walk_impl(raft::handle_t const& handle,
// Sort for nbr_intersection, must sort all together
if (previous_vertices) {
if constexpr (multi_gpu){
thrust::sort_by_key(handle.get_thrust_policy(),
current_vertices.begin(),
current_vertices.end(),
thrust::make_zip_iterator(current_position.begin(),
current_gpu.begin(),
(*previous_vertices).begin()));
thrust::sort(handle.get_thrust_policy(),
thrust::make_zip_iterator(current_vertices.begin(),
(*previous_vertices).begin(),
current_position.begin(),
current_gpu.begin()),
thrust::make_zip_iterator(current_vertices.end(),
(*previous_vertices).end(),
current_position.end(),
current_gpu.end()));
} else {
thrust::sort_by_key(handle.get_thrust_policy(),
current_vertices.begin(),
current_vertices.end(),
thrust::make_zip_iterator(current_position.begin(),
(*previous_vertices).begin()));
thrust::sort(handle.get_thrust_policy(),
thrust::make_zip_iterator(current_vertices.begin(),
(*previous_vertices).begin(),
current_position.begin()),
thrust::make_zip_iterator(current_vertices.end(),
(*previous_vertices).end(),
current_position.end()));
}
}

Expand Down Expand Up @@ -683,7 +687,7 @@ random_walk_impl(raft::handle_t const& handle,
auto input_iter = thrust::make_zip_iterator(
current_vertices.begin(), new_weights->begin(), current_position.begin());

auto compacted_length = thrust::distance(
compacted_length = thrust::distance(
input_iter,
thrust::remove_if(handle.get_thrust_policy(),
input_iter,
Expand Down

0 comments on commit 597530b

Please sign in to comment.