 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
 
-from rotary_embedding_torch import RotaryEmbedding
 from beartype import beartype
 
 from collections import namedtuple
@@ -205,38 +204,39 @@ def speculative_decoding(
 
         # do a bunch of slicing and align everything to the right, including kv caches
 
-        max_num_rejected = num_rejected.amax()
-        seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
-        seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]
-
         seq_lens -= num_rejected
         max_seq_len = seq_lens.amax()
+        curr_len = out.shape[-1]
 
-        if batch > 1:
-            out = F.pad(out, (0, max_num_rejected), value = pad_id)
-            out = out[batch_range, seq_offset_indices]
+        seq_arange = torch.arange(max_seq_len, device = device, dtype = torch.long) + (curr_len - max_seq_len)
+        seq_offset_indices = seq_arange - num_rejected[..., None]
 
-            cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
-            small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)
+        cached_kv, _ = cache
+        small_cached_kv, _ = small_cache
 
-            cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
-            small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)
+        if batch > 1:
+            small_cached_kv = F.pad(small_cached_kv, (0, 0, 0, 1))
 
-            cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
-            small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)
+            out = out[batch_range, seq_offset_indices]
 
-            cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
-            small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)
+            cached_kv = rearrange(cached_kv, 'b ... n d -> b n ... d')
+            small_cached_kv = rearrange(small_cached_kv, 'b ... n d -> b n ... d')
 
-            if out.shape[-1] > max_seq_len:
-                left_index = out.shape[-1] - max_seq_len
-                out = out[:, left_index:]
-                cache = tuple(t[..., left_index:, :] for t in cache)
-                small_cache = tuple(t[..., left_index:, :] for t in small_cache)
+            cached_kv = cached_kv[batch_range, seq_offset_indices]
+            small_cached_kv = small_cached_kv[batch_range, seq_offset_indices]
 
+            cached_kv = rearrange(cached_kv, 'b n ... d -> b ... n d')
+            small_cached_kv = rearrange(small_cached_kv, 'b n ... d -> b ... n d')
+
+            small_cached_kv = small_cached_kv[..., :-1, :]
         else:
-            # if batch size of 1, just slice to be equal to the lone int in seq_lens
-            out = out[..., :seq_lens.item()]
+            # if batch size of 1, just slice to max_seq_len
+            out = out[..., :max_seq_len]
+            cached_kv = cached_kv[..., :max_seq_len, :]
+            small_cached_kv = small_cached_kv[..., :max_seq_len, :]
+
+        cache = (cached_kv, None)
+        small_cache = (small_cached_kv, None)
 
         # sample the additional token, one of the tricks in the paper to better bound the worst case
 
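The hunk above replaces the old pad-then-truncate approach with a single right-aligned gather: one arange anchored at the right edge of `out`, shifted left per row by that row's `num_rejected`. A minimal standalone sketch of the gather, with toy tensors whose shapes and values are illustrative assumptions rather than the repo's:

import torch

# toy batch: two rows of token ids, with a different number of rejected draft tokens each
out = torch.tensor([
    [10, 11, 12, 13, 14, 15],
    [20, 21, 22, 23, 24, 25],
])
num_rejected = torch.tensor([1, 3])
seq_lens = out.shape[-1] - num_rejected      # kept length per row -> [5, 3]
max_seq_len = seq_lens.amax()                # 5
curr_len = out.shape[-1]                     # 6

batch_range = torch.arange(out.shape[0])[..., None]

# same indexing as the new code: one shared arange anchored at the right edge,
# shifted left per row by that row's number of rejections
seq_arange = torch.arange(max_seq_len) + (curr_len - max_seq_len)
seq_offset_indices = seq_arange - num_rejected[..., None]

aligned = out[batch_range, seq_offset_indices]
print(aligned)
# tensor([[10, 11, 12, 13, 14],
#         [24, 25, 20, 21, 22]])
# each row's accepted tokens end flush at the right edge; rows with more rejections
# index negatively and wrap, so their leading slots hold stale values beyond that row's seq_len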
@@ -364,37 +364,38 @@ def speculative_decoding_with_same_model(
 
         # do a bunch of slicing and align everything to the right, including kv caches
 
-        max_num_rejected = num_rejected.amax()
-        seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
-        seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]
-
        seq_lens -= num_rejected
         max_seq_len = seq_lens.amax()
+        curr_len = out.shape[-1]
+
+        seq_arange = torch.arange(max_seq_len, device = device, dtype = torch.long) + (curr_len - max_seq_len)
+        seq_offset_indices = seq_arange - num_rejected[..., None]
+
+        cached_kv, _ = cache
+        small_cached_kv, _ = small_cache
 
         if batch > 1:
-            out = F.pad(out, (0, max_num_rejected), value = pad_id)
+            small_cached_kv = F.pad(small_cached_kv, (0, 0, 0, 1))
             out = out[batch_range, seq_offset_indices]
 
-            cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
-            small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)
-
-            cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
-            small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)
+            cached_kv = rearrange(cached_kv, 'b ... n d -> b n ... d')
+            small_cached_kv = rearrange(small_cached_kv, 'b ... n d -> b n ... d')
 
-            cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
-            small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)
+            cached_kv = cached_kv[batch_range, seq_offset_indices]
+            small_cached_kv = small_cached_kv[batch_range, seq_offset_indices]
 
-            cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
-            small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)
+            cached_kv = rearrange(cached_kv, 'b n ... d -> b ... n d')
+            small_cached_kv = rearrange(small_cached_kv, 'b n ... d -> b ... n d')
 
-            if out.shape[-1] > max_seq_len:
-                left_index = out.shape[-1] - max_seq_len
-                out = out[:, left_index:]
-                cache = tuple(t[..., left_index:, :] for t in cache)
-                small_cache = tuple(t[..., left_index:, :] for t in small_cache)
+            small_cached_kv = small_cached_kv[..., :-1, :]
         else:
-            # if batch size of 1, just slice to be equal to the lone int in seq_lens
-            out = out[..., :seq_lens.item()]
+            # if batch size of 1, just slice to max_seq_len
+            out = out[..., :max_seq_len]
+            cached_kv = cached_kv[..., :max_seq_len, :]
+            small_cached_kv = small_cached_kv[..., :max_seq_len, :]
+
+        cache = (cached_kv, None)
+        small_cache = (small_cached_kv, None)
 
         # sample the additional token, one of the tricks in the paper to better bound the worst case
 
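The context line about sampling the additional token refers to the acceptance and resampling rules of the speculative decoding paper: accept a drafted token with probability min(1, p/q), resample a rejection from the renormalized residual max(0, p - q), and, when the whole drafted block is accepted, still draw one extra token from the large model's distribution. A hedged sketch of those rules; the function and variable names below are mine, not the repo's:

import torch

def accept_or_resample(p_large, p_small, drafted):
    # p_large, p_small: (vocab,) probabilities from the large and small model at one position
    # drafted: token id proposed by the small model
    accept_prob = (p_large[drafted] / p_small[drafted]).clamp(max = 1.)

    if torch.rand(()) < accept_prob:
        return drafted, True

    # on rejection, sample the replacement from the renormalized residual max(0, p - q)
    residual = (p_large - p_small).clamp(min = 0.)
    residual = residual / residual.sum()
    return torch.multinomial(residual, 1).squeeze(), False

# toy distributions over a 5-token vocabulary (illustrative numbers only)
p_large = torch.tensor([0.1, 0.4, 0.2, 0.2, 0.1])
p_small = torch.tensor([0.3, 0.1, 0.3, 0.2, 0.1])

token, accepted = accept_or_resample(p_large, p_small, drafted = torch.tensor(1))

# if the whole drafted block is accepted, one additional token is drawn from the large
# model's distribution at the position after the block, so each verification round
# always yields at least one token from the large model
extra = torch.multinomial(p_large, 1)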
@@ -414,17 +415,6 @@ def speculative_decoding_with_same_model(
 
     return out[..., prompt_seq_len:], total_accepted / num_steps
 
-# norm
-
-class RMSNorm(Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.scale = dim ** 0.5
-        self.gamma = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.gamma
-
 # attention and feedforward
 
 class CausalAttention(Module):
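The RMSNorm class deleted above (l2-normalize, rescale by sqrt(dim), learned gain) computes the same normalization as the built-in `nn.RMSNorm` that the following hunks switch to, available in PyTorch 2.4 and later, up to epsilon handling. A quick equivalence check, with illustrative shapes:

import torch
from torch import nn
import torch.nn.functional as F

dim = 512
x = torch.randn(2, 16, dim)

# the deleted hand-rolled norm: x / ||x||_2 * sqrt(dim) * gamma == x / rms(x) * gamma
gamma = torch.ones(dim)
manual = F.normalize(x, dim = -1) * (dim ** 0.5) * gamma

# the built-in replacement (requires PyTorch >= 2.4), weight initialized to ones
builtin = nn.RMSNorm(dim)(x)

print(torch.allclose(manual, builtin, atol = 1e-5))  # expected True, up to epsilon differences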
@@ -440,7 +430,7 @@ def __init__(
         self.heads = heads
         dim_inner = dim_head * heads
 
-        self.norm = RMSNorm(dim)
+        self.norm = nn.RMSNorm(dim)
 
         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
@@ -492,7 +482,7 @@ def forward(
 def FeedForward(dim, mult = 4):
     dim_inner = dim * mult
     return nn.Sequential(
-        RMSNorm(dim),
+        nn.RMSNorm(dim),
         nn.Linear(dim, dim_inner),
         nn.GELU(),
         nn.Linear(dim_inner, dim)
@@ -529,7 +519,7 @@ def __init__(
             ]))
 
         self.to_logits = nn.Sequential(
-            RMSNorm(dim),
+            nn.RMSNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
         )
 