5
5
// (tocarry, toindex, fromindex, n, replacement, starts, stops, length, invocation_index, err_code) = args
6
6
// scan_in_array_offsets = cupy.zeros(length + 1, dtype=cupy.int64)
7
7
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_a", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))(grid, block, (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, invocation_index, err_code))
8
- // scan_in_array_offsets = cupy.cumsum(scan_in_array_offsets)
9
- // scan_in_array_parents = cupy.zeros(int(scan_in_array_offsets[length]), dtype=cupy.int64)
10
- // scan_in_array_local_indices = cupy.zeros(int(scan_in_array_offsets[length]), dtype=cupy.int64)
11
- // for i in range(1, length + 1):
12
- // scan_in_array_parents[scan_in_array_offsets[i - 1]:scan_in_array_offsets[i]] = i - 1
13
- // if int(scan_in_array_offsets[length]) < 1024:
14
- // block_size = int(scan_in_array_offsets[length])
15
- // else:
16
- // block_size = 1024
17
- // if block_size > 0:
18
- // grid_size = math.floor((int(scan_in_array_offsets[length]) + block_size - 1) / block_size)
19
- // else:
20
- // grid_size = 1
21
- // cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_b", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, scan_in_array_local_indices, invocation_index, err_code))
8
+ // cupy.cumsum(scan_in_array_offsets, out = scan_in_array_offsets)
9
+ // totallen=int(scan_in_array_offsets[length])
10
+ // if totallen == 0:
11
+ // return # Nothing to do if no combinations, skip the rest
12
+ // block_size = min(1024, totallen)
13
+ // grid_size = (totallen + block_size - 1)//block_size
14
+ // scan_in_array_parents = cupy.zeros(totallen, dtype=cupy.int64)
15
+ // scan_in_array_local_indices = cupy.zeros(totallen, dtype=cupy.int64)
16
+ // cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_b", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, invocation_index, err_code))
22
17
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_c", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, scan_in_array_local_indices, invocation_index, err_code))
18
+ // cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_combinations_d", tocarry[0].dtype, toindex.dtype, fromindex.dtype, starts.dtype, stops.dtype]))((grid_size,), (block_size,), (tocarry, toindex, fromindex, n, replacement, starts, stops, length, scan_in_array_offsets, scan_in_array_parents, scan_in_array_local_indices, invocation_index, err_code))
23
19
// out["awkward_ListArray_combinations_a", {dtype_specializations}] = None
24
20
// out["awkward_ListArray_combinations_b", {dtype_specializations}] = None
25
21
// out["awkward_ListArray_combinations_c", {dtype_specializations}] = None
22
+ // out["awkward_ListArray_combinations_d", {dtype_specializations}] = None
26
23
// END PYTHON
27
24
28
25
enum class LISTARRAY_COMBINATIONS_ERRORS {
@@ -43,25 +40,61 @@ awkward_ListArray_combinations_a(
43
40
int64_t * scan_in_array_offsets,
44
41
uint64_t invocation_index,
45
42
uint64_t * err_code) {
46
- if (err_code[0 ] == NO_ERROR) {
47
- int64_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
48
- if (thread_id < length) {
49
- if (n != 2 ) {
50
- RAISE_ERROR (LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
51
- }
52
- int64_t counts = stops[thread_id] - starts[thread_id];
53
- if (replacement) {
54
- scan_in_array_offsets[thread_id + 1 ] = counts * (counts + 1 ) / 2 ;
55
- } else {
56
- scan_in_array_offsets[thread_id + 1 ] = counts * (counts - 1 ) / 2 ;
57
- }
43
+ if (err_code[0 ] != NO_ERROR) {
44
+ return ;
45
+ }
46
+
47
+ // For now only n==2 supported
48
+ if (n != 2 ) {
49
+ if (threadIdx .x == 0 && blockIdx .x == 0 ) {
50
+ RAISE_ERROR (LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
58
51
}
52
+ return ;
53
+ }
54
+
55
+ int64_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
56
+
57
+ // Grid-stride loop for general scalability
58
+ for (int64_t i = thread_id; i < length; i += gridDim .x * blockDim .x ) {
59
+ V start = starts[i];
60
+ W stop = stops[i];
61
+ int64_t counts = stop - start;
62
+ int64_t result = replacement
63
+ ? counts * (counts + 1 ) / 2
64
+ : counts * (counts - 1 ) / 2 ;
65
+ scan_in_array_offsets[i + 1 ] = result;
59
66
}
60
67
}
61
68
62
69
// Second stage of the combinations pipeline: fill `scan_in_array_parents`.
//
// For every list index k in [0, length), each output slot in the half-open
// range [scan_in_array_offsets[k], scan_in_array_offsets[k + 1]) is set to k.
//
// Parameters read or written by this kernel:
//   length                 number of lists; `scan_in_array_offsets` is indexed
//                          from 0 through `length` inclusive.
//   scan_in_array_offsets  read-only offsets (indices k and k + 1 are read).
//   scan_in_array_parents  output; written at every index inside the ranges
//                          described above.
// The remaining parameters (tocarry, toindex, fromindex, n, replacement,
// starts, stops, invocation_index, err_code) are not accessed here.
//
// Launch layout: 1D grid of 1D blocks, any size. The Python driver in this
// file sizes the launch from the total number of combinations rather than
// from `length`, so the grid may contain fewer threads than there are lists.
// Each thread therefore strides over the list indices instead of handling
// exactly one list; otherwise lists whose index exceeds the thread count
// would be skipped and their parent slots left at their initial value.
template <typename T, typename C, typename U, typename V, typename W>
__global__ void
awkward_ListArray_combinations_b(
    T** tocarry,
    C* toindex,
    U* fromindex,
    int64_t n,
    bool replacement,
    const V* starts,
    const W* stops,
    int64_t length,
    const int64_t* __restrict__ scan_in_array_offsets,
    int64_t* __restrict__ scan_in_array_parents,
    uint64_t invocation_index,
    uint64_t* err_code) {
  // Widen to 64 bits before multiplying so the products cannot wrap in
  // 32-bit unsigned arithmetic.
  const int64_t first_list =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t total_threads =
      static_cast<int64_t>(gridDim.x) * blockDim.x;

  for (int64_t list = first_list; list < length; list += total_threads) {
    const int64_t slot_begin = scan_in_array_offsets[list];
    const int64_t slot_end = scan_in_array_offsets[list + 1];

    // Empty lists (slot_begin == slot_end) write nothing.
    for (int64_t slot = slot_begin; slot < slot_end; slot++) {
      scan_in_array_parents[slot] = list;
    }
  }
}
94
+
95
+ template <typename T, typename C, typename U, typename V, typename W>
96
+ __global__ void
97
+ awkward_ListArray_combinations_c (
65
98
T** tocarry,
66
99
C* toindex,
67
100
U* fromindex,
@@ -75,22 +108,31 @@ awkward_ListArray_combinations_b(
75
108
int64_t * scan_in_array_local_indices,
76
109
uint64_t invocation_index,
77
110
uint64_t * err_code) {
78
- if (err_code[0 ] == NO_ERROR) {
79
- int64_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
80
- int64_t offsetslength = scan_in_array_offsets[length];
81
-
82
- if (thread_id < offsetslength) {
83
- if (n != 2 ) {
84
- RAISE_ERROR (LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
85
- }
86
- scan_in_array_local_indices[thread_id] = thread_id - scan_in_array_offsets[scan_in_array_parents[thread_id]];
111
+ if (err_code[0 ] != NO_ERROR) {
112
+ return ;
113
+ }
114
+
115
+ // For now only n==2 supported
116
+ if (n != 2 ) {
117
+ if (threadIdx .x == 0 && blockIdx .x == 0 ) {
118
+ RAISE_ERROR (LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
87
119
}
120
+ return ;
121
+ }
122
+
123
+ int64_t offsetslength = scan_in_array_offsets[length];
124
+ int64_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
125
+
126
+ // Grid-stride loop
127
+ for (int64_t i = thread_id; i < offsetslength; i += gridDim .x * blockDim .x ) {
128
+ int64_t parent_offset = scan_in_array_offsets[scan_in_array_parents[i]];
129
+ scan_in_array_local_indices[i] = i - parent_offset;
88
130
}
89
131
}
90
132
91
133
template <typename T, typename C, typename U, typename V, typename W>
92
134
__global__ void
93
- awkward_ListArray_combinations_c (
135
+ awkward_ListArray_combinations_d (
94
136
T** tocarry,
95
137
C* toindex,
96
138
U* fromindex,
@@ -104,38 +146,55 @@ awkward_ListArray_combinations_c(
104
146
int64_t * scan_in_array_local_indices,
105
147
uint64_t invocation_index,
106
148
uint64_t * err_code) {
107
- if (err_code[0 ] == NO_ERROR) {
108
- int64_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
109
- int64_t offsetslength = scan_in_array_offsets[length];
110
- int64_t i = 0 ;
111
- int64_t j = 0 ;
112
-
113
- if (thread_id < offsetslength) {
114
- if (n != 2 ) {
115
- RAISE_ERROR (LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
116
- }
117
-
118
- int64_t n = stops[scan_in_array_parents[thread_id]] - starts[scan_in_array_parents[thread_id]];
119
-
120
- if (replacement) {
121
- int64_t b = 2 * n + 1 ;
122
- float discriminant = sqrtf (b * b - 8 * scan_in_array_local_indices[thread_id]);
123
- i = (int64_t )((b - discriminant) / 2 );
124
- j = scan_in_array_local_indices[thread_id] + i * (i - b + 2 ) / 2 ;
125
- } else {
126
- int64_t b = 2 * n - 1 ;
127
- float discriminant = sqrtf (b * b - 8 * scan_in_array_local_indices[thread_id]);
128
- i = (int64_t )((b - discriminant) / 2 );
129
- j = scan_in_array_local_indices[thread_id] + i * (i - b + 2 ) / 2 + 1 ;
130
- }
131
-
132
- i += starts[scan_in_array_parents[thread_id]];
133
- j += starts[scan_in_array_parents[thread_id]];
134
-
135
- tocarry[0 ][thread_id] = i;
136
- tocarry[1 ][thread_id] = j;
137
- toindex[0 ] = offsetslength;
138
- toindex[1 ] = offsetslength;
149
+ if (err_code[0 ] != NO_ERROR) {
150
+ return ;
151
+ }
152
+
153
+ // For now only n==2 supported
154
+ if (n != 2 ) {
155
+ if (threadIdx .x == 0 && blockIdx .x == 0 ) {
156
+ RAISE_ERROR (LISTARRAY_COMBINATIONS_ERRORS::N_NOT_IMPLEMENTED)
139
157
}
158
+ return ;
159
+ }
160
+
161
+ int64_t offsetslength = scan_in_array_offsets[length];
162
+ int64_t thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
163
+
164
+ // Grid-stride loop
165
+ for (int64_t idx = thread_id; idx < offsetslength; idx += gridDim .x * blockDim .x ) {
166
+
167
+ int64_t parent = scan_in_array_parents[idx];
168
+ V start = starts[parent];
169
+ W stop = stops[parent];
170
+ int64_t count = stop - start;
171
+ int64_t local_index = scan_in_array_local_indices[idx];
172
+
173
+ float discriminant;
174
+ int64_t i, j;
175
+
176
+ if (replacement) {
177
+ int64_t b = 2 * count + 1 ;
178
+ discriminant = sqrtf (float (b * b - 8 * local_index));
179
+ i = (int64_t )((b - discriminant) / 2 .0f );
180
+ j = local_index + i * (i - b + 2 ) / 2 ;
181
+ } else {
182
+ int64_t b = 2 * count - 1 ;
183
+ discriminant = sqrtf (float (b * b - 8 * local_index));
184
+ i = (int64_t )((b - discriminant) / 2 .0f );
185
+ j = local_index + i * (i - b + 2 ) / 2 + 1 ;
186
+ }
187
+
188
+ i += start;
189
+ j += start;
190
+
191
+ tocarry[0 ][idx] = i;
192
+ tocarry[1 ][idx] = j;
193
+ }
194
+
195
+ // Set toindex[0] and [1] only once per kernel call (thread 0 of block 0)
196
+ if (threadIdx .x == 0 && blockIdx .x == 0 ) {
197
+ toindex[0 ] = offsetslength;
198
+ toindex[1 ] = offsetslength;
140
199
}
141
200
}
0 commit comments