Skip to content

Commit ef1e851

Browse files
authored
feat: add some misc CUDA kernels (#3141)
* feat: add awkward_NumpyArray_subrange_equal and awkward_NumpyArray_subrange_equal_bool kernel * fix: grid-stride loop * fix: awkward_ListOffsetArray_rpad_axis1 * feat: add awkward_UnionArray_regular_index.cu * test: rearrange and add tests
1 parent 5de7b35 commit ef1e851

16 files changed

+707
-507
lines changed

dev/generate-kernel-signatures.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
"awkward_ListArray_getitem_next_range_counts",
6868
"awkward_ListArray_rpad_and_clip_length_axis1",
6969
"awkward_ListArray_rpad_axis1",
70+
"awkward_UnionArray_regular_index",
7071
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
7172
"awkward_ListArray_getitem_next_range_spreadadvanced",
7273
"awkward_ListArray_localindex",
@@ -84,6 +85,8 @@
8485
"awkward_Content_getitem_next_missing_jagged_getmaskstartstop",
8586
"awkward_index_rpad_and_clip_axis0",
8687
"awkward_index_rpad_and_clip_axis1",
88+
"awkward_NumpyArray_subrange_equal",
89+
"awkward_NumpyArray_subrange_equal_bool",
8790
"awkward_IndexedArray_flatten_nextcarry",
8891
"awkward_IndexedArray_flatten_none2empty",
8992
"awkward_IndexedArray_getitem_nextcarry",

dev/generate-tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ def gencpuunittests(specdict):
852852
"awkward_ListArray_getitem_next_range_counts",
853853
"awkward_ListArray_rpad_and_clip_length_axis1",
854854
"awkward_ListArray_rpad_axis1",
855+
"awkward_UnionArray_regular_index",
855856
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
856857
"awkward_ListArray_getitem_next_range_spreadadvanced",
857858
"awkward_ListArray_localindex",
@@ -869,6 +870,8 @@ def gencpuunittests(specdict):
869870
"awkward_Content_getitem_next_missing_jagged_getmaskstartstop",
870871
"awkward_index_rpad_and_clip_axis0",
871872
"awkward_index_rpad_and_clip_axis1",
873+
"awkward_NumpyArray_subrange_equal",
874+
"awkward_NumpyArray_subrange_equal_bool",
872875
"awkward_IndexedArray_flatten_nextcarry",
873876
"awkward_IndexedArray_flatten_none2empty",
874877
"awkward_IndexedArray_getitem_nextcarry",

kernel-specification.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ kernels:
543543
k = k + 1
544544
automatic-tests: true
545545

546-
- name: awkward_IndexedArray_local_preparenext
546+
- name: awkward_IndexedArray_local_preparenext_64
547547
specializations:
548548
- name: awkward_IndexedArray_local_preparenext_64
549549
args:
@@ -555,7 +555,7 @@ kernels:
555555
- {name: nextlen, type: "Const[int64_t]", dir: in, role: default}
556556
description: null
557557
definition: |
558-
def awkward_IndexedArray_local_preparenext(
558+
def awkward_IndexedArray_local_preparenext_64(
559559
tocarry, starts, parents, parentslength, nextparents, nextlen
560560
):
561561
j = 0

kernel-test-data.json

Lines changed: 137 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10677,7 +10677,7 @@
1067710677
},
1067810678
{
1067910679
"name": "awkward_ListOffsetArray_rpad_axis1",
10680-
"status": false,
10680+
"status": true,
1068110681
"tests": [
1068210682
{
1068310683
"error": false,
@@ -15608,7 +15608,7 @@
1560815608
},
1560915609
{
1561015610
"name": "awkward_UnionArray_regular_index",
15611-
"status": false,
15611+
"status": true,
1561215612
"tests": [
1561315613
{
1561415614
"error": false,
@@ -16614,7 +16614,7 @@
1661416614
]
1661516615
},
1661616616
{
16617-
"name": "awkward_IndexedArray_local_preparenext",
16617+
"name": "awkward_IndexedArray_local_preparenext_64",
1661816618
"status": true,
1661916619
"tests": [
1662016620
{
@@ -16639,7 +16639,7 @@
1663916639
"nextparents": [0, 0, 0, 0, 1, 1, 1],
1664016640
"parentslength": 11,
1664116641
"parents": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
16642-
"starts": [0, 2, 5]
16642+
"starts": [0, 6]
1664316643
},
1664416644
"outputs": {
1664516645
"tocarry": [0, 1, 2, 3, -1, -1, 4, 5, 6, -1, -1]
@@ -19436,7 +19436,7 @@
1943619436
},
1943719437
{
1943819438
"name": "awkward_ListOffsetArray_reduce_nonlocal_nextshifts_64",
19439-
"status": false,
19439+
"status": true,
1944019440
"tests": [
1944119441
{
1944219442
"error": false,
@@ -26387,12 +26387,38 @@
2638726387
"tmpptr": [0, 2, 2, 3, 5],
2638826388
"fromstarts": [0, 2, 3, 3],
2638926389
"fromstops": [2, 3, 3, 5],
26390-
"length": 4
26390+
"length": 5
2639126391
},
2639226392
"outputs": {
2639326393
"toequal": [0]
2639426394
}
2639526395
},
26396+
{
26397+
"error": false,
26398+
"message": "",
26399+
"inputs": {
26400+
"tmpptr": [0, 2, 2, 0, 2],
26401+
"fromstarts": [0, 2, 3, 3],
26402+
"fromstops": [2, 3, 3, 5],
26403+
"length": 5
26404+
},
26405+
"outputs": {
26406+
"toequal": [1]
26407+
}
26408+
},
26409+
{
26410+
"error": false,
26411+
"message": "",
26412+
"inputs": {
26413+
"tmpptr": [0, 0, 0, 0, 0],
26414+
"fromstarts": [0, 2, 3, 3],
26415+
"fromstops": [2, 3, 3, 5],
26416+
"length": 5
26417+
},
26418+
"outputs": {
26419+
"toequal": [1]
26420+
}
26421+
},
2639626422
{
2639726423
"error": false,
2639826424
"message": "",
@@ -26406,6 +26432,32 @@
2640626432
"toequal": [1]
2640726433
}
2640826434
},
26435+
{
26436+
"error": false,
26437+
"message": "",
26438+
"inputs": {
26439+
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
26440+
"fromstarts": [0, 2, 4, 6, 8, 10],
26441+
"fromstops": [2, 4, 6, 8, 10, 12],
26442+
"length": 6
26443+
},
26444+
"outputs": {
26445+
"toequal": [1]
26446+
}
26447+
},
26448+
{
26449+
"error": false,
26450+
"message": "",
26451+
"inputs": {
26452+
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2],
26453+
"fromstarts": [0, 2, 4, 6, 8, 10],
26454+
"fromstops": [2, 4, 6, 8, 10, 12],
26455+
"length": 6
26456+
},
26457+
"outputs": {
26458+
"toequal": [1]
26459+
}
26460+
},
2640926461
{
2641026462
"error": false,
2641126463
"message": "",
@@ -26418,6 +26470,19 @@
2641826470
"outputs": {
2641926471
"toequal": [1]
2642026472
}
26473+
},
26474+
{
26475+
"error": false,
26476+
"message": "",
26477+
"inputs": {
26478+
"tmpptr": [1, 2, 3, 4, 5, 6],
26479+
"fromstarts": [2, 2, 2, 2, 2, 2],
26480+
"fromstops": [4, 4, 4, 4, 4, 4],
26481+
"length": 6
26482+
},
26483+
"outputs": {
26484+
"toequal": [1]
26485+
}
2642126486
}
2642226487
]
2642326488
},
@@ -26458,12 +26523,38 @@
2645826523
"tmpptr": [0, 2, 2, 3, 5],
2645926524
"fromstarts": [0, 2, 3, 3],
2646026525
"fromstops": [2, 3, 3, 5],
26461-
"length": 4
26526+
"length": 5
2646226527
},
2646326528
"outputs": {
2646426529
"toequal": [0]
2646526530
}
2646626531
},
26532+
{
26533+
"error": false,
26534+
"message": "",
26535+
"inputs": {
26536+
"tmpptr": [0, 2, 2, 0, 2],
26537+
"fromstarts": [0, 2, 3, 3],
26538+
"fromstops": [2, 3, 3, 5],
26539+
"length": 5
26540+
},
26541+
"outputs": {
26542+
"toequal": [1]
26543+
}
26544+
},
26545+
{
26546+
"error": false,
26547+
"message": "",
26548+
"inputs": {
26549+
"tmpptr": [0, 0, 0, 0, 0],
26550+
"fromstarts": [0, 2, 3, 3],
26551+
"fromstops": [2, 3, 3, 5],
26552+
"length": 5
26553+
},
26554+
"outputs": {
26555+
"toequal": [1]
26556+
}
26557+
},
2646726558
{
2646826559
"error": false,
2646926560
"message": "",
@@ -26477,6 +26568,32 @@
2647726568
"toequal": [1]
2647826569
}
2647926570
},
26571+
{
26572+
"error": false,
26573+
"message": "",
26574+
"inputs": {
26575+
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
26576+
"fromstarts": [0, 2, 4, 6, 8, 10],
26577+
"fromstops": [2, 4, 6, 8, 10, 12],
26578+
"length": 6
26579+
},
26580+
"outputs": {
26581+
"toequal": [1]
26582+
}
26583+
},
26584+
{
26585+
"error": false,
26586+
"message": "",
26587+
"inputs": {
26588+
"tmpptr": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 2],
26589+
"fromstarts": [0, 2, 4, 6, 8, 10],
26590+
"fromstops": [2, 4, 6, 8, 10, 12],
26591+
"length": 6
26592+
},
26593+
"outputs": {
26594+
"toequal": [1]
26595+
}
26596+
},
2648026597
{
2648126598
"error": false,
2648226599
"message": "",
@@ -26489,6 +26606,19 @@
2648926606
"outputs": {
2649026607
"toequal": [1]
2649126608
}
26609+
},
26610+
{
26611+
"error": false,
26612+
"message": "",
26613+
"inputs": {
26614+
"tmpptr": [1, 2, 3, 4, 5, 6],
26615+
"fromstarts": [2, 2, 2, 2, 2, 2],
26616+
"fromstops": [4, 4, 4, 4, 4, 4],
26617+
"length": 6
26618+
},
26619+
"outputs": {
26620+
"toequal": [1]
26621+
}
2649226622
}
2649326623
]
2649426624
},

