Commit a5561c3

freewym authored and danpovey committed
[src,scripts] Simplify model combination: do simple average over last n models (kaldi-asr#2067)
1 parent 48656c3 commit a5561c3

17 files changed: +246 -1757 lines

egs/wsj/s5/steps/info/chain_dir_info.pl

+3
@@ -137,6 +137,9 @@ sub get_combine_info {
       if (m/Combining nnets, objective function changed from (\S+) to (\S+)/) {
         close(F);
         return sprintf(" combine=%.3f->%.3f", $1, $2);
+      } elsif (m/Combining (\S+) nnets, objective function changed from (\S+) to (\S+)/) {
+        close(F);
+        return sprintf(" combine=%.3f->%.3f (over %d)", $2, $3, $1);
       }
     }
   }

egs/wsj/s5/steps/info/nnet3_dir_info.pl

+3
@@ -137,6 +137,9 @@ sub get_combine_info {
       if (m/Combining nnets, objective function changed from (\S+) to (\S+)/) {
         close(F);
         return sprintf(" combine=%.2f->%.2f", $1, $2);
+      } elsif (m/Combining (\S+) nnets, objective function changed from (\S+) to (\S+)/) {
+        close(F);
+        return sprintf(" combine=%.2f->%.2f (over %d)", $2, $3, $1);
       }
     }
   }
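Both info scripts gain the same elsif branch so that steps/info/*_dir_info.pl can summarize the log line written by the simplified combination binary. As a quick illustration of what the new regex captures, here is a minimal Python sketch; the sample log text is hypothetical but follows the KALDI_LOG format added in nnet3-chain-combine.cc later in this commit:

import re

# Illustrative only: mirrors the new Perl branch above.
line = "Combining 18 nnets, objective function changed from -0.112 to -0.108"

m = re.search(r"Combining (\S+) nnets, objective function changed from (\S+) to (\S+)", line)
if m:
    num_models, objf_from, objf_to = m.groups()
    # Same ordering as the Perl sprintf: objective before/after, then model count.
    print(" combine=%.3f->%.3f (over %d)" % (float(objf_from), float(objf_to), int(num_models)))
    # -> " combine=-0.112->-0.108 (over 18)"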

egs/wsj/s5/steps/libs/nnet3/train/chain_objf/acoustic_model.py

+4 -13
@@ -492,7 +492,7 @@ def compute_progress(dir, iter, run_opts):
 def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_str,
                    egs_dir, leaky_hmm_coefficient, l2_regularize,
                    xent_regularize, run_opts,
-                   sum_to_one_penalty=0.0):
+                   max_objective_evaluations=30):
     """ Function to do model combination

     In the nnet3 setup, the logic
@@ -505,9 +505,6 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_str,

     models_to_combine.add(num_iters)

-    # TODO: if it turns out the sum-to-one-penalty code is not useful,
-    # remove support for it.
-
     for iter in sorted(models_to_combine):
         model_file = '{0}/{1}.mdl'.format(dir, iter)
         if os.path.exists(model_file):
@@ -528,12 +525,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_str,

     common_lib.execute_command(
         """{command} {combine_queue_opt} {dir}/log/combine.log \
-                nnet3-chain-combine --num-iters={opt_iters} \
+                nnet3-chain-combine \
+                --max-objective-evaluations={max_objective_evaluations} \
                 --l2-regularize={l2} --leaky-hmm-coefficient={leaky} \
-                --separate-weights-per-component={separate_weights} \
-                --enforce-sum-to-one={hard_enforce} \
-                --sum-to-one-penalty={penalty} \
-                --enforce-positive-weights=true \
                 --verbose=3 {dir}/den.fst {raw_models} \
                 "ark,bg:nnet3-chain-copy-egs ark:{egs_dir}/combine.cegs ark:- | \
                 nnet3-chain-merge-egs --minibatch-size={num_chunk_per_mb} \
@@ -542,12 +536,9 @@ def combine_models(dir, num_iters, models_to_combine, num_chunk_per_minibatch_str,
                 {dir}/final.mdl""".format(
                     command=run_opts.command,
                     combine_queue_opt=run_opts.combine_queue_opt,
-                    opt_iters=(20 if sum_to_one_penalty <= 0 else 80),
-                    separate_weights=(sum_to_one_penalty > 0),
+                    max_objective_evaluations=max_objective_evaluations,
                     l2=l2_regularize, leaky=leaky_hmm_coefficient,
                     dir=dir, raw_models=" ".join(raw_model_strings),
-                    hard_enforce=(sum_to_one_penalty <= 0),
-                    penalty=sum_to_one_penalty,
                     num_chunk_per_mb=num_chunk_per_minibatch_str,
                     num_iters=num_iters,
                     egs_dir=egs_dir))

egs/wsj/s5/steps/libs/nnet3/train/common.py

+11 -3
@@ -852,6 +852,16 @@ def __init__(self,
                                  the final model combination stage. These
                                  models will themselves be averages of
                                  iteration-number ranges""")
+        self.parser.add_argument("--trainer.optimization.max-objective-evaluations",
+                                 "--trainer.max-objective-evaluations",
+                                 type=int, dest='max_objective_evaluations',
+                                 default=30,
+                                 help="""The maximum number of objective
+                                 evaluations in order to figure out the
+                                 best number of models to combine. It helps to
+                                 speedup if the number of models provided to the
+                                 model combination binary is quite large (e.g.
+                                 several hundred).""")
         self.parser.add_argument("--trainer.optimization.do-final-combination",
                                  dest='do_final_combination', type=str,
                                  action=common_lib.StrToBoolAction,
@@ -861,9 +871,7 @@ def __init__(self,
                                  last-numbered model as the final.mdl).""")
         self.parser.add_argument("--trainer.optimization.combine-sum-to-one-penalty",
                                  type=float, dest='combine_sum_to_one_penalty', default=0.0,
-                                 help="""If > 0, activates 'soft' enforcement of the
-                                 sum-to-one penalty in combination (may be helpful
-                                 if using dropout). E.g. 1.0e-03.""")
+                                 help="""This option is deprecated and does nothing.""")
         self.parser.add_argument("--trainer.optimization.momentum", type=float,
                                  dest='momentum', default=0.0,
                                  help="""Momentum used in update computation.

egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py

+4 -7
@@ -452,7 +452,7 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir,
                    minibatch_size_str,
                    run_opts,
                    chunk_width=None, get_raw_nnet_from_am=True,
-                   sum_to_one_penalty=0.0,
+                   max_objective_evaluations=30,
                    use_multitask_egs=False,
                    compute_per_dim_accuracy=False):
     """ Function to do model combination
@@ -501,10 +501,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir,
                             use_multitask_egs=use_multitask_egs)
     common_lib.execute_command(
         """{command} {combine_queue_opt} {dir}/log/combine.log \
-                nnet3-combine --num-iters=80 \
-                --enforce-sum-to-one={hard_enforce} \
-                --sum-to-one-penalty={penalty} \
-                --enforce-positive-weights=true \
+                nnet3-combine \
+                --max-objective-evaluations={max_objective_evaluations} \
                 --verbose=3 {raw_models} \
                 "ark,bg:nnet3-copy-egs {multitask_egs_opts} \
                 {egs_rspecifier} ark:- | \
@@ -513,9 +511,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir,
         """.format(command=run_opts.command,
                    combine_queue_opt=run_opts.combine_queue_opt,
                    dir=dir, raw_models=" ".join(raw_model_strings),
+                   max_objective_evaluations=max_objective_evaluations,
                    egs_rspecifier=egs_rspecifier,
-                   hard_enforce=(sum_to_one_penalty <= 0),
-                   penalty=sum_to_one_penalty,
                    mbsize=minibatch_size_str,
                    out_model=out_model,
                    multitask_egs_opts=multitask_egs_opts))

egs/wsj/s5/steps/nnet3/chain/train.py

+1 -1
@@ -554,7 +554,7 @@ def train(args, run_opts):
                 l2_regularize=args.l2_regularize,
                 xent_regularize=args.xent_regularize,
                 run_opts=run_opts,
-                sum_to_one_penalty=args.combine_sum_to_one_penalty)
+                max_objective_evaluations=args.max_objective_evaluations)
     else:
         logger.info("Copying the last-numbered model to final.mdl")
         common_lib.force_symlink("{0}.mdl".format(num_iters),

egs/wsj/s5/steps/nnet3/train_dnn.py

+1 -1
@@ -364,7 +364,7 @@ def train(args, run_opts):
                 models_to_combine=models_to_combine,
                 egs_dir=egs_dir,
                 minibatch_size_str=args.minibatch_size, run_opts=run_opts,
-                sum_to_one_penalty=args.combine_sum_to_one_penalty)
+                max_objective_evaluations=args.max_objective_evaluations)

     if args.stage <= num_iters + 1:
         logger.info("Getting average posterior for purposes of "

egs/wsj/s5/steps/nnet3/train_raw_dnn.py

+1 -1
@@ -398,7 +398,7 @@ def train(args, run_opts):
                 models_to_combine=models_to_combine, egs_dir=egs_dir,
                 minibatch_size_str=args.minibatch_size, run_opts=run_opts,
                 get_raw_nnet_from_am=False,
-                sum_to_one_penalty=args.combine_sum_to_one_penalty,
+                max_objective_evaluations=args.max_objective_evaluations,
                 use_multitask_egs=use_multitask_egs)
     else:
         common_lib.force_symlink("{0}.raw".format(num_iters),

egs/wsj/s5/steps/nnet3/train_raw_rnn.py

+1 -1
@@ -475,7 +475,7 @@ def train(args, run_opts):
                 run_opts=run_opts, chunk_width=args.chunk_width,
                 get_raw_nnet_from_am=False,
                 compute_per_dim_accuracy=args.compute_per_dim_accuracy,
-                sum_to_one_penalty=args.combine_sum_to_one_penalty)
+                max_objective_evaluations=args.max_objective_evaluations)
     else:
         common_lib.force_symlink("{0}.raw".format(num_iters),
                                  "{0}/final.raw".format(args.dir))

egs/wsj/s5/steps/nnet3/train_rnn.py

+1 -1
@@ -451,7 +451,7 @@ def train(args, run_opts):
                 run_opts=run_opts,
                 minibatch_size_str=args.num_chunk_per_minibatch,
                 chunk_width=args.chunk_width,
-                sum_to_one_penalty=args.combine_sum_to_one_penalty,
+                max_objective_evaluations=args.max_objective_evaluations,
                 compute_per_dim_accuracy=args.compute_per_dim_accuracy)

     if args.stage <= num_iters + 1:

src/chainbin/nnet3-chain-combine.cc

+109 -23
@@ -1,6 +1,7 @@
 // chainbin/nnet3-chain-combine.cc

 // Copyright 2012-2015  Johns Hopkins University (author: Daniel Povey)
+//                2017  Yiming Wang

 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -19,7 +20,65 @@

 #include "base/kaldi-common.h"
 #include "util/common-utils.h"
-#include "nnet3/nnet-chain-combine.h"
+#include "nnet3/nnet-utils.h"
+#include "nnet3/nnet-compute.h"
+#include "nnet3/nnet-chain-diagnostics.h"
+
+
+namespace kaldi {
+namespace nnet3 {
+
+// Computes and returns the objective function for the examples in 'egs' given
+// the model in 'nnet'. If either of batchnorm/dropout test modes is true, we
+// make a copy of 'nnet', set test modes on that and evaluate its objective.
+// Note: the object that prob_computer->nnet_ refers to should be 'nnet'.
+double ComputeObjf(bool batchnorm_test_mode, bool dropout_test_mode,
+                   const std::vector<NnetChainExample> &egs, const Nnet &nnet,
+                   const chain::ChainTrainingOptions &chain_config,
+                   const fst::StdVectorFst &den_fst,
+                   NnetChainComputeProb *prob_computer) {
+  if (batchnorm_test_mode || dropout_test_mode) {
+    Nnet nnet_copy(nnet);
+    if (batchnorm_test_mode)
+      SetBatchnormTestMode(true, &nnet_copy);
+    if (dropout_test_mode)
+      SetDropoutTestMode(true, &nnet_copy);
+    NnetComputeProbOptions compute_prob_opts;
+    NnetChainComputeProb prob_computer_test(compute_prob_opts, chain_config,
+                                            den_fst, nnet_copy);
+    return ComputeObjf(false, false, egs, nnet_copy,
+                       chain_config, den_fst, &prob_computer_test);
+  } else {
+    prob_computer->Reset();
+    std::vector<NnetChainExample>::const_iterator iter = egs.begin(),
+                                                  end = egs.end();
+    for (; iter != end; ++iter)
+      prob_computer->Compute(*iter);
+    const ChainObjectiveInfo *objf_info =
+        prob_computer->GetObjective("output");
+    if (objf_info == NULL)
+      KALDI_ERR << "Error getting objective info (unsuitable egs?)";
+    KALDI_ASSERT(objf_info->tot_weight > 0.0);
+    // inf/nan tot_objf->return -inf objective.
+    double tot_objf = objf_info->tot_like + objf_info->tot_l2_term;
+    if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0))
+      return -std::numeric_limits<double>::infinity();
+    // we prefer to deal with normalized objective functions.
+    return tot_objf / objf_info->tot_weight;
+  }
+}
+
+// Updates moving average over num_models nnets, given the average over
+// previous (num_models - 1) nnets, and the new nnet.
+void UpdateNnetMovingAverage(int32 num_models,
+                             const Nnet &nnet, Nnet *moving_average_nnet) {
+  KALDI_ASSERT(NumParameters(nnet) == NumParameters(*moving_average_nnet));
+  ScaleNnet((num_models - 1.0) / num_models, moving_average_nnet);
+  AddNnet(nnet, 1.0 / num_models, moving_average_nnet);
+}
+
+}
+}


 int main(int argc, char *argv[]) {
@@ -30,9 +89,11 @@ int main(int argc, char *argv[]) {
     typedef kaldi::int64 int64;

     const char *usage =
-        "Using a subset of training or held-out nnet3+chain examples, compute an\n"
-        "optimal combination of a number of nnet3 neural nets by maximizing the\n"
-        "'chain' objective function. See documentation of options for more details.\n"
+        "Using a subset of training or held-out nnet3+chain examples, compute\n"
+        "the average over the first n nnet models where we maximize the\n"
+        "'chain' objective function for n. Note that the order of models has\n"
+        "been reversed before feeding into this binary. So we are actually\n"
+        "combining last n models.\n"
         "Inputs and outputs are nnet3 raw nnets.\n"
         "\n"
         "Usage: nnet3-chain-combine [options] <den-fst> <raw-nnet-in1> <raw-nnet-in2> ... <raw-nnet-inN> <chain-examples-in> <raw-nnet-out>\n"
@@ -41,23 +102,28 @@ int main(int argc, char *argv[]) {
         " nnet3-combine den.fst 35.raw 36.raw 37.raw 38.raw ark:valid.cegs final.raw\n";

     bool binary_write = true;
+    int32 max_objective_evaluations = 30;
     bool batchnorm_test_mode = false,
         dropout_test_mode = true;
     std::string use_gpu = "yes";
-    NnetCombineConfig combine_config;
     chain::ChainTrainingOptions chain_config;

     ParseOptions po(usage);
     po.Register("binary", &binary_write, "Write output in binary mode");
+    po.Register("max-objective-evaluations", &max_objective_evaluations, "The "
+                "maximum number of objective evaluations in order to figure "
+                "out the best number of models to combine. It helps to speedup "
+                "if the number of models provided to this binary is quite "
+                "large (e.g. several hundred).");
     po.Register("use-gpu", &use_gpu,
                 "yes|no|optional|wait, only has effect if compiled with CUDA");
     po.Register("batchnorm-test-mode", &batchnorm_test_mode,
-                "If true, set test-mode to true on any BatchNormComponents.");
+                "If true, set test-mode to true on any BatchNormComponents "
+                "while evaluating objectives.");
     po.Register("dropout-test-mode", &dropout_test_mode,
                 "If true, set test-mode to true on any DropoutComponents and "
-                "DropoutMaskComponents.");
+                "DropoutMaskComponents while evaluating objectives.");

-    combine_config.Register(&po);
     chain_config.Register(&po);

     po.Read(argc, argv);
@@ -83,11 +149,10 @@ int main(int argc, char *argv[]) {

     Nnet nnet;
     ReadKaldiObject(raw_nnet_rxfilename, &nnet);
-
-    if (batchnorm_test_mode)
-      SetBatchnormTestMode(true, &nnet);
-    if (dropout_test_mode)
-      SetDropoutTestMode(true, &nnet);
+    Nnet moving_average_nnet(nnet), best_nnet(nnet);
+    NnetComputeProbOptions compute_prob_opts;
+    NnetChainComputeProb prob_computer(compute_prob_opts, chain_config,
+                                       den_fst, moving_average_nnet);

     std::vector<NnetChainExample> egs;
     egs.reserve(10000);  // reserve a lot of space to minimize the chance of
@@ -102,29 +167,50 @@ int main(int argc, char *argv[]) {
       KALDI_ASSERT(!egs.empty());
     }

+    // first evaluates the objective using the last model.
+    int32 best_num_to_combine = 1;
+    double
+        init_objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
+            egs, moving_average_nnet, chain_config, den_fst, &prob_computer),
+        best_objf = init_objf;
+    KALDI_LOG << "objective function using the last model is " << init_objf;

     int32 num_nnets = po.NumArgs() - 3;
-    NnetChainCombiner combiner(combine_config, chain_config,
-                               num_nnets, egs, den_fst, nnet);
-
+    // then each time before we re-evaluate the objective function, we will add
+    // num_to_add models to the moving average.
+    int32 num_to_add = (num_nnets + max_objective_evaluations - 1) /
+                       max_objective_evaluations;
     for (int32 n = 1; n < num_nnets; n++) {
       std::string this_nnet_rxfilename = po.GetArg(n + 2);
       ReadKaldiObject(this_nnet_rxfilename, &nnet);
-      combiner.AcceptNnet(nnet);
+      // updates the moving average
+      UpdateNnetMovingAverage(n + 1, nnet, &moving_average_nnet);
+      // evaluates the objective everytime after adding num_to_add model or
+      // all the models to the moving average.
+      if ((n - 1) % num_to_add == num_to_add - 1 || n == num_nnets - 1) {
+        double objf = ComputeObjf(batchnorm_test_mode, dropout_test_mode,
+            egs, moving_average_nnet, chain_config, den_fst, &prob_computer);
+        KALDI_LOG << "Combining last " << n + 1
+                  << " models, objective function is " << objf;
+        if (objf > best_objf) {
+          best_objf = objf;
+          best_nnet = moving_average_nnet;
+          best_num_to_combine = n + 1;
+        }
+      }
     }
+    KALDI_LOG << "Combining " << best_num_to_combine
+              << " nnets, objective function changed from " << init_objf
+              << " to " << best_objf;

-    combiner.Combine();
-
-    nnet = combiner.GetNnet();
     if (HasBatchnorm(nnet))
-      RecomputeStats(egs, chain_config, den_fst, &nnet);
+      RecomputeStats(egs, chain_config, den_fst, &best_nnet);

 #if HAVE_CUDA==1
     CuDevice::Instantiate().PrintProfile();
 #endif

-    WriteKaldiObject(nnet, nnet_wxfilename, binary_write);
-
+    WriteKaldiObject(best_nnet, nnet_wxfilename, binary_write);
     KALDI_LOG << "Finished combining neural nets, wrote model to "
               << nnet_wxfilename;
   } catch(const std::exception &e) {
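To summarize the new logic in one place: instead of optimizing per-model weights with NnetChainCombiner, the binary now keeps a running average over the models (which the scripts feed in last-to-first) and keeps whichever prefix average scores best on the combine examples. The following Python sketch mirrors that loop, with plain lists standing in for Nnets and a caller-supplied compute_objf; it is illustrative only and not part of the commit:

def update_moving_average(num_models, params, moving_average):
    # Mirrors UpdateNnetMovingAverage():
    # avg_n = ((n-1)/n) * avg_{n-1} + (1/n) * new_model
    assert len(params) == len(moving_average)
    for i in range(len(moving_average)):
        moving_average[i] = ((num_models - 1.0) / num_models) * moving_average[i] \
                            + (1.0 / num_models) * params[i]

def combine_models(models, compute_objf, max_objective_evaluations=30):
    # 'models' is ordered newest-to-oldest, as the scripts pass them in.
    moving_average = list(models[0])
    best_nnet = list(moving_average)
    best_objf = compute_objf(moving_average)
    best_n = 1
    num_nnets = len(models)
    # Ceil division: how many models to fold in between objective evaluations.
    num_to_add = (num_nnets + max_objective_evaluations - 1) // max_objective_evaluations
    for n in range(1, num_nnets):
        update_moving_average(n + 1, models[n], moving_average)
        # Evaluate after every num_to_add models, and always after the last one.
        if (n - 1) % num_to_add == num_to_add - 1 or n == num_nnets - 1:
            objf = compute_objf(moving_average)
            if objf > best_objf:
                best_objf, best_n = objf, n + 1
                best_nnet = list(moving_average)
    return best_nnet, best_n, best_objf

# Example: 5 scalar "models", newest first; the objective prefers values near 0.3.
models = [[0.5], [0.4], [0.35], [0.2], [0.1]]
best, n, objf = combine_models(models, lambda p: -(p[0] - 0.3) ** 2)
print(n)  # -> 5 (averaging all five models scores best here)

Because only uniform 1/n averaging weights are considered, the search is one-dimensional in n, which is what allows --max-objective-evaluations to cap the cost even when several hundred models are passed in.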

src/nnet3/Makefile

+2 -2
@@ -22,9 +22,9 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
            nnet-example.o nnet-nnet.o nnet-compile-utils.o \
            nnet-utils.o nnet-compute.o nnet-test-utils.o nnet-analyze.o \
            nnet-example-utils.o nnet-training.o \
-           nnet-diagnostics.o nnet-combine.o nnet-am-decodable-simple.o \
+           nnet-diagnostics.o nnet-am-decodable-simple.o \
            nnet-optimize-utils.o nnet-chain-example.o \
-           nnet-chain-training.o nnet-chain-diagnostics.o nnet-chain-combine.o \
+           nnet-chain-training.o nnet-chain-diagnostics.o \
            discriminative-supervision.o nnet-discriminative-example.o \
            nnet-discriminative-diagnostics.o \
            discriminative-training.o nnet-discriminative-training.o \
