Skip to content

Commit 914e012

Browse files
fix: runtime for ak.combinations with cuda backend (#3502)
* optimize awkward_ListArray_combinations_length.cu * optimize awkward_ListArray_combinations.cu * update awkward_ListArray_combinations.cu * update awkward_ListArray_combinations_length.cu * cleanup * style: pre-commit fixes * Update awkward_ListArray_combinations.cu --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d688c0d commit 914e012

File tree

2 files changed

+177
-111
lines changed

2 files changed

+177
-111
lines changed

src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_combinations.cu

Lines changed: 127 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,21 @@
55
// (tocarry, toindex, fromindex, n, replacement, starts, stops, length, invocation_index, err_code) = args
66
// scan_in_array_offsets = cupy.zeros(length + 1, dtype=cupy.int64)
77
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_a", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))(grid, block, (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, invocation_index, err_code))
8-
// scan_in_array_offsets = cupy.cumsum(scan_in_array_offsets)
9-
// scan_in_array_parents = cupy.zeros(int(scan_in_array_offsets[length]), dtype=cupy.int64)
10-
// scan_in_array_local_indices = cupy.zeros(int(scan_in_array_offsets[length]), dtype=cupy.int64)
11-
// for i in range(1, length + 1):
12-
// scan_in_array_parents[scan_in_array_offsets[i - 1]:scan_in_array_offsets[i]] = i - 1
13-
// if int(scan_in_array_offsets[length]) < 1024:
14-
// block_size = int(scan_in_array_offsets[length])
15-
// else:
16-
// block_size = 1024
17-
// if block_size > 0:
18-
// grid_size = math.floor((int(scan_in_array_offsets[length]) + block_size - 1) / block_size)
19-
// else:
20-
// grid_size = 1
21-
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_b", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, scan_in_array_local_indices, invocation_index, err_code))
8+
// cupy.cumsum(scan_in_array_offsets, out = scan_in_array_offsets)
9+
// totallen=int(scan_in_array_offsets[length])
10+
// if totallen == 0:
11+
// return # Nothing to do if no combinations, skip the rest
12+
// block_size = min(1024, totallen)
13+
// grid_size = (totallen + block_size - 1)//block_size
14+
// scan_in_array_parents = cupy.zeros(totallen, dtype=cupy.int64)
15+
// scan_in_array_local_indices = cupy.zeros(totallen, dtype=cupy.int64)
16+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_b", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, invocation_index, err_code))
2217
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_c", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, scan_in_array_local_indices, invocation_index, err_code))
18+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_d", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, scan_in_array_local_indices, invocation_index, err_code))
2319
// out["awkward_ListArray_combinations_a", {dtype_specializations}] = None
2420
// out["awkward_ListArray_combinations_b", {dtype_specializations}] = None
2521
// out["awkward_ListArray_combinations_c", {dtype_specializations}] = None
22+
// out["awkward_ListArray_combinations_d", {dtype_specializations}] = None
2623
// END PYTHON
2724

2825
enum class LISTARRAY_COMBINATIONS_ERRORS {
@@ -43,25 +40,61 @@ awkward_ListArray_combinations_a(
4340
int64_t* scan_in_array_offsets,
4441
uint64_t invocation_index,
4542
uint64_t* err_code) {
46-
if (err_code[0] == NO_ERROR) {
47-
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
48-
if (thread_id < length) {
49-
if (n != 2) {
50-
RAISE_ERROR(LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
51-
}
52-
int64_t counts = stops[thread_id] - starts[thread_id];
53-
if (replacement) {
54-
scan_in_array_offsets[thread_id + 1] = counts * (counts + 1) / 2;
55-
} else {
56-
scan_in_array_offsets[thread_id + 1] = counts * (counts - 1) / 2;
57-
}
43+
if (err_code[0] != NO_ERROR) {
44+
return;
45+
}
46+
47+
// For now only n==2 supported
48+
if (n != 2) {
49+
if (threadIdx.x == 0 && blockIdx.x == 0) {
50+
RAISE_ERROR(LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
5851
}
52+
return;
53+
}
54+
55+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
56+
57+
// Grid-stride loop for general scalability
58+
for (int64_t i = thread_id; i < length; i += gridDim.x * blockDim.x) {
59+
V start = starts[i];
60+
W stop = stops[i];
61+
int64_t counts = stop - start;
62+
int64_t result = replacement
63+
? counts * (counts + 1) / 2
64+
: counts * (counts - 1) / 2;
65+
scan_in_array_offsets[i + 1] = result;
5966
}
6067
}
6168

6269
template <typename T, typename C, typename U, typename V, typename W>
6370
__global__ void
6471
awkward_ListArray_combinations_b(
72+
T** tocarry,
73+
C* toindex,
74+
U* fromindex,
75+
int64_t n,
76+
bool replacement,
77+
const V* starts,
78+
const W* stops,
79+
int64_t length,
80+
const int64_t* __restrict__ scan_in_array_offsets,
81+
int64_t* __restrict__ scan_in_array_parents,
82+
uint64_t invocation_index,
83+
uint64_t* err_code) {
84+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
85+
if (thread_id >= length) return;
86+
87+
int64_t start = scan_in_array_offsets[thread_id];
88+
int64_t stop = scan_in_array_offsets[thread_id + 1];
89+
90+
for (int64_t i = start; i < stop; i++) {
91+
scan_in_array_parents[i] = thread_id;
92+
}
93+
}
94+
95+
template <typename T, typename C, typename U, typename V, typename W>
96+
__global__ void
97+
awkward_ListArray_combinations_c(
6598
T** tocarry,
6699
C* toindex,
67100
U* fromindex,
@@ -75,22 +108,31 @@ awkward_ListArray_combinations_b(
75108
int64_t* scan_in_array_local_indices,
76109
uint64_t invocation_index,
77110
uint64_t* err_code) {
78-
if (err_code[0] == NO_ERROR) {
79-
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
80-
int64_t offsetslength = scan_in_array_offsets[length];
81-
82-
if (thread_id < offsetslength) {
83-
if (n != 2) {
84-
RAISE_ERROR(LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
85-
}
86-
scan_in_array_local_indices[thread_id] = thread_id - scan_in_array_offsets[scan_in_array_parents[thread_id]];
111+
if (err_code[0] != NO_ERROR) {
112+
return;
113+
}
114+
115+
// For now only n==2 supported
116+
if (n != 2) {
117+
if (threadIdx.x == 0 && blockIdx.x == 0) {
118+
RAISE_ERROR(LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
87119
}
120+
return;
121+
}
122+
123+
int64_t offsetslength = scan_in_array_offsets[length];
124+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
125+
126+
// Grid-stride loop
127+
for (int64_t i = thread_id; i < offsetslength; i += gridDim.x * blockDim.x) {
128+
int64_t parent_offset = scan_in_array_offsets[scan_in_array_parents[i]];
129+
scan_in_array_local_indices[i] = i - parent_offset;
88130
}
89131
}
90132

91133
template <typename T, typename C, typename U, typename V, typename W>
92134
__global__ void
93-
awkward_ListArray_combinations_c(
135+
awkward_ListArray_combinations_d(
94136
T** tocarry,
95137
C* toindex,
96138
U* fromindex,
@@ -104,38 +146,55 @@ awkward_ListArray_combinations_c(
104146
int64_t* scan_in_array_local_indices,
105147
uint64_t invocation_index,
106148
uint64_t* err_code) {
107-
if (err_code[0] == NO_ERROR) {
108-
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
109-
int64_t offsetslength = scan_in_array_offsets[length];
110-
int64_t i = 0;
111-
int64_t j = 0;
112-
113-
if (thread_id < offsetslength) {
114-
if (n != 2) {
115-
RAISE_ERROR(LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
116-
}
117-
118-
int64_t n = stops[scan_in_array_parents[thread_id]] - starts[scan_in_array_parents[thread_id]];
119-
120-
if (replacement) {
121-
int64_t b = 2 * n + 1;
122-
float discriminant = sqrtf(b * b - 8 * scan_in_array_local_indices[thread_id]);
123-
i = (int64_t)((b - discriminant) / 2);
124-
j = scan_in_array_local_indices[thread_id] + i * (i - b + 2) / 2;
125-
} else {
126-
int64_t b = 2 * n - 1;
127-
float discriminant = sqrtf(b * b - 8 * scan_in_array_local_indices[thread_id]);
128-
i = (int64_t)((b - discriminant) / 2);
129-
j = scan_in_array_local_indices[thread_id] + i * (i - b + 2) / 2 + 1;
130-
}
131-
132-
i += starts[scan_in_array_parents[thread_id]];
133-
j += starts[scan_in_array_parents[thread_id]];
134-
135-
tocarry[0][thread_id] = i;
136-
tocarry[1][thread_id] = j;
137-
toindex[0] = offsetslength;
138-
toindex[1] = offsetslength;
149+
if (err_code[0] != NO_ERROR) {
150+
return;
151+
}
152+
153+
// For now only n==2 supported
154+
if (n != 2) {
155+
if (threadIdx.x == 0 && blockIdx.x == 0) {
156+
RAISE_ERROR(LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
139157
}
158+
return;
159+
}
160+
161+
int64_t offsetslength = scan_in_array_offsets[length];
162+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
163+
164+
// Grid-stride loop
165+
for (int64_t idx = thread_id; idx < offsetslength; idx += gridDim.x * blockDim.x) {
166+
167+
int64_t parent = scan_in_array_parents[idx];
168+
V start = starts[parent];
169+
W stop = stops[parent];
170+
int64_t count = stop - start;
171+
int64_t local_index = scan_in_array_local_indices[idx];
172+
173+
float discriminant;
174+
int64_t i, j;
175+
176+
if (replacement) {
177+
int64_t b = 2 * count + 1;
178+
discriminant = sqrtf(float(b * b - 8 * local_index));
179+
i = (int64_t)((b - discriminant) / 2.0f);
180+
j = local_index + i * (i - b + 2) / 2;
181+
} else {
182+
int64_t b = 2 * count - 1;
183+
discriminant = sqrtf(float(b * b - 8 * local_index));
184+
i = (int64_t)((b - discriminant) / 2.0f);
185+
j = local_index + i * (i - b + 2) / 2 + 1;
186+
}
187+
188+
i += start;
189+
j += start;
190+
191+
tocarry[0][idx] = i;
192+
tocarry[1][idx] = j;
193+
}
194+
195+
// Set toindex[0] and [1] only once per kernel call (thread 0 of block 0)
196+
if (threadIdx.x == 0 && blockIdx.x == 0) {
197+
toindex[0] = offsetslength;
198+
toindex[1] = offsetslength;
140199
}
141200
}

src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_combinations_length.cu

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
// BEGIN PYTHON
44
// def f(grid, block, args):
55
// (totallen, tooffsets, n, replacement, starts, stops, length, invocation_index, err_code) = args
6-
// scan_in_array_totallen = cupy.zeros(length, dtype=cupy.int64)
7-
// scan_in_array_tooffsets = cupy.zeros(length, dtype=cupy.int64)
8-
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_length_a", totallen.dtype, tooffsets.dtype, starts.dtype, stops.dtype]))(grid, block, (totallen, tooffsets, n, replacement, starts, stops, length, scan_in_array_totallen, scan_in_array_tooffsets, invocation_index, err_code))
9-
// scan_in_array_totallen = cupy.cumsum(scan_in_array_totallen)
10-
// scan_in_array_tooffsets = cupy.cumsum(scan_in_array_tooffsets)
11-
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_length_b", totallen.dtype, tooffsets.dtype, starts.dtype, stops.dtype]))(grid, block, (totallen, tooffsets, n, replacement, starts, stops, length, scan_in_array_totallen, scan_in_array_tooffsets, invocation_index, err_code))
6+
// scan_out = cupy.zeros(length, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_length_a", totallen.dtype, tooffsets.dtype, starts.dtype, stops.dtype]))(grid, block, (totallen, tooffsets, n, replacement, starts, stops, length, scan_out, invocation_index, err_code))
8+
// cupy.cumsum(scan_out, out=scan_out)
9+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_length_b", totallen.dtype, tooffsets.dtype, starts.dtype, stops.dtype]))(grid, block, (totallen, tooffsets, n, replacement, starts, stops, length, scan_out, invocation_index, err_code))
1210
// out["awkward_ListArray_combinations_length_a", {dtype_specializations}] = None
1311
// out["awkward_ListArray_combinations_length_b", {dtype_specializations}] = None
1412
// END PYTHON
@@ -23,40 +21,42 @@ awkward_ListArray_combinations_length_a(
2321
const U* starts,
2422
const V* stops,
2523
int64_t length,
26-
int64_t* scan_in_array_totallen,
27-
int64_t* scan_in_array_tooffsets,
24+
int64_t* scan_out,
2825
uint64_t invocation_index,
2926
uint64_t* err_code) {
30-
if (err_code[0] == NO_ERROR) {
31-
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
27+
if (err_code[0] != NO_ERROR) {
28+
return;
29+
}
30+
31+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
32+
if (thread_id >= length) {
33+
return;
34+
}
35+
36+
int64_t size = stops[thread_id] - starts[thread_id];
37+
int64_t combinationslen = 0;
38+
39+
if (replacement) {
40+
size += (n - 1);
41+
}
42+
43+
if (n > size) {
44+
combinationslen = 0;
45+
}
46+
else if (n == size) {
47+
combinationslen = 1;
48+
}
49+
else {
50+
// Choose the smaller of n and size - n for fewer multiplications
51+
int64_t k = (n * 2 > size) ? (size - n) : n;
3252

33-
if (thread_id < length) {
34-
int64_t size = (int64_t)(stops[thread_id] - starts[thread_id]);
35-
if (replacement) {
36-
size += (n - 1);
37-
}
38-
int64_t thisn = n;
39-
int64_t combinationslen;
40-
if (thisn > size) {
41-
combinationslen = 0;
42-
}
43-
else if (thisn == size) {
44-
combinationslen = 1;
45-
}
46-
else {
47-
if (thisn * 2 > size) {
48-
thisn = size - thisn;
49-
}
50-
combinationslen = size;
51-
for (int64_t j = 2 + threadIdx.y; j <= thisn; j += blockDim.y) {
52-
combinationslen *= (size - j + 1);
53-
combinationslen /= j;
54-
}
55-
}
56-
scan_in_array_totallen[thread_id] = combinationslen;
57-
scan_in_array_tooffsets[thread_id] = combinationslen;
53+
combinationslen = 1;
54+
for (int64_t j = 1; j <= k; ++j) {
55+
combinationslen = (combinationslen * (size - j + 1)) / j;
5856
}
5957
}
58+
59+
scan_out[thread_id] = combinationslen;
6060
}
6161

6262
template <typename T, typename C, typename U, typename V>
@@ -69,17 +69,24 @@ awkward_ListArray_combinations_length_b(
6969
const U* starts,
7070
const V* stops,
7171
int64_t length,
72-
int64_t* scan_in_array_totallen,
73-
int64_t* scan_in_array_tooffsets,
72+
int64_t* scan_out,
7473
uint64_t invocation_index,
7574
uint64_t* err_code) {
76-
if (err_code[0] == NO_ERROR) {
77-
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
78-
*totallen = length > 0 ? scan_in_array_totallen[length - 1] : 0;
75+
76+
if (err_code[0] != NO_ERROR) {
77+
return;
78+
}
79+
80+
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
81+
82+
// Let a single thread handle totallen and tooffsets[0]
83+
if (thread_id == 0) {
84+
*totallen = (length > 0) ? scan_out[length - 1] : 0;
7985
tooffsets[0] = 0;
86+
}
8087

81-
if (thread_id < length) {
82-
tooffsets[thread_id + 1] = scan_in_array_tooffsets[thread_id];
83-
}
88+
// Copy scan_out values into tooffsets (shifted by 1)
89+
if (thread_id < length) {
90+
tooffsets[thread_id + 1] = scan_out[thread_id];
8491
}
8592
}

0 commit comments

Comments
 (0)