1
1
// chainbin/nnet3-chain-combine.cc
2
2
3
3
// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey)
4
+ // 2017 Yiming Wang
4
5
5
6
// See ../../COPYING for clarification regarding multiple authors
6
7
//
19
20
20
21
#include <cmath>

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "nnet3/nnet-utils.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/nnet-chain-diagnostics.h"
26
+
27
+
28
+ namespace kaldi {
29
+ namespace nnet3 {
30
+
31
+ // Computes and returns the objective function for the examples in 'egs' given
32
+ // the model in 'nnet'. If either of batchnorm/dropout test modes is true, we
33
+ // make a copy of 'nnet', set test modes on that and evaluate its objective.
34
+ // Note: the object that prob_computer->nnet_ refers to should be 'nnet'.
35
+ double ComputeObjf (bool batchnorm_test_mode, bool dropout_test_mode,
36
+ const std::vector<NnetChainExample> &egs, const Nnet &nnet,
37
+ const chain::ChainTrainingOptions &chain_config,
38
+ const fst::StdVectorFst &den_fst,
39
+ NnetChainComputeProb *prob_computer) {
40
+ if (batchnorm_test_mode || dropout_test_mode) {
41
+ Nnet nnet_copy (nnet);
42
+ if (batchnorm_test_mode)
43
+ SetBatchnormTestMode (true , &nnet_copy);
44
+ if (dropout_test_mode)
45
+ SetDropoutTestMode (true , &nnet_copy);
46
+ NnetComputeProbOptions compute_prob_opts;
47
+ NnetChainComputeProb prob_computer_test (compute_prob_opts, chain_config,
48
+ den_fst, nnet_copy);
49
+ return ComputeObjf (false , false , egs, nnet_copy,
50
+ chain_config, den_fst, &prob_computer_test);
51
+ } else {
52
+ prob_computer->Reset ();
53
+ std::vector<NnetChainExample>::const_iterator iter = egs.begin (),
54
+ end = egs.end ();
55
+ for (; iter != end; ++iter)
56
+ prob_computer->Compute (*iter);
57
+ const ChainObjectiveInfo *objf_info =
58
+ prob_computer->GetObjective (" output" );
59
+ if (objf_info == NULL )
60
+ KALDI_ERR << " Error getting objective info (unsuitable egs?)" ;
61
+ KALDI_ASSERT (objf_info->tot_weight > 0.0 );
62
+ // inf/nan tot_objf->return -inf objective.
63
+ double tot_objf = objf_info->tot_like + objf_info->tot_l2_term ;
64
+ if (!(tot_objf == tot_objf && tot_objf - tot_objf == 0 ))
65
+ return -std::numeric_limits<double >::infinity ();
66
+ // we prefer to deal with normalized objective functions.
67
+ return tot_objf / objf_info->tot_weight ;
68
+ }
69
+ }
70
+
71
+ // Updates moving average over num_models nnets, given the average over
72
+ // previous (num_models - 1) nnets, and the new nnet.
73
+ void UpdateNnetMovingAverage (int32 num_models,
74
+ const Nnet &nnet, Nnet *moving_average_nnet) {
75
+ KALDI_ASSERT (NumParameters (nnet) == NumParameters (*moving_average_nnet));
76
+ ScaleNnet ((num_models - 1.0 ) / num_models, moving_average_nnet);
77
+ AddNnet (nnet, 1.0 / num_models, moving_average_nnet);
78
+ }
79
+
80
+ }
81
+ }
23
82
24
83
25
84
int main (int argc, char *argv[]) {
@@ -30,9 +89,11 @@ int main(int argc, char *argv[]) {
30
89
typedef kaldi::int64 int64;
31
90
32
91
const char *usage =
33
- " Using a subset of training or held-out nnet3+chain examples, compute an\n "
34
- " optimal combination of anumber of nnet3 neural nets by maximizing the\n "
35
- " 'chain' objective function. See documentation of options for more details.\n "
92
+ " Using a subset of training or held-out nnet3+chain examples, compute\n "
93
+ " the average over the first n nnet models where we maximize the\n "
94
+ " 'chain' objective function for n. Note that the order of models has\n "
95
+ " been reversed before feeding into this binary. So we are actually\n "
96
+ " combining last n models.\n "
36
97
" Inputs and outputs are nnet3 raw nnets.\n "
37
98
" \n "
38
99
" Usage: nnet3-chain-combine [options] <den-fst> <raw-nnet-in1> <raw-nnet-in2> ... <raw-nnet-inN> <chain-examples-in> <raw-nnet-out>\n "
@@ -41,23 +102,28 @@ int main(int argc, char *argv[]) {
41
102
" nnet3-combine den.fst 35.raw 36.raw 37.raw 38.raw ark:valid.cegs final.raw\n " ;
42
103
43
104
bool binary_write = true ;
105
+ int32 max_objective_evaluations = 30 ;
44
106
bool batchnorm_test_mode = false ,
45
107
dropout_test_mode = true ;
46
108
std::string use_gpu = " yes" ;
47
- NnetCombineConfig combine_config;
48
109
chain::ChainTrainingOptions chain_config;
49
110
50
111
ParseOptions po (usage);
51
112
po.Register (" binary" , &binary_write, " Write output in binary mode" );
113
+ po.Register (" max-objective-evaluations" , &max_objective_evaluations, " The "
114
+ " maximum number of objective evaluations in order to figure "
115
+ " out the best number of models to combine. It helps to speedup "
116
+ " if the number of models provided to this binary is quite "
117
+ " large (e.g. several hundred)." );
52
118
po.Register (" use-gpu" , &use_gpu,
53
119
" yes|no|optional|wait, only has effect if compiled with CUDA" );
54
120
po.Register (" batchnorm-test-mode" , &batchnorm_test_mode,
55
- " If true, set test-mode to true on any BatchNormComponents." );
121
+ " If true, set test-mode to true on any BatchNormComponents "
122
+ " while evaluating objectives." );
56
123
po.Register (" dropout-test-mode" , &dropout_test_mode,
57
124
" If true, set test-mode to true on any DropoutComponents and "
58
- " DropoutMaskComponents." );
125
+ " DropoutMaskComponents while evaluating objectives ." );
59
126
60
- combine_config.Register (&po);
61
127
chain_config.Register (&po);
62
128
63
129
po.Read (argc, argv);
@@ -83,11 +149,10 @@ int main(int argc, char *argv[]) {
83
149
84
150
Nnet nnet;
85
151
ReadKaldiObject (raw_nnet_rxfilename, &nnet);
86
-
87
- if (batchnorm_test_mode)
88
- SetBatchnormTestMode (true , &nnet);
89
- if (dropout_test_mode)
90
- SetDropoutTestMode (true , &nnet);
152
+ Nnet moving_average_nnet (nnet), best_nnet (nnet);
153
+ NnetComputeProbOptions compute_prob_opts;
154
+ NnetChainComputeProb prob_computer (compute_prob_opts, chain_config,
155
+ den_fst, moving_average_nnet);
91
156
92
157
std::vector<NnetChainExample> egs;
93
158
egs.reserve (10000 ); // reserve a lot of space to minimize the chance of
@@ -102,29 +167,50 @@ int main(int argc, char *argv[]) {
102
167
KALDI_ASSERT (!egs.empty ());
103
168
}
104
169
170
+ // first evaluates the objective using the last model.
171
+ int32 best_num_to_combine = 1 ;
172
+ double
173
+ init_objf = ComputeObjf (batchnorm_test_mode, dropout_test_mode,
174
+ egs, moving_average_nnet, chain_config, den_fst, &prob_computer),
175
+ best_objf = init_objf;
176
+ KALDI_LOG << " objective function using the last model is " << init_objf;
105
177
106
178
int32 num_nnets = po.NumArgs () - 3 ;
107
- NnetChainCombiner combiner (combine_config, chain_config,
108
- num_nnets, egs, den_fst, nnet);
109
-
179
+ // then each time before we re-evaluate the objective function, we will add
180
+ // num_to_add models to the moving average.
181
+ int32 num_to_add = (num_nnets + max_objective_evaluations - 1 ) /
182
+ max_objective_evaluations;
110
183
for (int32 n = 1 ; n < num_nnets; n++) {
111
184
std::string this_nnet_rxfilename = po.GetArg (n + 2 );
112
185
ReadKaldiObject (this_nnet_rxfilename, &nnet);
113
- combiner.AcceptNnet (nnet);
186
+ // updates the moving average
187
+ UpdateNnetMovingAverage (n + 1 , nnet, &moving_average_nnet);
188
+ // evaluates the objective everytime after adding num_to_add model or
189
+ // all the models to the moving average.
190
+ if ((n - 1 ) % num_to_add == num_to_add - 1 || n == num_nnets - 1 ) {
191
+ double objf = ComputeObjf (batchnorm_test_mode, dropout_test_mode,
192
+ egs, moving_average_nnet, chain_config, den_fst, &prob_computer);
193
+ KALDI_LOG << " Combining last " << n + 1
194
+ << " models, objective function is " << objf;
195
+ if (objf > best_objf) {
196
+ best_objf = objf;
197
+ best_nnet = moving_average_nnet;
198
+ best_num_to_combine = n + 1 ;
199
+ }
200
+ }
114
201
}
202
+ KALDI_LOG << " Combining " << best_num_to_combine
203
+ << " nnets, objective function changed from " << init_objf
204
+ << " to " << best_objf;
115
205
116
- combiner.Combine ();
117
-
118
- nnet = combiner.GetNnet ();
119
206
if (HasBatchnorm (nnet))
120
- RecomputeStats (egs, chain_config, den_fst, &nnet );
207
+ RecomputeStats (egs, chain_config, den_fst, &best_nnet );
121
208
122
209
#if HAVE_CUDA==1
123
210
CuDevice::Instantiate ().PrintProfile ();
124
211
#endif
125
212
126
- WriteKaldiObject (nnet, nnet_wxfilename, binary_write);
127
-
213
+ WriteKaldiObject (best_nnet, nnet_wxfilename, binary_write);
128
214
KALDI_LOG << " Finished combining neural nets, wrote model to "
129
215
<< nnet_wxfilename;
130
216
} catch (const std::exception &e) {
0 commit comments