Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ivector training in pytorch model #3969

Merged
merged 4 commits into from
Mar 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions egs/aishell/s10/chain/egs_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,16 @@ def get_egs_dataloader(egs_dir_or_scp,
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=local_rank, shuffle=True)
dataloader = DataLoader(dataset,
batch_size=batch_size,
collate_fn=collate_fn,
sampler=sampler)
batch_size=batch_size,
collate_fn=collate_fn,
sampler=sampler)
else:
base_sampler = torch.utils.data.RandomSampler(dataset)
sampler = torch.utils.data.BatchSampler(base_sampler, batch_size, False)
dataloader = DataLoader(dataset,
batch_sampler=sampler,
collate_fn=collate_fn)
base_sampler = torch.utils.data.RandomSampler(dataset)
sampler = torch.utils.data.BatchSampler(
base_sampler, batch_size, False)
dataloader = DataLoader(dataset,
batch_sampler=sampler,
collate_fn=collate_fn)
return dataloader


Expand Down Expand Up @@ -146,18 +147,21 @@ def __call__(self, batch):

batch_size = supervision.num_sequences

frames_per_sequence = (supervision.frames_per_sequence * \
self.frame_subsampling_factor) + \
self.egs_left_context + self.egs_right_context
frames_per_sequence = (supervision.frames_per_sequence *
self.frame_subsampling_factor) + \
self.egs_left_context + self.egs_right_context

# TODO(fangjun): support ivector
assert len(eg.inputs) == 1
assert eg.inputs[0].name == 'input'

_feats = kaldi.FloatMatrix()
eg.inputs[0].features.GetMatrix(_feats)
feats = _feats.numpy()

if len(eg.inputs) > 1:
_ivectors = kaldi.FloatMatrix()
eg.inputs[1].features.GetMatrix(_ivectors)
ivectors = _ivectors.numpy()

assert feats.shape[0] == batch_size * frames_per_sequence

feat_list = []
Expand All @@ -173,6 +177,11 @@ def __call__(self, batch):
end_index -= 1 # remove the rightmost frame added for frame shift
feat = feats[start_index:end_index:, :]
feat = splice_feats(feat)
if len(eg.inputs) > 1:
repeat_ivector = torch.from_numpy(
ivectors[i]).repeat(feat.shape[0], 1)
feat = torch.cat(
(torch.from_numpy(feat), repeat_ivector), dim=1).numpy()
feat_list.append(feat)

batched_feat = np.stack(feat_list, axis=0)
Expand All @@ -182,7 +191,11 @@ def __call__(self, batch):
# the first -2 is from extra left/right context
# the second -2 is from lda feats splicing
assert batched_feat.shape[1] == frames_per_sequence - 4
assert batched_feat.shape[2] == feats.shape[-1] * 3
if len(eg.inputs) > 1:
assert batched_feat.shape[2] == feats.shape[-1] * \
3 + ivectors.shape[-1]
else:
assert batched_feat.shape[2] == feats.shape[-1] * 3

torch_feat = torch.from_numpy(batched_feat).float()
feature_list.append(torch_feat)
Expand Down Expand Up @@ -222,8 +235,8 @@ def _test_nnet_chain_example_dataset():
for b in dataloader:
key_list, feature_list, supervision_list = b
assert feature_list[0].shape == (128, 204, 129) \
or feature_list[0].shape == (128, 144, 129) \
or feature_list[0].shape == (128, 165, 129)
or feature_list[0].shape == (128, 144, 129) \
or feature_list[0].shape == (128, 165, 129)
assert supervision_list[0].weight == 1
supervision_list[0].num_sequences == 128 # minibach size is 128

Expand Down
88 changes: 73 additions & 15 deletions egs/aishell/s10/chain/feat_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Apache 2.0

import os

import math
import numpy as np
import torch

Expand All @@ -22,13 +22,18 @@
def get_feat_dataloader(feats_scp,
model_left_context,
model_right_context,
frames_per_chunk=51,
ivector_scp=None,
ivector_period=10,
batch_size=16,
num_workers=10):
dataset = FeatDataset(feats_scp=feats_scp)
dataset = FeatDataset(feats_scp=feats_scp, ivector_scp=ivector_scp)

collate_fn = FeatDatasetCollateFunc(model_left_context=model_left_context,
model_right_context=model_right_context,
frame_subsampling_factor=3)
frame_subsampling_factor=3,
frames_per_chunk=frames_per_chunk,
ivector_period=ivector_period)

