This repository has been archived by the owner on Aug 23, 2023. It is now read-only.
forked from Kyubyong/dc_tts
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsynthesize.py
87 lines (73 loc) · 3.17 KB
/
synthesize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/dc_tts
'''
from __future__ import print_function
import argparse
import os
from hyperparams import Hyperparams as hp
import numpy as np
import tensorflow as tf
from train import Graph
from utils import *
from data_load import load_new_data
from scipy.io.wavfile import write
from tqdm import tqdm
def _restore_latest(sess, var_list, ckpt_dir, label):
    """Restore `var_list` into `sess` from the newest checkpoint in `ckpt_dir`.

    Args:
      sess: Session the variables are restored into.
      var_list: Variables handed to `tf.train.Saver`.
      ckpt_dir: Directory searched with `tf.train.latest_checkpoint`.
      label: Network name used in the log line and in the error message.

    Raises:
      ValueError: If no checkpoint is found in `ckpt_dir`.
    """
    ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt is None:
        # latest_checkpoint() yields None when nothing is found; fail here and
        # name the directory instead of handing None to Saver.restore().
        raise ValueError(
            "No {} checkpoint found in {}".format(label, ckpt_dir))
    tf.train.Saver(var_list=var_list).restore(sess, ckpt)
    print("LOADED: {} Restored from {}".format(label, ckpt))


def synthesize(filename, outdir):
    """Synthesize one wav file (plus a spectrogram image) per loaded input.

    The inputs returned by `load_new_data(filename)` are run through the
    'Text2Mel' variables restored from `hp.logdir + "-1"` and the 'SSRN'
    variables restored from `hp.logdir + "-2"`. Results are written to
    `outdir` as `01.wav`, `01.png`, `02.wav`, ... (1-based, zero-padded to
    two digits).

    Args:
      filename: Passed unchanged to `load_new_data`.
      outdir: Output directory; created when it does not exist.

    Raises:
      ValueError: If either checkpoint directory holds no checkpoint.
    """
    # Load data
    L = load_new_data(filename)

    # Load graph
    g = Graph(mode="synthesize")
    print("Graph loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore parameters
        text2mel_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'Text2Mel')
        _restore_latest(sess, text2mel_vars, hp.logdir + "-1", "Text2Mel")

        ssrn_vars = (
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'SSRN')
            + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'gs'))
        _restore_latest(sess, ssrn_vars, hp.logdir + "-2", "SSRN")

        # Feed forward
        ## mel
        # Y starts as all zeros and is filled one time step per iteration:
        # step j feeds back everything produced so far and keeps only frame j
        # of the new prediction.
        Y = np.zeros((len(L), hp.max_T, hp.n_mels), np.float32)
        prev_max_attentions = np.zeros((len(L),), np.int32)
        for j in tqdm(range(hp.max_T)):
            # Only the tensors that are consumed below are fetched.
            _Y, _max_attentions = sess.run(
                [g.Y, g.max_attentions],
                {g.L: L,
                 g.mels: Y,
                 g.prev_max_attentions: prev_max_attentions})
            Y[:, j, :] = _Y[:, j, :]
            prev_max_attentions = _max_attentions[:, j]

        # Get magnitude
        Z = sess.run(g.Z, {g.Y: Y})

        # Generate wav files
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        for i, mag in enumerate(Z):
            print("Working on file", i+1)
            stem = "{:02d}".format(i+1)
            wav_name = stem + ".wav"
            wav_path = os.path.join(outdir, wav_name)
            png_path = os.path.join(outdir, stem + ".png")

            wav = spectrogram2wav(mag)
            write(wav_path, hp.sr, wav)
            render_spectrogram(wav_path, wav_name, png_path)
if __name__ == '__main__':
    # Script entry point: read the GPU index, the sentence file and the output
    # directory from the command line, set CUDA device visibility, then run
    # synthesis.
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('-g', '--gpu', dest='gpu', type=int, default=-1,
                     help='specify GPU; default none (-1)')
    cli.add_argument('-f', '--file', dest='sentences', type=str,
                     default=hp.test_data, help='test_data, def from hp')
    cli.add_argument('-o', '--outdir', dest='outdir', type=str,
                     default=hp.sampledir, help='sampledir, def from hp')
    args = cli.parse_args()

    # restrict GPU usage here, if using multi-gpu
    gpu_requested = args.gpu >= 0
    if gpu_requested:
        print("restricting GPU usage to gpu/", args.gpu, "\n")
        visible_devices = str(args.gpu)
    else:
        # Any negative index is treated the same as the default (-1).
        print("restricting to CPU\n")
        visible_devices = '-1'
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    synthesize(args.sentences, args.outdir)
    print("Done")