Skip to content

Commit bf32309

Browse files
committed
improv
1 parent bf1e07b commit bf32309

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.27.5"
3+
version = "1.27.6"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_readme.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ def test_vq_mask():
7474
@pytest.mark.parametrize('implicit_neural_codebook, use_cosine_sim', ((True, False), (False, True), (False, False)))
7575
@pytest.mark.parametrize('train', (True, False))
7676
@pytest.mark.parametrize('shared_codebook', (True, False))
77+
@pytest.mark.parametrize('quant_grad_frac', (0., 0.1))
7778
def test_residual_vq(
7879
implicit_neural_codebook,
7980
use_cosine_sim,
8081
train,
81-
shared_codebook
82+
shared_codebook,
83+
quant_grad_frac
8284
):
8385
from vector_quantize_pytorch import ResidualVQ
8486

@@ -88,7 +90,8 @@ def test_residual_vq(
8890
codebook_size = 128,
8991
implicit_neural_codebook = implicit_neural_codebook,
9092
use_cosine_sim = use_cosine_sim,
91-
shared_codebook = shared_codebook
93+
shared_codebook = shared_codebook,
94+
quant_grad_frac = quant_grad_frac
9295
)
9396

9497
x = torch.randn(1, 256, 32)

vector_quantize_pytorch/residual_vq.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def unique(arr):
3838
def round_up_multiple(num, mult):
3939
return ceil(num / mult) * mult
4040

41+
def frac_gradient(t, frac):
    """Return `t` with only a fraction `frac` of its gradient flowing.

    The forward value is always exactly `t`; only autograd is affected.
    With `frac <= 0` the result is fully detached (no gradient); otherwise
    the gradient reaching `t` is scaled by `frac` via a convex blend of the
    live tensor and its detached copy.
    """
    no_grad = t.detach()

    # non-positive fraction: block the gradient entirely
    if frac <= 0:
        return no_grad

    # value == t, gradient scaled by frac
    return t * frac + no_grad * (1. - frac)
48+
4149
# tensor helpers
4250

4351
def pad_at_dim(
@@ -174,6 +182,7 @@ def __init__(
174182
beam_size = None,
175183
eval_beam_size = None,
176184
beam_score_quantizer_weights: list[float] | None = None,
185+
quant_grad_frac = 0.,
177186
**vq_kwargs
178187
):
179188
super().__init__()
@@ -233,6 +242,10 @@ def __init__(
233242

234243
self.vq_is_ema_updating = first(self.layers).ema_update
235244

245+
# gradient related - how much gradients to allow up the residual path (previous layers influence the layer above)
246+
247+
self.quant_grad_frac = quant_grad_frac
248+
236249
# beam size
237250

238251
assert not (exists(eval_beam_size) and not exists(beam_size))
@@ -493,7 +506,7 @@ def forward(
493506

494507
# core residual vq logic
495508

496-
residual = residual - quantized.detach()
509+
residual = residual - frac_gradient(quantized, self.quant_grad_frac)
497510
quantized_out = quantized_out + quantized
498511

499512
# handle sort and topk beams

0 commit comments

Comments
 (0)