Skip to content

Commit ba4890a

Browse files
ManasviGoyal, ianna, jpivarski
authored
feat: add reduce kernels (#3136)
* feat: add tree reduction implementation of argmin and argmax * feat: add awkward_ListOffsetArray_reduce_local_outoffsets_64 kernel * test: integration tests for cuda * test: some more integration tests for cuda * feat: add awkward_reduce_count_64 kernel * fix: indexing and indentation * feat: add awkward_reduce_countnonzero kernel * feat: add reduce sum, min and max kernels * feat: add reduce prod and sum_int_bool * feat: add sum_bool and prod_bool kernels * fix: use cpt.assert_allclose * test: reducer integration tests * fix: type conversion * fix: use atomic to avoid race conditions * fix: remove unnecessary variable * fix: minor fixes * fix: all reducer for atomics * fix: missing template * fix: remove complex * fix: atomicMin() for float 32 and indentation * fix: pass correct dtype of identity * fix: remove combinations test * fix: manage resources and disable failing test * fix: uncomment fixed test for slicing * fix: correctly interpret typetracer array for cuda backend * fix: tests-spec error for bool * fix: check for the backend of head * Update dev/generate-tests.py --------- Co-authored-by: Ianna Osborne <[email protected]> Co-authored-by: Jim Pivarski <[email protected]>
1 parent db6cece commit ba4890a

31 files changed

+2674
-437
lines changed

dev/generate-kernel-signatures.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
"awkward_ListOffsetArray_drop_none_indexes",
109109
"awkward_ListOffsetArray_reduce_local_nextparents_64",
110110
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
111+
"awkward_ListOffsetArray_reduce_local_outoffsets_64",
111112
"awkward_UnionArray_flatten_length",
112113
"awkward_UnionArray_flatten_combine",
113114
"awkward_UnionArray_nestedfill_tags_index",
@@ -123,6 +124,7 @@
123124
"awkward_reduce_sum_int32_bool_64",
124125
"awkward_reduce_sum_int64_bool_64",
125126
"awkward_reduce_sum_bool",
127+
"awkward_reduce_prod",
126128
"awkward_reduce_prod_bool",
127129
"awkward_reduce_countnonzero",
128130
"awkward_sorting_ranges",
@@ -381,6 +383,8 @@ def kernel_signatures_cuda_py(specification):
381383
from awkward._connect.cuda import fetch_specialization
382384
from awkward._connect.cuda import import_cupy
383385
386+
import math
387+
384388
cupy = import_cupy("Awkward Arrays with CUDA")
385389
"""
386390
)

dev/generate-tests.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ def genspectests(specdict):
424424
425425
"""
426426
)
427-
f.write("import pytest\nimport kernels\n\n")
428427
f.write("import pytest\nimport numpy as np\nimport kernels\n\n")
429428
num = 1
430429
if spec.tests == []:
@@ -894,6 +893,7 @@ def gencpuunittests(specdict):
894893
"awkward_ListOffsetArray_drop_none_indexes",
895894
"awkward_ListOffsetArray_reduce_local_nextparents_64",
896895
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
896+
"awkward_ListOffsetArray_reduce_local_outoffsets_64",
897897
"awkward_UnionArray_flatten_length",
898898
"awkward_UnionArray_flatten_combine",
899899
"awkward_UnionArray_nestedfill_tags_index",
@@ -909,6 +909,7 @@ def gencpuunittests(specdict):
909909
"awkward_reduce_sum_int32_bool_64",
910910
"awkward_reduce_sum_int64_bool_64",
911911
"awkward_reduce_sum_bool",
912+
"awkward_reduce_prod",
912913
"awkward_reduce_prod_bool",
913914
"awkward_reduce_countnonzero",
914915
"awkward_sorting_ranges",
@@ -959,6 +960,8 @@ def gencudakerneltests(specdict):
959960

