-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresample_folder.py
More file actions
64 lines (46 loc) · 2.19 KB
/
resample_folder.py
File metadata and controls
64 lines (46 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
import time
from pathlib import Path
import typing as tp
import torch
import torchaudio
from tqdm import tqdm
def get_filelist(folder: tp.Union[str, os.PathLike], extensions: tp.Optional[tp.List[str]] = None) -> tp.List[str]:
    """Recursively collect the files under *folder* that have one of *extensions*.

    Args:
        folder: Root directory to search; all sub-directories are visited.
        extensions: File suffixes to keep, including the leading dot
            (e.g. ``'.wav'``). Matching is case-insensitive. When ``None``,
            a default set of common audio suffixes is used.

    Returns:
        The matching file paths as strings, sorted so the order is
        deterministic regardless of filesystem enumeration order.

    Raises:
        ValueError: If *folder* is not an existing directory.
    """
    if extensions is None:
        extensions = ['.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a']
    wanted = {ext.lower() for ext in extensions}

    path = Path(folder)
    if not path.is_dir():
        raise ValueError(f"The provided directory '{folder}' is not a valid directory.")

    # rglob('*') yields directories as well; a directory called e.g.
    # "album.wav" must not be reported as an audio file.
    return sorted(
        str(entry)
        for entry in path.rglob('*')
        if entry.suffix.lower() in wanted and entry.is_file()
    )
if __name__ == "__main__":
    # Script entry point: resample every audio file found under the input
    # folder to the target sample rate and write the results to the output
    # folder, mirroring the input's sub-directory layout.

    # Defaults used when the corresponding command-line flags are omitted.
    INPUT_FOLDER = "exp_recon/test-clean"
    TARGET_SR = 16000  # pass --target_sr 24000 for 24 kHz output

    parser = argparse.ArgumentParser(description="Resample audio files in a folder to a specified sample rate.")
    parser.add_argument('--device', type=str, default='cpu', help='Device to run resampling on')
    parser.add_argument('--input_folder', type=str, default=INPUT_FOLDER,
                        help='Folder that is searched recursively for audio files')
    parser.add_argument('--output_folder', type=str, default=None,
                        help='Destination folder (default: "<input_folder>_<target_sr>")')
    parser.add_argument('--target_sr', type=int, default=TARGET_SR, help='Target sample rate for resampling')
    args = parser.parse_args()

    device = args.device
    target_sr = args.target_sr

    # Derive the default destination from the *parsed* arguments so that
    # overriding --input_folder or --target_sr is reflected in the output
    # location. normpath drops a trailing separator, which would otherwise
    # place the output inside the input folder ("in/" -> "in/_16000").
    output_folder = args.output_folder
    if output_folder is None:
        output_folder = f"{os.path.normpath(args.input_folder)}_{target_sr}"

    # List the input first: an invalid input folder raises here, before any
    # output directory has been created.
    filelist = get_filelist(args.input_folder)
    os.makedirs(output_folder, exist_ok=True)

    print(f"*** Resampling all audio files to {target_sr} Hz ***")

    # One Resample transform per source sample rate; building the transform
    # for every file repeats the kernel computation and device transfer.
    resamplers = {}

    start_time = time.time()
    with torch.no_grad():
        for file_path in tqdm(filelist):
            wav, sr = torchaudio.load(file_path)
            wav = wav.to(device)

            if sr != target_sr:
                resampler = resamplers.get(sr)
                if resampler is None:
                    resampler = torchaudio.transforms.Resample(sr, target_sr).to(device)
                    resamplers[sr] = resampler
                wav = resampler(wav)

            # Recreate the file's position relative to the input root
            # underneath the output root.
            rel_path = os.path.relpath(file_path, start=args.input_folder)
            output_path = os.path.join(output_folder, rel_path)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, wav.cpu(), target_sr)
    end_time = time.time()

    print(f"Total time taken: {end_time - start_time:.2f} seconds.")