Skip to content

Fix OOB error, BFS C API should validate that the source vertex is a valid vertex #4077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 12, 2024
17 changes: 16 additions & 1 deletion cpp/include/cugraph/detail/utility_wrappers.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -174,5 +174,20 @@ bool is_equal(raft::handle_t const& handle,
raft::device_span<data_t> span1,
raft::device_span<data_t> span2);

/**
* @brief Count the number of times a value appears in a span
*
* @tparam data_t type of data in span
* @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and
* handles to various CUDA libraries) to run graph algorithms.
* @param span The span of data to compare
* @param value The value to count
* @return The count of how many instances of that value occur
*/
template <typename data_t>
size_t count_values(raft::handle_t const& handle,
raft::device_span<data_t const> span,
data_t value);

} // namespace detail
} // namespace cugraph
12 changes: 9 additions & 3 deletions cpp/src/c_api/abstract_functor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,8 +32,14 @@ struct abstract_functor {

void unsupported()
{
error_code_ = CUGRAPH_UNSUPPORTED_TYPE_COMBINATION;
error_->error_message_ = "Type Dispatcher executing unsupported combination of types";
mark_error(CUGRAPH_UNSUPPORTED_TYPE_COMBINATION,
"Type Dispatcher executing unsupported combination of types");
}

void mark_error(cugraph_error_code_t error_code, std::string const& error_message)
{
error_code_ = error_code;
error_->error_message_ = error_message;
}
};

Expand Down
17 changes: 16 additions & 1 deletion cpp/src/c_api/bfs.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -113,6 +113,21 @@ struct bfs_functor : public abstract_functor {
graph_view.local_vertex_partition_range_last(),
do_expensive_check_);

size_t invalid_count = cugraph::detail::count_values(
handle_,
raft::device_span<vertex_t const>{sources.data(), sources.size()},
cugraph::invalid_vertex_id<vertex_t>::value);

if constexpr (multi_gpu) {
invalid_count = cugraph::host_scalar_allreduce(
handle_.get_comms(), invalid_count, raft::comms::op_t::SUM, handle_.get_stream());
}

if (invalid_count != 0) {
mark_error(CUGRAPH_INVALID_INPUT, "Found invalid vertex in the input sources");
return;
}

cugraph::bfs<vertex_t, edge_t, multi_gpu>(
handle_,
graph_view,
Expand Down
19 changes: 18 additions & 1 deletion cpp/src/detail/utility_wrappers.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,11 +15,13 @@
*/
#include <cugraph/detail/utility_wrappers.hpp>
#include <cugraph/utilities/error.hpp>
#include <cugraph/utilities/host_scalar_comm.hpp>

#include <raft/random/rng.cuh>

#include <rmm/exec_policy.hpp>

#include <thrust/count.h>
#include <thrust/distance.h>
#include <thrust/functional.h>
#include <thrust/iterator/zip_iterator.h>
Expand Down Expand Up @@ -227,5 +229,20 @@ template bool is_equal(raft::handle_t const& handle,
raft::device_span<int64_t const> span1,
raft::device_span<int64_t const> span2);

template <typename data_t>
size_t count_values(raft::handle_t const& handle,
raft::device_span<data_t const> span,
data_t value)
{
return thrust::count(handle.get_thrust_policy(), span.begin(), span.end(), value);
}

template size_t count_values<int32_t>(raft::handle_t const& handle,
raft::device_span<int32_t const> span,
int32_t value);
template size_t count_values<int64_t>(raft::handle_t const& handle,
raft::device_span<int64_t const> span,
int64_t value);

} // namespace detail
} // namespace cugraph