
Commit becf35d

address #3

1 parent 3194f53 commit becf35d

6 files changed: +78 −102 lines changed

setup.py

+3 −3

@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.2.0',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
@@ -18,8 +18,8 @@
   ],
   install_requires=[
     'beartype',
-    'einops>=0.6.1',
-    'torch>=1.12',
+    'einops>=0.8.0',
+    'torch>=2.4',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',

speculative_decoding/speculative_decoding.py

+48 −58

@@ -5,7 +5,6 @@
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F

-from rotary_embedding_torch import RotaryEmbedding
 from beartype import beartype

 from collections import namedtuple
@@ -205,38 +204,39 @@ def speculative_decoding(

     # do a bunch of slicing and align everything to the right, including kv caches

-    max_num_rejected = num_rejected.amax()
-    seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
-    seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]
-
     seq_lens -= num_rejected
     max_seq_len = seq_lens.amax()
+    curr_len = out.shape[-1]

-    if batch > 1:
-        out = F.pad(out, (0, max_num_rejected), value = pad_id)
-        out = out[batch_range, seq_offset_indices]
+    seq_arange = torch.arange(max_seq_len, device = device, dtype = torch.long) + (curr_len - max_seq_len)
+    seq_offset_indices = seq_arange - num_rejected[..., None]

-        cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
-        small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)
+    cached_kv, _ = cache
+    small_cached_kv, _ = small_cache

-        cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
-        small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)
+    if batch > 1:
+        small_cached_kv = F.pad(small_cached_kv, (0, 0, 0, 1))

-        cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
-        small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)
+        out = out[batch_range, seq_offset_indices]

-        cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
-        small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)
+        cached_kv = rearrange(cached_kv, 'b ... n d -> b n ... d')
+        small_cached_kv = rearrange(small_cached_kv, 'b ... n d -> b n ... d')

-        if out.shape[-1] > max_seq_len:
-            left_index = out.shape[-1] - max_seq_len
-            out = out[:, left_index:]
-            cache = tuple(t[..., left_index:, :] for t in cache)
-            small_cache = tuple(t[..., left_index:, :] for t in small_cache)
+        cached_kv = cached_kv[batch_range, seq_offset_indices]
+        small_cached_kv = small_cached_kv[batch_range, seq_offset_indices]

+        cached_kv = rearrange(cached_kv, 'b n ... d -> b ... n d')
+        small_cached_kv = rearrange(small_cached_kv, 'b n ... d -> b ... n d')
+
+        small_cached_kv = small_cached_kv[..., :-1, :]
     else:
-        # if batch size of 1, just slice to be equal to the lone int in seq_lens
-        out = out[..., :seq_lens.item()]
+        # if batch size of 1, just slice to max_seq_len
+        out = out[..., :max_seq_len]
+        cached_kv = cached_kv[..., :max_seq_len, :]
+        small_cached_kv = small_cached_kv[..., :max_seq_len, :]
+
+    cache = (cached_kv, None)
+    small_cache = (small_cached_kv, None)

     # sample the additional token, one of the tricks in the paper to better bound the worst case

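The hunk above replaces the old pad-then-shift right alignment with a single gather: for each row, a window of the last max_seq_len positions, shifted left by that row's number of rejected draft tokens, is indexed out of out and the kv caches. A minimal, self-contained sketch with toy values (not taken from the repo; batch_range, num_rejected and seq_lens only mirror the names used in the diff):

import torch

# toy setup: 2 sequences of current length 8, row 1 had 2 draft tokens rejected
batch, curr_len = 2, 8
out = torch.arange(batch * curr_len).reshape(batch, curr_len)  # stand-in for token ids
num_rejected = torch.tensor([0, 2])
seq_lens = torch.tensor([8, 6])                                # valid lengths after rejection

max_seq_len = int(seq_lens.amax())
batch_range = torch.arange(batch)[..., None]                   # (batch, 1) for advanced indexing

# same indexing scheme as the new code: a length-max_seq_len window per row,
# ending at that row's last accepted token
seq_arange = torch.arange(max_seq_len) + (curr_len - max_seq_len)
seq_offset_indices = seq_arange - num_rejected[..., None]      # (batch, max_seq_len)

aligned = out[batch_range, seq_offset_indices]
print(aligned)
# every row's last accepted token now sits in the final column (right aligned);
# rows with more rejections carry junk in their leftmost columns, which the
# caller treats as padding via seq_lens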

@@ -364,37 +364,38 @@ def speculative_decoding_with_same_model(

     # do a bunch of slicing and align everything to the right, including kv caches

-    max_num_rejected = num_rejected.amax()
-    seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
-    seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]
-
     seq_lens -= num_rejected
     max_seq_len = seq_lens.amax()
+    curr_len = out.shape[-1]
+
+    seq_arange = torch.arange(max_seq_len, device = device, dtype = torch.long) + (curr_len - max_seq_len)
+    seq_offset_indices = seq_arange - num_rejected[..., None]
+
+    cached_kv, _ = cache
+    small_cached_kv, _ = small_cache

     if batch > 1:
-        out = F.pad(out, (0, max_num_rejected), value = pad_id)
+        small_cached_kv = F.pad(small_cached_kv, (0, 0, 0, 1))
         out = out[batch_range, seq_offset_indices]

-        cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
-        small_cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in small_cache)
-
-        cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
-        small_cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in small_cache)
+        cached_kv = rearrange(cached_kv, 'b ... n d -> b n ... d')
+        small_cached_kv = rearrange(small_cached_kv, 'b ... n d -> b n ... d')

-        cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
-        small_cache = tuple(t[batch_range, seq_offset_indices] for t in small_cache)
+        cached_kv = cached_kv[batch_range, seq_offset_indices]
+        small_cached_kv = small_cached_kv[batch_range, seq_offset_indices]

-        cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
-        small_cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in small_cache)
+        cached_kv = rearrange(cached_kv, 'b n ... d -> b ... n d')
+        small_cached_kv = rearrange(small_cached_kv, 'b n ... d -> b ... n d')

-        if out.shape[-1] > max_seq_len:
-            left_index = out.shape[-1] - max_seq_len
-            out = out[:, left_index:]
-            cache = tuple(t[..., left_index:, :] for t in cache)
-            small_cache = tuple(t[..., left_index:, :] for t in small_cache)
+        small_cached_kv = small_cached_kv[..., :-1, :]
     else:
-        # if batch size of 1, just slice to be equal to the lone int in seq_lens
-        out = out[..., :seq_lens.item()]
+        # if batch size of 1, just slice to max_seq_len
+        out = out[..., :max_seq_len]
+        cached_kv = cached_kv[..., :max_seq_len, :]
+        small_cached_kv = small_cached_kv[..., :max_seq_len, :]
+
+    cache = (cached_kv, None)
+    small_cache = (small_cached_kv, None)

     # sample the additional token, one of the tricks in the paper to better bound the worst case

@@ -414,17 +415,6 @@ def speculative_decoding_with_same_model(

     return out[..., prompt_seq_len:], total_accepted / num_steps

-# norm
-
-class RMSNorm(Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.scale = dim ** 0.5
-        self.gamma = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.gamma
-
 # attention and feedforward

 class CausalAttention(Module):
@@ -440,7 +430,7 @@ def __init__(
         self.heads = heads
         dim_inner = dim_head * heads

-        self.norm = RMSNorm(dim)
+        self.norm = nn.RMSNorm(dim)

         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
@@ -492,7 +482,7 @@ def forward(
 def FeedForward(dim, mult = 4):
     dim_inner = dim * mult
     return nn.Sequential(
-        RMSNorm(dim),
+        nn.RMSNorm(dim),
         nn.Linear(dim, dim_inner),
         nn.GELU(),
         nn.Linear(dim_inner, dim)
@@ -529,7 +519,7 @@ def __init__(
         ]))

         self.to_logits = nn.Sequential(
-            RMSNorm(dim),
+            nn.RMSNorm(dim),
             nn.Linear(dim, num_tokens, bias = False)
         )

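The RMSNorm hunks above drop the repo's hand-rolled module in favor of torch's built-in nn.RMSNorm, which has shipped with PyTorch since 2.4 and lines up with the torch>=2.4 bump in setup.py. A small sketch (not from the repo) showing the two behave the same up to epsilon handling:

import torch
from torch import nn
import torch.nn.functional as F

class OldRMSNorm(nn.Module):
    # the module this commit removes: l2-normalize, then rescale by sqrt(dim) * gamma,
    # which equals x / rms(x) * gamma up to the epsilon used for numerical stability
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

dim = 8
x = torch.randn(2, 4, dim)

old = OldRMSNorm(dim)
new = nn.RMSNorm(dim)  # learnable weight, initialized to ones, playing the role of gamma

print(torch.allclose(old(x), new(x), atol = 1e-5))  # True for typical inputs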

speculative_decoding/speculative_decoding_with_prophet.py

+17 −27

@@ -5,7 +5,6 @@
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F

-from rotary_embedding_torch import RotaryEmbedding
 from beartype import beartype

 from collections import namedtuple
@@ -92,17 +91,6 @@ def base_decoding(

     return out[..., prompt_seq_len:]

-# norm
-
-class RMSNorm(Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.scale = dim ** 0.5
-        self.gamma = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * self.gamma
-
 # attention and feedforward

 class CausalAttention(Module):
@@ -118,7 +106,7 @@ def __init__(
         self.heads = heads
         dim_inner = dim_head * heads

-        self.norm = RMSNorm(dim)
+        self.norm = nn.RMSNorm(dim)

         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
@@ -170,7 +158,7 @@ def forward(
 def FeedForward(dim, mult = 4):
     dim_inner = dim * mult
     return nn.Sequential(
-        RMSNorm(dim),
+        nn.RMSNorm(dim),
         nn.Linear(dim, dim_inner),
         nn.GELU(),
         nn.Linear(dim_inner, dim)
@@ -205,7 +193,7 @@ def __init__(
         ]))

         self.to_logits = nn.Sequential(
-            RMSNorm(dim),
+            nn.RMSNorm(dim),
             nn.Linear(dim, num_tokens, bias = False)
         )

@@ -509,25 +497,28 @@ def speculative_decoding_with_prophet_model(
     # do a bunch of slicing and align everything to the right, including kv caches

     max_num_rejected = num_rejected.amax()
-    seq_arange = torch.arange(out.shape[-1], device = device, dtype = torch.long)
-    seq_offset_indices = seq_arange + (max_num_rejected - num_rejected)[..., None]

+    curr_len = out.shape[-1]
     seq_lens -= num_rejected
     max_seq_len = seq_lens.amax()

+    seq_arange = torch.arange(max_seq_len, device = device, dtype = torch.long) + (curr_len - max_seq_len)
+
+    seq_offset_indices = seq_arange - num_rejected[..., None]
+
+    cached_kv, embeds = cache
+
     if batch > 1:
-        out = F.pad(out, (0, max_num_rejected), value = pad_id)
         out = out[batch_range, seq_offset_indices]

-        cache = tuple(F.pad(t, (0, 0, 0, max_num_rejected), value = pad_id) for t in cache)
-        cache = tuple(rearrange(t, 'b ... n d -> b n ... d') for t in cache)
-        cache = tuple(t[batch_range, seq_offset_indices] for t in cache)
-        cache = tuple(rearrange(t, 'b n ... d -> b ... n d') for t in cache)
+        cached_kv = rearrange(cached_kv, 'b ... n d -> b n ... d')
+        cached_kv = cached_kv[batch_range, seq_offset_indices]
+        cached_kv = rearrange(cached_kv, 'b n ... d -> b ... n d')
+    else:
+        out = out[..., :max_seq_len]
+        cached_kv = cached_kv[..., :max_seq_len, :]

-        if out.shape[-1] > max_seq_len:
-            left_index = out.shape[-1] - max_seq_len
-            out = out[:, left_index:]
-            cache = tuple(t[..., left_index:, :] for t in cache)
+    cache = (cached_kv, None)

     # sample the additional token, one of the tricks in the paper to better bound the worst case

@@ -536,7 +527,6 @@ def speculative_decoding_with_prophet_model(
     out = torch.cat((out, next_token), dim = -1)
     seq_lens += 1

-    _, embeds = cache
     next_prophet_start_tokens = to_prophet_start_token(embeds[:, -num_start_tokens:])

     # now left align

train.py

+7 −8

@@ -27,7 +27,7 @@
 LEARNING_RATE = 1e-4
 VALIDATE_EVERY = 100
 PRIME_LENGTH = 128
-GENERATE_EVERY = 500
+GENERATE_EVERY = 100
 GENERATE_LENGTH = 512
 SEQ_LEN = 512
 GAMMA = 5
@@ -104,7 +104,7 @@ def __len__(self):
 train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
 val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
 train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
-val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size=1))

 # optimizer

@@ -126,8 +126,7 @@ def __len__(self):
         (loss / GRAD_ACCUM_EVERY).backward()
         (small_loss / GRAD_ACCUM_EVERY).backward()

-    print(f"training loss: {loss.item():.3f}")
-    print(f"training small loss: {small_loss.item():.3f}")
+    print(f"loss: {loss.item():.3f}\tsmall loss: {small_loss.item():.3f}")

     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     torch.nn.utils.clip_grad_norm_(small_model.parameters(), 0.5)
@@ -144,10 +143,9 @@ def __len__(self):
         valid_data = next(val_loader)

         loss = model(valid_data, return_loss = True)
-        print(f"validation loss: {loss.item():.3f}")
-
         small_loss = small_model(valid_data, return_loss = True)
-        print(f"validation small loss: {small_loss.item():.3f}")
+
+        print(f"validation - loss: {loss.item():.3f}\tsmall loss: {small_loss.item():.3f}")

     if i % GENERATE_EVERY == 0:
         model.eval()
@@ -157,7 +155,8 @@ def __len__(self):
         prime = decode_tokens(inp)
         print(f"%s \n\n %s", (prime, "*" * 100))

-        prompt = inp[None, ...]
+        from einops import repeat
+        prompt = repeat(inp, '... -> b ...', b = 2)

         sampled, base_decode_elapsed = benchmark(base_decoding)(model, prompt, GENERATE_LENGTH)

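A note on the last train.py hunk: the sampling prompt is now repeated to a batch of 2 (rather than inp[None, ...]) so that generation exercises the batch > 1 alignment path changed above. A toy sketch of the einops pattern (not from the repo):

import torch
from einops import repeat

inp = torch.arange(5)                        # stand-in for the tokenized prime
prompt = repeat(inp, '... -> b ...', b = 2)  # add a leading batch axis of size 2
print(prompt.shape)                          # torch.Size([2, 5])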

train_early_exit.py

+2 −4

@@ -116,8 +116,7 @@ def __len__(self):

         ((loss + small_loss * EARLY_EXIT_LOSS_WEIGHT) / GRAD_ACCUM_EVERY).backward()

-    print(f"training loss: {loss.item():.3f}")
-    print(f"training small loss: {small_loss.item():.3f}")
+    print(f"loss: {loss.item():.3f}\tsmall loss: {small_loss.item():.3f}")

     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

@@ -130,8 +129,7 @@ def __len__(self):
         valid_data = next(val_loader)

         loss, small_loss = model(valid_data, return_loss = True)
-        print(f"validation loss: {loss.item():.3f}")
-        print(f"validation small loss: {small_loss.item():.3f}")
+        print(f"validation - loss: {loss.item():.3f}\tsmall loss: {small_loss.item():.3f}")

     if i % GENERATE_EVERY == 0:
         model.eval()

train_prophet.py

+1 −2

@@ -130,8 +130,7 @@ def __len__(self):

         (total_loss / GRAD_ACCUM_EVERY).backward()

-    print(f"training loss: {loss.item():.3f}")
-    print(f"training prophet loss: {prophet_loss.item():.3f}")
+    print(f"loss: {loss.item():.3f}\tprophet loss: {prophet_loss.item():.3f}")

     torch.nn.utils.clip_grad_norm_(model_and_prophet.parameters(), 0.5)
