Skip to content

Commit 8b36a03

Browse files
committed
first pass through residual vq beam search
1 parent 6fb8fb5 commit 8b36a03

File tree

4 files changed

+206
-29
lines changed

4 files changed

+206
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.26.0"
3+
version = "1.27.0"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_manual_ema.py renamed to tests/test_beam.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,23 @@ def test_topk_and_manual_ema_update():
4242
assert torch.allclose(vq1._codebook.cluster_size, vq2._codebook.cluster_size)
4343
assert torch.allclose(vq1._codebook.embed_avg, vq2._codebook.embed_avg)
4444
assert torch.allclose(vq1.codebook, vq2.codebook)
45+
46+
def test_beam_search():
    """Smoke test: residual vq with beam search returns correctly shaped outputs."""
    import torch
    from vector_quantize_pytorch import ResidualVQ

    batch, seq_len, dim = 1, 1024, 256
    num_quantizers = 8

    residual_vq = ResidualVQ(
        dim = dim,
        num_quantizers = num_quantizers,  # how many residual quantization layers
        codebook_size = 1024,             # entries per codebook
        quantize_dropout = True
    )

    x = torch.randn(batch, seq_len, dim)

    # run a handful of forward passes, each one searching with 3 beams

    for _ in range(5):
        quantized, indices, commit_loss = residual_vq(x, beam_size = 3)

    expected_quantized_shape = (batch, seq_len, dim)
    expected_indices_shape = (batch, seq_len, num_quantizers)
    expected_loss_shape = (num_quantizers,)

    assert quantized.shape == expected_quantized_shape
    assert indices.shape == expected_indices_shape
    assert commit_loss.shape == expected_loss_shape

vector_quantize_pytorch/residual_vq.py

Lines changed: 180 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
from itertools import zip_longest
77

88
import torch
9-
from torch import nn, Tensor
9+
from torch import nn, Tensor, arange, cat
1010
from torch.nn import Module, ModuleList
1111
import torch.nn.functional as F
1212
import torch.distributed as dist
1313
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
1414

1515
from einops import rearrange, repeat, reduce, pack, unpack
1616

17+
import einx
1718
from einx import get_at
1819

1920
# helper functions
@@ -36,6 +37,47 @@ def unique(arr):
3637
def round_up_multiple(num, mult):
3738
return ceil(num / mult) * mult
3839

40+
# tensor helpers
41+
42+
def pad_at_dim(
    t,
    pad: tuple[int, int],
    dim = -1,
    value = 0.
):
    """Pad a single dimension of `t` with a constant value.

    Args:
        t: tensor to pad.
        pad: `(left, right)` number of elements to add before and after `dim`.
        dim: dimension to pad, negative values count from the last dimension.
        value: constant fill value for the padded positions.

    Returns:
        `t` itself when `pad` is `(0, 0)`, otherwise a new padded tensor.

    Raises:
        IndexError: if `dim` is not a valid dimension of `t`.
    """
    # accept any 2-sequence so that e.g. [0, 0] also takes the no-op path
    pad = tuple(pad)

    if pad == (0, 0):
        return t

    # an out of range positive `dim` would otherwise give a negative repeat count below,
    # producing no leading zeros and silently padding the last dimension instead
    if not (-t.ndim <= dim < t.ndim):
        raise IndexError(f'dim {dim} is out of range for a tensor with {t.ndim} dimensions')

    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)

    # F.pad consumes (left, right) pairs starting from the last dimension,
    # so every dimension to the right of `dim` gets an explicit (0, 0)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)
54+
55+
def pack_one(t, pattern):
    """Pack a single tensor with `pattern` and hand back a function that undoes it.

    Returns a tuple of the packed tensor and an `inverse(out, inv_pattern = None)`
    callable. The callable unpacks `out` using the shape recorded at pack time,
    with `inv_pattern` falling back to the original `pattern` when not given.
    """
    packed, ps = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        unpack_pattern = default(inv_pattern, pattern)
        unpacked = unpack(out, ps, unpack_pattern)

        # only one tensor was packed, so only one comes back out
        return first(unpacked)

    return packed, inverse
