Commit f4b2495

[Operator] Add vstack op (#175)
1 parent da86496 commit f4b2495

5 files changed: 196 additions & 0 deletions


benchmark/test_special_perf.py

Lines changed: 18 additions & 0 deletions
@@ -270,3 +270,21 @@ def cat_kwargs(dtype, batch, size):
         kwargs_func=cat_kwargs,
     )
     bench.run()
+
+
+def test_perf_vstack():
+    def vstack_args(dtype, batch, size):
+        inp1 = torch.randn(size=(batch, size), dtype=dtype, device="cuda")
+        inp2 = torch.randn(size=(batch + 1, size), dtype=dtype, device="cuda")
+        inp3 = torch.randn(size=(batch + 2, size), dtype=dtype, device="cuda")
+        return [[inp1, inp2, inp3]]
+
+    bench = Benchmark(
+        op_name="vstack",
+        torch_op=torch.vstack,
+        arg_func=vstack_args,
+        dtypes=FLOAT_DTYPES,
+        batch=(512),
+        sizes=SIZES,
+    )
+    bench.run()

src/flag_gems/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ def enable(lib=aten_lib):
     lib.impl("hstack", hstack, "CUDA")
     lib.impl("cat", cat, "CUDA")
     lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")
+    lib.impl("vstack", vstack, "CUDA")


 class use_gems:
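
Once the override above is registered, vstack goes through the normal ATen dispatch whenever Gems is enabled. A minimal usage sketch (not part of the diff; it mirrors the new accuracy test further below and assumes a CUDA device with flag_gems installed):

import torch
import flag_gems

a = torch.randn(3, 33, device="cuda")
b = torch.randn(7, 33, device="cuda")

with flag_gems.use_gems():      # temporarily routes aten vstack to the Triton implementation
    out = torch.vstack([a, b])  # rows are concatenated along dim 0 -> shape (10, 33)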

src/flag_gems/ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -100,6 +100,7 @@
 from .unique import _unique2
 from .var_mean import var_mean
 from .vector_norm import vector_norm
+from .vstack import vstack
 from .where import where_scalar_other, where_scalar_self, where_self
 from .zeros import zeros
 from .zeros_like import zeros_like
@@ -237,4 +238,5 @@
     "hstack",
     "cat",
     "repeat_interleave_self_int",
+    "vstack",
 ]

src/flag_gems/ops/vstack.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": k}, num_warps=w)
+        for w in [4, 8, 16, 32]
+        for k in [512, 1024, 2048, 4096]
+    ],
+    key=[
+        "max_tile_elems",
+    ],
+)
+@triton.jit
+def vstack_kernel(
+    itensor_ptr0,
+    itensor_ptr1,
+    itensor_ptr2,
+    itensor_ptr3,
+    output_ptr,
+    local_row0,
+    local_row1,
+    local_row2,
+    local_row3,
+    exc_row_offset0,
+    exc_row_offset1,
+    exc_row_offset2,
+    exc_row_offset3,
+    total_row_offset,
+    row_stride,
+    max_tile_elems,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_x = tl.program_id(axis=0)
+    tensor_idx = tl.program_id(axis=1)
+    col_idx = tl.arange(0, BLOCK_SIZE)
+
+    intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)
+    intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)
+    intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)
+    base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)
+    base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)
+    base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)
+    local_row = tl.where(tensor_idx == 0, local_row0, local_row1)
+    local_row = tl.where(tensor_idx == 2, local_row2, local_row)
+    local_row = tl.where(tensor_idx == 3, local_row3, local_row)
+
+    end_idx = local_row * row_stride.to(tl.int64)
+    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
+    offset_mask = idx < end_idx
+    in_offset = intensor_ptr + idx
+    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
+    out_offset = output_ptr + row_stride_offset + idx
+    out = tl.load(in_offset, mask=offset_mask)
+    tl.store(out_offset, out, mask=offset_mask)
+
+
+def vstack(tensors: list[torch.Tensor]):
+    logging.debug("GEMS VSTACK")
+
+    tensors = torch.atleast_2d(tensors)
+    num_tensors = len(tensors)
+    assert num_tensors > 0
+
+    # Ensure all tensors are on the same device and have the same dtype
+    device = tensors[0].device
+    dtype = tensors[0].dtype
+    for tensor in tensors:
+        assert (
+            tensor.device == device
+            and tensor.dtype == dtype
+            and tensors[0].shape[1:] == tensor.shape[1:]
+        )
+
+    c_tensors = [t.contiguous() for t in tensors]
+    # Calculate the output shape
+    total_rows = sum(tensor.shape[0] for tensor in c_tensors)
+    output_shape = list(c_tensors[0].shape)
+    output_shape[0] = total_rows
+    output = torch.empty(output_shape, device=device, dtype=dtype)
+    row_stride = c_tensors[0].stride(0)
+
+    outer_iters = triton.cdiv(num_tensors, 4)
+    total_row_offset = 0
+    for i in range(outer_iters):
+        max_rows = 1
+        itensors = []
+        exclusive_row = []
+        local_row = []
+        array_row_offset = 0
+        scheduled_num_tensors = 0
+        for j in range(4):
+            tensor_idx = i * 4 + j
+            if tensor_idx < num_tensors:
+                scheduled_num_tensors += 1
+                itensors.append(c_tensors[tensor_idx])
+                local_row.append(c_tensors[tensor_idx].shape[0])
+                exclusive_row.append(array_row_offset)
+                array_row_offset += c_tensors[tensor_idx].shape[0]
+                max_rows = max(max_rows, c_tensors[tensor_idx].shape[0])
+            else:
+                empty_tensor = torch.empty(
+                    0, dtype=c_tensors[0].dtype, device=c_tensors[0].device
+                )
+                itensors.append(empty_tensor)
+                local_row.append(local_row[-1])
+                exclusive_row.append(exclusive_row[-1])
+        max_tile_elems = max_rows * row_stride
+        grid = lambda META: (
+            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
+            scheduled_num_tensors,
+        )
+        # Launch the kernel
+        with torch.cuda.device(c_tensors[0].device):
+            vstack_kernel[grid](
+                itensors[0],
+                itensors[1],
+                itensors[2],
+                itensors[3],
+                output,
+                local_row[0],
+                local_row[1],
+                local_row[2],
+                local_row[3],
+                exclusive_row[0],
+                exclusive_row[1],
+                exclusive_row[2],
+                exclusive_row[3],
+                total_row_offset,
+                row_stride,
+                max_tile_elems,
+            )
+        total_row_offset += array_row_offset
+    return output
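
A note on the launch loop above: the kernel accepts at most four input pointers, so the host code walks the tensor list in groups of four; within a group, exclusive_row holds each tensor's row offset relative to the group, and total_row_offset carries the rows already written by earlier groups. The sketch below (not part of the commit; the helper name is invented for illustration) reproduces that bookkeeping in plain Python:

def vstack_row_offsets(row_counts, group=4):
    # For each input tensor, compute the absolute output row where its data starts,
    # using the same total_row_offset + exclusive_row scheme as the launch loop.
    offsets = []
    total_row_offset = 0
    for i in range(0, len(row_counts), group):
        array_row_offset = 0
        for rows in row_counts[i:i + group]:
            offsets.append(total_row_offset + array_row_offset)
            array_row_offset += rows
        total_row_offset += array_row_offset
    return offsets

# Three inputs with 3, 7, and 13 rows start at output rows 0, 3, and 10.
assert vstack_row_offsets([3, 7, 13]) == [0, 3, 10]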

tests/test_special_ops.py

Lines changed: 34 additions & 0 deletions
@@ -652,3 +652,37 @@ def test_accuracy_cat(shape, dim, dtype):
     with flag_gems.use_gems():
         res_out = torch.cat(inp, dim)
     gems_assert_equal(res_out, ref_out)
+
+
+VSTACK_SHAPES = [
+    [(3,), (3,)],
+    [(3, 33), (7, 33)],
+    [(13, 3, 333), (17, 3, 333), (7, 3, 333)],
+    [
+        (13, 3, 64, 5, 2),
+        (16, 3, 64, 5, 2),
+        (7, 3, 64, 5, 2),
+        (4, 3, 64, 5, 2),
+        (1, 3, 64, 5, 2),
+    ],
+]
+
+
+@pytest.mark.parametrize("shape", VSTACK_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
+def test_accuracy_vstack(shape, dtype):
+    if dtype in FLOAT_DTYPES:
+        inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
+    else:
+        inp = [
+            torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to(
+                dtype
+            )
+            for s in shape
+        ]
+    ref_inp = [to_reference(_) for _ in inp]
+    ref_out = torch.vstack(ref_inp)
+
+    with flag_gems.use_gems():
+        res_out = torch.vstack(inp)
+    gems_assert_equal(res_out, ref_out)
