
Commit 3cef662

LinjianMa authored and facebook-github-bot committed
Move batched_complete_cumsum op to FBGEMM (#4036)
Summary:
Pull Request resolved: #4036
X-link: facebookresearch/FBGEMM#1121
X-link: meta-recsys/generative-recommenders#271

`asynchronous_batched_complete_cumsum` implements a batched complete_cumsum in a single kernel, so callers no longer need to write a Python for loop over the batch. Move the kernel from the `generative_recommenders` folder into FBGEMM so that it can be reused across models. The kernel has not been used in production yet, so the change is safe.

Reviewed By: sryap, zhaozhul, q10

Differential Revision: D72906062

fbshipit-source-id: 53a4a5c4e0a68449141cefe55c4f66d289589f09
1 parent ed83720 commit 3cef662
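
The motivation in one code sketch: with only the per-row op, callers looped over the batch in Python; the new op handles a whole 2D tensor in a single call. A minimal usage sketch (assumes an fbgemm_gpu build that includes this change, so importing fbgemm_gpu registers both operators):

import torch
import fbgemm_gpu  # noqa: F401  -- registers the torch.ops.fbgemm operators

values = torch.randint(0, 1000, (8, 128), dtype=torch.int64)

# Before: one complete_cumsum call per row, driven by a Python for loop.
looped = torch.stack(
    [torch.ops.fbgemm.asynchronous_complete_cumsum(values[i]) for i in range(values.shape[0])]
)

# After: a single batched call (one kernel launch for CUDA inputs).
batched = torch.ops.fbgemm.asynchronous_batched_complete_cumsum(values)

torch.testing.assert_close(looped, batched)  # both have shape (8, 129)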

File tree

4 files changed: +238 -0 lines changed


fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 6 additions & 0 deletions

@@ -38,6 +38,9 @@ at::Tensor asynchronous_complete_cumsum_gpu(const at::Tensor& t_in);
 ///@ingroup sparse-data-cuda
 at::Tensor asynchronous_inclusive_cumsum_gpu(const at::Tensor& t_in);
 
+///@ingroup sparse-data-cuda
+at::Tensor asynchronous_batched_complete_cumsum_gpu(const at::Tensor& t_in);
+
 ///@ingroup sparse-data-cpu
 at::Tensor asynchronous_exclusive_cumsum_cpu(const at::Tensor& t_in);
 
@@ -49,6 +52,9 @@ void asynchronous_exclusive_cumsum_cpu_out(
 ///@ingroup sparse-data-cpu
 at::Tensor asynchronous_complete_cumsum_cpu(const at::Tensor& t_in);
 
+///@ingroup sparse-data-cpu
+at::Tensor asynchronous_batched_complete_cumsum_cpu(const at::Tensor& t_in);
+
 ///@ingroup sparse-data-cpu
 at::Tensor asynchronous_complete_cumsum_cpu_out(
     at::Tensor& t_out,

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"

namespace fbgemm_gpu {

at::Tensor asynchronous_batched_complete_cumsum_cpu(const at::Tensor& values) {
  auto B = values.size(0);
  auto len = values.size(1);
  auto output = at::empty({B, len + 1}, values.options());
  const at::Tensor index = at::range(0, len, at::kLong).cpu();
  for (auto i : c10::irange(B)) {
    at::Tensor t = output[i];
    at::index_put_(
        t, {index}, fbgemm_gpu::asynchronous_complete_cumsum_cpu(values[i]));
  }
  return output;
}

at::Tensor asynchronous_batched_complete_cumsum_meta(const at::Tensor& values) {
  auto B = values.sym_size(0);
  auto len = values.sym_size(1);
  auto output = at::native::empty_meta_symint(
      {B, len + 1},
      /*dtype=*/::std::make_optional(values.scalar_type()),
      /*layout=*/::std::make_optional(values.layout()),
      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
      /*pin_memory=*/::std::nullopt);
  return output;
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def("asynchronous_batched_complete_cumsum(Tensor values) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
  m.impl(
      "asynchronous_batched_complete_cumsum",
      fbgemm_gpu::asynchronous_batched_complete_cumsum_cpu);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
  m.impl(
      "asynchronous_batched_complete_cumsum",
      fbgemm_gpu::asynchronous_batched_complete_cumsum_meta);
}
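
The CPU implementation above reuses the existing per-row asynchronous_complete_cumsum_cpu in a loop and registers the fbgemm::asynchronous_batched_complete_cumsum schema together with its CPU and Meta kernels. For intuition: a complete cumsum of a length-len row is a length-(len + 1) prefix sum that starts at 0 and ends at the row total. A rough pure-PyTorch equivalent of the batched semantics (an illustrative sketch, not the FBGEMM implementation):

import torch

def batched_complete_cumsum_reference(values: torch.Tensor) -> torch.Tensor:
    # values: (B, len) integer tensor -> (B, len + 1) output, where
    # out[b, 0] == 0 and out[b, i + 1] == values[b, : i + 1].sum()
    zeros = torch.zeros(
        values.size(0), 1, dtype=values.dtype, device=values.device
    )
    return torch.cat([zeros, values], dim=1).cumsum(dim=1)
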
Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cub/block/block_scan.cuh>
#include "common.cuh"

static constexpr uint32_t kMaxThreads = 1024;

namespace fbgemm_gpu {

C10_ALWAYS_INLINE uint32_t next_power_of_2(uint32_t n) {
  n--;
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
  n++;
  return n;
}

template <
    typename val_t,
    typename = std::enable_if_t<std::is_integral<val_t>::value>>
struct BlockPrefixCallbackOp {
  val_t running_total;

  __device__ BlockPrefixCallbackOp(val_t running_total)
      : running_total(running_total) {}

  __device__ val_t operator()(val_t block_aggregate) {
    val_t old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

template <
    typename val_t,
    uint32_t nthreads_per_block,
    typename = std::enable_if_t<std::is_integral<val_t>::value>>
__global__ __launch_bounds__(kMaxThreads) void _batched_complete_cumsum_kernel(
    const at::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> values,
    const uint32_t len,
    const uint32_t items_per_thread,
    at::PackedTensorAccessor64<val_t, 2, at::RestrictPtrTraits> out) {
  using BlockScan = cub::BlockScan<val_t, nthreads_per_block>;
  __shared__ typename BlockScan::TempStorage temp_storage;

  BlockPrefixCallbackOp<val_t> prefix_op(0);
  if (threadIdx.x == 0) {
    out[blockIdx.x][0] = 0;
  }

  for (uint32_t offset = 0; offset < items_per_thread; offset++) {
    uint32_t i = offset * nthreads_per_block + threadIdx.x;
    val_t data = 0;
    if (i < len) {
      data = (val_t)values[blockIdx.x][i];
    }
    BlockScan(temp_storage).InclusiveSum(data, data, prefix_op);
    cub::CTA_SYNC();
    if (i < len) {
      out[blockIdx.x][i + 1] = data;
    }
  }
}

at::Tensor asynchronous_batched_complete_cumsum_gpu(const at::Tensor& values) {
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(values.get_device());

  TORCH_CHECK(values.dim() == 2, "values of batched_complete_cumsum must be 2D")
  TORCH_CHECK(
      values.size(0) <= UINT32_MAX,
      "values.size(0) must be no higher than UINT32_MAX")
  TORCH_CHECK(
      values.size(1) <= UINT32_MAX,
      "values.size(1) must be no higher than UINT32_MAX")

  const uint32_t B = values.size(0);
  const uint32_t len = values.size(1);
  const uint32_t nthreads_per_block =
      min(max(next_power_of_2(len), 64), kMaxThreads);
  const uint32_t items_per_thread = div_round_up(len, nthreads_per_block);

  auto cumsum = at::empty({B, len + 1}, values.options());

  AT_DISPATCH_INTEGRAL_TYPES(
      values.scalar_type(), "batched_complete_cumsum_cuda_input1", [&] {
        using val_t = scalar_t;
        if (nthreads_per_block == 64) {
          _batched_complete_cumsum_kernel<val_t, 64>
              <<<B, 64, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else if (nthreads_per_block == 128) {
          _batched_complete_cumsum_kernel<val_t, 128>
              <<<B, 128, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else if (nthreads_per_block == 256) {
          _batched_complete_cumsum_kernel<val_t, 256>
              <<<B, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else if (nthreads_per_block == 512) {
          _batched_complete_cumsum_kernel<val_t, 512>
              <<<B, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        } else {
          _batched_complete_cumsum_kernel<val_t, 1024>
              <<<B, 1024, 0, at::cuda::getCurrentCUDAStream()>>>(
                  values.packed_accessor64<val_t, 2, at::RestrictPtrTraits>(),
                  len,
                  items_per_thread,
                  cumsum.packed_accessor64<val_t, 2, at::RestrictPtrTraits>());
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  return cumsum;
}

} // namespace fbgemm_gpu

FBGEMM_OP_DISPATCH(
    CUDA,
    "asynchronous_batched_complete_cumsum",
    fbgemm_gpu::asynchronous_batched_complete_cumsum_gpu);
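
On the GPU path, the kernel launches one thread block per batch row (B blocks) and scans each row in items_per_thread tiles of nthreads_per_block elements using cub::BlockScan; BlockPrefixCallbackOp carries the running total from one tile to the next so the tiles chain into a single inclusive scan, and thread 0 writes the leading 0 that makes the cumsum "complete". A rough Python sketch of that per-row tiling logic (illustrative only; the real scan runs in parallel within the block):

def tiled_complete_cumsum(row, tile_size):
    # Scan the row tile by tile, carrying a running total across tiles
    # (the role played by BlockPrefixCallbackOp in the CUDA kernel).
    out = [0]
    running_total = 0
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        acc = 0
        for v in tile:
            acc += v
            out.append(running_total + acc)
        running_total += acc
    return out

assert tiled_complete_cumsum([1, 2, 3, 4, 5], tile_size=2) == [0, 1, 3, 6, 10, 15]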

fbgemm_gpu/test/sparse/cumsum_test.py

Lines changed: 32 additions & 0 deletions

@@ -131,6 +131,38 @@ def test_asynchronous_complete_cumsum_2d(
             zc.cpu(),
         )
 
+    @given(
+        batch_size=st.integers(10, 1000),
+        max_len=st.integers(10, 1000),
+        dtype=st.sampled_from([torch.int32, torch.int64]),
+        device=cpu_and_maybe_gpu(),
+    )
+    @settings(
+        verbosity=Verbosity.verbose,
+        max_examples=50,
+        deadline=None,
+    )
+    def test_batched_complete_cumsum(
+        self,
+        batch_size: int,
+        max_len: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> None:
+        def cumsum_base(values: torch.Tensor) -> torch.Tensor:
+            out = [
+                torch.ops.fbgemm.asynchronous_complete_cumsum(values[i])
+                for i in range(values.shape[0])
+            ]
+            return torch.stack(out, dim=0)
+
+        values = torch.randint(
+            0, 1000, (batch_size, max_len), device=device, dtype=dtype
+        )
+        out = torch.ops.fbgemm.asynchronous_batched_complete_cumsum(values)
+        out2 = cumsum_base(values)
+        torch.testing.assert_close(out, out2)
+
 
 extend_test_class(CumSumTest)
