16
16
17
17
#include " c_api/abstract_functor.hpp"
18
18
#include " c_api/graph.hpp"
19
+ #include " c_api/random.hpp"
19
20
#include " c_api/resource_handle.hpp"
20
21
#include " c_api/utils.hpp"
21
22
@@ -153,10 +154,11 @@ namespace {
153
154
154
155
struct uniform_random_walks_functor : public cugraph ::c_api::abstract_functor {
155
156
raft::handle_t const & handle_;
157
+ // FIXME: rng_state_ should be passed as a parameter
158
+ cugraph::c_api::cugraph_rng_state_t * rng_state_{nullptr };
156
159
cugraph::c_api::cugraph_graph_t * graph_{nullptr };
157
160
cugraph::c_api::cugraph_type_erased_device_array_view_t const * start_vertices_{nullptr };
158
161
size_t max_length_{0 };
159
- size_t seed_{0 };
160
162
cugraph::c_api::cugraph_random_walk_result_t * result_{nullptr };
161
163
162
164
uniform_random_walks_functor (cugraph_resource_handle_t const * handle,
@@ -222,13 +224,17 @@ struct uniform_random_walks_functor : public cugraph::c_api::abstract_functor {
222
224
graph_view.local_vertex_partition_range_last (),
223
225
false );
224
226
227
+ // FIXME: remove once rng_state passed as parameter
228
+ rng_state_ = reinterpret_cast <cugraph::c_api::cugraph_rng_state_t *>(
229
+ new cugraph::c_api::cugraph_rng_state_t {raft::random ::RngState{0 }});
230
+
225
231
auto [paths, weights] = cugraph::uniform_random_walks (
226
232
handle_,
233
+ rng_state_->rng_state_ ,
227
234
graph_view,
228
235
(edge_weights != nullptr ) ? std::make_optional (edge_weights->view ()) : std::nullopt,
229
236
raft::device_span<vertex_t const >{start_vertices.data (), start_vertices.size ()},
230
- max_length_,
231
- seed_);
237
+ max_length_);
232
238
233
239
//
234
240
// Need to unrenumber the vertices in the resulting paths
@@ -255,11 +261,12 @@ struct uniform_random_walks_functor : public cugraph::c_api::abstract_functor {
255
261
256
262
struct biased_random_walks_functor : public cugraph ::c_api::abstract_functor {
257
263
raft::handle_t const & handle_;
264
+ // FIXME: rng_state_ should be passed as a parameter
265
+ cugraph::c_api::cugraph_rng_state_t * rng_state_{nullptr };
258
266
cugraph::c_api::cugraph_graph_t * graph_{nullptr };
259
267
cugraph::c_api::cugraph_type_erased_device_array_view_t const * start_vertices_{nullptr };
260
268
size_t max_length_{0 };
261
269
cugraph::c_api::cugraph_random_walk_result_t * result_{nullptr };
262
- uint64_t seed_{0 };
263
270
264
271
biased_random_walks_functor (cugraph_resource_handle_t const * handle,
265
272
cugraph_graph_t * graph,
@@ -326,13 +333,17 @@ struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {
326
333
graph_view.local_vertex_partition_range_last (),
327
334
false );
328
335
336
+ // FIXME: remove once rng_state passed as parameter
337
+ rng_state_ = reinterpret_cast <cugraph::c_api::cugraph_rng_state_t *>(
338
+ new cugraph::c_api::cugraph_rng_state_t {raft::random ::RngState{0 }});
339
+
329
340
auto [paths, weights] = cugraph::biased_random_walks (
330
341
handle_,
342
+ rng_state_->rng_state_ ,
331
343
graph_view,
332
344
edge_weights->view (),
333
345
raft::device_span<vertex_t const >{start_vertices.data (), start_vertices.size ()},
334
- max_length_,
335
- seed_);
346
+ max_length_);
336
347
337
348
//
338
349
// Need to unrenumber the vertices in the resulting paths
@@ -354,12 +365,13 @@ struct biased_random_walks_functor : public cugraph::c_api::abstract_functor {
354
365
355
366
struct node2vec_random_walks_functor : public cugraph ::c_api::abstract_functor {
356
367
raft::handle_t const & handle_;
368
+ // FIXME: rng_state_ should be passed as a parameter
369
+ cugraph::c_api::cugraph_rng_state_t * rng_state_{nullptr };
357
370
cugraph::c_api::cugraph_graph_t * graph_{nullptr };
358
371
cugraph::c_api::cugraph_type_erased_device_array_view_t const * start_vertices_{nullptr };
359
372
size_t max_length_{0 };
360
373
double p_{0 };
361
374
double q_{0 };
362
- uint64_t seed_{0 };
363
375
cugraph::c_api::cugraph_random_walk_result_t * result_{nullptr };
364
376
365
377
node2vec_random_walks_functor (cugraph_resource_handle_t const * handle,
@@ -431,15 +443,19 @@ struct node2vec_random_walks_functor : public cugraph::c_api::abstract_functor {
431
443
graph_view.local_vertex_partition_range_last (),
432
444
false );
433
445
446
+ // FIXME: remove once rng_state passed as parameter
447
+ rng_state_ = reinterpret_cast <cugraph::c_api::cugraph_rng_state_t *>(
448
+ new cugraph::c_api::cugraph_rng_state_t {raft::random ::RngState{0 }});
449
+
434
450
auto [paths, weights] = cugraph::node2vec_random_walks (
435
451
handle_,
452
+ rng_state_->rng_state_ ,
436
453
graph_view,
437
454
(edge_weights != nullptr ) ? std::make_optional (edge_weights->view ()) : std::nullopt,
438
455
raft::device_span<vertex_t const >{start_vertices.data (), start_vertices.size ()},
439
456
max_length_,
440
457
static_cast <weight_t >(p_),
441
- static_cast <weight_t >(q_),
442
- seed_);
458
+ static_cast <weight_t >(q_));
443
459
444
460
// FIXME: Need to fix invalid_vtx issue here. We can't unrenumber max_vertex_id+1
445
461
// properly...
0 commit comments