@@ -212,6 +212,22 @@ struct all_pairs_similarity_functor : public cugraph::c_api::abstract_functor {
212
212
: std::nullopt,
213
213
topk_ != SIZE_MAX ? std::make_optional (topk_) : std::nullopt);
214
214
215
+ cugraph::unrenumber_int_vertices<vertex_t , multi_gpu>(
216
+ handle_,
217
+ v1.data (),
218
+ v1.size (),
219
+ number_map->data (),
220
+ graph_view.vertex_partition_range_lasts (),
221
+ false );
222
+
223
+ cugraph::unrenumber_int_vertices<vertex_t , multi_gpu>(
224
+ handle_,
225
+ v2.data (),
226
+ v2.size (),
227
+ number_map->data (),
228
+ graph_view.vertex_partition_range_lasts (),
229
+ false );
230
+
215
231
result_ = new cugraph::c_api::cugraph_similarity_result_t {
216
232
new cugraph::c_api::cugraph_type_erased_device_array_t (similarity_coefficients,
217
233
graph_->weight_type_ ),
@@ -274,6 +290,33 @@ struct sorensen_functor {
274
290
}
275
291
};
276
292
293
+ struct cosine_functor {
294
+ template <typename vertex_t , typename edge_t , typename weight_t , bool multi_gpu>
295
+ rmm::device_uvector<weight_t > operator ()(
296
+ raft::handle_t const & handle,
297
+ cugraph::graph_view_t <vertex_t , edge_t , false , multi_gpu> const & graph_view,
298
+ std::optional<cugraph::edge_property_view_t <edge_t , weight_t const *>> edge_weight_view,
299
+ std::tuple<raft::device_span<vertex_t const >, raft::device_span<vertex_t const >> vertex_pairs)
300
+ {
301
+ return cugraph::cosine_similarity_coefficients (
302
+ handle, graph_view, edge_weight_view, vertex_pairs);
303
+ }
304
+
305
+ template <typename vertex_t , typename edge_t , typename weight_t , bool multi_gpu>
306
+ std::tuple<rmm::device_uvector<vertex_t >,
307
+ rmm::device_uvector<vertex_t >,
308
+ rmm::device_uvector<weight_t >>
309
+ operator ()(raft::handle_t const & handle,
310
+ cugraph::graph_view_t <vertex_t , edge_t , false , multi_gpu> const & graph_view,
311
+ std::optional<cugraph::edge_property_view_t <edge_t , weight_t const *>> edge_weight_view,
312
+ std::optional<raft::device_span<vertex_t const >> vertices,
313
+ std::optional<size_t > topk)
314
+ {
315
+ return cugraph::cosine_similarity_all_pairs_coefficients (
316
+ handle, graph_view, edge_weight_view, vertices, topk);
317
+ }
318
+ };
319
+
277
320
struct overlap_functor {
278
321
template <typename vertex_t , typename edge_t , typename weight_t , bool multi_gpu>
279
322
rmm::device_uvector<weight_t > operator ()(
@@ -300,6 +343,33 @@ struct overlap_functor {
300
343
}
301
344
};
302
345
346
+ struct cosine_similarity_functor {
347
+ template <typename vertex_t , typename edge_t , typename weight_t , bool multi_gpu>
348
+ rmm::device_uvector<weight_t > operator ()(
349
+ raft::handle_t const & handle,
350
+ cugraph::graph_view_t <vertex_t , edge_t , false , multi_gpu> const & graph_view,
351
+ std::optional<cugraph::edge_property_view_t <edge_t , weight_t const *>> edge_weight_view,
352
+ std::tuple<raft::device_span<vertex_t const >, raft::device_span<vertex_t const >> vertex_pairs)
353
+ {
354
+ return cugraph::cosine_similarity_coefficients (
355
+ handle, graph_view, edge_weight_view, vertex_pairs);
356
+ }
357
+
358
+ template <typename vertex_t , typename edge_t , typename weight_t , bool multi_gpu>
359
+ std::tuple<rmm::device_uvector<vertex_t >,
360
+ rmm::device_uvector<vertex_t >,
361
+ rmm::device_uvector<weight_t >>
362
+ operator ()(raft::handle_t const & handle,
363
+ cugraph::graph_view_t <vertex_t , edge_t , false , multi_gpu> const & graph_view,
364
+ std::optional<cugraph::edge_property_view_t <edge_t , weight_t const *>> edge_weight_view,
365
+ std::optional<raft::device_span<vertex_t const >> vertices,
366
+ std::optional<size_t > topk)
367
+ {
368
+ return cugraph::cosine_similarity_all_pairs_coefficients (
369
+ handle, graph_view, edge_weight_view, vertices, topk);
370
+ }
371
+ };
372
+
303
373
} // namespace
304
374
305
375
extern " C" cugraph_type_erased_device_array_view_t * cugraph_similarity_result_get_similarity (
@@ -391,6 +461,28 @@ extern "C" cugraph_error_code_t cugraph_overlap_coefficients(
391
461
return cugraph::c_api::run_algorithm (graph, functor, result, error);
392
462
}
393
463
464
+ extern " C" cugraph_error_code_t cugraph_cosine_similarity_coefficients (
465
+ const cugraph_resource_handle_t * handle,
466
+ cugraph_graph_t * graph,
467
+ const cugraph_vertex_pairs_t * vertex_pairs,
468
+ bool_t use_weight,
469
+ bool_t do_expensive_check,
470
+ cugraph_similarity_result_t ** result,
471
+ cugraph_error_t ** error)
472
+ {
473
+ if (use_weight) {
474
+ CAPI_EXPECTS (
475
+ reinterpret_cast <cugraph::c_api::cugraph_graph_t *>(graph)->edge_weights_ != nullptr ,
476
+ CUGRAPH_INVALID_INPUT,
477
+ " use_weight is true but edge weights are not provided." ,
478
+ *error);
479
+ }
480
+ similarity_functor functor (
481
+ handle, graph, vertex_pairs, cosine_similarity_functor{}, use_weight, do_expensive_check);
482
+
483
+ return cugraph::c_api::run_algorithm (graph, functor, result, error);
484
+ }
485
+
394
486
extern " C" cugraph_error_code_t cugraph_all_pairs_jaccard_coefficients (
395
487
const cugraph_resource_handle_t * handle,
396
488
cugraph_graph_t * graph,
@@ -459,3 +551,26 @@ extern "C" cugraph_error_code_t cugraph_all_pairs_overlap_coefficients(
459
551
460
552
return cugraph::c_api::run_algorithm (graph, functor, result, error);
461
553
}
554
+
555
+ extern " C" cugraph_error_code_t cugraph_all_pairs_cosine_similarity_coefficients (
556
+ const cugraph_resource_handle_t * handle,
557
+ cugraph_graph_t * graph,
558
+ const cugraph_type_erased_device_array_view_t * vertices,
559
+ bool_t use_weight,
560
+ size_t topk,
561
+ bool_t do_expensive_check,
562
+ cugraph_similarity_result_t ** result,
563
+ cugraph_error_t ** error)
564
+ {
565
+ if (use_weight) {
566
+ CAPI_EXPECTS (
567
+ reinterpret_cast <cugraph::c_api::cugraph_graph_t *>(graph)->edge_weights_ != nullptr ,
568
+ CUGRAPH_INVALID_INPUT,
569
+ " use_weight is true but edge weights are not provided." ,
570
+ *error);
571
+ }
572
+ all_pairs_similarity_functor functor (
573
+ handle, graph, vertices, overlap_functor{}, use_weight, topk, do_expensive_check);
574
+
575
+ return cugraph::c_api::run_algorithm (graph, functor, result, error);
576
+ }
0 commit comments