dataloader = DataLoader(dataset,
batch_size=batch_size,
Expand All @@ -55,21 +60,40 @@ def _add_model_left_right_context(x, left_context, right_context):

class FeatDataset(Dataset):

def __init__(self, feats_scp):
def __init__(self, feats_scp, ivector_scp=None):
assert os.path.isfile(feats_scp)
if ivector_scp:
assert os.path.isfile(ivector_scp)

self.feats_scp = feats_scp

# items is a list of [key, rxfilename]
items = list()
# items is a dict of [uttid, feat_rxfilename, None]
# or [uttid, feat_rxfilename, ivector_rxfilename] if ivector_scp is not None
items = dict()

with open(feats_scp, 'r') as f:
for line in f:
split = line.split()
assert len(split) == 2
items.append(split)

self.items = items
uttid, rxfilename = split
assert uttid not in items
items[uttid] = [uttid, rxfilename, None]
self.ivector_scp = None
if ivector_scp:
self.ivector_scp = ivector_scp
expected_count = len(items)
n = 0
with open(ivector_scp, 'r') as f:
for line in f:
uttid_rxfilename = line.split()
assert len(uttid_rxfilename) == 2
uttid, rxfilename = uttid_rxfilename
assert uttid in items
items[uttid][-1] = rxfilename
n += 1
assert n == expected_count

self.items = list(items.values())

self.num_items = len(self.items)

Expand All @@ -81,6 +105,8 @@ def __getitem__(self, i):

def __str__(self):
s = 'feats scp: {}\n'.format(self.feats_scp)
if self.ivector_scp:
s += 'ivector_scp scp: {}\n'.format(self.ivector_scp)
s += 'num utt: {}\n'.format(self.num_items)
return s

Expand All @@ -90,26 +116,37 @@ class FeatDatasetCollateFunc:
def __init__(self,
model_left_context,
model_right_context,
frame_subsampling_factor=3):
frame_subsampling_factor=3,
frames_per_chunk=51,
ivector_period=10):
'''
We need `frame_subsampling_factor` because we want to know
the number of output frames of different waves in the same batch
'''
self.model_left_context = model_left_context
self.model_right_context = model_right_context
self.frame_subsampling_factor = frame_subsampling_factor
self.frames_per_chunk = frames_per_chunk
self.ivector_period = ivector_period

def __call__(self, batch):
'''
batch is a list of [key, rxfilename]
'''
key_list = []
feat_list = []
ivector_list = []
ivector_len_list = []
output_len_list = []
subsampled_frames_per_chunk = (self.frames_per_chunk //
self.frame_subsampling_factor)
for b in batch:
key, rxfilename = b
key, rxfilename, ivector_rxfilename = b
key_list.append(key)
feat = kaldi.read_mat(rxfilename).numpy()
if ivector_rxfilename:
ivector = kaldi.read_mat(
ivector_rxfilename).numpy() # L // 10 * C
feat_len = feat.shape[0]
output_len = (feat_len + self.frame_subsampling_factor -
1) // self.frame_subsampling_factor
Expand All @@ -118,12 +155,33 @@ def __call__(self, batch):
feat = _add_model_left_right_context(feat, self.model_left_context,
self.model_right_context)
feat = splice_feats(feat)
feat_list.append(feat)
# no need to sort the feat by length

# the user should sort utterances by length offline
# to avoid unnecessary padding
# now we split feat to chunk, then we can do decode by chunk
input_num_frames = (feat.shape[0] + 2
- self.model_left_context - self.model_right_context)
for i in range(0, output_len, subsampled_frames_per_chunk):
# input len:418 -> output len:140 -> output chunk:[0, 17, 34, 51, 68, 85, 102, 119, 136]
first_output = i * self.frame_subsampling_factor
last_output = min(input_num_frames,
first_output + (subsampled_frames_per_chunk-1) * self.frame_subsampling_factor)
first_input = first_output
last_input = last_output + self.model_left_context + self.model_right_context
input_x = feat[first_input:last_input+1, :]
if ivector_rxfilename:
ivector_index = (
first_output + last_output) // 2 // self.ivector_period
input_ivector = ivector[ivector_index, :].reshape(1, -1)
feat_list.append(np.concatenate((input_x,
np.repeat(input_ivector, input_x.shape[0], axis=0)),
axis=-1))
else:
feat_list.append(input_x)

padded_feat = pad_sequence(
[torch.from_numpy(feat).float() for feat in feat_list],
batch_first=True)

assert sum([math.ceil(l / subsampled_frames_per_chunk) for l in output_len_list]) \
== padded_feat.shape[0]

return key_list, padded_feat, output_len_list
15 changes: 12 additions & 3 deletions egs/aishell/s10/chain/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import sys
import math

