66from itertools import zip_longest
77
88import torch
9- from torch import nn , Tensor
9+ from torch import nn , Tensor , arange , cat
1010from torch .nn import Module , ModuleList
1111import torch .nn .functional as F
1212import torch .distributed as dist
1313from vector_quantize_pytorch .vector_quantize_pytorch import VectorQuantize
1414
1515from einops import rearrange , repeat , reduce , pack , unpack
1616
17+ import einx
1718from einx import get_at
1819
1920# helper functions
@@ -36,6 +37,47 @@ def unique(arr):
3637def round_up_multiple (num , mult ):
3738 return ceil (num / mult ) * mult
3839
40+ # tensor helpers
41+
def pad_at_dim(
    t,
    pad: tuple[int, int],
    dim = -1,
    value = 0.
):
    """Pad tensor `t` along a single dimension with a constant value.

    Args:
        t: tensor to pad.
        pad: `(left, right)` number of elements to add before / after the
            existing entries along `dim`.
        dim: dimension to pad; negative values count from the last dimension.
        value: constant fill value forwarded to `F.pad`.

    Returns:
        `t` itself (not a copy) when `pad` is `(0, 0)`, otherwise a new padded
        tensor.

    Raises:
        IndexError: if padding is requested and `dim` is outside
            `[-t.ndim, t.ndim - 1]`.
    """
    # normalise so that a list such as [0, 0] also takes the no-op fast path
    pad = tuple(pad)

    if pad == (0, 0):
        return t

    # a positive out-of-range `dim` would make `dims_from_right` negative,
    # `(0, 0) * negative` would collapse to `()`, and the *last* dimension
    # would be padded silently - fail loudly instead
    if not (-t.ndim <= dim < t.ndim):
        raise IndexError(f'dim {dim} is out of range for a tensor with {t.ndim} dimension(s)')

    # F.pad takes (left, right) pairs starting from the last dimension, so
    # prepend one (0, 0) pair for every dimension to the right of `dim`
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)
54+
def pack_one(t, pattern):
    """Pack a single tensor with einops `pack` and return it with its inverse.

    Args:
        t: tensor to pack.
        pattern: einops pack pattern (for example `'* k d'`).

    Returns:
        A `(packed, inverse)` tuple. `inverse(out, inv_pattern = None)` unpacks
        `out` using the shape information recorded while packing `t` and
        returns the single resulting tensor. `inv_pattern` is passed through
        the file's `default` helper together with `pattern` - presumably so
        that the packing pattern is reused when no inverse pattern is given
        (confirm against `default`).
    """
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        # `unpack` returns a list; exactly one tensor was packed
        return first(unpack(out, packed_shape, inv_pattern))

    return packed, inverse
