-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets.py
124 lines (92 loc) · 3.9 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import os
from torch.utils.data import Dataset
from pathlib import Path
import torchaudio
class VivosDataset(Dataset):
    """VIVOS Vietnamese speech-corpus loader.

    Expects the standard VIVOS layout:
        <root>/<subset>/waves/<speaker>/<utt_id>.wav
        <root>/<subset>/prompts.txt   (lines: "<utt_id> <transcript>")

    Each item is a ``(spectrogram, transcript)`` pair where the spectrogram
    has shape ``(time, feature)`` and the transcript is lower-cased text.
    """

    def __init__(self, root: str = "", subset: str = "train", n_fft: int = 200):
        """
        Args:
            root: dataset root directory containing the subset folders.
            subset: either ``"train"`` or ``"test"``.
            n_fft: FFT size for the spectrogram (feature dim is n_fft // 2 + 1).

        Raises:
            ValueError: if ``subset`` is not ``"train"`` or ``"test"``.
        """
        super().__init__()
        self.root = root
        self.subset = subset
        # `assert` is stripped under `python -O`; raise explicitly instead.
        if self.subset not in ("train", "test"):
            raise ValueError("subset not found")

        path = os.path.join(self.root, self.subset)
        waves_path = os.path.join(path, "waves")
        transcript_path = os.path.join(path, "prompts.txt")

        # One audio file per utterance, grouped into per-speaker folders.
        self.walker = list(Path(waves_path).glob("*/*"))

        # prompts.txt lines look like "<utt_id> <transcript>"; build the
        # utt_id -> transcript lookup in one pass.
        with open(transcript_path, "r", encoding="utf-8") as f:
            lines = f.read().strip().split("\n")
        pairs = [line.split(" ", 1) for line in lines]
        self.transcripts = {utt_id: text for utt_id, text in pairs}

        self.feature_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft)

    def __len__(self):
        """Number of audio files found under the subset."""
        return len(self.walker)

    def __getitem__(self, idx):
        """Return ``(spectrogram, transcript)`` for the idx-th utterance."""
        filepath = str(self.walker[idx])
        # Utterance id = basename up to the first dot (matches prompts.txt keys).
        filename = filepath.rsplit(os.sep, 1)[-1].split(".")[0]
        wave, sr = torchaudio.load(filepath)
        specs = self.feature_transform(wave)  # (channel, feature, time)
        specs = specs.permute(0, 2, 1)  # (channel, time, feature)
        # squeeze(0) drops only the channel dim; a bare squeeze() would also
        # drop a size-1 time dim for degenerately short clips.
        specs = specs.squeeze(0)  # (time, feature)
        trans = self.transcripts[filename].lower()
        return specs, trans
class ComposeDataset(Dataset):
    """Combined speech dataset over VIVOS and VLSP-style corpora.

    For the ``"train"`` subset the VLSP utterances are appended to the VIVOS
    ones; for ``"test"`` only VIVOS is used.  Each item is a
    ``(spectrogram, transcript)`` pair with spectrogram shape
    ``(time, feature)``.  ``podcasts_root`` is accepted but not used yet.
    """

    def __init__(
        self,
        vivos_root: str = "",
        vivos_subset: str = "train",
        vlsp_root: str = "",
        podcasts_root: str = "",
        n_fft: int = 400,
    ):
        """
        Args:
            vivos_root: root of the VIVOS corpus.
            vivos_subset: ``"train"`` or ``"test"``.
            vlsp_root: flat directory of VLSP "<id>.wav" / "<id>.txt" pairs.
            podcasts_root: reserved for a future corpus; currently ignored.
            n_fft: FFT size for the spectrogram (feature dim is n_fft // 2 + 1).
        """
        super().__init__()
        self.feature_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft)
        # walker holds (filepath, transcript) pairs from every sub-corpus.
        self.walker = self.init_vivos(vivos_root, vivos_subset)
        if vivos_subset == "train":
            self.walker += self.init_vlsp(vlsp_root)

    def init_vivos(self, root, subset):
        """Collect ``(filepath, transcript)`` pairs from a VIVOS subset.

        Raises:
            ValueError: if ``subset`` is not ``"train"`` or ``"test"``.
        """
        # `assert` is stripped under `python -O`; raise explicitly instead.
        if subset not in ("train", "test"):
            raise ValueError("subset not found")
        path = os.path.join(root, subset)
        waves_path = os.path.join(path, "waves")
        transcript_path = os.path.join(path, "prompts.txt")

        walker = list(Path(waves_path).glob("*/*"))

        # prompts.txt lines look like "<utt_id> <transcript>".
        with open(transcript_path, "r", encoding="utf-8") as f:
            lines = f.read().strip().split("\n")
        pairs = [line.split(" ", 1) for line in lines]
        transcripts = {utt_id: text for utt_id, text in pairs}

        # Utterance id = filename up to the first dot (matches prompts.txt keys).
        return [
            (filepath, transcripts[filepath.name.split(".")[0]].lower())
            for filepath in walker
        ]

    def init_vlsp(self, root):
        """Collect ``(filepath, transcript)`` pairs from a flat VLSP directory.

        Every "<id>.wav" is expected to have a sibling "<id>.txt" transcript;
        "<unk>" tokens are removed from the lower-cased text.
        """
        root = Path(root)

        def read_transcript(filepath):
            # Sibling transcript keyed by filename up to the first dot.
            txt_path = root / (filepath.name.split(".")[0] + ".txt")
            with open(txt_path, "r", encoding="utf-8") as f:
                trans = f.read().strip().lower()
            return filepath, trans.replace("<unk>", "").strip()

        return [read_transcript(filepath) for filepath in root.glob("*.wav")]

    def __len__(self):
        """Total number of collected utterances."""
        return len(self.walker)

    def __getitem__(self, idx):
        """Return ``(spectrogram, transcript)`` for the idx-th utterance."""
        filepath, trans = self.walker[idx]
        wave, _sr = torchaudio.load(filepath)
        specs = self.feature_transform(wave)  # (channel, feature, time)
        specs = specs.permute(0, 2, 1)  # (channel, time, feature)
        # squeeze(0) drops only the channel dim; a bare squeeze() would also
        # drop a size-1 time dim for degenerately short clips.
        specs = specs.squeeze(0)  # (time, feature)
        return specs, trans