@@ -184,18 +184,22 @@ def _convert_unpacked_qkv_to_packed(
         max_seqlen_q = attention_mask.shape[-1]
         max_seqlen_k = attention_mask.shape[-1]

-        q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
-        kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
-            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
         )

         return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k

     def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         assert inference_params is not None, "inference_params is required for inference"
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        attention_mask = inference_params.get("attention_mask", None)
-        sequence_len_offset = inference_params.get("sequence_len_offset", 0)
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset

         batch_size = x.shape[0]
         # wqkv, output: q, kv
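Not part of the diff: a minimal sketch of the packing step changed in the hunk above, using hypothetical shapes (batch of 2 with 8 and 5 valid tokens, 4 heads, head_dim 16). The only behavioral change is the trailing `unsqueeze(0)`, which puts a leading pseudo-batch dimension of 1 on top of the packed-token axis; the other change assumes `inference_params` exposes `attention_mask` and `sequence_len_offset` as attributes rather than dict keys.

```python
import torch

# Hypothetical sizes, not taken from the diff.
batch_size, seqlen, h, d = 2, 8, 4, 16
q = torch.randn(batch_size, seqlen, h, d)
kv = torch.randn(batch_size, seqlen, 2, h, d)

# Boolean padding mask: sequence 0 has 8 valid tokens, sequence 1 has 5.
attention_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool)
attention_mask[0, :8] = True
attention_mask[1, :5] = True

# masked_select flattens to 1-D, so .view restores the trailing (h, d) and
# (2, h, d) dims with all valid tokens concatenated along a single axis.
q_packed = (
    q.masked_select(attention_mask.view(batch_size, -1, 1, 1))
    .view(-1, q.shape[-2], q.shape[-1])
    .unsqueeze(0)  # the new bit: leading pseudo-batch dim of size 1
)
kv_packed = (
    kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
    .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
    .unsqueeze(0)
)

assert q_packed.shape == (1, 13, h, d)       # 8 + 5 valid tokens
assert kv_packed.shape == (1, 13, 2, h, d)
```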
@@ -230,21 +234,21 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
                 q = self.rotary_emb(
                     q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved
                 )
-                k = kv[:, :, 0].squeueze(2)
+                k = kv[:, :, 0].squeeze(2)
                 self.rotary_emb(
                     k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                 )  # in-place is important
             else:
                 if self.rotary_emb_dim > 0:
                     q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved)
-                    k = kv[:, :, 0].squeueze(2)
+                    k = kv[:, :, 0].squeeze(2)
                     self.rotary_emb(
                         k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                     )  # in-place is important
         else:
             assert self.rotary_emb_dim > 0, "You should use rotary_emb."

-            k, v = kv[:, :, 0].squeueze(2), kv[:, :, 1].squeueze(2)
+            k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2)

             if attention_mask is None:
                 q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved)
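Not part of the diff: a quick check of the spelling fix above, on a hypothetical kv tensor of shape (batch, seqlen, 2, num_heads, head_dim). `squeueze` is not a `torch.Tensor` method, so the removed lines would raise `AttributeError` the first time this branch ran; `squeeze` is the valid spelling.

```python
import torch

# Hypothetical kv layout, not taken from the diff: (batch, seqlen, 2, heads, head_dim).
kv = torch.randn(2, 16, 2, 8, 64)

k = kv[:, :, 0]  # (2, 16, 8, 64): indexing already drops the k/v axis
v = kv[:, :, 1]

# The removed code spelled the call `.squeueze(2)`, which torch.Tensor does not
# have; `.squeeze(2)` exists (and leaves k unchanged here, since dim 2 is 8).
assert not hasattr(k, "squeueze")
assert k.squeeze(2).shape == (2, 16, 8, 64)
```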
@@ -474,27 +478,31 @@ def _convert_unpacked_qkv_to_packed(
         max_seqlen_q = attention_mask.shape[-1]
         max_seqlen_k = attention_mask.shape[-1]

-        q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
-        kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
-            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
         )

         return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k

     def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         assert inference_params is not None, "inference_params is required for inference"
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        attention_mask = inference_params.get("attention_mask", None)
-        sequence_len_offset = inference_params.get("sequence_len_offset", 0)
-        window_size = inference_params.get("window_size", None)
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset
+        window_size = inference_params.window_size

         batch_size = x.shape[0]

         # wqkv, output: q, k, v
         if self.enable_qkv_fusion:
             qkv = self.wqkv(x)
             qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim)
-            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :].unsqueeze(-2), qkv[..., -1, :].unsqueeze(-2))
+            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :])
             q = rearrange(q, "b s h gs d -> b s (h gs) d")
         else:
             q, k, v = self.wq(x), self.wk(x), self.wv(x)
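Not part of the diff: a standalone sketch of the fused-QKV split changed in the last hunk, with hypothetical GQA sizes (2 kv heads, 4 query heads per kv head, head_dim 64). Dropping `unsqueeze(-2)` leaves k and v shaped `(b, s, num_kv_heads, head_dim)` instead of carrying a singleton group axis `(b, s, num_kv_heads, 1, head_dim)`.

```python
import torch
from einops import rearrange

# Hypothetical GQA sizes, not taken from the diff.
b, s, h, q_per_kv, d = 2, 16, 2, 4, 64  # h = number of kv heads
qkv = torch.randn(b, s, h * (q_per_kv + 2) * d)  # fused projection output

qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=q_per_kv + 2, d=d)

# New behaviour: plain indexing of the group axis gives k, v of shape (b, s, h, d).
q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]
q = rearrange(q, "b s h gs d -> b s (h gs) d")

assert q.shape == (b, s, h * q_per_kv, d)
assert k.shape == v.shape == (b, s, h, d)

# Old behaviour kept a singleton group axis on k and v: (b, s, h, 1, d).
k_old = qkv[..., -2, :].unsqueeze(-2)
assert k_old.shape == (b, s, h, 1, d)
```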