63+
def batch_select(t, indices, pattern = None):
    """Select entries along dimension 1 of `t`, independently per batch row.

    Args:
        t: source tensor. When `pattern` is given, `t` is first packed with
            `pack_one(t, pattern)`; selection then indexes dimension 1 of the
            packed tensor.
        indices: integer tensor of positions to pick. When `pattern` is given,
            all of its leading dimensions are flattened into one so that it
            lines up with the packed `t`. Assumes the result has one row per
            entry of `t`'s (packed) first dimension - TODO confirm with callers.
        pattern: optional einops pack pattern (for example `'* k d'`).

    Returns:
        The selected entries. When `pattern` is given, the leading dimensions
        are restored through the inverse returned by `pack_one`.
    """
    if exists(pattern):
        indices = rearrange(indices, '... k -> (...) k')
        t, inv_pack = pack_one(t, pattern)

    # column vector of row ids, broadcast against `indices` by advanced indexing
    batch_indices = arange(t.shape[0], device = t.device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')

    out = t[batch_indices, indices]

    if exists(pattern):
        out = inv_pack(out)

    return out
80+
3981# distributed helpers
4082
4183def is_distributed ():
@@ -128,6 +170,8 @@ def __init__(
128170 accept_image_fmap = False ,
129171 implicit_neural_codebook = False , # QINCo from https://arxiv.org/abs/2401.14732
130172 mlp_kwargs : dict = dict (),
173+ beam_size = None ,
174+ eval_beam_size = None ,
131175 ** vq_kwargs
132176 ):
133177 super ().__init__ ()
@@ -183,6 +227,15 @@ def __init__(
183227 self .quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
184228 self .quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
185229
230+ # determine whether the vq layers are using ema updates
231+
232+ self .vq_is_ema_updating = first (self .layers ).ema_update
233+
234+ # beam size
235+
236+ self .beam_size = default (beam_size , eval_beam_size )
237+ self .eval_beam_size = eval_beam_size
238+
186239 # setting up the MLPs for implicit neural codebooks
187240
188241 self .mlps = None
@@ -295,19 +348,29 @@ def forward(
295348 return_all_codes = False ,
296349 sample_codebook_temp = None ,
297350 freeze_codebook = False ,
351+ beam_size = None ,
298352 rand_quantize_dropout_fixed_seed = None
299353 ):
300- num_quant , quant_dropout_multiple_of , return_loss , device = self .num_quantizers , self .quantize_dropout_multiple_of , exists (indices ), x .device
354+
355+ # variables
356+
357+ input_shape , num_quant , quant_dropout_multiple_of , return_loss , device = x .shape , self .num_quantizers , self .quantize_dropout_multiple_of , exists (indices ), x .device
358+
359+ beam_size = default (beam_size , self .beam_size if self .training else self .eval_beam_size )
360+
361+ is_beam_search = exists (beam_size ) and beam_size > 1
362+
363+ # projecting in
301364
302365 x = self .project_in (x )
303366
304367 assert not (self .accept_image_fmap and exists (indices ))
305368
306- quantized_out = 0.
369+ quantized_out = torch . zeros_like ( x )
307370 residual = x
308371
309- all_losses = []
310- all_indices = []
372+ all_losses = torch . empty (( 0 ,), device = device , dtype = torch . float32 )
373+ all_indices = torch . empty (( * input_shape [: - 1 ], 0 ), device = device , dtype = torch . long )
311374
312375 if isinstance (indices , list ):
313376 indices = torch .stack (indices )
@@ -319,7 +382,6 @@ def forward(
319382 should_quantize_dropout = self .training and self .quantize_dropout and not return_loss
320383
321384 # sample a layer index at which to dropout further residual quantization
322- # also prepare null indices and loss
323385
324386 if should_quantize_dropout :
325387
@@ -335,9 +397,24 @@ def forward(
335397 if quant_dropout_multiple_of != 1 :
336398 rand_quantize_dropout_index = round_up_multiple (rand_quantize_dropout_index + 1 , quant_dropout_multiple_of ) - 1
337399
338- null_indices_shape = (x .shape [0 ], * x .shape [- 2 :]) if self .accept_image_fmap else tuple (x .shape [:2 ])
339- null_indices = torch .full (null_indices_shape , - 1. , device = device , dtype = torch .long )
340- null_loss = torch .full ((1 ,), 0. , device = device , dtype = x .dtype )
400+ # save all inputs across layers, for use during expiration at end under shared codebook setting, or ema update during beam search
401+
402+ all_residuals = torch .empty ((* input_shape [:- 1 ], 0 , input_shape [- 1 ]), dtype = residual .dtype , device = device )
403+
404+ # maybe prepare beam search
405+
406+ if is_beam_search :
407+ prec_dims = x .shape [:- 1 ]
408+
409+ search_scores = torch .zeros ((* prec_dims , 1 ), device = device , dtype = x .dtype )
410+
411+ residual = rearrange (residual , '... d -> ... 1 d' )
412+ quantized_out = rearrange (quantized_out , '... d -> ... 1 d' )
413+
414+ all_residuals = rearrange (all_residuals , '... l d -> ... 1 l d' )
415+ all_indices = rearrange (all_indices , '... l -> ... 1 l' )
416+
417+ all_losses = all_losses .reshape (* input_shape [:- 1 ], 1 , 0 )
341418
342419 # setup the mlps for implicit neural codebook
343420
@@ -346,17 +423,13 @@ def forward(
346423 if self .implicit_neural_codebook :
347424 maybe_code_transforms = (None , * self .mlps )
348425
349- # save all inputs across layers, for use during expiration at end under shared codebook setting
350-
351- all_residuals = []
352-
353426 # go through the layers
354427
355428 for quantizer_index , (vq , maybe_mlp ) in enumerate (zip (self .layers , maybe_code_transforms )):
356429
357430 if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index :
358- all_indices . append ( null_indices )
359- all_losses . append ( null_loss )
431+ all_indices = pad_at_dim ( all_indices , ( 0 , 1 ), value = - 1 , dim = - 1 )
432+ all_losses = pad_at_dim ( all_losses , ( 0 , 1 ), value = 0 , dim = - 1 )
360433 continue
361434
362435 layer_indices = None
@@ -368,9 +441,9 @@ def forward(
368441 if exists (maybe_mlp ):
369442 maybe_mlp = partial (maybe_mlp , condition = quantized_out )
370443
371- # save for expiration
444+ # save the residual input for maybe expiration as well as ema update after beam search
372445
373- all_residuals . append ( residual )
446+ all_residuals = cat (( all_residuals , rearrange ( residual , '... d -> ... 1 d' )), dim = - 2 )
374447
375448 # vector quantize forward
376449
@@ -380,29 +453,111 @@ def forward(
380453 indices = layer_indices ,
381454 sample_codebook_temp = sample_codebook_temp ,
382455 freeze_codebook = freeze_codebook ,
383- codebook_transform_fn = maybe_mlp
456+ codebook_transform_fn = maybe_mlp ,
457+ topk = beam_size if is_beam_search else None
384458 )
385459
386- residual = residual - quantized .detach ()
387- quantized_out = quantized_out + quantized
460+ # cross entropy loss for some old paper
388461
389462 if return_loss :
390- ce_loss = rest [ 0 ]
463+ ce_loss = first ( rest )
391464 ce_losses .append (ce_loss )
392465 continue
393466
394467 embed_indices , loss = rest
395468
396- all_indices .append (embed_indices )
397- all_losses .append (loss )
469+ # handle expanding first residual if doing beam search
470+
471+ if is_beam_search :
472+
473+ search_scores = einx .add ('... j, ... j k -> ... (j k)' , search_scores , - loss )
474+
475+ residual = rearrange (residual , '... j d -> ... j 1 d' )
476+ quantized_out = rearrange (quantized_out , '... j d -> ... j 1 d' )
477+
478+ all_residuals = repeat (all_residuals , '... j l d -> ... (j k) l d' , k = beam_size )
479+
480+ # core residual vq logic
481+
482+ residual = residual - quantized .detach ()
483+ quantized_out = quantized_out + quantized
484+
485+ # handle sort and topk beams
486+
487+ if is_beam_search :
488+ residual = rearrange (residual , '... j k d -> ... (j k) d' )
489+ quantized_out = rearrange (quantized_out , '... j k d -> ... (j k) d' )
490+
491+ # broadcat the indices
492+
493+ all_indices = repeat (all_indices , '... j l -> ... j k l' , k = embed_indices .shape [- 1 ])
494+ embed_indices = rearrange (embed_indices , '... j k -> ... j k 1' )
495+
496+ all_indices = cat ((all_indices , embed_indices ), dim = - 1 )
497+ all_indices = rearrange (all_indices , '... j k l -> ... (j k) l' )
498+
499+ # broadcat the losses
500+
501+ all_losses = repeat (all_losses , '... j l -> ... j k l' , k = loss .shape [- 1 ])
502+ loss = rearrange (loss , '... -> ... 1' )
503+
504+ all_losses = cat ((all_losses , loss ), dim = - 1 )
505+ all_losses = rearrange (all_losses , '... j k l -> ... (j k) l' )
506+
507+ # handle sort and selection of highest beam size
508+
509+ if search_scores .shape [- 1 ] > beam_size :
510+ search_scores , select_indices = search_scores .topk (beam_size , dim = - 1 )
511+
512+ residual = batch_select (residual , select_indices , '* k d' )
513+ quantized_out = batch_select (quantized_out , select_indices , '* k d' )
514+
515+ all_indices = batch_select (all_indices , select_indices , '* k l' )
516+ all_losses = batch_select (all_losses , select_indices , '* k l' )
517+
518+ all_residuals = batch_select (all_residuals , select_indices , '* k l d' )
519+ else :
520+ # aggregate indices and losses
521+
522+ all_indices = cat ((all_indices , rearrange (embed_indices , '... -> ... 1' )), dim = - 1 )
523+
524+ all_losses = cat ((all_losses , rearrange (loss , '... -> ... 1' )), dim = - 1 )
525+
526+ # handle beam search
527+
528+ if is_beam_search :
529+ top_index = search_scores .argmax (dim = - 1 , keepdim = True )
530+
531+ quantized_out = batch_select (quantized_out , top_index , '* k d' )
532+ all_indices = batch_select (all_indices , top_index , '* k l' )
533+ all_losses = batch_select (all_losses , top_index , '* k l' )
534+ all_residuals = batch_select (all_residuals , top_index , '* k l d' )
535+
536+ quantized_out , all_indices , all_losses , all_residuals = [t [..., 0 , :] for t in (quantized_out , all_indices , all_losses , all_residuals )]
537+
538+ # handle commit loss, which should be the average
539+
540+ if exists (mask ):
541+ all_losses = einx .where ('..., ... l,' , mask , all_losses , 0. )
542+ all_losses = reduce (all_losses , '... l -> l' , 'sum' ) / mask .sum (dim = - 1 ).clamp_min (1e-4 )
543+ else :
544+ all_losses = reduce (all_losses , '... l -> l' , 'mean' )
545+
546+ # handle updating ema
547+
548+ if self .vq_is_ema_updating :
549+ for vq , layer_input , indices in zip (self .layers , all_residuals .unbind (dim = - 2 ), all_indices .unbind (dim = - 1 )): # in the case of quantize dropout, zip will terminate with the shorter sequence, which should be all_residuals
550+ vq .update_ema_indices (layer_input , indices , mask = mask )
398551
399552 # if shared codebook, update ema only at end
400553
401- if self .training and self .shared_codebook :
554+ if self .training and self .shared_codebook and not is_beam_search :
402555 shared_layer = first (self .layers )
403556 shared_layer ._codebook .update_ema ()
404557 shared_layer .update_in_place_optimizer ()
405- shared_layer .expire_codes_ (torch .cat (all_residuals , dim = - 2 ))
558+
559+ all_codes_for_expire = rearrange (all_residuals , '... n l d -> ... (n l) d' )
560+ shared_layer .expire_codes_ (all_codes_for_expire )
406561
407562 # project out, if needed
408563
@@ -415,8 +570,6 @@ def forward(
415570
416571 # stack all losses and indices
417572
418- all_losses , all_indices = map (partial (torch .stack , dim = - 1 ), (all_losses , all_indices ))
419-
420573 ret = (quantized_out , all_indices , all_losses )
421574
422575 if return_all_codes :
0 commit comments