Skip to content

Commit 5e8b036

Browse files
committed
The extra norm applied when projecting to the prophet model's dimensions hurt performance for some reason
1 parent 2315a8a commit 5e8b036

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'speculative-decoding',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'Speculative Decoding',
99
author = 'Phil Wang',

speculative_decoding/speculative_decoding_with_prophet.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -301,15 +301,17 @@ def __init__(
301301
model: Decoder,
302302
prophet: Decoder,
303303
prophet_train_length = 8, # should be greater than spec decoding gamma, as main model cache embedding is one step behind
304-
detach_model_embed_for_prophet = False
304+
detach_model_embed_for_prophet = False,
305+
num_leading_start_tokens = 1
305306
):
306307
super().__init__()
307308
self.model = model
308309
self.prophet = prophet
309310

310311
model_prophet_same_dim = model.dim == prophet.dim
311-
self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Sequential(RMSNorm(model.dim), nn.Linear(model.dim, prophet.dim, bias = False))
312+
self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Linear(model.dim, prophet.dim, bias = False)
312313

314+
self.num_leading_start_tokens = num_leading_start_tokens
313315
self.prophet_train_length = prophet_train_length
314316
self.detach_model_embed_for_prophet = detach_model_embed_for_prophet
315317

train_prophet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,15 @@ def inner(*args, **kwargs):
7676

7777
prophet = Decoder(
7878
num_tokens = 256,
79-
dim = 512,
79+
dim = 256,
8080
depth = 2
8181
)
8282

8383
model_and_prophet = ModelWithProphetWrapper(
8484
model,
8585
prophet,
8686
prophet_train_length = GAMMA + 2,
87+
num_leading_start_tokens = 1,
8788
detach_model_embed_for_prophet = False # train end to end, shouldn't hurt (although benefit is dubious) given ProphetNet paper - of course, trying to get to the bottom of the benefits in spec decoding setting here
8889
).to(device)
8990

0 commit comments

Comments
 (0)