Skip to content

Commit 7377ea2

Browse files
PatriceVignolafacebook-github-bot
authored andcommitted
Make kv_cache tests device agnostic (#3961)
Summary: X-link: facebookresearch/FBGEMM#1044 The kv_cache tests are currently hardcoded to use CUDA, but we want to be able to hook other devices to it. Differential Revision: D72370780
1 parent 82fbe91 commit 7377ea2

File tree

1 file changed

+101
-44
lines changed

1 file changed

+101
-44
lines changed

fbgemm_gpu/experimental/gen_ai/test/kv_cache/kv_cache_test.py

Lines changed: 101 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class LogicalDtype(Enum):
4040

4141

4242
def _get_varseq_batch_seqpos(
43-
seqlens_q: List[int], seqlens_kv: List[int]
43+
seqlens_q: List[int], seqlens_kv: List[int], device: torch.device
4444
) -> Tuple[torch.Tensor, torch.Tensor]:
4545
"""
4646
varseq_batch[i] is batch index of query i
@@ -49,7 +49,11 @@ def _get_varseq_batch_seqpos(
4949

5050
varseq_batch = torch.cat(
5151
[
52-
torch.as_tensor([i for _ in range(len_q)], dtype=torch.int, device="cuda")
52+
torch.as_tensor(
53+
[i for _ in range(len_q)],
54+
dtype=torch.int,
55+
device=device,
56+
)
5357
for i, len_q in enumerate(seqlens_q)
5458
]
5559
)
@@ -58,7 +62,7 @@ def _get_varseq_batch_seqpos(
5862
torch.as_tensor(
5963
[len_kv - len_q + t for t in range(len_q)],
6064
dtype=torch.int,
61-
device="cuda",
65+
device=device,
6266
)
6367
for len_q, len_kv in zip(seqlens_q, seqlens_kv)
6468
]
@@ -67,6 +71,22 @@ def _get_varseq_batch_seqpos(
6771

6872

6973
class KVCacheTests(unittest.TestCase):
74+
@classmethod
75+
def setUpClass(cls) -> None:
76+
super().setUpClass()
77+
device = torch.accelerator.current_accelerator()
78+
assert device is not None
79+
cls.device = device
80+
81+
# Perform a dummy compilation to test if inductor is supported
82+
try:
83+
torch.compile(torch.abs, backend="inductor")(
84+
torch.tensor(0, device=cls.device)
85+
)
86+
cls.compile_backend = "inductor"
87+
except torch._dynamo.exc.BackendCompilerFailed:
88+
cls.compile_backend = "eager"
89+
7090
@settings(deadline=None)
7191
@given(
7292
num_groups=st.sampled_from([1, 2, 4, 8]),
@@ -88,26 +108,34 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
88108
# PROMPT_T = 1024
89109

90110
xq = (
91-
torch.randn(size=(B * T, N_H_L, D_H), dtype=torch.bfloat16, device="cuda")
111+
torch.randn(
112+
size=(B * T, N_H_L, D_H), dtype=torch.bfloat16, device=self.device
113+
)
92114
* 0.01
93115
)
94116
xk = (
95-
torch.randn(size=(B * T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda")
117+
torch.randn(
118+
size=(B * T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
119+
)
96120
* 0.01
97121
)
98122
xv = (
99-
torch.randn(size=(B * T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda")
123+
torch.randn(
124+
size=(B * T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
125+
)
100126
* 0.01
101127
)
102128
varseq_seqpos = torch.cat(
103129
[
104-
torch.as_tensor(list(range(T)), dtype=torch.int, device="cuda")
130+
torch.as_tensor(list(range(T)), dtype=torch.int, device=self.device)
105131
for b in range(B)
106132
]
107133
)
108134
varseq_batch = torch.cat(
109135
[
110-
torch.as_tensor([b for _ in range(T)], dtype=torch.int, device="cuda")
136+
torch.as_tensor(
137+
[b for _ in range(T)], dtype=torch.int, device=self.device
138+
)
111139
for b in range(B)
112140
]
113141
)
@@ -118,19 +146,21 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
118146
kv_seqlen=[T for _ in range(B)],
119147
)
120148
)
121-
attn_bias.k_seqinfo.to(torch.device("cuda"))
149+
attn_bias.k_seqinfo.to(self.device)
122150
assert attn_bias.k_seqinfo.seqlen.shape == (B,)
123151
assert attn_bias.k_seqinfo.seqlen.tolist() == [T for _ in range(B)]
124152

125153
theta = 10000.0
126154
cache_k_bf16 = torch.zeros(
127-
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
155+
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
128156
)
129157
cache_v_bf16 = torch.zeros(
130-
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
158+
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
131159
)
132160

133-
xq_out_bf16 = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
161+
xq_out_bf16 = torch.compile(
162+
torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend
163+
)(
134164
xq,
135165
xk,
136166
xv,
@@ -145,14 +175,16 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
145175
cache_k_int4 = torch.zeros(
146176
size=(B, MAX_T, N_KVH_L, int(D_H // 2) + qparam_offset),
147177
dtype=torch.uint8,
148-
device="cuda",
178+
device=self.device,
149179
)
150180
cache_v_int4 = torch.zeros(
151181
size=(B, MAX_T, N_KVH_L, int(D_H // 2) + qparam_offset),
152182
dtype=torch.uint8,
153-
device="cuda",
183+
device=self.device,
154184
)
155-
xq_out = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
185+
xq_out = torch.compile(
186+
torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend
187+
)(
156188
xq,
157189
xk,
158190
xv,
@@ -166,7 +198,9 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
166198
)
167199
torch.testing.assert_close(xq_out_bf16, xq_out)
168200

169-
dequantized_cache = torch.compile(torch.ops.fbgemm.dequantize_int4_cache)(
201+
dequantized_cache = torch.compile(
202+
torch.ops.fbgemm.dequantize_int4_cache, backend=self.compile_backend
203+
)(
170204
cache_k_int4,
171205
cache_v_int4,
172206
attn_bias.k_seqinfo.seqlen,
@@ -205,7 +239,8 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
205239
xq = (
206240
torch.cat(
207241
[
208-
torch.randn(N_H_L, D_H, dtype=torch.bfloat16, device="cuda") * (i)
242+
torch.randn(N_H_L, D_H, dtype=torch.bfloat16, device=self.device)
243+
* (i)
209244
for i in range(B * T)
210245
]
211246
)
@@ -215,14 +250,14 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
215250
xk_rows = [
216251
scale_step
217252
* (i + 1)
218-
* torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda")
253+
* torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device)
219254
+ i * shift_step
220255
for i in range(B * T)
221256
]
222257
xv_rows = [
223258
scale_step
224259
* (i + 1)
225-
* torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda")
260+
* torch.randn(size=(N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device)
226261
+ i * shift_step
227262
for i in range(B * T)
228263
]
@@ -232,13 +267,15 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
232267
xv = (torch.cat(xv_rows)).view(B * T, N_KVH_L, D_H)
233268
varseq_seqpos = torch.cat(
234269
[
235-
torch.as_tensor(list(range(T)), dtype=torch.int, device="cuda")
270+
torch.as_tensor(list(range(T)), dtype=torch.int, device=self.device)
236271
for b in range(B)
237272
]
238273
)
239274
varseq_batch = torch.cat(
240275
[
241-
torch.as_tensor([b for _ in range(T)], dtype=torch.int, device="cuda")
276+
torch.as_tensor(
277+
[b for _ in range(T)], dtype=torch.int, device=self.device
278+
)
242279
for b in range(B)
243280
]
244281
)
@@ -249,19 +286,21 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
249286
kv_seqlen=[T for _ in range(B)],
250287
)
251288
)
252-
attn_bias.k_seqinfo.to(torch.device("cuda"))
289+
attn_bias.k_seqinfo.to(self.device)
253290
assert attn_bias.k_seqinfo.seqlen.shape == (B,)
254291
assert attn_bias.k_seqinfo.seqlen.tolist() == [T for _ in range(B)]
255292

256293
theta = 10000.0
257294
cache_k_bf16 = torch.zeros(
258-
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
295+
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
259296
)
260297
cache_v_bf16 = torch.zeros(
261-
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
298+
size=(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
262299
)
263300

264-
xq_out_bf16 = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
301+
xq_out_bf16 = torch.compile(
302+
torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend
303+
)(
265304
xq,
266305
xk,
267306
xv,
@@ -276,14 +315,16 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
276315
cache_k_fp8 = torch.zeros(
277316
size=(B, MAX_T, N_KVH_L, int(D_H) + qparam_offset),
278317
dtype=torch.uint8,
279-
device="cuda",
318+
device=self.device,
280319
)
281320
cache_v_fp8 = torch.zeros(
282321
size=(B, MAX_T, N_KVH_L, int(D_H) + qparam_offset),
283322
dtype=torch.uint8,
284-
device="cuda",
323+
device=self.device,
285324
)
286-
xq_out = torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)(
325+
xq_out = torch.compile(
326+
torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend
327+
)(
287328
xq,
288329
xk,
289330
xv,
@@ -296,7 +337,9 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
296337
)
297338
torch.testing.assert_close(xq_out_bf16, xq_out)
298339

299-
dequantized_cache = torch.compile(torch.ops.fbgemm.dequantize_fp8_cache)(
340+
dequantized_cache = torch.compile(
341+
torch.ops.fbgemm.dequantize_fp8_cache, backend=self.compile_backend
342+
)(
300343
cache_k_fp8,
301344
cache_v_fp8,
302345
attn_bias.k_seqinfo.seqlen,
@@ -339,12 +382,12 @@ def test_positional_encoding_with_paged_attention(
339382
kv_seqlens = torch.randint(low=0, high=MAX_T, size=(B,)).tolist()
340383
q_seqlens = kv_seqlens if prefill else [1 for _ in range(B)]
341384
seq_positions = torch.tensor(
342-
[x - 1 for x in kv_seqlens], device="cuda", dtype=torch.int32
385+
[x - 1 for x in kv_seqlens], device=self.device, dtype=torch.int32
343386
)
344387
total_length_q = sum(q_seqlens)
345388

346389
cache_k = torch.randn(
347-
(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
390+
(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
348391
)
349392
cache_v = torch.randn_like(cache_k)
350393

@@ -363,7 +406,7 @@ def test_positional_encoding_with_paged_attention(
363406
N_H_L + 2 * N_KVH_L,
364407
D_H,
365408
dtype=torch.bfloat16,
366-
device="cuda",
409+
device=self.device,
367410
)
368411
xq = xqkv[:, :N_H_L, :]
369412
# This clone is to avoid a weirdness in torch.compile:
@@ -394,15 +437,20 @@ def test_positional_encoding_with_paged_attention(
394437
assert cache_v.shape == (B, MAX_T, N_KVH_L, D_H)
395438

396439
if prefill:
397-
seqpos_args = _get_varseq_batch_seqpos(q_seqlens, kv_seqlens)
440+
seqpos_args = _get_varseq_batch_seqpos(q_seqlens, kv_seqlens, self.device)
398441
else:
399442
seqpos_args = (seq_positions,)
400443

401444
if rope_theta is not None:
402445
func = (
403-
torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)
446+
torch.compile(
447+
torch.ops.fbgemm.rope_qkv_varseq_prefill,
448+
backend=self.compile_backend,
449+
)
404450
if prefill
405-
else torch.compile(torch.ops.fbgemm.rope_qkv_decoding)
451+
else torch.compile(
452+
torch.ops.fbgemm.rope_qkv_decoding, backend=self.compile_backend
453+
)
406454
)
407455
xq_out_ref = func(
408456
xq,
@@ -428,9 +476,14 @@ def test_positional_encoding_with_paged_attention(
428476
)
429477
else:
430478
func = (
431-
torch.compile(torch.ops.fbgemm.xpos_qkv_varseq_prefill)
479+
torch.compile(
480+
torch.ops.fbgemm.xpos_qkv_varseq_prefill,
481+
backend=self.compile_backend,
482+
)
432483
if prefill
433-
else torch.compile(torch.ops.fbgemm.xpos_qkv_decoding)
484+
else torch.compile(
485+
torch.ops.fbgemm.xpos_qkv_decoding, backend=self.compile_backend
486+
)
434487
)
435488
xq_out_ref = func(
436489
xq,
@@ -510,12 +563,12 @@ def test_rope_positional_encoding_only(
510563
kv_seqlens = torch.randint(low=0, high=MAX_T, size=(B,)).tolist()
511564
q_seqlens = kv_seqlens if prefill else [1 for _ in range(B)]
512565
seq_positions = torch.tensor(
513-
[x - 1 for x in kv_seqlens], device="cuda", dtype=torch.int32
566+
[x - 1 for x in kv_seqlens], device=self.device, dtype=torch.int32
514567
)
515568
total_length_q = sum(q_seqlens)
516569

517570
cache_k = torch.randn(
518-
(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
571+
(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device=self.device
519572
)
520573
cache_v = torch.randn_like(cache_k)
521574

@@ -524,7 +577,7 @@ def test_rope_positional_encoding_only(
524577
N_H_L + 2 * N_KVH_L,
525578
D_H,
526579
dtype=torch.bfloat16,
527-
device="cuda",
580+
device=self.device,
528581
)
529582
xq = xqkv[:, :N_H_L, :]
530583
xk = xqkv[:, N_H_L : N_H_L + N_KVH_L, :].clone()
@@ -542,14 +595,18 @@ def test_rope_positional_encoding_only(
542595
assert cache_v.shape == (B, MAX_T, N_KVH_L, D_H)
543596

544597
if prefill:
545-
seqpos_args = _get_varseq_batch_seqpos(q_seqlens, kv_seqlens)
598+
seqpos_args = _get_varseq_batch_seqpos(q_seqlens, kv_seqlens, self.device)
546599
else:
547600
seqpos_args = (seq_positions,)
548601

549602
func = (
550-
torch.compile(torch.ops.fbgemm.rope_qkv_varseq_prefill)
603+
torch.compile(
604+
torch.ops.fbgemm.rope_qkv_varseq_prefill, backend=self.compile_backend
605+
)
551606
if prefill
552-
else torch.compile(torch.ops.fbgemm.rope_qkv_decoding)
607+
else torch.compile(
608+
torch.ops.fbgemm.rope_qkv_decoding, backend=self.compile_backend
609+
)
553610
)
554611
xq_out = func(
555612
xq,
@@ -569,7 +626,7 @@ def test_rope_positional_encoding_only(
569626
kv_seqlen=kv_seqlens,
570627
)
571628
)
572-
attn_bias.k_seqinfo.to(torch.device("cuda"))
629+
attn_bias.k_seqinfo.to(self.device)
573630
xq = xq.view(1, xq.shape[0], N_H_L, D_H)
574631
xk = xk.view(1, xk.shape[0], N_KVH_L, D_H)
575632
xv = xv.view(1, xv.shape[0], N_KVH_L, D_H)

0 commit comments

Comments
 (0)