960961
f.write(
961962
"import cupy\n"
963+
"import cupy.testing as cpt\n"
964+
"import numpy as np\n"
962965
"import pytest\n\n"
963966
"import awkward as ak\n"
964967
"import awkward._connect.cuda as ak_cu\n"
@@ -1028,7 +1031,7 @@ def gencudakerneltests(specdict):
10281031
if isinstance(val, list):
10291032
f.write(
10301033
" " * 4
1031-
+ f"assert cupy.array_equal({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
1034+
+ f"cpt.assert_allclose({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
10321035
)
10331036
else:
10341037
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")
@@ -1088,6 +1091,7 @@ def gencudaunittests(specdict):
10881091
f.write(
10891092
"import re\n"
10901093
"import cupy\n"
1094+
"import cupy.testing as cpt\n"
10911095
"import pytest\n\n"
10921096
"import awkward as ak\n"
10931097
"import awkward._connect.cuda as ak_cu\n"
@@ -1224,7 +1228,7 @@ def gencudaunittests(specdict):
12241228
if isinstance(val, list):
12251229
f.write(
12261230
" " * 4
1227-
+ f"assert cupy.array_equal({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
1231+
+ f"cpt.assert_allclose({arg}[:len(pytest_{arg})], cupy.array(pytest_{arg}))\n"
12281232
)
12291233
else:
12301234
f.write(" " * 4 + f"assert {arg} == pytest_{arg}\n")

kernel-test-data.json

Lines changed: 946 additions & 154 deletions
Large diffs are not rendered by default.

src/awkward/_connect/cuda/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def fetch_template_specializations(kernel_dict):
108108
"awkward_ListArray_rpad_axis1",
109109
"awkward_ListOffsetArray_drop_none_indexes",
110110
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
111+
"awkward_ListOffsetArray_reduce_local_outoffsets_64",
111112
"awkward_UnionArray_regular_index",
112113
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
113114
"awkward_ListOffsetArray_rpad_axis1",
@@ -119,6 +120,7 @@ def fetch_template_specializations(kernel_dict):
119120
"awkward_reduce_sum_int32_bool_64",
120121
"awkward_reduce_sum_int64_bool_64",
121122
"awkward_reduce_sum_bool",
123+
"awkward_reduce_prod",
122124
"awkward_reduce_prod_bool",
123125
"awkward_reduce_argmax",
124126
"awkward_reduce_argmin",
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// Host-side launcher for the three-stage outoffsets computation.
//
// Stage _a zeroes `outoffsets`, stage _b accumulates per-parent counts into
// `scan_in_array`, the launcher turns those counts into a running total with
// cupy.cumsum, and stage _c copies the totals into `outoffsets[1:]`.
//
// Stages _a and _c index by `outlength`, stage _b indexes by `lenparents`, so
// each stage gets a grid sized from the length it actually iterates over.
// Stages _a/_c always get at least one block so that `outoffsets[0]` is
// written even when `outlength` is 0.
// BEGIN PYTHON
// def f(grid, block, args):
//     (outoffsets, parents, lenparents, outlength, invocation_index, err_code) = args
//     if block[0] > 0:
//         out_grid_size = max(1, math.floor((outlength + block[0] - 1) / block[0]))
//         grid_size = math.floor((lenparents + block[0] - 1) / block[0])
//     else:
//         out_grid_size = 1
//         grid_size = 1
//     temp = cupy.zeros(lenparents, dtype=cupy.int64)
//     scan_in_array = cupy.zeros(outlength, dtype=cupy.uint64)
//     cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListOffsetArray_reduce_local_outoffsets_64_a", cupy.dtype(outoffsets.dtype).type, parents.dtype]))((out_grid_size,), block, (outoffsets, parents, lenparents, outlength, scan_in_array, temp, invocation_index, err_code))
//     cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListOffsetArray_reduce_local_outoffsets_64_b", cupy.dtype(outoffsets.dtype).type, parents.dtype]))((grid_size,), block, (outoffsets, parents, lenparents, outlength, scan_in_array, temp, invocation_index, err_code))
//     scan_in_array = cupy.cumsum(scan_in_array)
//     cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListOffsetArray_reduce_local_outoffsets_64_c", cupy.dtype(outoffsets.dtype).type, parents.dtype]))((out_grid_size,), block, (outoffsets, parents, lenparents, outlength, scan_in_array, temp, invocation_index, err_code))
// out["awkward_ListOffsetArray_reduce_local_outoffsets_64_a", {dtype_specializations}] = None
// out["awkward_ListOffsetArray_reduce_local_outoffsets_64_b", {dtype_specializations}] = None
// out["awkward_ListOffsetArray_reduce_local_outoffsets_64_c", {dtype_specializations}] = None
// END PYTHON
21+
22+
// Stage (a): each thread whose global index is below `outlength` stores 0
// into its own slot of `outoffsets`.
//
// `parents`, `lenparents`, `scan_in_array`, `temp` and `invocation_index` are
// accepted but not read by this stage.
// Nothing is written unless `err_code[0]` equals NO_ERROR.
template <typename T, typename C>
__global__ void
awkward_ListOffsetArray_reduce_local_outoffsets_64_a(
    T* outoffsets,
    const C* parents,
    int64_t lenparents,
    int64_t outlength,
    uint64_t* scan_in_array,
    int64_t* temp,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] != NO_ERROR) {
    return;
  }

  const int64_t out_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (out_index >= outlength) {
    return;
  }

  outoffsets[out_index] = 0;
}
41+
42+
// Stage (b): counts, per parent, how many entries of `parents` refer to it and
// accumulates those counts into `scan_in_array[parent]`.
//
// Each in-range thread starts with a count of 1 in `temp[thread_id]`. The
// stride-doubling loop then adds in the count held `stride` positions to the
// left, but only when that position is in the same block (`idx >= stride`) and
// has the same parent. Afterwards the last thread of every run of equal
// parents (or the last thread of the block / of the input) adds its partial
// count to `scan_in_array[parent]` with atomicAdd, so runs that straddle a
// block boundary are combined there.
//
// `outoffsets`, `outlength` and `invocation_index` are not read by this stage.
// Nothing is done unless `err_code[0]` equals NO_ERROR.
template <typename T, typename C>
__global__ void
awkward_ListOffsetArray_reduce_local_outoffsets_64_b(
    T* outoffsets,
    const C* parents,
    int64_t lenparents,
    int64_t outlength,
    uint64_t* scan_in_array,
    int64_t* temp,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    int64_t idx = threadIdx.x;
    int64_t thread_id = blockIdx.x * blockDim.x + idx;
    // Threads in the tail of the last block have no element of their own.
    // They must still reach every __syncthreads() below, so they are masked
    // out of the memory accesses instead of returning early.
    const bool in_range = thread_id < lenparents;

    if (in_range) {
      temp[thread_id] = 1;
    }
    __syncthreads();

    for (int64_t stride = 1; stride < blockDim.x; stride *= 2) {
      int64_t val = 0;
      if (idx >= stride && in_range && parents[thread_id] == parents[thread_id - stride]) {
        val = temp[thread_id - stride];
      }
      // Barrier between the read phase and the write phase of this step.
      __syncthreads();
      // Only in-range threads own a slot of `temp`; touching
      // temp[thread_id] for thread_id >= lenparents would be out of bounds.
      if (in_range) {
        temp[thread_id] += val;
      }
      __syncthreads();
    }

    if (in_range) {
      int64_t parent = parents[thread_id];
      // The short-circuit order keeps parents[thread_id + 1] from being read
      // when thread_id is the last valid index.
      if (idx == blockDim.x - 1 || thread_id == lenparents - 1 || parents[thread_id] != parents[thread_id + 1]) {
        atomicAdd(&scan_in_array[parent], temp[thread_id]);
      }
    }
  }
}
80+
81+
// Stage (c): writes the final offsets. `outoffsets[0]` is set to 0 and, for
// every index i below `outlength`, `outoffsets[i + 1]` receives
// `scan_in_array[i]` converted to T.
//
// `parents`, `lenparents`, `temp` and `invocation_index` are not read by this
// stage. Nothing is written unless `err_code[0]` equals NO_ERROR.
template <typename T, typename C>
__global__ void
awkward_ListOffsetArray_reduce_local_outoffsets_64_c(
    T* outoffsets,
    const C* parents,
    int64_t lenparents,
    int64_t outlength,
    uint64_t* scan_in_array,
    int64_t* temp,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] != NO_ERROR) {
    return;
  }

  const int64_t list_index = blockIdx.x * blockDim.x + threadIdx.x;

  // Every launched thread stores the same leading zero.
  outoffsets[0] = 0;

  if (list_index < outlength) {
    outoffsets[list_index + 1] = static_cast<T>(scan_in_array[list_index]);
  }
}

