/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cub/block/block_scan.cuh>
#include "common.cuh"

static constexpr uint32_t kMaxThreads = 1024;

namespace fbgemm_gpu {

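// Round a 32-bit value up to the next power of two; values that are already
// powers of two are returned unchanged.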
C10_ALWAYS_INLINE uint32_t next_power_of_2(uint32_t n) {
  n--;
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
  n++;
  return n;
}

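// Stateful callback handed to cub::BlockScan: each call receives the aggregate
// of the tile that was just scanned and returns the running total of all
// preceding tiles, chaining consecutive tiles into a single scan.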
template <
    typename val_t,
    typename = std::enable_if_t<std::is_integral<val_t>::value>>
struct BlockPrefixCallbackOp {
  val_t running_total;

  __device__ BlockPrefixCallbackOp(val_t running_total)
      : running_total(running_total) {}

  __device__ val_t operator()(val_t block_aggregate) {
    val_t old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

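// One thread block produces the complete (zero-prepended) cumulative sum of
// one batch row: the row is scanned in items_per_thread tiles of
// nthreads_per_block elements, and out[b][i + 1] ends up holding the inclusive
// sum of values[b][0..i].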
template <
    typename val_t,
    uint32_t nthreads_per_block,
    typename = std::enable_if_t<std::is_integral<val_t>::value>>
__global__ __launch_bounds__(kMaxThreads) void _batched_complete_cumsum_kernel(
    const at::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> values,
    const uint32_t len,
    const uint32_t items_per_thread,
    at::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> out) {
  using BlockScan = cub::BlockScan<val_t, nthreads_per_block>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  BlockPrefixCallbackOp<val_t> prefix_op(0);
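  // Each row of the output starts with an explicit leading zero.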
  if (threadIdx.x == 0) {
    out[blockIdx.x][0] = 0;
  }

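  // Scan the row tile by tile: prefix_op carries the running total across
  // tiles, and the block-wide sync guards the reuse of temp_storage between
  // iterations.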
  for (uint32_t offset = 0; offset < items_per_thread; offset++) {
    uint32_t i = offset * nthreads_per_block + threadIdx.x;
    val_t data = 0;
    if (i < len) {
      data = (val_t)values[blockIdx.x][i];
    }
    BlockScan(temp_storage).InclusiveSum(data, data, prefix_op);
    cub::CTA_SYNC();
    if (i < len) {
      out[blockIdx.x][i + 1] = data;
    }
  }
}

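// Computes, for each row of a 2D integer tensor, its cumulative sum prepended
// with a zero ("complete" cumsum), launching one thread block per row. The
// returned tensor has shape {B, len + 1} and the same dtype as the input.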
at::Tensor asynchronous_batched_complete_cumsum_gpu(const at::Tensor& values) {
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(values.get_device());

  TORCH_CHECK(
      values.dim() == 2,
      "values of batched_complete_cumsum must be a 2D tensor");
  TORCH_CHECK(
      values.size(0) <= UINT32_MAX,
      "values.size(0) must be no higher than UINT32_MAX");
  TORCH_CHECK(
      values.size(1) <= UINT32_MAX,
      "values.size(1) must be no higher than UINT32_MAX");

  const uint32_t B = values.size(0);
  const uint32_t len = values.size(1);
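  // Block size: next power of two >= len, clamped to [64, kMaxThreads]. Each
  // thread then processes one element per tile, for items_per_thread tiles.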
  const uint32_t nthreads_per_block =
      min(max(next_power_of_2(len), 64), kMaxThreads);
  const uint32_t items_per_thread = div_round_up(len, nthreads_per_block);

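  // The complete cumsum has one extra leading element per row, so the output
  // is {B, len + 1}.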
  auto cumsum = at::empty({B, len + 1}, values.options());

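  // cub::BlockScan needs the block size as a compile-time constant, so each
  // supported power-of-two block size gets its own kernel instantiation below.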
  AT_DISPATCH_INTEGRAL_TYPES(
      values.scalar_type(), "batched_complete_cumsum_cuda_input1", [&] {
        using val_t = scalar_t;
        if (nthreads_per_block == 64) {
          _batched_complete_cumsum_kernel<val_t, 64>
              <<<B, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else if (nthreads_per_block == 128) {
          _batched_complete_cumsum_kernel<val_t, 128>
              <<<B, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else if (nthreads_per_block == 256) {
          _batched_complete_cumsum_kernel<val_t, 256>
              <<<B, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else if (nthreads_per_block == 512) {
          _batched_complete_cumsum_kernel<val_t, 512>
              <<<B, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else {
          _batched_complete_cumsum_kernel<val_t, 1024>
              <<<B, 1024, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  return cumsum;
}

} // namespace fbgemm_gpu

FBGEMM_OP_DISPATCH(
    CUDA,
    "asynchronous_batched_complete_cumsum",
    fbgemm_gpu::asynchronous_batched_complete_cumsum_gpu);
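
// A minimal host-side usage sketch (not part of the file above), assuming the
// usual ATen headers and a CUDA device: for a 2D integer tensor of shape
// {B, len}, the op returns a tensor of shape {B, len + 1} whose rows start at
// 0 and end at the row sum.
//
//   const auto values = at::randint(
//       0, 10, {8, 1000}, at::device(at::kCUDA).dtype(at::kLong));
//   const auto row_cumsum =
//       fbgemm_gpu::asynchronous_batched_complete_cumsum_gpu(values);
//   // row_cumsum.sizes() == {8, 1001}; row_cumsum[b][0] == 0 for every b.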