src/awkward/_connect/cuda/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def fetch_template_specializations(kernel_dict):
8585
"awkward_IndexedArray_flatten_none2empty",
8686
"awkward_IndexedArray_getitem_nextcarry",
8787
"awkward_IndexedArray_getitem_nextcarry_outindex",
88-
"awkward_ListArray_getitem_next_range_counts",
8988
"awkward_IndexedArray_index_of_nulls",
9089
"awkward_IndexedArray_ranges_next_64",
9190
"awkward_IndexedArray_ranges_carry_next_64",
@@ -103,19 +102,18 @@ def fetch_template_specializations(kernel_dict):
103102
"awkward_ListArray_getitem_jagged_shrink",
104103
"awkward_ListArray_getitem_next_range",
105104
"awkward_ListArray_getitem_next_range_carrylength",
105+
"awkward_ListArray_getitem_next_range_counts",
106106
"awkward_ListArray_min_range",
107107
"awkward_ListArray_rpad_and_clip_length_axis1",
108108
"awkward_ListArray_rpad_axis1",
109109
"awkward_ListOffsetArray_drop_none_indexes",
110-
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
111110
"awkward_ListOffsetArray_reduce_nonlocal_maxcount_offsetscopy_64",
111+
"awkward_UnionArray_regular_index",
112+
"awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64",
113+
"awkward_ListOffsetArray_rpad_axis1",
112114
"awkward_ListOffsetArray_rpad_length_axis1",
113115
"awkward_MaskedArray_getitem_next_jagged_project",
114-
"awkward_UnionArray_nestedfill_tags_index",
115116
"awkward_NumpyArray_rearrange_shifted",
116-
"awkward_UnionArray_flatten_length",
117-
"awkward_UnionArray_flatten_combine",
118-
"awkward_UnionArray_project",
119117
"awkward_reduce_count_64",
120118
"awkward_reduce_sum",
121119
"awkward_reduce_sum_int32_bool_64",
@@ -129,6 +127,10 @@ def fetch_template_specializations(kernel_dict):
129127
"awkward_reduce_min",
130128
"awkward_sorting_ranges",
131129
"awkward_sorting_ranges_length",
130+
"awkward_UnionArray_flatten_length",
131+
"awkward_UnionArray_flatten_combine",
132+
"awkward_UnionArray_nestedfill_tags_index",
133+
"awkward_UnionArray_project",
132134
]
133135
template_specializations = []
134136
import re

src/awkward/_connect/cuda/cuda_kernels/awkward_IndexedArray_fill.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ awkward_IndexedArray_fill(
1414
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
1515
if (thread_id < length) {
1616
C fromval = fromindex[thread_id];
17-
toindex[toindexoffset + thread_id] = fromval < 0 ? -1 : (C)(fromval + base);
17+
toindex[toindexoffset + thread_id] = fromval < 0 ? (C)-1 : (C)(fromval + base);
1818
}
1919
}
2020
}

src/awkward/_connect/cuda/cuda_kernels/awkward_IndexedArray_getitem_nextcarry_outindex.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ awkward_IndexedArray_getitem_nextcarry_outindex_b(
6060
RAISE_ERROR(
6161
INDEXEDARRAY_GETITEM_NEXTCARRY_OUTINDEX_ERRORS::IND_OUT_OF_RANGE)
6262
} else if (j < 0) {
63-
toindex[thread_id] = -1;
63+
toindex[thread_id] = (C)-1;
6464
} else {
6565
tocarry[scan_in_array[thread_id] - 1] = j;
6666
toindex[thread_id] = (C)(scan_in_array[thread_id] - 1);

0 commit comments

Comments
 (0)