src/awkward/_connect/cuda/cuda_kernels/awkward_reduce_argmax.cu

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,18 @@
33
// BEGIN PYTHON
44
// def f(grid, block, args):
55
// (toptr, fromptr, parents, lenparents, outlength, invocation_index, err_code) = args
6-
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_argmax_a", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, invocation_index, err_code))
7-
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_argmax_b", toptr.dtype, fromptr.dtype, parents.dtype]))(grid, block, (toptr, fromptr, parents, lenparents, outlength, invocation_index, err_code))
6+
// if block[0] > 0:
7+
// grid_size = math.floor((lenparents + block[0] - 1) / block[0])
8+
// else:
9+
// grid_size = 1
10+
// atomic_toptr = cupy.array(toptr, dtype=cupy.uint64)
11+
// temp = cupy.zeros(lenparents, dtype=toptr.dtype)
12+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_argmax_a", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, atomic_toptr, temp, invocation_index, err_code))
13+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_argmax_b", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, atomic_toptr, temp, invocation_index, err_code))
14+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_reduce_argmax_c", cupy.dtype(toptr.dtype).type, cupy.dtype(fromptr.dtype).type, parents.dtype]))((grid_size,), block, (toptr, fromptr, parents, lenparents, outlength, atomic_toptr, temp, invocation_index, err_code))
815
// out["awkward_reduce_argmax_a", {dtype_specializations}] = None
916
// out["awkward_reduce_argmax_b", {dtype_specializations}] = None
17+
// out["awkward_reduce_argmax_c", {dtype_specializations}] = None
1018
// END PYTHON
1119

