@@ -213,70 +213,32 @@ kmeans_model train_model(uint32_t n, arma::mat data) {
213
213
214
214
// function used for production
215
215
gaussian_mixture_model retrain_model (uint32_t n, arma::mat data, std::vector<variant> variants, uint32_t lower_n, double var_floor){
216
-
216
+ double initial_covariance = 0.005 ;
217
217
gaussian_mixture_model gmodel;
218
218
gmodel.n = n;
219
-
220
- arma::mat centroids;
221
219
arma::mat initial_means (1 , n, arma::fill::zeros);
222
220
223
221
arma::mat cov (1 , n, arma::fill::zeros);
224
222
std::vector<double > total_distances;
225
223
std::vector<std::vector<double >> all_centroids;
224
+
225
+ // run a kmeans to seed the GMM
226
+ kmeans_model initial_model = train_model (n, data);
226
227
227
- for (uint32_t j=0 ; j < 15 ; j++){
228
- // std::cerr << "iteration j " << j << std::endl;
229
- bool status2 = arma::kmeans (centroids, data, n, arma::random_spread, 10 , false );
230
- if (!status2) continue ;
231
- double total_dist = 0 ;
232
- std::vector<std::vector<double >> clusters (n);
233
- for (auto point : data){
234
- // using std::min_element to find the closest element
235
- auto closest_it = std::min_element (centroids.begin (), centroids.end (),
236
- [point](double a, double b) {
237
- return std::abs (a - point) < std::abs (b - point);
238
- });
239
-
240
- uint32_t index = std::distance (centroids.begin (), closest_it);
241
- // std::cerr << point << " " << centroids[index] << std::endl;
242
- clusters[index].push_back (point);
243
- total_dist += std::abs (point-centroids[index]);
244
- }
245
- std::vector<double > tmp;
246
- for (auto c : centroids){
247
- // std::cerr << c << " ";
248
- tmp.push_back ((double )c);
249
- }
250
- // std::cerr << "\n";
251
- all_centroids.push_back (tmp);
252
- // std::cerr << total_dist << std::endl;
253
- total_distances.push_back (total_dist);
228
+ for (uint32_t c=0 ; c < initial_model.means .size (); c++){
229
+ initial_means.col (c) = (double )initial_model.means [c];
230
+ cov.col (c) = initial_covariance;
231
+ // std::cerr << initial_model.means[c] << std::endl;
254
232
}
255
233
256
- uint32_t i=0 ;
257
- auto min_it = std::min_element (total_distances.begin (), total_distances.end ());
258
- uint32_t index = std::distance (total_distances.begin (), min_it);
259
- std::vector<double > centroid_vec = all_centroids[index];
260
-
261
- for (uint32_t c=0 ; c < centroid_vec.size (); c++){
262
- initial_means.col (i) = (double )centroid_vec[c];
263
- std::cerr << " initial k means " << centroid_vec[c] << std::endl;
264
- cov.col (i) = 0.005 ;
265
- ++i;
266
- }
267
- // original had this a 0.001, then 0.8, then 0.01
268
234
arma::gmm_diag model;
269
235
model.reset (1 , n);
270
236
model.set_means (initial_means);
271
237
model.set_dcovs (cov);
272
238
bool status = model.learn (data, n, arma::eucl_dist, arma::keep_existing, 1 , 10 , var_floor, false );
273
239
if (!status){
274
- std::cerr << " model failed to converge" << std::endl;
275
- }
276
- uint32_t c = 0 ;
277
- for (auto m : model.means ){
278
- std::cerr << " retrain means " << m << std::endl;
279
- c++;
240
+ std::cerr << " GMM failed to converge" << std::endl;
241
+ exit (1 );
280
242
}
281
243
std::vector<double > means;
282
244
std::vector<double > hefts;
@@ -303,24 +265,11 @@ gaussian_mixture_model retrain_model(uint32_t n, arma::mat data, std::vector<var
303
265
prob_matrix.push_back (tmp);
304
266
}
305
267
306
- // Compute Log-Likelihood
307
- arma::rowvec log_probs = model.log_p (data);
308
- double log_likelihood = arma::accu (log_probs);
309
-
310
- // Compute number of parameters (k)
311
- int d = data.n_rows ; // Number of dimensions (features)
312
- int k = n * (d + d * d); // Number of estimated parameters
313
-
314
- // Compute AIC
315
- double AIC = 2 * k - 2 * log_likelihood;
316
-
317
-
318
268
gmodel.dcovs = dcovs;
319
269
gmodel.prob_matrix = prob_matrix;
320
270
gmodel.means = means;
321
271
gmodel.hefts = hefts;
322
272
gmodel.model = model;
323
- gmodel.aic = AIC;
324
273
return (gmodel);
325
274
}
326
275
@@ -849,7 +798,7 @@ std::vector<uint32_t> find_deletion_positions(std::string filename, uint32_t dep
849
798
return (deletion_positions);
850
799
}
851
800
852
- void parse_internal_variants (std::string filename, std::vector<variant> &variants, uint32_t depth_cutoff, float lower_bound, float upper_bound, std::vector<uint32_t > deletion_positions, uint32_t round_val, double quality_threshold){
801
+ void parse_internal_variants (std::string filename, std::vector<variant> &variants, uint32_t depth_cutoff, float lower_bound, float upper_bound, std::vector<uint32_t > deletion_positions, uint32_t round_val, uint8_t quality_threshold){
853
802
/*
854
803
* Parses the variants file produced internally by reading bam file line-by-line.
855
804
*/
@@ -952,44 +901,37 @@ void parse_internal_variants(std::string filename, std::vector<variant> &variant
952
901
}
953
902
}
954
903
955
- std::vector<variant> gmm_model (std::string prefix, std::string output_prefix){
904
+ std::vector<variant> gmm_model (std::string prefix, std::string output_prefix, uint32_t min_depth, uint8_t min_qual ){
956
905
uint32_t n=8 ;
957
- uint32_t depth_cutoff = 10 ;
958
- float quality_threshold = 20 ;
959
906
uint32_t round_val = 4 ;
960
-
961
907
bool development_mode=true ;
962
- double error_rate = cluster_error (prefix, quality_threshold, depth_cutoff );
963
-
964
- float lower_bound = 1 -error_rate;
965
- float upper_bound = error_rate;
908
+ double error_rate = cluster_error (prefix, min_qual, min_depth );
909
+ // add these adjusters because of rounding errors
910
+ float lower_bound = 1 -error_rate+ 0.0001 ;
911
+ float upper_bound = error_rate- 0.0001 ;
966
912
std::cerr << " lower " << lower_bound << " upper " << upper_bound << std::endl;
967
913
std::vector<variant> base_variants;
968
- std::vector<uint32_t > deletion_positions = find_deletion_positions (prefix, depth_cutoff , lower_bound, upper_bound, round_val);
914
+ std::vector<uint32_t > deletion_positions = find_deletion_positions (prefix, min_depth , lower_bound, upper_bound, round_val);
969
915
970
- parse_internal_variants (prefix, base_variants, depth_cutoff , lower_bound, upper_bound, deletion_positions, round_val, quality_threshold );
916
+ parse_internal_variants (prefix, base_variants, min_depth , lower_bound, upper_bound, deletion_positions, round_val, min_qual );
971
917
972
918
// if ivar 1.0 is in use, calculate the frequency of the reference
973
919
if (base_variants[0 ].version_1_var ){
974
- calculate_reference_frequency (base_variants, prefix, depth_cutoff , lower_bound, upper_bound, deletion_positions);
920
+ calculate_reference_frequency (base_variants, prefix, min_depth , lower_bound, upper_bound, deletion_positions);
975
921
}
976
922
std::string filename = prefix + " .txt" ;
977
923
978
924
// this whole things needs to be reconfigured
979
925
uint32_t useful_var=0 ;
980
- std::vector<double > all_var;
981
926
std::vector<variant> variants;
982
927
std::vector<uint32_t > count_pos;
983
928
984
929
for (uint32_t i=0 ; i < base_variants.size (); i++){
985
- if (!base_variants[i].amplicon_flux && !base_variants[i].depth_flag && !base_variants[i].outside_freq_range && !base_variants[i].qual_flag && !base_variants[i].del_flag && !base_variants[i]. amplicon_masked && !base_variants[i].primer_masked ){
930
+ if (!base_variants[i].amplicon_flux && !base_variants[i].depth_flag && !base_variants[i].outside_freq_range && !base_variants[i].qual_flag && !base_variants[i].amplicon_masked && !base_variants[i].primer_masked ){
986
931
useful_var++;
987
932
variants.push_back (base_variants[i]);
988
- // all_var.push_back(base_variants[i].freq);
989
- // TESTLINES
990
- all_var.push_back (base_variants[i].gapped_freq );
991
933
count_pos.push_back (base_variants[i].position );
992
- std::cerr << base_variants[i].freq << " " << base_variants[i].position << " " << base_variants[i].nuc << " " << base_variants[i].depth << " " << base_variants[i].gapped_freq << std::endl;
934
+ // std::cerr << base_variants[i].freq << " " << base_variants[i].position << " " << base_variants[i].nuc << " " << base_variants[i].depth << " " << base_variants[i].gapped_freq << std::endl;
993
935
}
994
936
}
995
937
std::cerr << " useful var " << useful_var << std::endl;
@@ -1005,9 +947,6 @@ std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
1005
947
// (rows, cols) where each columns is a sample
1006
948
uint32_t count=0 ;
1007
949
for (uint32_t i = 0 ; i < variants.size (); i++){
1008
- // check if variant should be filtered for first pass model
1009
- // double tmp = static_cast<double>(variants[i].freq);
1010
- // TESTLINES
1011
950
double tmp = static_cast <double >(variants[i].gapped_freq );
1012
951
data.col (count) = tmp;
1013
952
count += 1 ;
@@ -1063,9 +1002,6 @@ std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
1063
1002
std::cerr << " \n " ;
1064
1003
tmp_mads.push_back (mad);
1065
1004
float ratio = (float )useful_var / (float ) n; // originally data.size()
1066
- /* for(auto d : data){
1067
- std::cerr << d << std::endl;
1068
- }*/
1069
1005
std::cerr << " mean " << mean << " mad " << mad << " cluster size " << data.size () << " ratio " << ratio << std::endl;
1070
1006
if (data.size () > 5 ){
1071
1007
if (mad <= 0.10 ){
@@ -1146,9 +1082,9 @@ std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
1146
1082
// look at all variants despite other parameters
1147
1083
base_variants.clear ();
1148
1084
variants.clear ();
1149
- parse_internal_variants (prefix, base_variants, depth_cutoff , lower_bound, upper_bound, deletion_positions, round_val, quality_threshold );
1085
+ parse_internal_variants (prefix, base_variants, min_depth , lower_bound, upper_bound, deletion_positions, round_val, min_qual );
1150
1086
if (base_variants[0 ].version_1_var ){
1151
- calculate_reference_frequency (base_variants, prefix, depth_cutoff , lower_bound, upper_bound, deletion_positions);
1087
+ calculate_reference_frequency (base_variants, prefix, min_depth , lower_bound, upper_bound, deletion_positions);
1152
1088
}
1153
1089
if (!variants[0 ].version_1_var ){
1154
1090
calculate_gapped_frequency (base_variants, upper_bound, lower_bound);
0 commit comments