Skip to content

AssertionError in "ppgs.from_audio" with Batched Audio Input #20

@dzq84

Description

@dzq84

Calling `ppgs.from_audio` on a batch of two audio clips raises an `AssertionError`. Code to reproduce:

import ppgs
import torch

# Path to a single audio file used to reproduce the issue.
audio_file = 'test.wav'

# NOTE(review): assumes ppgs.load.audio returns a (1, samples) tensor — confirm.
audio = ppgs.load.audio(audio_file)  

# Build a batch of two identical clips. Under the assumption above, repeat(2, 1)
# gives (2, samples) and unsqueeze(dim=1) gives (2, 1, samples).
audio_batch = audio.repeat(2, 1).unsqueeze(dim=1)

# GPU index forwarded to ppgs.from_audio.
gpu = 0

# Raises AssertionError: the traceback shows from_audio building a one-element
# lengths tensor (core.py:60), while this batch holds two items.
latent = ppgs.from_audio(audio_batch, ppgs.SAMPLE_RATE, gpu=gpu)

print(f"Latent shape: {latent.shape}")

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], [line 12](vscode-notebook-cell:?execution_count=2&line=12)
      [8](vscode-notebook-cell:?execution_count=2&line=8) audio_batch = audio.repeat(2, 1).unsqueeze(dim=1)
     [10](vscode-notebook-cell:?execution_count=2&line=10) gpu = 0
---> [12](vscode-notebook-cell:?execution_count=2&line=12) latent = ppgs.from_audio(audio_batch, ppgs.SAMPLE_RATE, gpu=gpu)
     [14](vscode-notebook-cell:?execution_count=2&line=14) print(f"Latent shape: {latent.shape}")

File ~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:63, in from_audio(audio, sample_rate, representation, checkpoint, gpu, legacy_mode)
     [60](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:60) length = torch.tensor([features.shape[-1]], dtype=torch.long)
     [62](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:62) # Infer
---> [63](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:63) return from_features(
     [64](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:64)     features=features,
     [65](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:65)     lengths=length,
     [66](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:66)     representation=representation,
     [67](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:67)     checkpoint=checkpoint,
     [68](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:68)     gpu=gpu,
     [69](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:69)     legacy_mode=legacy_mode)

File ~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:122, in from_features(features, lengths, representation, checkpoint, gpu, softmax, legacy_mode)
    [119](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:119)         features = from_features.frontend(features.to(device))
    [121](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:121) # Infer
--> [122](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:122) return infer(
    [123](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:123)     features=features.to(device),
    [124](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:124)     lengths=lengths.to(device),
...
   [1638](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/torch/__init__.py:1638)         raise TypeError("message must be a callable")
   [1640](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/torch/__init__.py:1640)     message_evaluated = str(message())
-> [1642](https://vscode-remote+ssh-002dremote-002bbeta4090.vscode-resource.vscode-cdn.net/home/zheqid/workspace/musimple2/musimple/~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/torch/__init__.py:1642) raise error_type(message_evaluated)

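AssertionError: Expected key_padded_mask.shape[0] to be 2, but got 1

Likely cause: `from_audio` builds `length = torch.tensor([features.shape[-1]], dtype=torch.long)` (core.py, line 60). That tensor always has exactly one element, regardless of batch size. The padding mask derived from `lengths` therefore has batch size 1, while the features have batch size 2.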

Environment:
ppgs = 0.0.9
torch = 2.7.0
python = 3.10.16

Metadata

Metadata

Labels

No labels
No labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions