93 changes: 93 additions & 0 deletions configs/vocos-matcha.yaml
@@ -0,0 +1,93 @@
# pytorch_lightning==1.8.6
seed_everything: 4444

data:
  class_path: vocos.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: ???
      sampling_rate: 22050
      num_samples: 16384
      batch_size: 16
      num_workers: 8

    val_params:
      filelist_path: ???
      sampling_rate: 22050
      num_samples: 48384
      batch_size: 16
      num_workers: 8

model:
  class_path: vocos.experiment.VocosExp
  init_args:
    sample_rate: 22050
    initial_learning_rate: 5e-4
    mel_loss_coeff: 45
    mrd_loss_coeff: 0.1
    num_warmup_steps: 0 # Optimizers warmup steps
    pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration

    # automatic evaluation
    evaluate_utmos: true
    evaluate_pesq: true
    evaluate_periodicty: true

    feature_extractor:
      class_path: vocos.feature_extractors.MelSpectrogramFeatures
      init_args:
        sample_rate: 22050
        n_fft: 1024
        hop_length: 256
        n_mels: 80
        padding: same
        f_min: 0
        f_max: 8000
        norm: "slaney"
        mel_scale: "slaney"

    backbone:
      class_path: vocos.models.VocosBackbone
      init_args:
        input_channels: 80
        dim: 512
        intermediate_dim: 1536
        num_layers: 8

    head:
      class_path: vocos.heads.ISTFTHead
      init_args:
        dim: 512
        n_fft: 1024
        hop_length: 256
        padding: same


trainer:
  logger:
    class_path: pytorch_lightning.loggers.TensorBoardLogger
    init_args:
      save_dir: /mnt/netapp1/Proxecto_NOS/bsc/speech/TTS/outputs/logs/vocos
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f}
        save_top_k: 3
        save_last: true
    - class_path: vocos.helpers.GradNormCallback

  # Lightning counts max_steps across all optimizer steps (rather than number of batches)
  # This equals 1M steps per generator and 1M per discriminator
  max_steps: 2000000
  # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
  limit_val_batches: 100
  accelerator: gpu
  strategy: ddp
  devices: [0]
  log_every_n_steps: 100
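
For context, this YAML is meant to be consumed by a LightningCLI entry point; the sketch below is an approximation of how that looks (Vocos ships a train.py along these lines, but the exact upstream script may differ).

# Minimal sketch of a LightningCLI entry point that consumes this config.
# Usage: python train.py -c configs/vocos-matcha.yaml
from pytorch_lightning.cli import LightningCLI

if __name__ == "__main__":
    # Each class_path/init_args pair in the YAML is instantiated by the CLI,
    # so the data, model, and trainer sections map onto VocosDataModule,
    # VocosExp, and the pl.Trainer respectively.
    cli = LightningCLI(run=False)
    cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)
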
15 changes: 14 additions & 1 deletion vocos/feature_extractors.py
@@ -26,7 +26,16 @@ def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:


class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
    def __init__(self,
                 sample_rate=24000,
                 n_fft=1024,
                 hop_length=256,
                 n_mels=100,
                 padding="center",
                 f_min=0,  # to match matcha :X
                 f_max=8000,
                 norm="slaney",
                 mel_scale="slaney"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
@@ -38,6 +47,10 @@ def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, pa
            n_mels=n_mels,
            center=padding == "center",
            power=1,
            f_min=f_min,  # to match matcha :X
            f_max=f_max,
            norm=norm,
            mel_scale=mel_scale
        )

    def forward(self, audio, **kwargs):
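
As a quick sanity check, the extended extractor can be instantiated with the Matcha-style values this PR puts in configs/vocos-matcha.yaml. This is a hedged sketch: the dummy input length comes from train_params and the stated output shape is approximate.

# Sketch: run the extended MelSpectrogramFeatures on a dummy batch.
import torch
from vocos.feature_extractors import MelSpectrogramFeatures

extractor = MelSpectrogramFeatures(
    sample_rate=22050,
    n_fft=1024,
    hop_length=256,
    n_mels=80,
    padding="same",
    f_min=0,
    f_max=8000,
    norm="slaney",
    mel_scale="slaney",
)

audio = torch.randn(1, 16384)  # (batch, num_samples), matching train_params
mel = extractor(audio)         # log-mel features, roughly (1, 80, 16384 // 256)
print(mel.shape)
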
22 changes: 19 additions & 3 deletions vocos/loss.py
@@ -12,12 +12,28 @@ class MelSpecReconstructionLoss(nn.Module):
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
"""

    def __init__(
        self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
    def __init__(self,
                 sample_rate: int = 22050,
                 n_fft: int = 1024,
                 hop_length: int = 256,
                 n_mels: int = 80,

Reviewer:
Is n_mels in loss.py here meant to have the default changed to 80? In feature_extractors.py it remains at 100, presumably the default in loss.py was also meant to stay at 100 and only be adjusted by the vocos-matcha.yaml?

Author:

You're right, we should keep n_mels at 100 in loss.py. Also, in feature_extractors.py the defaults should be:

f_max=None
norm=None,
mel_scale="htk"
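
A minimal illustration of this suggestion (hypothetical, not part of the diff): these values are exactly torchaudio's MelSpectrogram defaults, so keeping them as defaults leaves vanilla Vocos behaviour unchanged, and the Matcha-specific settings live only in the yaml.

# Hypothetical check: f_min=0, f_max=None, norm=None, mel_scale="htk" are
# torchaudio's defaults, so passing them explicitly changes nothing.
import torch
import torchaudio

vanilla = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, center=True, power=1,
)
explicit = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, center=True, power=1,
    f_min=0, f_max=None, norm=None, mel_scale="htk",
)
audio = torch.randn(1, 16384)
assert torch.allclose(vanilla(audio), explicit(audio))  # identical outputs
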

Reviewer:

Would you happen to have any reference on the decision between 80 and 100 n_mels?

I understand 80 has been quite common, so many models are trained with it as a result, but I'm curious about the original rationale:

  • Is 80 intended to be sufficient for speech specifically?
  • I came across a paper recently that cited 96 as a minimum for covering not only speech, but also music and general sound effects.

Both 80 and 96 are multiples of 8, which I'm used to being preferable for compute (at least traditionally, much like game texture sizes, although those tend to be powers of 2, e.g. 64 vs 128). Perhaps Vocos just rounded that up to 100 🤔 I'm not sure if that would actually regress anywhere vs 96 😅

                 f_min: int = 0,
                 f_max: int = 8000,
                 norm: str = "slaney",
                 mel_scale: str = "slaney",
    ):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
            f_min=f_min,
            f_max=f_max,
            norm=norm,
            mel_scale=mel_scale
        )

    def forward(self, y_hat, y) -> torch.Tensor:
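
Analogously, a hedged sketch of constructing the updated reconstruction loss with the Matcha-style mel settings passed explicitly; the waveforms and batch size below are placeholders, not values from the PR.

# Sketch: mel reconstruction loss with the Matcha-matched mel parameters.
import torch
from vocos.loss import MelSpecReconstructionLoss

mel_loss = MelSpecReconstructionLoss(
    sample_rate=22050,
    n_fft=1024,
    hop_length=256,
    n_mels=80,
    f_min=0,
    f_max=8000,
    norm="slaney",
    mel_scale="slaney",
)

y_hat = torch.randn(4, 16384)  # generated waveforms (placeholder)
y = torch.randn(4, 16384)      # ground-truth waveforms (placeholder)
loss = mel_loss(y_hat, y)      # scalar L1 distance between log-mel spectrograms
print(loss.item())
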