1220
template <typename T, typename C, typename U>
@@ -17,12 +25,15 @@ awkward_reduce_argmax_a(
1725
const U* parents,
1826
int64_t lenparents,
1927
int64_t outlength,
28+
uint64_t* atomic_toptr,
29+
T* temp,
2030
uint64_t invocation_index,
2131
uint64_t* err_code) {
2232
if (err_code[0] == NO_ERROR) {
2333
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
34+
2435
if (thread_id < outlength) {
25-
toptr[thread_id] = -1;
36+
atomic_toptr[thread_id] = -1;
2637
}
2738
}
2839
}
@@ -35,17 +46,57 @@ awkward_reduce_argmax_b(
3546
const U* parents,
3647
int64_t lenparents,
3748
int64_t outlength,
49+
uint64_t* atomic_toptr,
50+
T* temp,
3851
uint64_t invocation_index,
3952
uint64_t* err_code) {
4053
if (err_code[0] == NO_ERROR) {
41-
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
54+
int64_t idx = threadIdx.x;
55+
int64_t thread_id = blockIdx.x * blockDim.x + idx;
56+
57+
if (thread_id < lenparents) {
58+
temp[thread_id] = thread_id;
59+
}
60+
__syncthreads();
61+
62+
for (int64_t stride = 1; stride < blockDim.x; stride *= 2) {
63+
int64_t index = -1;
64+
if (idx >= stride && thread_id < lenparents && parents[thread_id] == parents[thread_id - stride]) {
65+
index = temp[thread_id - stride];
66+
}
67+
if (index != -1 && (temp[thread_id] == -1 || fromptr[index] > fromptr[temp[thread_id]] ||
68+
(fromptr[index] == fromptr[temp[thread_id]] && index < temp[thread_id]))) {
69+
temp[thread_id] = index;
70+
}
71+
__syncthreads();
72+
}
4273

4374
if (thread_id < lenparents) {
4475
int64_t parent = parents[thread_id];
45-
if (toptr[parent] == -1 ||
46-
(fromptr[thread_id] > (fromptr[toptr[parent]]))) {
47-
toptr[parent] = thread_id; // we need the last parent filled, thread random order problem, find max arg at that index
76+
if (idx == blockDim.x - 1 || thread_id == lenparents - 1 || parents[thread_id] != parents[thread_id + 1]) {
77+
atomicExch(&atomic_toptr[parent], temp[thread_id]);
4878
}
4979
}
5080
}
5181
}
82+
83+
// Final argmax stage: copies each of the first `outlength` entries of the
// uint64 staging buffer `atomic_toptr` into `toptr`, converting to T.
//
// `fromptr`, `parents`, `lenparents`, `temp` and `invocation_index` are not
// read by this stage. Nothing is written unless `err_code[0]` equals
// NO_ERROR.
template <typename T, typename C, typename U>
__global__ void
awkward_reduce_argmax_c(
    T* toptr,
    const C* fromptr,
    const U* parents,
    int64_t lenparents,
    int64_t outlength,
    uint64_t* atomic_toptr,
    T* temp,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] != NO_ERROR) {
    return;
  }

  const int64_t out_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (out_index >= outlength) {
    return;
  }

  toptr[out_index] = static_cast<T>(atomic_toptr[out_index]);
}

0 commit comments

Comments
 (0)