support ivector training in pytorch model #3969

Merged 4 commits on Mar 3, 2020
Changes from 3 commits
14 changes: 12 additions & 2 deletions egs/aishell/s10/chain/egs_dataset.py
@@ -151,13 +151,17 @@ def __call__(self, batch):
self.egs_left_context + self.egs_right_context

# TODO(fangjun): support ivector
assert len(eg.inputs) == 1
assert eg.inputs[0].name == 'input'

_feats = kaldi.FloatMatrix()
eg.inputs[0].features.GetMatrix(_feats)
feats = _feats.numpy()

if len(eg.inputs) > 1:
_ivectors = kaldi.FloatMatrix()
eg.inputs[1].features.GetMatrix(_ivectors)
ivectors = _ivectors.numpy()

assert feats.shape[0] == batch_size * frames_per_sequence

feat_list = []
@@ -173,6 +177,9 @@ def __call__(self, batch):
end_index -= 1 # remove the rightmost frame added for frame shift
feat = feats[start_index:end_index:, :]
feat = splice_feats(feat)
if len(eg.inputs) > 1:
repeat_ivector = torch.from_numpy(ivectors[i]).repeat(feat.shape[0], 1)
feat = torch.cat((torch.from_numpy(feat), repeat_ivector), dim=1).numpy()
feat_list.append(feat)

batched_feat = np.stack(feat_list, axis=0)
@@ -182,7 +189,10 @@ def __call__(self, batch):
# the first -2 is from extra left/right context
# the second -2 is from lda feats splicing
assert batched_feat.shape[1] == frames_per_sequence - 4
assert batched_feat.shape[2] == feats.shape[-1] * 3
if len(eg.inputs) > 1:
assert batched_feat.shape[2] == feats.shape[-1] * 3 + ivectors.shape[-1]
else:
assert batched_feat.shape[2] == feats.shape[-1] * 3

torch_feat = torch.from_numpy(batched_feat).float()
feature_list.append(torch_feat)
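Note on the shape check above: each spliced chunk carries feats.shape[-1] * 3 feature columns, and when a second input is present, the single i-vector of the sequence is repeated along the time axis and concatenated, adding ivectors.shape[-1] columns. A minimal standalone sketch of that concatenation (dimensions are illustrative, not taken from the recipe):

```python
import numpy as np
import torch

num_frames, feat_dim, ivector_dim = 100, 40, 100   # illustrative sizes

spliced = np.random.randn(num_frames, feat_dim * 3).astype(np.float32)
ivector = np.random.randn(ivector_dim).astype(np.float32)  # one i-vector per sequence

# repeat the i-vector for every frame, then concatenate along the feature axis
repeat_ivector = torch.from_numpy(ivector).repeat(num_frames, 1)
feat = torch.cat((torch.from_numpy(spliced), repeat_ivector), dim=1).numpy()

assert feat.shape == (num_frames, feat_dim * 3 + ivector_dim)  # (100, 220)
```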
82 changes: 67 additions & 15 deletions egs/aishell/s10/chain/feat_dataset.py
@@ -4,7 +4,7 @@
# Apache 2.0

import os

import math
import numpy as np
import torch

@@ -22,13 +22,18 @@
def get_feat_dataloader(feats_scp,
model_left_context,
model_right_context,
frames_per_chunk=51,
ivector_scp=None,
ivector_period=10,
batch_size=16,
num_workers=10):
dataset = FeatDataset(feats_scp=feats_scp)
dataset = FeatDataset(feats_scp=feats_scp, ivector_scp=ivector_scp)

collate_fn = FeatDatasetCollateFunc(model_left_context=model_left_context,
model_right_context=model_right_context,
frame_subsampling_factor=3)
frame_subsampling_factor=3,
frames_per_chunk=frames_per_chunk,
ivector_period=ivector_period)

dataloader = DataLoader(dataset,
batch_size=batch_size,
@@ -55,21 +60,36 @@ def _add_model_left_right_context(x, left_context, right_context):

class FeatDataset(Dataset):

def __init__(self, feats_scp):
def __init__(self, feats_scp, ivector_scp=None):
assert os.path.isfile(feats_scp)

self.feats_scp = feats_scp

# items is a list of [key, rxfilename]
items = list()
# items is a dict of {key: [key, rxfilename, ivec]}
items = dict()

with open(feats_scp, 'r') as f:
for line in f:
split = line.split()
assert len(split) == 2
items.append(split)

self.items = items
uttid, rxfilename = split
assert uttid not in items
items[uttid] = [uttid, rxfilename, None]
self.ivector_scp = ivector_scp
if ivector_scp:
expected_count = len(items)
n = 0
with open(ivector_scp, 'r') as f:
for line in f:
uttid_rxfilename = line.split()
assert len(uttid_rxfilename) == 2
uttid, rxfilename = uttid_rxfilename
assert uttid in items
items[uttid][-1] = rxfilename
n += 1
assert n == expected_count

self.items = list(items.values())

self.num_items = len(self.items)

@@ -81,6 +101,8 @@ def __getitem__(self, i):

def __str__(self):
s = 'feats scp: {}\n'.format(self.feats_scp)
if self.ivector_scp:
s += 'ivector_scp scp: {}\n'.format(self.ivector_scp)
s += 'num utt: {}\n'.format(self.num_items)
return s

@@ -90,26 +112,36 @@ class FeatDatasetCollateFunc:
def __init__(self,
model_left_context,
model_right_context,
frame_subsampling_factor=3):
frame_subsampling_factor=3,
frames_per_chunk=51,
ivector_period=10):
'''
We need `frame_subsampling_factor` because we want to know
the number of output frames of different waves in the same batch
'''
self.model_left_context = model_left_context
self.model_right_context = model_right_context
self.frame_subsampling_factor = frame_subsampling_factor
self.frames_per_chunk = frames_per_chunk
self.ivector_period = ivector_period

def __call__(self, batch):
'''
batch is a list of [key, rxfilename, ivector_rxfilename]
'''
key_list = []
feat_list = []
ivector_list = []
ivector_len_list = []
output_len_list = []
subsampled_frames_per_chunk = (self.frames_per_chunk //
self.frame_subsampling_factor)
for b in batch:
key, rxfilename = b
key, rxfilename, ivector_rxfilename = b
key_list.append(key)
feat = kaldi.read_mat(rxfilename).numpy()
if ivector_rxfilename:
ivector = kaldi.read_mat(ivector_rxfilename).numpy()  # shape: (L // ivector_period, C)
feat_len = feat.shape[0]
output_len = (feat_len + self.frame_subsampling_factor -
1) // self.frame_subsampling_factor
@@ -118,12 +150,32 @@ def __call__(self, batch):
feat = _add_model_left_right_context(feat, self.model_left_context,
self.model_right_context)
feat = splice_feats(feat)
feat_list.append(feat)
# no need to sort the feat by length

# the user should sort utterances by length offline
# to avoid unnecessary padding
# now we split feat into chunks so that we can decode chunk by chunk
input_num_frames = (feat.shape[0] + 2
- self.model_left_context - self.model_right_context)
for i in range(0, output_len, subsampled_frames_per_chunk):
# input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136]
first_output = i * self.frame_subsampling_factor
last_output = min(input_num_frames, \
first_output + (subsampled_frames_per_chunk-1) * self.frame_subsampling_factor)
first_input = first_output
last_input = last_output + self.model_left_context + self.model_right_context
input_x = feat[first_input:last_input+1, :]
if ivector_rxfilename:
ivector_index = (first_output + last_output) // 2 // self.ivector_period
input_ivector = ivector[ivector_index, :].reshape(1,-1)
feat_list.append(np.concatenate((input_x,
np.repeat(input_ivector, input_x.shape[0], axis=0)),
axis=-1))
else:
feat_list.append(input_x)

padded_feat = pad_sequence(
[torch.from_numpy(feat).float() for feat in feat_list],
batch_first=True)

assert sum([math.ceil(l / subsampled_frames_per_chunk) for l in output_len_list]) \
== padded_feat.shape[0]

return key_list, padded_feat, output_len_list
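The chunk bookkeeping above can be sanity-checked against the numbers in the inline comment: with the defaults frames_per_chunk=51 and frame_subsampling_factor=3, an utterance with 418 input frames yields output_len=140 subsampled frames, cut into chunks whose first output frames are 0, 17, ..., 136. A small standalone sketch of that arithmetic:

```python
import math

frames_per_chunk = 51
frame_subsampling_factor = 3
subsampled_frames_per_chunk = frames_per_chunk // frame_subsampling_factor  # 17

feat_len = 418
output_len = (feat_len + frame_subsampling_factor - 1) // frame_subsampling_factor  # 140

chunk_starts = list(range(0, output_len, subsampled_frames_per_chunk))
print(chunk_starts)   # [0, 17, 34, 51, 68, 85, 102, 119, 136]

num_chunks = math.ceil(output_len / subsampled_frames_per_chunk)
print(num_chunks)     # 9, matching the assert on padded_feat.shape[0] above
```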
15 changes: 12 additions & 3 deletions egs/aishell/s10/chain/inference.py
@@ -6,6 +6,7 @@
import logging
import os
import sys
import math

import torch
from torch.utils.dlpack import to_dlpack
@@ -38,6 +39,7 @@ def main():
model = get_chain_model(
feat_dim=args.feat_dim,
output_dim=args.output_dim,
ivector_dim=args.ivector_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
bottleneck_dim=args.bottleneck_dim,
@@ -64,22 +66,29 @@

dataloader = get_feat_dataloader(
feats_scp=args.feats_scp,
ivector_scp=args.ivector_scp,
model_left_context=args.model_left_context,
model_right_context=args.model_right_context,
batch_size=32)

batch_size=32,
num_workers=10)
subsampling_factor = 3
subsampled_frames_per_chunk = args.frames_per_chunk // subsampling_factor
for batch_idx, batch in enumerate(dataloader):
key_list, padded_feat, output_len_list = batch
padded_feat = padded_feat.to(device)
with torch.no_grad():
nnet_output, _ = model(padded_feat)

num = len(key_list)
first = 0
for i in range(num):
key = key_list[i]
output_len = output_len_list[i]
value = nnet_output[i, :output_len, :]
target_len = math.ceil(output_len / subsampled_frames_per_chunk)
result = nnet_output[first:first + target_len, :, :].split(1, 0)
value = torch.cat(result, dim=1)[0, :output_len, :]
value = value.cpu()
first += target_len

m = kaldi.SubMatrixFromDLPack(to_dlpack(value))
m = Matrix(m)
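Because every utterance now occupies math.ceil(output_len / subsampled_frames_per_chunk) consecutive rows of the batched nnet_output, the loop above splits those rows out and concatenates them back into a single [output_len, output_dim] matrix before writing it out. A hedged sketch of just that re-assembly step, using dummy tensors:

```python
import math
import torch

subsampled_frames_per_chunk = 17   # 51 // 3
output_dim = 8                     # dummy value; the real one comes from --output-dim
output_len = 140                   # subsampled frames of one utterance

target_len = math.ceil(output_len / subsampled_frames_per_chunk)   # 9 chunk rows
# pretend nnet_output holds only this utterance's chunks, padded to 17 frames each
nnet_output = torch.randn(target_len, subsampled_frames_per_chunk, output_dim)

first = 0
result = nnet_output[first:first + target_len, :, :].split(1, 0)   # 9 tensors of [1, 17, C]
value = torch.cat(result, dim=1)[0, :output_len, :]                # [140, C], padding dropped
assert value.shape == (output_len, output_dim)
```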
10 changes: 7 additions & 3 deletions egs/aishell/s10/chain/model.py
@@ -17,6 +17,7 @@

def get_chain_model(feat_dim,
output_dim,
ivector_dim,
hidden_dim,
bottleneck_dim,
prefinal_bottleneck_dim,
@@ -25,6 +26,7 @@ def get_chain_model(feat_dim,
lda_mat_filename=None):
model = ChainModel(feat_dim=feat_dim,
output_dim=output_dim,
ivector_dim=ivector_dim,
lda_mat_filename=lda_mat_filename,
hidden_dim=hidden_dim,
bottleneck_dim=bottleneck_dim,
@@ -82,6 +84,7 @@ class ChainModel(nn.Module):
def __init__(self,
feat_dim,
output_dim,
ivector_dim=0,
lda_mat_filename=None,
hidden_dim=1024,
bottleneck_dim=128,
@@ -97,8 +100,9 @@ def __init__(self,
assert len(kernel_size_list) == len(subsampling_factor_list)
num_layers = len(kernel_size_list)

input_dim = feat_dim * 3 + ivector_dim
# tdnn1_affine requires [N, T, C]
self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
self.tdnn1_affine = nn.Linear(in_features=input_dim,
out_features=hidden_dim)

# tdnn1_batchnorm requires [N, C, T]
@@ -142,11 +146,11 @@ def __init__(self,
if lda_mat_filename:
logging.info('Use LDA from {}'.format(lda_mat_filename))
self.lda_A, self.lda_b = load_lda_mat(lda_mat_filename)
assert feat_dim * 3 == self.lda_A.shape[0]
assert input_dim == self.lda_A.shape[0]
self.has_LDA = True
else:
logging.info('replace LDA with BatchNorm')
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
affine=False)
self.has_LDA = False

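With ivector_dim folded in, the first affine layer (and the LDA or BatchNorm in front of it) now sees feat_dim * 3 + ivector_dim input features. A quick dimension check (illustrative values; hidden_dim is the default from the recipe):

```python
import torch
import torch.nn as nn

feat_dim, ivector_dim, hidden_dim = 40, 100, 1024
input_dim = feat_dim * 3 + ivector_dim              # 220

tdnn1_affine = nn.Linear(in_features=input_dim, out_features=hidden_dim)

x = torch.randn(16, 150, input_dim)                 # [N, T, C], as tdnn1_affine expects
assert tdnn1_affine(x).shape == (16, 150, hidden_dim)
```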
26 changes: 25 additions & 1 deletion egs/aishell/s10/chain/options.py
@@ -27,6 +27,23 @@ def _set_inference_args(parser):
dest='feats_scp',
help='feats.scp filename, required for inference',
type=str)

parser.add_argument('--frames-per-chunk',
dest='frames_per_chunk',
help='frames per chunk',
type=int,
default=51)

parser.add_argument('--ivector-scp',
dest='ivector_scp',
help='ivector.scp filename, required for ivector inference',
type=str)

parser.add_argument('--ivector-period',
dest='ivector_period',
help='ivector period',
type=int,
default=10)

parser.add_argument('--model-left-context',
dest='model_left_context',
@@ -228,10 +245,17 @@ def get_args():

parser.add_argument('--feat-dim',
dest='feat_dim',
help='nn input dimension',
help='dimension of nn input 0 (features)',
required=True,
type=int)

parser.add_argument('--ivector-dim',
dest='ivector_dim',
help='dimension of nn input 1 (i-vector)',
required=False,
default=0,
type=int)

parser.add_argument('--output-dim',
dest='output_dim',
help='nn output dimension',
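--ivector-period matches the period used when the i-vectors were extracted (one i-vector row per ivector_period frames), which is why FeatDatasetCollateFunc picks the row nearest the centre of each chunk. A tiny sketch of that indexing with the default period:

```python
ivector_period = 10
# first/last frame indices of one chunk (illustrative values)
first_output, last_output = 51, 99

# index of the i-vector row nearest the chunk centre, as in feat_dataset.py
ivector_index = (first_output + last_output) // 2 // ivector_period
print(ivector_index)   # 7
```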
1 change: 1 addition & 0 deletions egs/aishell/s10/chain/train.py
@@ -270,6 +270,7 @@ def process_job(learning_rate, local_rank=None):
model = get_chain_model(
feat_dim=args.feat_dim,
output_dim=args.output_dim,
ivector_dim=args.ivector_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
bottleneck_dim=args.bottleneck_dim,
1 change: 1 addition & 0 deletions egs/aishell/s10/conf/online_cmvn.conf
@@ -0,0 +1 @@
# configuration file for apply-cmvn-online, used when invoking online2-wav-nnet3-latgen-faster.