Skip to content

Commit 19b1200

Browse files
committed
finish debugging... restructure code to match if structure in nx to make it easier to compare/debug, disabled python test again since nx release is not out
1 parent b50c349 commit 19b1200

File tree

3 files changed

+42
-49
lines changed

3 files changed

+42
-49
lines changed

cpp/src/centrality/betweenness_centrality_impl.cuh

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -540,65 +540,54 @@ rmm::device_uvector<weight_t> betweenness_centrality(
540540
do_expensive_check);
541541
}
542542

543-
std::optional<weight_t> scale_factor{std::nullopt};
543+
std::optional<weight_t> scale_nonsource{std::nullopt};
544544
std::optional<weight_t> scale_source{std::nullopt};
545545

546-
if (normalized) {
547-
if (include_endpoints) {
548-
if (graph_view.number_of_vertices() >= 2) {
549-
scale_factor = static_cast<weight_t>(
550-
std::min(static_cast<vertex_t>(num_sources), graph_view.number_of_vertices()) *
551-
(graph_view.number_of_vertices() - 1));
552-
scale_source = scale_factor;
553-
}
554-
} else if (static_cast<edge_t>(num_sources) == graph_view.number_of_vertices()) {
555-
scale_factor = static_cast<weight_t>(
556-
std::min(static_cast<vertex_t>(num_sources), graph_view.number_of_vertices() - 1) *
557-
(graph_view.number_of_vertices() - 2));
558-
scale_source = scale_factor;
559-
} else if (graph_view.number_of_vertices() > 2) {
560-
scale_factor = static_cast<weight_t>(
561-
std::min(static_cast<vertex_t>(num_sources), graph_view.number_of_vertices() - 1) *
562-
(graph_view.number_of_vertices() - 2));
563-
scale_source = static_cast<weight_t>(
564-
(std::min(static_cast<vertex_t>(num_sources), graph_view.number_of_vertices() - 1) - 1) *
565-
(graph_view.number_of_vertices() - 2));
546+
if ((static_cast<edge_t>(num_sources) == graph_view.number_of_vertices()) || include_endpoints) {
547+
if (normalized) {
548+
scale_nonsource = static_cast<weight_t>(num_sources * (graph_view.number_of_vertices() - 1));
549+
} else if (graph_view.is_symmetric()) {
550+
scale_nonsource = static_cast<weight_t>(num_sources * 2) /
551+
static_cast<weight_t>(graph_view.number_of_vertices());
552+
} else {
553+
scale_nonsource =
554+
static_cast<weight_t>(num_sources) / static_cast<weight_t>(graph_view.number_of_vertices());
555+
}
556+
557+
scale_source = scale_nonsource;
558+
} else if (normalized) {
559+
scale_nonsource = static_cast<weight_t>(num_sources) * (graph_view.number_of_vertices() - 1);
560+
scale_source = static_cast<weight_t>(num_sources - 1) * (graph_view.number_of_vertices() - 1);
561+
} else {
562+
scale_nonsource = static_cast<weight_t>(num_sources) / graph_view.number_of_vertices();
563+
scale_source = static_cast<weight_t>(num_sources - 1) / graph_view.number_of_vertices();
564+
565+
if (graph_view.is_symmetric()) {
566+
*scale_nonsource *= 2;
567+
*scale_source *= 2;
566568
}
567-
} else if (static_cast<edge_t>(num_sources) < graph_view.number_of_vertices()) {
568-
if ((graph_view.number_of_vertices() > 1) && (num_sources > 0))
569-
scale_factor =
570-
(graph_view.is_symmetric() ? weight_t{2} : weight_t{1}) *
571-
static_cast<weight_t>(num_sources) /
572-
(include_endpoints ? static_cast<weight_t>(graph_view.number_of_vertices())
573-
: static_cast<weight_t>(graph_view.number_of_vertices() - 1));
574-
scale_source = (graph_view.is_symmetric() ? weight_t{2} : weight_t{1}) *
575-
static_cast<weight_t>(num_sources) /
576-
(include_endpoints ? static_cast<weight_t>(graph_view.number_of_vertices())
577-
: static_cast<weight_t>(graph_view.number_of_vertices() - 1));
578-
} else if (graph_view.is_symmetric()) {
579-
scale_factor = weight_t{2};
580-
scale_source = weight_t{2};
581569
}
582570

583-
if (scale_factor) {
571+
if (scale_nonsource) {
584572
auto iter = thrust::make_zip_iterator(
585573
thrust::make_counting_iterator(graph_view.local_vertex_partition_range_first()),
586574
centralities.begin());
587575

588-
std::cout << "sf = " << *scale_factor << ", ss = " << *scale_source << std::endl;
589-
590576
thrust::transform(
591577
handle.get_thrust_policy(),
592578
iter,
593579
iter + centralities.size(),
594580
centralities.begin(),
595-
[sf = *scale_factor, ssf = *scale_source, vertices_begin, vertices_end] __device__(auto t) {
581+
[nonsource = *scale_nonsource,
582+
source = *scale_source,
583+
vertices_begin,
584+
vertices_end] __device__(auto t) {
596585
vertex_t v = thrust::get<0>(t);
597586
weight_t centrality = thrust::get<1>(t);
598587

599588
return (thrust::find(thrust::seq, vertices_begin, vertices_end, v) == vertices_end)
600-
? centrality / sf
601-
: centrality / ssf;
589+
? centrality / nonsource
590+
: centrality / source;
602591
});
603592
}
604593

cpp/tests/c_api/betweenness_centrality_test.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ int generic_betweenness_centrality_test(vertex_t* h_src,
116116
if (isnan(h_result[h_vertices[i]])) {
117117
TEST_ASSERT(test_ret_value, isnan(h_centralities[i]), "expected NaN, got a non-NaN value");
118118
} else {
119+
if (!nearlyEqual(h_result[h_vertices[i]], h_centralities[i], 0.0001))
120+
printf(" expected: %g, got %g\n", h_result[h_vertices[i]], h_centralities[i]);
121+
119122
TEST_ASSERT(test_ret_value,
120123
nearlyEqual(h_result[h_vertices[i]], h_centralities[i], 0.0001),
121124
"centralities results don't match");
@@ -173,7 +176,7 @@ int test_betweenness_centrality_specific_normalized()
173176
weight_t h_wgt[] = {
174177
0.1f, 2.1f, 1.1f, 5.1f, 3.1f, 4.1f, 7.2f, 3.2f, 0.1f, 2.1f, 1.1f, 5.1f, 3.1f, 4.1f, 7.2f, 3.2f};
175178
vertex_t h_seeds[] = {0, 3};
176-
weight_t h_result[] = {0, 0.395833, 0.16667, 0.16667, 0.0416667, 0.0625};
179+
weight_t h_result[] = {0, 0.316667, 0.133333, 0.133333, 0.0333333, 0.05};
177180

178181
return generic_betweenness_centrality_test(h_src,
179182
h_dst,
@@ -201,7 +204,7 @@ int test_betweenness_centrality_specific_unnormalized()
201204
weight_t h_wgt[] = {
202205
0.1f, 2.1f, 1.1f, 5.1f, 3.1f, 4.1f, 7.2f, 3.2f, 0.1f, 2.1f, 1.1f, 5.1f, 3.1f, 4.1f, 7.2f, 3.2f};
203206
vertex_t h_seeds[] = {0, 3};
204-
weight_t h_result[] = {0, 7.91667, 3.33333, 1.666667, 0.833333, 1.25};
207+
weight_t h_result[] = {0, 9.5, 4, 4, 1, 1.5};
205208

206209
return generic_betweenness_centrality_test(h_src,
207210
h_dst,
@@ -315,18 +318,18 @@ int test_issue_4941()
315318
{TRUE, TRUE, TRUE, 1, {1.0, 1.0, 0.25, 0.25, 0.25}},
316319
{TRUE, TRUE, FALSE, 0, {1.0, 0.4, 0.4, 0.4, 0.4}},
317320
{TRUE, TRUE, FALSE, 1, {1.0, 1.0, 0.25, 0.25, 0.25}},
318-
{TRUE, FALSE, TRUE, 0, {1.0, 0.0, 0.0, 0.0, 0.0}},
319-
{TRUE, FALSE, TRUE, 1, {1.0, NAN, 0.0, 0.0, 0.0}},
320-
{TRUE, FALSE, FALSE, 0, {1.0, 0.0, 0.0, 0.0, 0.0}},
321-
{TRUE, FALSE, FALSE, 1, {1.0, NAN, 0.0, 0.0, 0.0}},
321+
{TRUE, FALSE, TRUE, 0, {0.6, 0.0, 0.0, 0.0, 0.0}},
322+
{TRUE, FALSE, TRUE, 1, {0.75, NAN, 0.0, 0.0, 0.0}},
323+
{TRUE, FALSE, FALSE, 0, {0.6, 0.0, 0.0, 0.0, 0.0}},
324+
{TRUE, FALSE, FALSE, 1, {0.75, NAN, 0.0, 0.0, 0.0}},
322325
{FALSE, TRUE, TRUE, 0, {20.0, 8.0, 8.0, 8.0, 8.0}},
323326
{FALSE, TRUE, TRUE, 1, {20.0, 20.0, 5.0, 5.0, 5.0}},
324327
{FALSE, TRUE, FALSE, 0, {10.0, 4.0, 4.0, 4.0, 4.0}},
325328
{FALSE, TRUE, FALSE, 1, {10.0, 10.0, 2.5, 2.5, 2.5}},
326329
{FALSE, FALSE, TRUE, 0, {12.0, 0.0, 0.0, 0.0, 0.0}},
327-
{FALSE, FALSE, TRUE, 1, {12.0, 0.0, 0.0, 0.0, 0.0}},
330+
{FALSE, FALSE, TRUE, 1, {15, NAN, 0.0, 0.0, 0.0}},
328331
{FALSE, FALSE, FALSE, 0, {6.0, 0.0, 0.0, 0.0, 0.0}},
329-
{FALSE, FALSE, FALSE, 1, {6.0, 0.0, 0.0, 0.0, 0.0}},
332+
{FALSE, FALSE, FALSE, 1, {7.5, NAN, 0.0, 0.0, 0.0}},
330333
};
331334

332335
int test_result = 0;

python/cugraph/cugraph/tests/centrality/test_betweenness_centrality.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def compare_scores(sorted_df, first_key, second_key, epsilon=DEFAULT_EPSILON):
305305
# =============================================================================
306306
# Tests
307307
# =============================================================================
308+
@pytest.mark.skip(reason="https://github.com/networkx/networkx/pull/7908")
308309
@pytest.mark.sg
309310
@pytest.mark.parametrize("graph_file", SMALL_DATASETS)
310311
@pytest.mark.parametrize("directed", [False, True])

0 commit comments

Comments
 (0)