63+
64+
def batch_select(t, indices, pattern = None):
    """Select entries along the second dimension of `t`, separately for each batch row.

    When `pattern` is given, `t` is first packed with it and all leading
    dimensions of `indices` are flattened into one, the selection is done on the
    flattened batch, and the result is unpacked again with the same pattern.
    """
    should_pack = exists(pattern)

    if should_pack:
        t, inv_pack = pack_one(t, pattern)
        indices = rearrange(indices, '... k -> (...) k')

    # a (b, 1) column of row ids broadcasts against the (b, k) indices, pairing each row with its own picks

    row_ids = arange(t.shape[0], device = t.device)
    row_ids = rearrange(row_ids, 'b -> b 1')

    selected = t[row_ids, indices]

    if not should_pack:
        return selected

    return inv_pack(selected)
80+
3981
# distributed helpers
4082

4183
def is_distributed():
@@ -128,6 +170,8 @@ def __init__(
128170
accept_image_fmap = False,
129171
implicit_neural_codebook = False, # QINCo from https://arxiv.org/abs/2401.14732
130172
mlp_kwargs: dict = dict(),
173+
beam_size = None,
174+
eval_beam_size = None,
131175
**vq_kwargs
132176
):
133177
super().__init__()
@@ -183,6 +227,15 @@ def __init__(
183227
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
184228
self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
185229

230+
# determine whether the quantizer layers are using ema update
231+
232+
self.vq_is_ema_updating = first(self.layers).ema_update
233+
234+
# beam size
235+
236+
self.beam_size = default(beam_size, eval_beam_size)
237+
self.eval_beam_size = eval_beam_size
238+
186239
# setting up the MLPs for implicit neural codebooks
187240

188241
self.mlps = None
@@ -295,19 +348,29 @@ def forward(
295348
return_all_codes = False,
296349
sample_codebook_temp = None,
297350
freeze_codebook = False,
351+
beam_size = None,
298352
rand_quantize_dropout_fixed_seed = None
299353
):
300-
num_quant, quant_dropout_multiple_of, return_loss, device = self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
354+
355+
# variables
356+
357+
input_shape, num_quant, quant_dropout_multiple_of, return_loss, device = x.shape, self.num_quantizers, self.quantize_dropout_multiple_of, exists(indices), x.device
358+
359+
beam_size = default(beam_size, self.beam_size if self.training else self.eval_beam_size)
360+
361+
is_beam_search = exists(beam_size) and beam_size > 1
362+
363+
# projecting in
301364

302365
x = self.project_in(x)
303366

304367
assert not (self.accept_image_fmap and exists(indices))
305368

306-
quantized_out = 0.
369+
quantized_out = torch.zeros_like(x)
307370
residual = x
308371

309-
all_losses = []
310-
all_indices = []
372+
all_losses = torch.empty((0,), device = device, dtype = torch.float32)
373+
all_indices = torch.empty((*input_shape[:-1], 0), device = device, dtype = torch.long)
311374

312375
if isinstance(indices, list):
313376
indices = torch.stack(indices)
@@ -319,7 +382,6 @@ def forward(
319382
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
320383

321384
# sample a layer index at which to dropout further residual quantization
322-
# also prepare null indices and loss
323385

324386
if should_quantize_dropout:
325387

@@ -335,9 +397,24 @@ def forward(
335397
if quant_dropout_multiple_of != 1:
336398
rand_quantize_dropout_index = round_up_multiple(rand_quantize_dropout_index + 1, quant_dropout_multiple_of) - 1
337399

338-
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
339-
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
340-
null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)
400+
# save all inputs across layers, for use during expiration at end under shared codebook setting, or ema update during beam search
401+
402+
all_residuals = torch.empty((*input_shape[:-1], 0, input_shape[-1]), dtype = residual.dtype, device = device)
403+
404+
# maybe prepare beam search
405+
406+
if is_beam_search:
407+
prec_dims = x.shape[:-1]
408+
409+
search_scores = torch.zeros((*prec_dims, 1), device = device, dtype = x.dtype)
410+
411+
residual = rearrange(residual, '... d -> ... 1 d')
412+
quantized_out = rearrange(quantized_out, '... d -> ... 1 d')
413+
414+
all_residuals = rearrange(all_residuals, '... l d -> ... 1 l d')
415+
all_indices = rearrange(all_indices, '... l -> ... 1 l')
416+
417+
all_losses = all_losses.reshape(*input_shape[:-1], 1, 0)
341418

342419
# setup the mlps for implicit neural codebook
343420

@@ -346,17 +423,13 @@ def forward(
346423
if self.implicit_neural_codebook:
347424
maybe_code_transforms = (None, *self.mlps)
348425

349-
# save all inputs across layers, for use during expiration at end under shared codebook setting
350-
351-
all_residuals = []
352-
353426
# go through the layers
354427

355428
for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):
356429

357430
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
358-
all_indices.append(null_indices)
359-
all_losses.append(null_loss)
431+
all_indices = pad_at_dim(all_indices, (0, 1), value = -1, dim = -1)
432+
all_losses = pad_at_dim(all_losses, (0, 1), value = 0, dim = -1)
360433
continue
361434

362435
layer_indices = None
@@ -368,9 +441,9 @@ def forward(
368441
if exists(maybe_mlp):
369442
maybe_mlp = partial(maybe_mlp, condition = quantized_out)
370443

371-
# save for expiration
444+
# save the residual input for maybe expiration as well as ema update after beam search
372445

373-
all_residuals.append(residual)
446+
all_residuals = cat((all_residuals, rearrange(residual, '... d -> ... 1 d')), dim = -2)
374447

375448
# vector quantize forward
376449

@@ -380,29 +453,111 @@ def forward(
380453
indices = layer_indices,
381454
sample_codebook_temp = sample_codebook_temp,
382455
freeze_codebook = freeze_codebook,
383-
codebook_transform_fn = maybe_mlp
456+
codebook_transform_fn = maybe_mlp,
457+
topk = beam_size if is_beam_search else None
384458
)
385459

386-
residual = residual - quantized.detach()
387-
quantized_out = quantized_out + quantized
460+
# cross entropy loss for some old paper
388461

389462
if return_loss:
390-
ce_loss = rest[0]
463+
ce_loss = first(rest)
391464
ce_losses.append(ce_loss)
392465
continue
393466

394467
embed_indices, loss = rest
395468

396-
all_indices.append(embed_indices)
397-
all_losses.append(loss)
469+
# handle expanding first residual if doing beam search
470+
471+
if is_beam_search:
472+
473+
search_scores = einx.add('... j, ... j k -> ... (j k)', search_scores, -loss)
474+
475+
residual = rearrange(residual, '... j d -> ... j 1 d')
476+
quantized_out = rearrange(quantized_out, '... j d -> ... j 1 d')
477+
478+
all_residuals = repeat(all_residuals, '... j l d -> ... (j k) l d', k = beam_size)
479+
480+
# core residual vq logic
481+
482+
residual = residual - quantized.detach()
483+
quantized_out = quantized_out + quantized
484+
485+
# handle sort and topk beams
486+
487+
if is_beam_search:
488+
residual = rearrange(residual, '... j k d -> ... (j k) d')
489+
quantized_out = rearrange(quantized_out, '... j k d -> ... (j k) d')
490+
491+
# broadcat the indices
492+
493+
all_indices = repeat(all_indices, '... j l -> ... j k l', k = embed_indices.shape[-1])
494+
embed_indices = rearrange(embed_indices, '... j k -> ... j k 1')
495+
496+
all_indices = cat((all_indices, embed_indices), dim = -1)
497+
all_indices = rearrange(all_indices, '... j k l -> ... (j k) l')
498+
499+
# broadcat the losses
500+
501+
all_losses = repeat(all_losses, '... j l -> ... j k l', k = loss.shape[-1])
502+
loss = rearrange(loss, '... -> ... 1')
503+
504+
all_losses = cat((all_losses, loss), dim = -1)
505+
all_losses = rearrange(all_losses, '... j k l -> ... (j k) l')
506+
507+
# handle sort and selection of highest beam size
508+
509+
if search_scores.shape[-1] > beam_size:
510+
search_scores, select_indices = search_scores.topk(beam_size, dim = -1)
511+
512+
residual = batch_select(residual, select_indices, '* k d')
513+
quantized_out = batch_select(quantized_out, select_indices, '* k d')
514+
515+
all_indices = batch_select(all_indices, select_indices, '* k l')
516+
all_losses = batch_select(all_losses, select_indices, '* k l')
517+
518+
all_residuals = batch_select(all_residuals, select_indices, '* k l d')
519+
else:
520+
# aggregate indices and losses
521+
522+
all_indices = cat((all_indices, rearrange(embed_indices, '... -> ... 1')), dim = -1)
523+
524+
all_losses = cat((all_losses, rearrange(loss, '... -> ... 1')), dim = -1)
525+
526+
# handle beam search
527+
528+
if is_beam_search:
529+
top_index = search_scores.argmax(dim = -1, keepdim = True)
530+
531+
quantized_out = batch_select(quantized_out, top_index, '* k d')
532+
all_indices = batch_select(all_indices, top_index, '* k l')
533+
all_losses = batch_select(all_losses, top_index, '* k l')
534+
all_residuals = batch_select(all_residuals, top_index, '* k l d')
535+
536+
quantized_out, all_indices, all_losses, all_residuals = [t[..., 0, :] for t in (quantized_out, all_indices, all_losses, all_residuals)]
537+
538+
# handle commit loss, which should be the average
539+
540+
if exists(mask):
541+
all_losses = einx.where('..., ... l,', mask, all_losses, 0.)
542+
all_losses = reduce(all_losses, '... l -> l', 'sum') / mask.sum(dim = -1).clamp_min(1e-4)
543+
else:
544+
all_losses = reduce(all_losses, '... l -> l', 'mean')
545+
546+
# handle updating ema
547+
548+
if self.vq_is_ema_updating:
549+
for vq, layer_input, indices in zip(self.layers, all_residuals.unbind(dim = -2), all_indices.unbind(dim = -1)): # in the case of quantize dropout, zip will terminate with the shorter sequence, which should be all_residuals
550+
vq.update_ema_indices(layer_input, indices, mask = mask)
398551

399552
# if shared codebook, update ema only at end
400553

401-
if self.training and self.shared_codebook:
554+
if self.training and self.shared_codebook and not is_beam_search:
402555
shared_layer = first(self.layers)
403556
shared_layer._codebook.update_ema()
404557
shared_layer.update_in_place_optimizer()
405-
shared_layer.expire_codes_(torch.cat(all_residuals, dim = -2))
558+
559+
all_codes_for_expire = rearrange(all_residuals, '... n l d -> ... (n l) d')
560+
shared_layer.expire_codes_(all_codes_for_expire)
406561

407562
# project out, if needed
408563

@@ -415,8 +570,6 @@ def forward(
415570

416571
# stack all losses and indices
417572

418-
all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices))
419-
420573
ret = (quantized_out, all_indices, all_losses)
421574

422575
if return_all_codes:

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def forward(
728728
repeated_embed_ind = repeat(embed_ind, 'h b n -> h b n d', d = embed.shape[-1])
729729
quantize = repeated_embed.gather(-2, repeated_embed_ind)
730730

731-
if self.training and ema_update and not freeze_codebook:
731+
if self.training and ema_update and not freeze_codebook and not exists(topk):
732732
self.update_ema_part(flatten, embed_onehot, mask = mask, ema_update_weight = ema_update_weight, accum_ema_update = accum_ema_update)
733733

734734
if needs_codebook_dim:
@@ -912,6 +912,10 @@ def __init__(
912912
# whether to freeze the codebook, can be overridden on forward
913913
self.freeze_codebook = freeze_codebook
914914

915+
@property
def ema_update(self):
    """Expose the `ema_update` attribute of the underlying codebook."""
    codebook = self._codebook
    return codebook.ema_update
918+
915919
@property
916920
def codebook(self):
917921
codebook = self._codebook.embed

0 commit comments

Comments
 (0)