|
| 1 | +// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE |
| 2 | + |
| 3 | +// BEGIN PYTHON |
| 4 | +// def f(grid, block, args): |
| 5 | +// (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, invocation_index, err_code) = args |
| 6 | +// scan_in_array = cupy.zeros(sliceouterlen + 1, dtype=cupy.int64) |
| 7 | +// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_a", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code)) |
| 8 | +// scan_in_array = cupy.cumsum(scan_in_array) |
| 9 | +// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_b", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code)) |
| 10 | +// out["awkward_ListArray_getitem_jagged_apply_a", {dtype_specializations}] = None |
| 11 | +// out["awkward_ListArray_getitem_jagged_apply_b", {dtype_specializations}] = None |
| 12 | +// END PYTHON |
| 13 | + |
// Error codes that the awkward_ListArray_getitem_jagged_apply_{a,b} kernels
// hand to RAISE_ERROR. The enumerators rely on implicit numbering (0, 1, 2,
// ...), so their order is significant. Each enumerator is followed by a
// `// message: "..."` comment holding its human-readable text; keep those
// comments verbatim (presumably they are read by tooling -- confirm before
// editing them).
enum class LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS {
  JAG_STOP_LT_START,  // message: "jagged slice's stops[i] < starts[i]"
  OFF_GET_CON,        // message: "jagged slice's offsets extend beyond its content"
  STOP_LT_START,      // message: "stops[i] < starts[i]"
  STOP_GET_LEN,       // message: "stops[i] > len(content)"
  IND_OUT_OF_RANGE,   // message: "index out of range"
};
| 21 | + |
// Pass "a" of the two-pass jagged getitem: validates each outer entry and
// records the length of its slice in scan_in_array.
//
// Each thread handles the outer entry at its flattened x index `thread_id`.
// When slicestarts[thread_id] != slicestops[thread_id] the thread checks the
// slice bounds and the list bounds, reports violations through RAISE_ERROR,
// and stores slicestops[thread_id] - slicestarts[thread_id] in
// scan_in_array[thread_id + 1]. Entries belonging to empty slices are not
// written here; the launcher comment at the top of this file allocates the
// buffer with cupy.zeros, so those entries stay 0.
//
// Parameters:
//   tooffsets, tocarry  - not touched by this pass.
//   slicestarts/stops   - per-entry start/stop of the jagged slice.
//   sliceouterlen       - number of outer entries; threads at or beyond it
//                         do no per-entry work.
//   sliceindex          - not touched by this pass.
//   sliceinnerlen       - upper bound checked against slicestops[i].
//   fromstarts/stops    - per-entry start/stop of the list being indexed.
//   contentlen          - upper bound checked against fromstops[i] for
//                         non-empty lists.
//   scan_in_array       - output; indices 0..sliceouterlen are written, so it
//                         must hold at least sliceouterlen + 1 elements.
//   invocation_index    - not referenced directly in this body; presumably
//                         consumed by RAISE_ERROR -- confirm.
//   err_code            - the kernel does nothing unless err_code[0] is
//                         NO_ERROR on entry.
//
// NO_ERROR and RAISE_ERROR are defined outside this file.
// NOTE(review): unless RAISE_ERROR leaves the kernel, a thread that raised an
// error still stores its (possibly negative) length below -- confirm that the
// result is discarded whenever err_code is set.
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y>
__global__ void
awkward_ListArray_getitem_jagged_apply_a(
    T* tooffsets,
    C* tocarry,
    const U* slicestarts,
    const V* slicestops,
    int64_t sliceouterlen,
    const W* sliceindex,
    int64_t sliceinnerlen,
    const X* fromstarts,
    const Y* fromstops,
    int64_t contentlen,
    int64_t* scan_in_array,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x and blockDim.x are 32-bit unsigned,
    // so their product would wrap before being stored in the 64-bit index.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    // Every thread stores the same value here, so the duplicate writes agree.
    scan_in_array[0] = 0;

    if (thread_id < sliceouterlen) {
      U slicestart = slicestarts[thread_id];
      V slicestop = slicestops[thread_id];

      // Empty slices skip validation and contribute nothing.
      if (slicestart != slicestop) {
        if (slicestop < slicestart) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START)
        }
        if (slicestop > sliceinnerlen) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON)
        }
        int64_t start = (int64_t)fromstarts[thread_id];
        int64_t stop = (int64_t)fromstops[thread_id];
        if (stop < start) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START)
        }
        if (start != stop && stop > contentlen) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
        }
        // Stored one slot to the right so that slot 0 stays 0.
        scan_in_array[thread_id + 1] = slicestop - slicestart;
      }
    }
  }
}
| 66 | + |
// Pass "b" of the two-pass jagged getitem: fills tooffsets and tocarry.
//
// scan_in_array is read as a table of output offsets: entry i supplies
// tooffsets[i] and the position in tocarry where outer entry i starts writing.
// The launcher comment at the top of this file applies cupy.cumsum to the
// buffer between pass "a" and this pass.
//
// Each thread handles the outer entry at its flattened x index `thread_id`.
// It copies scan_in_array[thread_id] into tooffsets[thread_id]; then, for a
// non-empty slice, it repeats the bounds checks and walks
// sliceindex[slicestart .. slicestop), starting at slicestart + threadIdx.y
// and stepping by blockDim.y. Negative indices have the list length added to
// them, and start + index is written to tocarry.
//
// Parameters:
//   tooffsets           - output; indices 0..sliceouterlen are written.
//   tocarry             - output; written at scan_in_array[i] + (j - slicestart).
//   slicestarts/stops   - per-entry start/stop into sliceindex.
//   sliceouterlen       - number of outer entries.
//   sliceindex          - indices to apply, relative to each list.
//   sliceinnerlen       - upper bound checked against slicestops[i].
//   fromstarts/stops    - per-entry start/stop of the list being indexed.
//   contentlen          - upper bound checked against fromstops[i] for
//                         non-empty lists.
//   scan_in_array       - input; indices 0..sliceouterlen are read.
//   invocation_index    - not referenced directly in this body; presumably
//                         consumed by RAISE_ERROR -- confirm.
//   err_code            - the kernel does nothing unless err_code[0] is
//                         NO_ERROR on entry.
//
// NO_ERROR and RAISE_ERROR are defined outside this file.
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y>
__global__ void
awkward_ListArray_getitem_jagged_apply_b(
    T* tooffsets,
    C* tocarry,
    const U* slicestarts,
    const V* slicestops,
    int64_t sliceouterlen,
    const W* sliceindex,
    int64_t sliceinnerlen,
    const X* fromstarts,
    const Y* fromstops,
    int64_t contentlen,
    int64_t* scan_in_array,
    uint64_t invocation_index,
    uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    // Widen before multiplying: blockIdx.x and blockDim.x are 32-bit unsigned,
    // so their product would wrap before being stored in the 64-bit index.
    int64_t thread_id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id < sliceouterlen) {
      U slicestart = slicestarts[thread_id];
      V slicestop = slicestops[thread_id];
      // Written for every entry, including those with an empty slice.
      tooffsets[thread_id] = (T)(scan_in_array[thread_id]);
      if (slicestart != slicestop) {
        if (slicestop < slicestart) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START)
        }
        if (slicestop > sliceinnerlen) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON)
        }
        int64_t start = (int64_t)fromstarts[thread_id];
        int64_t stop = (int64_t)fromstops[thread_id];
        if (stop < start) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START)
        }
        if (start != stop && stop > contentlen) {
          RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN)
        }
        int64_t count = stop - start;
        // Threads that share thread_id but differ in threadIdx.y take
        // interleaved elements of the slice; with blockDim.y == 1 a single
        // thread visits every element.
        for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) {
          int64_t index = (int64_t) sliceindex[j];
          // NOTE(review): this accepts index == count, which maps to `stop`,
          // one past the last element of the list. Confirm whether `>=` is
          // intended; left unchanged here to keep the accepted inputs the same.
          if (index < -count || index > count) {
            RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::IND_OUT_OF_RANGE)
          }
          if (index < 0) {
            index += count;
          }
          tocarry[scan_in_array[thread_id] + j - slicestart] = start + index;
        }
      }
    }
    // Final offset. This sits outside the per-entry guard, so every thread
    // stores the same value; cast explicitly to match the per-entry store.
    tooffsets[sliceouterlen] = (T)(scan_in_array[sliceouterlen]);
  }
}
0 commit comments