Skip to content

Commit

Permalink
Fix OOB error, BFS C API should validate that the source vertex is a …
Browse files Browse the repository at this point in the history
…valid vertex (#4077)

* Added error check to be sure that the source vertex is a valid vertex in the graph.
* Updated `nx_cugraph.Graph` class to create PLC graphs using `vertices_array` in order to include isolated vertices. This is now needed since the error check added in this PR prevents NetworkX tests from passing if isolated vertices are treated as invalid, so this change prevents that.
* This resolves the problem that required the test workarounds done [here](#4029 (comment)) in [4029](#4029), so those workarounds have been removed in this PR.

Closes #4067

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Rick Ratzel (https://github.com/rlratzel)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)
  - Ray Douglass (https://github.com/raydouglass)
  - Erik Welch (https://github.com/eriknw)

URL: #4077
  • Loading branch information
ChuckHastings authored Jan 12, 2024
1 parent c09db10 commit 24d02a5
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 48 deletions.
5 changes: 0 additions & 5 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ popd
rapids-logger "pytest networkx using nx-cugraph backend"
pushd python/nx-cugraph
./run_nx_tests.sh
# Individually run tests that are skipped above b/c they may run out of memory
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAG and test_antichains"
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestMultiDiGraph_DAGLCA and test_all_pairs_lca_pairs_without_lca"
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAGLCA and test_all_pairs_lca_pairs_without_lca"
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestEfficiency and test_using_ego_graph"
# run_nx_tests.sh outputs coverage data, so check that total coverage is >0.0%
# in case nx-cugraph failed to load but fallback mode allowed the run to pass.
_coverage=$(coverage report|grep "^TOTAL")
Expand Down
17 changes: 16 additions & 1 deletion cpp/include/cugraph/detail/utility_wrappers.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -174,5 +174,20 @@ bool is_equal(raft::handle_t const& handle,
raft::device_span<data_t> span1,
raft::device_span<data_t> span2);

/**
* @brief Count the number of times a value appears in a span
*
* @tparam data_t type of data in span
* @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and
* handles to various CUDA libraries) to run graph algorithms.
* @param span The span of data to compare
* @param value The value to count
* @return The count of how many instances of that value occur
*/
template <typename data_t>
size_t count_values(raft::handle_t const& handle,
raft::device_span<data_t const> span,
data_t value);

} // namespace detail
} // namespace cugraph
12 changes: 9 additions & 3 deletions cpp/src/c_api/abstract_functor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,8 +32,14 @@ struct abstract_functor {

void unsupported()
{
error_code_ = CUGRAPH_UNSUPPORTED_TYPE_COMBINATION;
error_->error_message_ = "Type Dispatcher executing unsupported combination of types";
mark_error(CUGRAPH_UNSUPPORTED_TYPE_COMBINATION,
"Type Dispatcher executing unsupported combination of types");
}

void mark_error(cugraph_error_code_t error_code, std::string const& error_message)
{
error_code_ = error_code;
error_->error_message_ = error_message;
}
};

Expand Down
17 changes: 16 additions & 1 deletion cpp/src/c_api/bfs.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -113,6 +113,21 @@ struct bfs_functor : public abstract_functor {
graph_view.local_vertex_partition_range_last(),
do_expensive_check_);

size_t invalid_count = cugraph::detail::count_values(
handle_,
raft::device_span<vertex_t const>{sources.data(), sources.size()},
cugraph::invalid_vertex_id<vertex_t>::value);

if constexpr (multi_gpu) {
invalid_count = cugraph::host_scalar_allreduce(
handle_.get_comms(), invalid_count, raft::comms::op_t::SUM, handle_.get_stream());
}

if (invalid_count != 0) {
mark_error(CUGRAPH_INVALID_INPUT, "Found invalid vertex in the input sources");
return;
}

cugraph::bfs<vertex_t, edge_t, multi_gpu>(
handle_,
graph_view,
Expand Down
19 changes: 18 additions & 1 deletion cpp/src/detail/utility_wrappers.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,11 +15,13 @@
*/
#include <cugraph/detail/utility_wrappers.hpp>
#include <cugraph/utilities/error.hpp>
#include <cugraph/utilities/host_scalar_comm.hpp>

#include <raft/random/rng.cuh>

#include <rmm/exec_policy.hpp>

#include <thrust/count.h>
#include <thrust/distance.h>
#include <thrust/functional.h>
#include <thrust/iterator/zip_iterator.h>
Expand Down Expand Up @@ -227,5 +229,20 @@ template bool is_equal(raft::handle_t const& handle,
raft::device_span<int64_t const> span1,
raft::device_span<int64_t const> span2);

template <typename data_t>
size_t count_values(raft::handle_t const& handle,
raft::device_span<data_t const> span,
data_t value)
{
return thrust::count(handle.get_thrust_policy(), span.begin(), span.end(), value);
}

template size_t count_values<int32_t>(raft::handle_t const& handle,
raft::device_span<int32_t const> span,
int32_t value);
template size_t count_values<int64_t>(raft::handle_t const& handle,
raft::device_span<int64_t const> span,
int64_t value);

} // namespace detail
} // namespace cugraph
17 changes: 16 additions & 1 deletion python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -65,6 +65,7 @@ class Graph:
key_to_id: dict[NodeKey, IndexValue] | None
_id_to_key: list[NodeKey] | None
_N: int
_node_ids: cp.ndarray[IndexValue] | None # holds plc.SGGraph.vertices_array data

# Used by graph._get_plc_graph
_plc_type_map: ClassVar[dict[np.dtype, np.dtype]] = {
Expand Down Expand Up @@ -116,6 +117,7 @@ def from_coo(
new_graph.key_to_id = None if key_to_id is None else dict(key_to_id)
new_graph._id_to_key = None if id_to_key is None else list(id_to_key)
new_graph._N = op.index(N) # Ensure N is integral
new_graph._node_ids = None
new_graph.graph = new_graph.graph_attr_dict_factory()
new_graph.graph.update(attr)
size = new_graph.src_indices.size
Expand Down Expand Up @@ -157,6 +159,16 @@ def from_coo(
f"(got {new_graph.dst_indices.dtype.name})."
)
new_graph.dst_indices = dst_indices

# If the graph contains isolates, plc.SGGraph() must be passed a value
# for vertices_array that contains every vertex ID, since the
# src/dst_indices arrays will not contain IDs for isolates. Create this
# only if needed. Like src/dst_indices, the _node_ids array must be
# maintained for the lifetime of the plc.SGGraph
isolates = nxcg.algorithms.isolate._isolates(new_graph)
if len(isolates) > 0:
new_graph._node_ids = cp.arange(new_graph._N, dtype=index_dtype)

return new_graph

@classmethod
Expand Down Expand Up @@ -405,6 +417,7 @@ def clear(self) -> None:
self.src_indices = cp.empty(0, self.src_indices.dtype)
self.dst_indices = cp.empty(0, self.dst_indices.dtype)
self._N = 0
self._node_ids = None
self.key_to_id = None
self._id_to_key = None

Expand Down Expand Up @@ -637,6 +650,7 @@ def _get_plc_graph(
dst_indices = self.dst_indices
if switch_indices:
src_indices, dst_indices = dst_indices, src_indices

return plc.SGGraph(
resource_handle=plc.ResourceHandle(),
graph_properties=plc.GraphProperties(
Expand All @@ -649,6 +663,7 @@ def _get_plc_graph(
store_transposed=store_transposed,
renumber=False,
do_expensive_check=False,
vertices_array=self._node_ids,
)

def _sort_edge_indices(self, primary="src"):
Expand Down
13 changes: 1 addition & 12 deletions python/nx-cugraph/nx_cugraph/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -242,20 +242,9 @@ def key(testpath):
)

too_slow = "Too slow to run"
maybe_oom = "out of memory in CI"
skip = {
key("test_tree_isomorphism.py:test_positive"): too_slow,
key("test_tree_isomorphism.py:test_negative"): too_slow,
key("test_efficiency.py:TestEfficiency.test_using_ego_graph"): maybe_oom,
key("test_dag.py:TestDAG.test_antichains"): maybe_oom,
key(
"test_lowest_common_ancestors.py:"
"TestDAGLCA.test_all_pairs_lca_pairs_without_lca"
): maybe_oom,
key(
"test_lowest_common_ancestors.py:"
"TestMultiDiGraph_DAGLCA.test_all_pairs_lca_pairs_without_lca"
): maybe_oom,
# These repeatedly call `bfs_layers`, which converts the graph every call
key(
"test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph2_different_labels"
Expand Down
Loading

0 comments on commit 24d02a5

Please sign in to comment.