trainer.py (forked from seujung/WaveNet-gluon)
import os, sys
import numpy as np
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn,utils
import mxnet.ndarray as F
from tqdm import trange
from models import *
from utils import *
from data_loader import load_wav, data_generation, data_generation_sample
# select the compute context (GPU or CPU)
def setting_ctx(use_gpu):
    if use_gpu:
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()
    return ctx
class Train(object):
    def __init__(self, config):
        ## setting hyper-parameters
        self.batch_size = config.batch_size
        self.epochs = config.epochs
        self.mu = config.mu
        self.n_residue = config.n_residue
        self.n_skip = config.n_skip
        self.dilation_depth = config.dilation_depth
        self.n_repeat = config.n_repeat
        self.seq_size = config.seq_size
        self.use_gpu = config.use_gpu
        self.input = config.input
        self.ctx = setting_ctx(self.use_gpu)
        self.build_model()
    def build_model(self):
        self.net = WaveNet(mu=self.mu, n_residue=self.n_residue, n_skip=self.n_skip,
                           dilation_depth=self.dilation_depth, n_repeat=self.n_repeat)
        # parameter initialization
        self.net.collect_params().initialize(ctx=self.ctx)
        # set optimizer
        self.trainer = gluon.Trainer(self.net.collect_params(), optimizer='adam',
                                     optimizer_params={'learning_rate': 0.01})
        self.loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    def save_model(self, epoch, current_loss):
        filename = 'models/best_perf_epoch_' + str(epoch) + "_loss_" + str(current_loss)
        self.net.save_params(filename)
    def train(self):
        fs, data = load_wav(self.input)
        g = data_generation(data, fs, mu=self.mu, seq_size=self.seq_size, ctx=self.ctx)
        loss_save = []
        best_loss = sys.maxsize
        for epoch in trange(self.epochs):
            loss = 0.0
            for _ in range(self.batch_size):
                batch = next(g)
                x = batch[:-1]
                with autograd.record():
                    logits = self.net(x)
                    sz = logits.shape[0]
                    loss = loss + self.loss_fn(logits, batch[-sz:])
                loss.backward()
            self.trainer.step(1, ignore_stale_grad=True)
            current_loss = nd.sum(loss).asscalar() / self.batch_size
            loss_save.append(current_loss)
            # save the best model so far
            if best_loss > current_loss:
                print('epoch {}, loss {}'.format(epoch, current_loss))
                self.save_model(epoch, current_loss)
                best_loss = current_loss
    def generate_slow(self, x, models, dilation_depth, n_repeat, ctx, n=100):
        dilations = [2 ** i for i in range(dilation_depth)] * n_repeat
        res = list(x.asnumpy())
        for _ in trange(n):
            # feed the last receptive-field worth of samples back into the network
            x = nd.array(res[-sum(dilations) - 1:], ctx=ctx)
            y = models(x)
            res.append(y.argmax(1).asnumpy()[-1])
        return res
    def generation(self):
        # note: the seed wav path is hard-coded here rather than taken from self.input
        fs, data = load_wav('parametric-2.wav')
        initial_data = data_generation_sample(data, fs, mu=self.mu, seq_size=3000, ctx=self.ctx)
        gen_rst = self.generate_slow(initial_data[0:3000], self.net,
                                     dilation_depth=10, n_repeat=2, n=2000, ctx=self.ctx)
        gen_wav = np.array(gen_rst)
        gen_wav = decode_mu_law(gen_wav, 128)
        np.save("wav.npy", gen_wav)
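
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): the repository drives Train
# from a separate entry script, but a minimal stand-alone invocation would
# look like the block below. The argument names mirror the fields read in
# __init__; the default values are illustrative assumptions, not the
# project's published hyper-parameters.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train WaveNet on a single wav file')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--mu', type=int, default=128)            # mu-law quantization levels
    parser.add_argument('--n_residue', type=int, default=24)
    parser.add_argument('--n_skip', type=int, default=128)
    parser.add_argument('--dilation_depth', type=int, default=10)
    parser.add_argument('--n_repeat', type=int, default=2)
    parser.add_argument('--seq_size', type=int, default=20000)
    parser.add_argument('--use_gpu', action='store_true')
    parser.add_argument('--input', type=str, default='parametric-2.wav')
    config = parser.parse_args()

    trainer = Train(config)
    trainer.train()
    trainer.generation()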