
Commit 06abaff

Authored by alex2awesome (Alex Spangher) and loubbrad (Louis)
synthetic data and baselines (#19)
* added codecs compression to augmentation
* updated
* added soundfonts to gitignore
* updated for synthetic data creation

---------

Co-authored-by: Alex Spangher <[email protected]>
Co-authored-by: Louis <[email protected]>
1 parent: d6fea7f

8 files changed (+981, -1 lines)

.gitignore  (+3, -1)

@@ -15,9 +15,11 @@
 *.xml
 *.html
 *.htm
+*.sf2

 .idea/
-notebooks/scratch
+notebooks/scratch
+baselines/hft_transformer/model_files/

 # Byte-compiled / optimized / DLL files
 __pycache__/

amt/audio.py  (+17)

@@ -197,6 +197,7 @@ def __init__(
         bandpass_ratio: float = 0.1,
         distort_ratio: float = 0.15,
         reduce_ratio: float = 0.01,
+        codecs_ratio: float = 0.01,
         spec_aug_ratio: float = 0.5,
     ):
         super().__init__()
@@ -219,6 +220,7 @@ def __init__(
         self.distort_ratio = distort_ratio
         self.reduce_ratio = reduce_ratio
         self.spec_aug_ratio = spec_aug_ratio
+        self.codecs_ratio = codecs_ratio
         self.reduction_resample_rate = 6000  # Hardcoded?

         # Audio aug
@@ -397,6 +399,20 @@ def distortion_aug_cpu(self, wav: torch.Tensor):

         return wav

+    def apply_codec(self, wav: torch.tensor):
+        """
+        Apply different audio codecs to the audio.
+        """
+        format_encoder_pairs = [
+            ("wav", "pcm_mulaw"),
+            ("g722", None),
+            ("ogg", "vorbis")
+        ]
+        for format, encoder in format_encoder_pairs:
+            encoder = torchaudio.io.AudioEffector(format=format, encoder=encoder)
+            if random.random() < self.codecs_ratio:
+                wav = encoder.apply(wav, self.sample_rate)
+
     def shift_spec(self, specs: torch.Tensor, shift: int):
         if shift == 0:
             return specs
@@ -429,6 +445,7 @@ def aug_wav(self, wav: torch.Tensor):
         # Noise
         if random.random() < self.noise_ratio:
             wav = self.apply_noise(wav)
+
         if random.random() < self.applause_ratio:
             wav = self.apply_applause(wav)
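As committed, apply_codec builds an AudioEffector for every entry in format_encoder_pairs but has no return statement, so the codec-compressed waveform never reaches callers such as aug_wav. Below is a minimal corrected sketch, assuming the same torchaudio.io.AudioEffector API used in the diff; it is a reading of the intent, not part of the commit.

def apply_codec(self, wav: torch.Tensor):
    """Randomly round-trip the waveform through lossy codecs (mu-law WAV, G.722, Ogg/Vorbis)."""
    format_encoder_pairs = [
        ("wav", "pcm_mulaw"),
        ("g722", None),
        ("ogg", "vorbis"),
    ]
    for fmt, enc in format_encoder_pairs:
        if random.random() < self.codecs_ratio:
            effector = torchaudio.io.AudioEffector(format=fmt, encoder=enc)
            # AudioEffector.apply expects a (frames, channels) tensor; reshape first
            # if wav is stored as (channels, frames) upstream.
            wav = effector.apply(wav, self.sample_rate)
    return wav

Checking codecs_ratio before constructing the effector also avoids building encoders that are never used.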

amt/data.py  (+75)

@@ -11,6 +11,81 @@
 from amt.tokenizer import AmtTokenizer
 from amt.config import load_config
 from amt.audio import pad_or_trim
+from midi2audio import FluidSynth
+import random
+
+
+class SyntheticMidiHandler:
+    def __init__(self, soundfont_path: str, soundfont_prob_dict: dict = None, num_wavs_per_midi: int = 1):
+        """
+        File to load MIDI files and convert them to audio.
+
+        Parameters
+        ----------
+        soundfont_path : str
+            Path to the directory containing soundfont files.
+        soundfont_prob_dict : dict, optional
+            Dictionary containing the probability of using a soundfont file.
+            The keys are the soundfont file names and the values are the
+            probability of using the soundfont file. If none is given, then
+            a uniform distribution is used.
+        num_wavs_per_midi : int, optional
+            Number of audio files to generate per MIDI file.
+        """
+        self.soundfont_path = soundfont_path
+        self.soundfont_prob_dict = soundfont_prob_dict
+        self.num_wavs_per_midi = num_wavs_per_midi
+
+        self.fs_objs = self._load_soundfonts()
+        self.soundfont_cumul_prob_dict = self._get_cumulative_prob_dict()
+
+    def _load_soundfonts(self):
+        """Loads the soundfonts into fluidsynth objects."""
+        fs_files = os.listdir(self.soundfont_path)
+        fs_objs = {}
+        for fs_file in fs_files:
+            fs_objs[fs_file] = FluidSynth(fs_file)
+        return fs_objs
+
+    def _get_cumulative_prob_dict(self):
+        """Returns a dictionary with the cumulative probabilities of the soundfonts.
+        Used for sampling the soundfonts.
+        """
+        if self.soundfont_prob_dict is None:
+            self.soundfont_prob_dict = {k: 1 / len(self.fs_objs) for k in self.fs_objs.keys()}
+        self.soundfont_prob_dict = {k: v / sum(self.soundfont_prob_dict.values())
+                                    for k, v in self.soundfont_prob_dict.items()}
+        cumul_prob_dict = {}
+        cumul_prob = 0
+        for k, v in self.soundfont_prob_dict.items():
+            cumul_prob_dict[k] = (cumul_prob, cumul_prob + v)
+            cumul_prob += v
+        return cumul_prob_dict
+
+    def _sample_soundfont(self):
+        """Samples a soundfont file."""
+        rand_num = random.random()
+        for k, (v_s, v_e) in self.soundfont_cumul_prob_dict.items():
+            if (rand_num >= v_s) and (rand_num < v_e):
+                return self.fs_objs[k]
+
+    def get_wav(self, midi_path: str, save_path: str):
+        """
+        Converts a MIDI file to audio.
+
+        Parameters
+        ----------
+        midi_path : str
+            Path to the MIDI file.
+        save_path : str
+            Path to save the audio file.
+        """
+        for i in range(self.num_wavs_per_midi):
+            soundfont = self._sample_soundfont()
+            if self.num_wavs_per_midi > 1:
+                save_path = save_path[:-4] + f"_{i}.wav"
+            soundfont.midi_to_audio(midi_path, save_path)


 def get_wav_mid_segments(
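A hedged usage sketch for SyntheticMidiHandler follows; the directory and file names are hypothetical. Note that _load_soundfonts passes the bare .sf2 file name to FluidSynth rather than joining it with soundfont_path, so the soundfont files need to be resolvable from the working directory, and the suffixing in get_wav rewrites save_path in place on each iteration.

from amt.data import SyntheticMidiHandler

handler = SyntheticMidiHandler(
    soundfont_path="soundfonts/",   # hypothetical directory of .sf2 files (kept out of git via *.sf2)
    soundfont_prob_dict=None,       # None -> uniform sampling over the loaded soundfonts
    num_wavs_per_midi=2,            # render each MIDI with two independently sampled soundfonts
)

# Renders example.mid twice; with the committed suffix logic this produces
# example_0.wav and then example_0_1.wav, since save_path is reused across iterations.
handler.get_wav(midi_path="example.mid", save_path="example.wav")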
transcribe_new_files.py  (new file, +67)

@@ -0,0 +1,67 @@
+import os
+import argparse
+import time
+import torch
+import piano_transcription_inference
+import glob
+
+
+def transcribe_piano(mp3s_dir, midis_dir, begin_index=None, end_index=None):
+    """Transcribe piano solo mp3s to midi files."""
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    os.makedirs(midis_dir, exist_ok=True)
+
+    # Transcriptor
+    transcriptor = piano_transcription_inference.PianoTranscription(device=device)
+
+    transcribe_time = time.time()
+    for n, mp3_path in enumerate(glob.glob(os.path.join(mp3s_dir, '*.mp3'))[begin_index:end_index]):
+        print(n, mp3_path)
+        midi_file = os.path.basename(mp3_path).replace('.mp3', '.midi')
+        midi_path = os.path.join(midis_dir, midi_file)
+        if os.path.exists(midi_path):
+            continue
+
+        (audio, _) = (
+            piano_transcription_inference
+            .load_audio(mp3_path, sr=piano_transcription_inference.sample_rate, mono=True)
+        )
+
+        try:
+            # Transcribe
+            transcribed_dict = transcriptor.transcribe(audio, midi_path)
+            print(transcribed_dict)
+        except:
+            print('Failed for this audio!')
+
+    print('Time: {:.3f} s'.format(time.time() - transcribe_time))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Example of parser. ')
+    parser.add_argument('--mp3s_dir', type=str, required=True, help='')
+    parser.add_argument('--midis_dir', type=str, required=True, help='')
+    parser.add_argument(
+        '--begin_index', type=int, required=False,
+        help='File num., of an ordered list of files, to start transcribing from.', default=None
+    )
+    parser.add_argument(
+        '--end_index', type=int, required=False, default=None,
+        help='File num., of an ordered list of files, to end transcription.'
+    )
+
+    # Parse arguments
+    args = parser.parse_args()
+    transcribe_piano(
+        mp3s_dir=args.mp3s_dir,
+        midis_dir=args.midis_dir,
+        begin_index=args.begin_index,
+        end_index=args.end_index
+    )
+
+"""
+python transcribe_new_files.py \
+    transcribe_piano \
+    --mp3s_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data \
+    --midis_dir /mnt/data10/spangher/aira-dl/hFT-Transformer/evaluation/glenn-gould-bach-data/kong-model
+"""
