-
Notifications
You must be signed in to change notification settings - Fork 8
Open
Description
Code:
# Minimal reproduction: passing a batch of two identical clips to
# ppgs.from_audio raises
#   AssertionError: Expected key_padded_mask.shape[0] to be 2, but got 1
import ppgs
import torch
audio_file = 'test.wav'
audio = ppgs.load.audio(audio_file)
# Duplicate the clip into a batch of 2.
# NOTE(review): presumably `audio` is (1, samples), which would make this
# (2, 1, samples) -- confirm against ppgs.load.audio.
audio_batch = audio.repeat(2, 1).unsqueeze(dim=1)
gpu = 0
# Fails here. Per the traceback, from_audio (ppgs/core.py:60) builds `length`
# as a one-element tensor, torch.tensor([features.shape[-1]]), no matter how
# many items are in the batch. Downstream, the key-padding-mask check expects
# a leading dimension of 2 (the batch size) but gets 1.
latent = ppgs.from_audio(audio_batch, ppgs.SAMPLE_RATE, gpu=gpu)
print(f"Latent shape: {latent.shape}")
Error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[2], line 12
      8 audio_batch = audio.repeat(2, 1).unsqueeze(dim=1)
     10 gpu = 0
---> 12 latent = ppgs.from_audio(audio_batch, ppgs.SAMPLE_RATE, gpu=gpu)
     14 print(f"Latent shape: {latent.shape}")
File ~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:63, in from_audio(audio, sample_rate, representation, checkpoint, gpu, legacy_mode)
     60 length = torch.tensor([features.shape[-1]], dtype=torch.long)
     62 # Infer
---> 63 return from_features(
     64     features=features,
     65     lengths=length,
     66     representation=representation,
     67     checkpoint=checkpoint,
     68     gpu=gpu,
     69     legacy_mode=legacy_mode)
File ~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/ppgs/core.py:122, in from_features(features, lengths, representation, checkpoint, gpu, softmax, legacy_mode)
    119 features = from_features.frontend(features.to(device))
    121 # Infer
--> 122 return infer(
    123     features=features.to(device),
    124     lengths=lengths.to(device),
...
File ~/miniconda3/envs/torchcfm/lib/python3.10/site-packages/torch/__init__.py:1642
   1638     raise TypeError("message must be a callable")
   1640 message_evaluated = str(message())
-> 1642 raise error_type(message_evaluated)
AssertionError: Expected key_padded_mask.shape[0] to be 2, but got 1
Environment:
ppgs = 0.0.9
torch = 2.7.0
python = 3.10.16
Metadata
Metadata
Assignees
Labels
No labels