forked from noetits/ICE-Talk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_attention_guides.py
62 lines (43 loc) · 1.8 KB
/
prepare_attention_guides.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
#!/usr/bin/env python2
from __future__ import print_function
from utils import get_attention_guide
import os
from data_load import load_data
import numpy as np
import tqdm
from concurrent.futures import ProcessPoolExecutor
from argparse import ArgumentParser
from libutil import basename, save_floats_as_8bit, safe_makedir
from configuration import load_config
def proc(fpath, text_length, hp):
    """Build and store the attention guide for a single utterance.

    The utterance's base name is derived from ``fpath`` and used to locate
    its feature file ``<hp.coarse_audio_dir>/<base>.npy``. The first
    dimension of the loaded array is used as the speech length and, together
    with ``text_length`` and ``hp.g``, is passed to ``get_attention_guide``.
    The resulting guide is handed to ``save_floats_as_8bit`` with the target
    path ``<hp.attention_guide_dir>/<base>``.

    If the feature file does not exist, a message is printed and nothing is
    written. The function returns None in both cases.
    """
    utt_name = basename(fpath)
    sep = os.path.sep
    mel_path = sep.join([hp.coarse_audio_dir, utt_name + '.npy'])
    # No '.npy' suffix is appended to the output path here.
    guide_path = sep.join([hp.attention_guide_dir, utt_name])

    if os.path.isfile(mel_path):
        n_frames = np.load(mel_path).shape[0]
        guide = get_attention_guide(text_length, n_frames, g=hp.g)
        save_floats_as_8bit(guide, guide_path)
    else:
        print('file %s not found'%(mel_path))
def main_work():
    """Command-line entry point: generate attention guides for a dataset.

    Command-line options:
        -c        path of the configuration to load (required).
        -ncores   number of worker processes for parallel processing
                  (default 1).

    Loads the configuration and the dataset, creates the attention guide
    output directory, then runs ``proc`` once per (fpath, text_length) pair
    in a process pool while displaying a tqdm progress bar.

    Raises:
        AssertionError: if ``hp.attention_guide_dir`` is falsy or
            ``hp.coarse_audio_dir`` does not exist.
    """
    #################################################
    # ============= Process command line ============
    a = ArgumentParser()
    a.add_argument('-c', dest='config', required=True, type=str)
    a.add_argument('-ncores', default=1, type=int, help='Number of cores for parallel processing')
    opts = a.parse_args()
    # ===============================================

    hp = load_config(opts.config)
    assert hp.attention_guide_dir, 'attention_guide_dir is not set in the config'

    dataset = load_data(hp)
    fpaths, text_lengths = dataset['fpaths'], dataset['text_lengths']

    assert os.path.exists(hp.coarse_audio_dir), \
        'coarse_audio_dir %s does not exist' % (hp.coarse_audio_dir)
    safe_makedir(hp.attention_guide_dir)

    # The context manager guarantees the pool is shut down (and its worker
    # processes released) even when one of the jobs raises.
    with ProcessPoolExecutor(max_workers=opts.ncores) as executor:
        futures = [
            executor.submit(proc, fpath, text_length, hp)
            for (fpath, text_length) in zip(fpaths, text_lengths)
        ]
        # result() blocks until the job is done and re-raises any exception
        # that occurred in the worker; proc itself returns nothing useful.
        for future in tqdm.tqdm(futures):
            future.result()
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main_work()