@@ -40,7 +40,7 @@ class LogicalDtype(Enum):
4040
4141
4242def _get_varseq_batch_seqpos (
43- seqlens_q : List [int ], seqlens_kv : List [int ]
43+ seqlens_q : List [int ], seqlens_kv : List [int ], device : torch . device
4444) -> Tuple [torch .Tensor , torch .Tensor ]:
4545 """
4646 varseq_batch[i] is batch index of query i
@@ -49,7 +49,11 @@ def _get_varseq_batch_seqpos(
4949
5050 varseq_batch = torch .cat (
5151 [
52- torch .as_tensor ([i for _ in range (len_q )], dtype = torch .int , device = "cuda" )
52+ torch .as_tensor (
53+ [i for _ in range (len_q )],
54+ dtype = torch .int ,
55+ device = device ,
56+ )
5357 for i , len_q in enumerate (seqlens_q )
5458 ]
5559 )
@@ -58,7 +62,7 @@ def _get_varseq_batch_seqpos(
5862 torch .as_tensor (
5963 [len_kv - len_q + t for t in range (len_q )],
6064 dtype = torch .int ,
61- device = "cuda" ,
65+ device = device ,
6266 )
6367 for len_q , len_kv in zip (seqlens_q , seqlens_kv )
6468 ]
@@ -67,6 +71,22 @@ def _get_varseq_batch_seqpos(
6771
6872
6973class KVCacheTests (unittest .TestCase ):
74+ @classmethod
75+ def setUpClass (cls ) -> None :
76+ super ().setUpClass ()
77+ device = torch .accelerator .current_accelerator () # device the tests in this class allocate their tensors on
78+ assert device is not None # fail in class setup rather than inside each test when no device was returned
79+ cls .device = device
80+
81+ # Probe the inductor backend once by compiling and running a trivial op on cls.device; the result picks the backend the tests pass to torch.compile
82+ try :
83+ torch .compile (torch .abs , backend = "inductor" )(
84+ torch .tensor (0 , device = cls .device )
85+ )
86+ cls .compile_backend = "inductor" # probe compiled and ran without raising
87+ except torch ._dynamo .exc .BackendCompilerFailed :
88+ cls .compile_backend = "eager" # inductor could not compile for this device; fall back to the eager backend
89+
7090 @settings (deadline = None )
7191 @given (
7292 num_groups = st .sampled_from ([1 , 2 , 4 , 8 ]),
@@ -88,26 +108,34 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
88108 # PROMPT_T = 1024
89109
90110 xq = (
91- torch .randn (size = (B * T , N_H_L , D_H ), dtype = torch .bfloat16 , device = "cuda" )
111+ torch .randn (
112+ size = (B * T , N_H_L , D_H ), dtype = torch .bfloat16 , device = self .device
113+ )
92114 * 0.01
93115 )
94116 xk = (
95- torch .randn (size = (B * T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda" )
117+ torch .randn (
118+ size = (B * T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self .device
119+ )
96120 * 0.01
97121 )
98122 xv = (
99- torch .randn (size = (B * T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda" )
123+ torch .randn (
124+ size = (B * T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self .device
125+ )
100126 * 0.01
101127 )
102128 varseq_seqpos = torch .cat (
103129 [
104- torch .as_tensor (list (range (T )), dtype = torch .int , device = "cuda" )
130+ torch .as_tensor (list (range (T )), dtype = torch .int , device = self . device )
105131 for b in range (B )
106132 ]
107133 )
108134 varseq_batch = torch .cat (
109135 [
110- torch .as_tensor ([b for _ in range (T )], dtype = torch .int , device = "cuda" )
136+ torch .as_tensor (
137+ [b for _ in range (T )], dtype = torch .int , device = self .device
138+ )
111139 for b in range (B )
112140 ]
113141 )
@@ -118,19 +146,21 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
118146 kv_seqlen = [T for _ in range (B )],
119147 )
120148 )
121- attn_bias .k_seqinfo .to (torch .device ( "cuda" ) )
149+ attn_bias .k_seqinfo .to (self .device )
122150 assert attn_bias .k_seqinfo .seqlen .shape == (B ,)
123151 assert attn_bias .k_seqinfo .seqlen .tolist () == [T for _ in range (B )]
124152
125153 theta = 10000.0
126154 cache_k_bf16 = torch .zeros (
127- size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda"
155+ size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device
128156 )
129157 cache_v_bf16 = torch .zeros (
130- size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda"
158+ size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device
131159 )
132160
133- xq_out_bf16 = torch .compile (torch .ops .fbgemm .rope_qkv_varseq_prefill )(
161+ xq_out_bf16 = torch .compile (
162+ torch .ops .fbgemm .rope_qkv_varseq_prefill , backend = self .compile_backend
163+ )(
134164 xq ,
135165 xk ,
136166 xv ,
@@ -145,14 +175,16 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
145175 cache_k_int4 = torch .zeros (
146176 size = (B , MAX_T , N_KVH_L , int (D_H // 2 ) + qparam_offset ),
147177 dtype = torch .uint8 ,
148- device = "cuda" ,
178+ device = self . device ,
149179 )
150180 cache_v_int4 = torch .zeros (
151181 size = (B , MAX_T , N_KVH_L , int (D_H // 2 ) + qparam_offset ),
152182 dtype = torch .uint8 ,
153- device = "cuda" ,
183+ device = self . device ,
154184 )
155- xq_out = torch .compile (torch .ops .fbgemm .rope_qkv_varseq_prefill )(
185+ xq_out = torch .compile (
186+ torch .ops .fbgemm .rope_qkv_varseq_prefill , backend = self .compile_backend
187+ )(
156188 xq ,
157189 xk ,
158190 xv ,
@@ -166,7 +198,9 @@ def test_int4_kv_cache(self, num_groups: int, MAX_T: int, N_KVH_L: int) -> None:
166198 )
167199 torch .testing .assert_close (xq_out_bf16 , xq_out )
168200
169- dequantized_cache = torch .compile (torch .ops .fbgemm .dequantize_int4_cache )(
201+ dequantized_cache = torch .compile (
202+ torch .ops .fbgemm .dequantize_int4_cache , backend = self .compile_backend
203+ )(
170204 cache_k_int4 ,
171205 cache_v_int4 ,
172206 attn_bias .k_seqinfo .seqlen ,
@@ -205,7 +239,8 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
205239 xq = (
206240 torch .cat (
207241 [
208- torch .randn (N_H_L , D_H , dtype = torch .bfloat16 , device = "cuda" ) * (i )
242+ torch .randn (N_H_L , D_H , dtype = torch .bfloat16 , device = self .device )
243+ * (i )
209244 for i in range (B * T )
210245 ]
211246 )
@@ -215,14 +250,14 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
215250 xk_rows = [
216251 scale_step
217252 * (i + 1 )
218- * torch .randn (size = (N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda" )
253+ * torch .randn (size = (N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device )
219254 + i * shift_step
220255 for i in range (B * T )
221256 ]
222257 xv_rows = [
223258 scale_step
224259 * (i + 1 )
225- * torch .randn (size = (N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda" )
260+ * torch .randn (size = (N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device )
226261 + i * shift_step
227262 for i in range (B * T )
228263 ]
@@ -232,13 +267,15 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
232267 xv = (torch .cat (xv_rows )).view (B * T , N_KVH_L , D_H )
233268 varseq_seqpos = torch .cat (
234269 [
235- torch .as_tensor (list (range (T )), dtype = torch .int , device = "cuda" )
270+ torch .as_tensor (list (range (T )), dtype = torch .int , device = self . device )
236271 for b in range (B )
237272 ]
238273 )
239274 varseq_batch = torch .cat (
240275 [
241- torch .as_tensor ([b for _ in range (T )], dtype = torch .int , device = "cuda" )
276+ torch .as_tensor (
277+ [b for _ in range (T )], dtype = torch .int , device = self .device
278+ )
242279 for b in range (B )
243280 ]
244281 )
@@ -249,19 +286,21 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
249286 kv_seqlen = [T for _ in range (B )],
250287 )
251288 )
252- attn_bias .k_seqinfo .to (torch .device ( "cuda" ) )
289+ attn_bias .k_seqinfo .to (self .device )
253290 assert attn_bias .k_seqinfo .seqlen .shape == (B ,)
254291 assert attn_bias .k_seqinfo .seqlen .tolist () == [T for _ in range (B )]
255292
256293 theta = 10000.0
257294 cache_k_bf16 = torch .zeros (
258- size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda"
295+ size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device
259296 )
260297 cache_v_bf16 = torch .zeros (
261- size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda"
298+ size = (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device
262299 )
263300
264- xq_out_bf16 = torch .compile (torch .ops .fbgemm .rope_qkv_varseq_prefill )(
301+ xq_out_bf16 = torch .compile (
302+ torch .ops .fbgemm .rope_qkv_varseq_prefill , backend = self .compile_backend
303+ )(
265304 xq ,
266305 xk ,
267306 xv ,
@@ -276,14 +315,16 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
276315 cache_k_fp8 = torch .zeros (
277316 size = (B , MAX_T , N_KVH_L , int (D_H ) + qparam_offset ),
278317 dtype = torch .uint8 ,
279- device = "cuda" ,
318+ device = self . device ,
280319 )
281320 cache_v_fp8 = torch .zeros (
282321 size = (B , MAX_T , N_KVH_L , int (D_H ) + qparam_offset ),
283322 dtype = torch .uint8 ,
284- device = "cuda" ,
323+ device = self . device ,
285324 )
286- xq_out = torch .compile (torch .ops .fbgemm .rope_qkv_varseq_prefill )(
325+ xq_out = torch .compile (
326+ torch .ops .fbgemm .rope_qkv_varseq_prefill , backend = self .compile_backend
327+ )(
287328 xq ,
288329 xk ,
289330 xv ,
@@ -296,7 +337,9 @@ def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
296337 )
297338 torch .testing .assert_close (xq_out_bf16 , xq_out )
298339
299- dequantized_cache = torch .compile (torch .ops .fbgemm .dequantize_fp8_cache )(
340+ dequantized_cache = torch .compile (
341+ torch .ops .fbgemm .dequantize_fp8_cache , backend = self .compile_backend
342+ )(
300343 cache_k_fp8 ,
301344 cache_v_fp8 ,
302345 attn_bias .k_seqinfo .seqlen ,
@@ -339,12 +382,12 @@ def test_positional_encoding_with_paged_attention(
339382 kv_seqlens = torch .randint (low = 0 , high = MAX_T , size = (B ,)).tolist ()
340383 q_seqlens = kv_seqlens if prefill else [1 for _ in range (B )]
341384 seq_positions = torch .tensor (
342- [x - 1 for x in kv_seqlens ], device = "cuda" , dtype = torch .int32
385+ [x - 1 for x in kv_seqlens ], device = self . device , dtype = torch .int32
343386 )
344387 total_length_q = sum (q_seqlens )
345388
346389 cache_k = torch .randn (
347- (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda"
390+ (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device
348391 )
349392 cache_v = torch .randn_like (cache_k )
350393
@@ -363,7 +406,7 @@ def test_positional_encoding_with_paged_attention(
363406 N_H_L + 2 * N_KVH_L ,
364407 D_H ,
365408 dtype = torch .bfloat16 ,
366- device = "cuda" ,
409+ device = self . device ,
367410 )
368411 xq = xqkv [:, :N_H_L , :]
369412 # This clone is to avoid a weirdness in torch.compile:
@@ -394,15 +437,20 @@ def test_positional_encoding_with_paged_attention(
394437 assert cache_v .shape == (B , MAX_T , N_KVH_L , D_H )
395438
396439 if prefill :
397- seqpos_args = _get_varseq_batch_seqpos (q_seqlens , kv_seqlens )
440+ seqpos_args = _get_varseq_batch_seqpos (q_seqlens , kv_seqlens , self . device )
398441 else :
399442 seqpos_args = (seq_positions ,)
400443
401444 if rope_theta is not None :
402445 func = (
403- torch .compile (torch .ops .fbgemm .rope_qkv_varseq_prefill )
446+ torch .compile (
447+ torch .ops .fbgemm .rope_qkv_varseq_prefill ,
448+ backend = self .compile_backend ,
449+ )
404450 if prefill
405- else torch .compile (torch .ops .fbgemm .rope_qkv_decoding )
451+ else torch .compile (
452+ torch .ops .fbgemm .rope_qkv_decoding , backend = self .compile_backend
453+ )
406454 )
407455 xq_out_ref = func (
408456 xq ,
@@ -428,9 +476,14 @@ def test_positional_encoding_with_paged_attention(
428476 )
429477 else :
430478 func = (
431- torch .compile (torch .ops .fbgemm .xpos_qkv_varseq_prefill )
479+ torch .compile (
480+ torch .ops .fbgemm .xpos_qkv_varseq_prefill ,
481+ backend = self .compile_backend ,
482+ )
432483 if prefill
433- else torch .compile (torch .ops .fbgemm .xpos_qkv_decoding )
484+ else torch .compile (
485+ torch .ops .fbgemm .xpos_qkv_decoding , backend = self .compile_backend
486+ )
434487 )
435488 xq_out_ref = func (
436489 xq ,
@@ -510,12 +563,12 @@ def test_rope_positional_encoding_only(
510563 kv_seqlens = torch .randint (low = 0 , high = MAX_T , size = (B ,)).tolist ()
511564 q_seqlens = kv_seqlens if prefill else [1 for _ in range (B )]
512565 seq_positions = torch .tensor (
513- [x - 1 for x in kv_seqlens ], device = "cuda" , dtype = torch .int32
566+ [x - 1 for x in kv_seqlens ], device = self . device , dtype = torch .int32
514567 )
515568 total_length_q = sum (q_seqlens )
516569
517570 cache_k = torch .randn (
518- (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = "cuda"
571+ (B , MAX_T , N_KVH_L , D_H ), dtype = torch .bfloat16 , device = self . device
519572 )
520573 cache_v = torch .randn_like (cache_k )
521574
@@ -524,7 +577,7 @@ def test_rope_positional_encoding_only(
524577 N_H_L + 2 * N_KVH_L ,
525578 D_H ,
526579 dtype = torch .bfloat16 ,
527- device = "cuda" ,
580+ device = self . device ,
528581 )
529582 xq = xqkv [:, :N_H_L , :]
530583 xk = xqkv [:, N_H_L : N_H_L + N_KVH_L , :].clone ()
@@ -542,14 +595,18 @@ def test_rope_positional_encoding_only(
542595 assert cache_v .shape == (B , MAX_T , N_KVH_L , D_H )
543596
544597 if prefill :
545- seqpos_args = _get_varseq_batch_seqpos (q_seqlens , kv_seqlens )
598+ seqpos_args = _get_varseq_batch_seqpos (q_seqlens , kv_seqlens , self . device )
546599 else :
547600 seqpos_args = (seq_positions ,)
548601
549602 func = (
550- torch .compile (torch .ops .fbgemm .rope_qkv_varseq_prefill )
603+ torch .compile (
604+ torch .ops .fbgemm .rope_qkv_varseq_prefill , backend = self .compile_backend
605+ )
551606 if prefill
552- else torch .compile (torch .ops .fbgemm .rope_qkv_decoding )
607+ else torch .compile (
608+ torch .ops .fbgemm .rope_qkv_decoding , backend = self .compile_backend
609+ )
553610 )
554611 xq_out = func (
555612 xq ,
@@ -569,7 +626,7 @@ def test_rope_positional_encoding_only(
569626 kv_seqlen = kv_seqlens ,
570627 )
571628 )
572- attn_bias .k_seqinfo .to (torch .device ( "cuda" ) )
629+ attn_bias .k_seqinfo .to (self .device )
573630 xq = xq .view (1 , xq .shape [0 ], N_H_L , D_H )
574631 xk = xk .view (1 , xk .shape [0 ], N_KVH_L , D_H )
575632 xv = xv .view (1 , xv .shape [0 ], N_KVH_L , D_H )
0 commit comments