-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathAudioReader.py
110 lines (92 loc) · 3.16 KB
/
AudioReader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torchaudio
import torch
def handle_scp(scp_path):
    '''
    Read a .scp script file into a dictionary.

    Every line must consist of exactly two whitespace-separated fields:
    a key followed by a wave file path.

    input:
          scp_path: .scp file's file path
    output:
          scp_dict: {'key':'wave file path'}
    raises:
          RuntimeError: a line does not contain exactly two fields
                        (this includes blank lines)
          ValueError: the same key appears on more than one line
    '''
    scp_dict = dict()
    # The context manager guarantees the file handle is released, even when
    # one of the format checks below raises.
    with open(scp_path, 'r') as scp_file:
        # Line numbers are 1-based so error messages match what an editor shows.
        for line_no, raw_line in enumerate(scp_file, start=1):
            scp_parts = raw_line.strip().split()
            if len(scp_parts) != 2:
                raise RuntimeError("For {}, format error in line[{:d}]: {}".format(
                    scp_path, line_no, scp_parts))
            key, value = scp_parts
            if key in scp_dict:
                raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
                    key, scp_path))
            scp_dict[key] = value
    return scp_dict
def read_wav(fname, return_rate=False):
    '''
    Load a wav file through torchaudio.

    input:
          fname: wav file path
          return_rate: when true, the sampling rate is returned as well
    output:
          the tensor produced by torchaudio.load (called with
          channels_first=True) after squeeze(), i.e. with every
          size-1 dimension removed.
          When return_rate is true the result is the pair
          (tensor, sample rate) instead of the tensor alone.
    '''
    waveform, rate = torchaudio.load(fname, channels_first=True)
    squeezed = waveform.squeeze()
    return (squeezed, rate) if return_rate else squeezed
def write_wav(fname, src, sample_rate):
    '''
    Write a wav file through torchaudio.

    input:
          fname: wav file path
          src: frames of audio as a tensor; a 1-D tensor is treated as
               a single channel, otherwise it is passed on unchanged
          sample_rate: An integer which is the sample rate of the audio
    output:
          None
    '''
    # read_wav() squeezes its result, so single-channel audio arrives here as
    # a 1-D tensor, while torchaudio.save requires a 2-D tensor. Restore the
    # leading channel axis so the two helpers round-trip.
    if src.dim() == 1:
        src = src.unsqueeze(0)
    torchaudio.save(fname, src, sample_rate)
class AudioReader(object):
    '''
        Class that reads Wav format files listed in a .scp file.

        The .scp file is parsed once at construction time into a
        {key: wave file path} mapping; the audio itself is loaded
        lazily, every time an entry is requested.

        input:
              scp_path: .scp file's file path
              sample_rate: expected sample rate of every file; pass None
                           to skip the check
    '''

    def __init__(self, scp_path, sample_rate=8000):
        super(AudioReader, self).__init__()
        self.sample_rate = sample_rate
        self.index_dict = handle_scp(scp_path)
        # Key order follows the order of the lines in the scp file.
        self.keys = list(self.index_dict.keys())

    def _load(self, key):
        '''
        Load the audio registered under `key`.
        Raises RuntimeError when a sample rate is configured and the
        file's sample rate differs from it.
        '''
        src, sr = read_wav(self.index_dict[key], return_rate=True)
        if self.sample_rate is not None and sr != self.sample_rate:
            raise RuntimeError('SampleRate mismatch: {:d} vs {:d}'.format(
                sr, self.sample_rate))
        return src

    def __len__(self):
        '''Number of utterances listed in the scp file.'''
        return len(self.keys)

    def __iter__(self):
        '''Yield (key, audio) pairs in scp file order.'''
        for key in self.keys:
            yield key, self._load(key)

    def __getitem__(self, index):
        '''
        Fetch audio by position (int) or by utterance key (str).
        Raises IndexError for any other index type, and KeyError for an
        integer outside [0, len(self)) or an unknown key.
        '''
        # bool is a subclass of int; it is rejected explicitly so that
        # True/False are not silently treated as positions 1/0.
        if isinstance(index, bool) or not isinstance(index, (int, str)):
            raise IndexError('Unsupported index type: {}'.format(type(index)))
        if isinstance(index, int):
            num_uttrs = len(self.keys)
            # The previous check (`num_uttrs < index and index < 0`) could
            # never be true; valid positions are 0 .. num_uttrs - 1.
            if index < 0 or index >= num_uttrs:
                raise KeyError('Interger index out of range, {:d} vs {:d}'.format(
                    index, num_uttrs))
            index = self.keys[index]
        if index not in self.index_dict:
            raise KeyError("Missing utterance {}!".format(index))
        return self._load(index)
if __name__ == "__main__":
    import sys

    # Smoke test: load the second utterance of an scp file and print it.
    # The scp path may be given as the first command-line argument; without
    # one, the original hard-coded path is used.
    scp_path = sys.argv[1] if len(sys.argv) > 1 else '/home/likai/data1/create_scp/cv_s2.scp'
    reader = AudioReader(scp_path)
    print(reader[1])