Skip to content

Commit bf1e07b

Browse files
committed
quick fix
1 parent 0b8fbad commit bf1e07b

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.27.4"
3+
version = "1.27.5"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_beam.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ def test_beam_search():
5151
dim = 256,
5252
num_quantizers = 8, # specify number of quantizers
5353
codebook_size = 1024, # codebook size
54-
quantize_dropout = True
54+
quantize_dropout = True,
55+
beam_size = 2,
56+
eval_beam_size = 3
5557
)
5658

5759
x = torch.randn(1, 1024, 256)
5860

5961
for _ in range(5):
60-
quantized, indices, commit_loss = residual_vq(x, beam_size = 3)
62+
quantized, indices, commit_loss = residual_vq(x)
6163

6264
assert quantized.shape == (1, 1024, 256)
6365
assert indices.shape == (1, 1024, 8)

vector_quantize_pytorch/residual_vq.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,10 @@ def __init__(
235235

236236
# beam size
237237

238-
self.beam_size = default(beam_size, eval_beam_size)
239-
self.eval_beam_size = eval_beam_size
238+
assert not (exists(eval_beam_size) and not exists(beam_size))
239+
240+
self.beam_size = beam_size
241+
self.eval_beam_size = default(eval_beam_size, beam_size)
240242

241243
# able to assign a different weight for the scoring at each quantizer layer
242244

0 commit comments

Comments
 (0)