import torch
from torch.utils.dlpack import to_dlpack
Expand Down Expand Up @@ -38,6 +39,7 @@ def main():
model = get_chain_model(
feat_dim=args.feat_dim,
output_dim=args.output_dim,
ivector_dim=args.ivector_dim,
lda_mat_filename=args.lda_mat_filename,
hidden_dim=args.hidden_dim,
bottleneck_dim=args.bottleneck_dim,
Expand All @@ -64,22 +66,29 @@ def main():

dataloader = get_feat_dataloader(
feats_scp=args.feats_scp,
ivector_scp=args.ivector_scp,
model_left_context=args.model_left_context,
model_right_context=args.model_right_context,
batch_size=32)

batch_size=32,
num_workers=10)
subsampling_factor = 3
subsampled_frames_per_chunk = args.frames_per_chunk // subsampling_factor
for batch_idx, batch in enumerate(dataloader):
key_list, padded_feat, output_len_list = batch
padded_feat = padded_feat.to(device)
with torch.no_grad():
nnet_output, _ = model(padded_feat)

num = len(key_list)
first = 0
for i in range(num):
key = key_list[i]
output_len = output_len_list[i]
value = nnet_output[i, :output_len, :]
target_len = math.ceil(output_len / subsampled_frames_per_chunk)
result = nnet_output[first:first + target_len, :, :].split(1, 0)
value = torch.cat(result, dim=1)[0, :output_len, :]
value = value.cpu()
first += target_len

m = kaldi.SubMatrixFromDLPack(to_dlpack(value))
m = Matrix(m)
Expand Down
10 changes: 7 additions & 3 deletions egs/aishell/s10/chain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

def get_chain_model(feat_dim,
output_dim,
ivector_dim,
hidden_dim,
bottleneck_dim,
prefinal_bottleneck_dim,
Expand All @@ -25,6 +26,7 @@ def get_chain_model(feat_dim,
lda_mat_filename=None):
model = ChainModel(feat_dim=feat_dim,
output_dim=output_dim,
ivector_dim=ivector_dim,
lda_mat_filename=lda_mat_filename,
hidden_dim=hidden_dim,
bottleneck_dim=bottleneck_dim,
Expand Down Expand Up @@ -82,6 +84,7 @@ class ChainModel(nn.Module):
def __init__(self,
feat_dim,
output_dim,
ivector_dim=0,
lda_mat_filename=None,
hidden_dim=1024,
bottleneck_dim=128,
Expand All @@ -97,8 +100,9 @@ def __init__(self,
assert len(kernel_size_list) == len(subsampling_factor_list)
num_layers = len(kernel_size_list)

input_dim = feat_dim * 3 + ivector_dim
# tdnn1_affine requires [N, T, C]
self.tdnn1_affine = nn.Linear(in_features=feat_dim * 3,
self.tdnn1_affine = nn.Linear(in_features=input_dim,
out_features=hidden_dim)

# tdnn1_batchnorm requires [N, C, T]
Expand Down Expand Up @@ -142,11 +146,11 @@ def __init__(self,
if lda_mat_filename:
logging.info('Use LDA from {}'.format(lda_mat_filename))
self.lda_A, self.lda_b = load_lda_mat(lda_mat_filename)
assert feat_dim * 3 == self.lda_A.shape[0]
assert input_dim == self.lda_A.shape[0]
self.has_LDA = True
else:
logging.info('replace LDA with BatchNorm')
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
affine=False)
self.has_LDA = False

Expand Down
26 changes: 25 additions & 1 deletion egs/aishell/s10/chain/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ def _set_inference_args(parser):
dest='feats_scp',
help='feats.scp filename, required for inference',
type=str)

parser.add_argument('--frames-per-chunk',
dest='frames_per_chunk',
help='frames per chunk',
type=int,
default=51)

parser.add_argument('--ivector-scp',
dest='ivector_scp',
help='ivector.scp filename, required for ivector inference',
type=str)

parser.add_argument('--ivector-period',
dest='ivector_period',
help='ivector period',
type=int,
default=10)

parser.add_argument('--model-left-context',
dest='model_left_context',
Expand Down Expand Up @@ -228,10 +245,17 @@ def get_args():

parser.add_argument('--feat-dim',
dest='feat_dim',
help='nn input dimension',
help='nn input 0 dimension',
required=True,
type=int)

parser.add_argument('--ivector-dim',
dest='ivector_dim',
help='nn input 1 dimension',
required=False,
default=0,
type=int)

parser.add_argument('--output-dim',
dest='output_dim',
help='nn output dimension',
Expand Down
Loading