Skip to content

Commit 81085be

Browse files
authored
feat: add slicing CUDA kernels (#3140)
* feat add awkward_ListArray_getitem_jagged_apply kernel * fix: remove print statements * feat: add awkward_ListArray_getitem_jagged_shrink kernel * test: cuda integration tests * test: more slicing integration tests * fix: ndarray error for cupy array shape * fix: remove unused variable
1 parent 0b9f6f4 commit 81085be

9 files changed

+2433
-3
lines changed

dev/generate-kernel-signatures.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@
5353
"awkward_missing_repeat",
5454
"awkward_RegularArray_getitem_jagged_expand",
5555
"awkward_ListArray_combinations_length",
56+
"awkward_ListArray_getitem_jagged_apply",
5657
"awkward_ListArray_getitem_jagged_carrylen",
5758
"awkward_ListArray_getitem_jagged_descend",
5859
"awkward_ListArray_getitem_jagged_expand",
5960
"awkward_ListArray_getitem_jagged_numvalid",
61+
"awkward_ListArray_getitem_jagged_shrink",
6062
"awkward_ListArray_getitem_next_array_advanced",
6163
"awkward_ListArray_getitem_next_array",
6264
"awkward_ListArray_getitem_next_at",

dev/generate-tests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,10 +838,12 @@ def gencpuunittests(specdict):
838838
"awkward_missing_repeat",
839839
"awkward_RegularArray_getitem_jagged_expand",
840840
"awkward_ListArray_combinations_length",
841+
"awkward_ListArray_getitem_jagged_apply",
841842
"awkward_ListArray_getitem_jagged_carrylen",
842843
"awkward_ListArray_getitem_jagged_descend",
843844
"awkward_ListArray_getitem_jagged_expand",
844845
"awkward_ListArray_getitem_jagged_numvalid",
846+
"awkward_ListArray_getitem_jagged_shrink",
845847
"awkward_ListArray_getitem_next_array_advanced",
846848
"awkward_ListArray_getitem_next_array",
847849
"awkward_ListArray_getitem_next_at",

src/awkward/_connect/cuda/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,11 @@ def fetch_template_specializations(kernel_dict):
9696
"awkward_ListArray_broadcast_tooffsets",
9797
"awkward_ListArray_combinations_length",
9898
"awkward_ListArray_compact_offsets",
99+
"awkward_ListArray_getitem_jagged_apply",
99100
"awkward_ListArray_getitem_jagged_carrylen",
100101
"awkward_ListArray_getitem_jagged_descend",
101102
"awkward_ListArray_getitem_jagged_numvalid",
103+
"awkward_ListArray_getitem_jagged_shrink",
102104
"awkward_ListArray_getitem_next_range",
103105
"awkward_ListArray_getitem_next_range_carrylength",
104106
"awkward_ListArray_min_range",
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, invocation_index, err_code) = args
6+
// scan_in_array = cupy.zeros(sliceouterlen + 1, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_a", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code))
8+
// scan_in_array = cupy.cumsum(scan_in_array)
9+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_b", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code))
10+
// out["awkward_ListArray_getitem_jagged_apply_a", {dtype_specializations}] = None
11+
// out["awkward_ListArray_getitem_jagged_apply_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
enum class LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS {
15+
JAG_STOP_LT_START, // message: "jagged slice's stops[i] < starts[i]"
16+
OFF_GET_CON, // message: "jagged slice's offsets extend beyond its content"
17+
STOP_LT_START, // message: "stops[i] < starts[i]"
18+
STOP_GET_LEN, // message: "stops[i] > len(content)"
19+
IND_OUT_OF_RANGE, // message: "index out of range"
20+
};
21+
22+
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y>
23+
__global__ void
24+
awkward_ListArray_getitem_jagged_apply_a(
25+
T* tooffsets,
26+
C* tocarry,
27+
const U* slicestarts,
28+
const V* slicestops,
29+
int64_t sliceouterlen,
30+
const W* sliceindex,
31+
int64_t sliceinnerlen,
32+
const X* fromstarts,
33+
const Y* fromstops,
34+
int64_t contentlen,
35+
int64_t* scan_in_array,
36+
uint64_t invocation_index,
37+
uint64_t* err_code) {
38+
if (err_code[0] == NO_ERROR) {
39+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
40+
scan_in_array[0] = 0;
41+
42+
if (thread_id < sliceouterlen) {
43+
U slicestart = slicestarts[thread_id];
44+
V slicestop = slicestops[thread_id];
45+
46+
if (slicestart != slicestop) {
47+
if (slicestop < slicestart) {
48+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START)
49+
}
50+
if (slicestop > sliceinnerlen) {
51+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON)
52+
}
53+
int64_t start = (int64_t)fromstarts[thread_id];
54+
int64_t stop = (int64_t)fromstops[thread_id];
55+
if (stop < start) {
56+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START)
57+
}
58+
if (start != stop && stop > contentlen) {
59+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
60+
}
61+
scan_in_array[thread_id + 1] = slicestop - slicestart;
62+
}
63+
}
64+
}
65+
}
66+
67+
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y>
68+
__global__ void
69+
awkward_ListArray_getitem_jagged_apply_b(
70+
T* tooffsets,
71+
C* tocarry,
72+
const U* slicestarts,
73+
const V* slicestops,
74+
int64_t sliceouterlen,
75+
const W* sliceindex,
76+
int64_t sliceinnerlen,
77+
const X* fromstarts,
78+
const Y* fromstops,
79+
int64_t contentlen,
80+
int64_t* scan_in_array,
81+
uint64_t invocation_index,
82+
uint64_t* err_code) {
83+
if (err_code[0] == NO_ERROR) {
84+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
85+
86+
if (thread_id < sliceouterlen) {
87+
U slicestart = slicestarts[thread_id];
88+
V slicestop = slicestops[thread_id];
89+
tooffsets[thread_id] = (T)(scan_in_array[thread_id]);
90+
if (slicestart != slicestop) {
91+
if (slicestop < slicestart) {
92+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START)
93+
}
94+
if (slicestop > sliceinnerlen) {
95+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON)
96+
}
97+
int64_t start = (int64_t)fromstarts[thread_id];
98+
int64_t stop = (int64_t)fromstops[thread_id];
99+
if (stop < start) {
100+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START)
101+
}
102+
if (start != stop && stop > contentlen) {
103+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
104+
}
105+
int64_t count = stop - start;
106+
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
107+
int64_t index = (int64_t) sliceindex[j];
108+
if (index < -count || index > count) {
109+
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::IND_OUT_OF_RANGE)
110+
}
111+
if (index < 0) {
112+
index += count;
113+
}
114+
tocarry[scan_in_array[thread_id] + j - slicestart] = start + index;
115+
}
116+
}
117+
}
118+
tooffsets[sliceouterlen] = scan_in_array[sliceouterlen];
119+
}
120+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, invocation_index, err_code) = args
6+
// if length > 0 and length < int(slicestops[length - 1]):
7+
// len_array = int(slicestops[length - 1])
8+
// else:
9+
// len_array = length
10+
// scan_in_array_k = cupy.zeros(len_array, dtype=cupy.int64)
11+
// scan_in_array_tosmalloffsets = cupy.zeros(length + 1, dtype=cupy.int64)
12+
// scan_in_array_tolargeoffsets = cupy.zeros(length + 1, dtype=cupy.int64)
13+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_a", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
14+
// scan_in_array_k = cupy.cumsum(scan_in_array_k)
15+
// scan_in_array_tosmalloffsets = cupy.cumsum(scan_in_array_tosmalloffsets)
16+
// scan_in_array_tolargeoffsets = cupy.cumsum(scan_in_array_tolargeoffsets)
17+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_b", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code))
18+
// out["awkward_ListArray_getitem_jagged_shrink_a", {dtype_specializations}] = None
19+
// out["awkward_ListArray_getitem_jagged_shrink_b", {dtype_specializations}] = None
20+
// END PYTHON
21+
22+
template <typename T, typename C, typename U, typename V, typename W, typename X>
23+
__global__ void
24+
awkward_ListArray_getitem_jagged_shrink_a(
25+
T* tocarry,
26+
C* tosmalloffsets,
27+
U* tolargeoffsets,
28+
const V* slicestarts,
29+
const W* slicestops,
30+
int64_t length,
31+
const X* missing,
32+
int64_t* scan_in_array_k,
33+
int64_t* scan_in_array_tosmalloffsets,
34+
int64_t* scan_in_array_tolargeoffsets,
35+
uint64_t invocation_index,
36+
uint64_t* err_code) {
37+
if (err_code[0] == NO_ERROR) {
38+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
39+
if (thread_id < length) {
40+
if (thread_id == 0) {
41+
scan_in_array_tosmalloffsets[0] = slicestarts[0];
42+
scan_in_array_tolargeoffsets[0] = slicestarts[0];
43+
}
44+
V slicestart = slicestarts[thread_id];
45+
W slicestop = slicestops[thread_id];
46+
if (slicestart != slicestop) {
47+
C smallcount = 0;
48+
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
49+
if (missing[j] >= 0) {
50+
smallcount++;
51+
}
52+
}
53+
scan_in_array_k[thread_id + 1] = smallcount;
54+
scan_in_array_tosmalloffsets[thread_id + 1] = smallcount;
55+
}
56+
scan_in_array_tolargeoffsets[thread_id + 1] = slicestop - slicestart;
57+
}
58+
}
59+
}
60+
61+
template <typename T, typename C, typename U, typename V, typename W, typename X>
62+
__global__ void
63+
awkward_ListArray_getitem_jagged_shrink_b(
64+
T* tocarry,
65+
C* tosmalloffsets,
66+
U* tolargeoffsets,
67+
const V* slicestarts,
68+
const W* slicestops,
69+
int64_t length,
70+
const X* missing,
71+
int64_t* scan_in_array_k,
72+
int64_t* scan_in_array_tosmalloffsets,
73+
int64_t* scan_in_array_tolargeoffsets,
74+
uint64_t invocation_index,
75+
uint64_t* err_code) {
76+
if (err_code[0] == NO_ERROR) {
77+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
78+
if (length == 0) {
79+
tosmalloffsets[0] = 0;
80+
tolargeoffsets[0] = 0;
81+
}
82+
else {
83+
tosmalloffsets[0] = slicestarts[0];
84+
tolargeoffsets[0] = slicestarts[0];
85+
}
86+
if (thread_id < length) {
87+
V slicestart = slicestarts[thread_id];
88+
W slicestop = slicestops[thread_id];
89+
int64_t k = scan_in_array_k[thread_id] - scan_in_array_k[0];
90+
if (slicestart != slicestop) {
91+
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
92+
if (missing[j] >= 0) {
93+
tocarry[k] = j;
94+
k++;
95+
}
96+
}
97+
tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id + 1];
98+
}
99+
else {
100+
tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id];
101+
}
102+
tolargeoffsets[thread_id + 1] = scan_in_array_tolargeoffsets[thread_id + 1];
103+
}
104+
}
105+
}

src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_next_range_carrylength.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ awkward_ListArray_getitem_next_range_carrylength_a(
2626
uint64_t* err_code) {
2727
if (err_code[0] == NO_ERROR) {
2828
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
29-
const int64_t kMaxInt64 = 9223372036854775806; // 2**63 - 2: see below
30-
const int64_t kSliceNone = kMaxInt64 + 1; // for Slice::none()
3129
if (thread_id < lenstarts) {
3230
int64_t length = fromstops[thread_id] - fromstarts[thread_id];
3331
int64_t regular_start = start;

src/awkward/_slicing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def _normalise_item_bool_to_int(item: Content, backend: Backend) -> Content:
550550

551551
# outindex fits into the lists; non-missing are sequential
552552
outindex = ak.index.Index64(
553-
item_backend.index_nplike.full(nextoffsets.data[-1], -1, dtype=np.int64)
553+
item_backend.index_nplike.full(nextoffsets[-1], -1, dtype=np.int64)
554554
)
555555
outindex.data[~isnegative[expanded]] = item_backend.index_nplike.arange(
556556
nextcontent.shape[0], dtype=np.int64

0 commit comments

Comments
 (0)