@@ -209,20 +209,27 @@ void decompress_edge_partition_to_fill_edgelist_majors(
209
209
}
210
210
}
211
211
212
- template <typename vertex_t , typename edge_t , typename weight_t , bool multi_gpu>
212
+ template <typename vertex_t ,
213
+ typename edge_t ,
214
+ typename weight_t ,
215
+ typename edge_type_t ,
216
+ bool multi_gpu>
213
217
void decompress_edge_partition_to_edgelist (
214
218
raft::handle_t const & handle,
215
219
edge_partition_device_view_t <vertex_t , edge_t , multi_gpu> edge_partition,
216
220
std::optional<edge_partition_edge_property_device_view_t <edge_t , weight_t const *>>
217
221
edge_partition_weight_view,
218
222
std::optional<edge_partition_edge_property_device_view_t <edge_t , edge_t const *>>
219
223
edge_partition_id_view,
224
+ std::optional<edge_partition_edge_property_device_view_t <edge_t , edge_type_t const *>>
225
+ edge_partition_type_view,
220
226
std::optional<edge_partition_edge_property_device_view_t <edge_t , uint32_t const *, bool >>
221
227
edge_partition_mask_view,
222
228
raft::device_span<vertex_t > edgelist_majors /* [OUT] */ ,
223
229
raft::device_span<vertex_t > edgelist_minors /* [OUT] */ ,
224
230
std::optional<raft::device_span<weight_t >> edgelist_weights /* [OUT] */ ,
225
231
std::optional<raft::device_span<edge_t >> edgelist_ids /* [OUT] */ ,
232
+ std::optional<raft::device_span<edge_type_t >> edgelist_types /* [OUT] */ ,
226
233
std::optional<std::vector<vertex_t >> const & segment_offsets)
227
234
{
228
235
auto number_of_edges = edge_partition.number_of_edges ();
@@ -271,6 +278,22 @@ void decompress_edge_partition_to_edgelist(
271
278
(*edgelist_ids).begin ());
272
279
}
273
280
}
281
+
282
+ if (edge_partition_type_view) {
283
+ assert (edgelist_types.has_value ());
284
+ if (edge_partition_mask_view) {
285
+ copy_if_mask_set (handle,
286
+ (*edge_partition_type_view).value_first (),
287
+ (*edge_partition_type_view).value_first () + number_of_edges,
288
+ (*edge_partition_mask_view).value_first (),
289
+ (*edgelist_types).begin ());
290
+ } else {
291
+ thrust::copy (handle.get_thrust_policy (),
292
+ (*edge_partition_type_view).value_first (),
293
+ (*edge_partition_type_view).value_first () + number_of_edges,
294
+ (*edgelist_types).begin ());
295
+ }
296
+ }
274
297
}
275
298
276
299
} // namespace detail
0 commit comments