Skip to content

Commit 811b5c2

Browse files
committed
adding adjustment term to error rate for rounding issues
1 parent 58d959f commit 811b5c2

File tree

1 file changed

+23
-87
lines changed

1 file changed

+23
-87
lines changed

src/gmm.cpp

Lines changed: 23 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -213,70 +213,32 @@ kmeans_model train_model(uint32_t n, arma::mat data) {
213213

214214
//function used for production
215215
gaussian_mixture_model retrain_model(uint32_t n, arma::mat data, std::vector<variant> variants, uint32_t lower_n, double var_floor){
216-
216+
double initial_covariance = 0.005;
217217
gaussian_mixture_model gmodel;
218218
gmodel.n = n;
219-
220-
arma::mat centroids;
221219
arma::mat initial_means(1, n, arma::fill::zeros);
222220

223221
arma::mat cov (1, n, arma::fill::zeros);
224222
std::vector<double> total_distances;
225223
std::vector<std::vector<double>> all_centroids;
224+
225+
//run a kmeans to seed the GMM
226+
kmeans_model initial_model = train_model(n, data);
226227

227-
for(uint32_t j=0; j < 15; j++){
228-
//std::cerr << "iteration j " << j << std::endl;
229-
bool status2 = arma::kmeans(centroids, data, n, arma::random_spread, 10, false);
230-
if(!status2) continue;
231-
double total_dist = 0;
232-
std::vector<std::vector<double>> clusters(n);
233-
for(auto point : data){
234-
//using std::min_element to find the closest element
235-
auto closest_it = std::min_element(centroids.begin(), centroids.end(),
236-
[point](double a, double b) {
237-
return std::abs(a - point) < std::abs(b - point);
238-
});
239-
240-
uint32_t index = std::distance(centroids.begin(), closest_it);
241-
//std::cerr << point << " " << centroids[index] << std::endl;
242-
clusters[index].push_back(point);
243-
total_dist += std::abs(point-centroids[index]);
244-
}
245-
std::vector<double> tmp;
246-
for(auto c : centroids){
247-
//std::cerr << c << " ";
248-
tmp.push_back((double)c);
249-
}
250-
//std::cerr << "\n";
251-
all_centroids.push_back(tmp);
252-
//std::cerr << total_dist << std::endl;
253-
total_distances.push_back(total_dist);
228+
for(uint32_t c=0; c < initial_model.means.size(); c++){
229+
initial_means.col(c) = (double)initial_model.means[c];
230+
cov.col(c) = initial_covariance;
231+
//std::cerr << initial_model.means[c] << std::endl;
254232
}
255233

256-
uint32_t i=0;
257-
auto min_it = std::min_element(total_distances.begin(), total_distances.end());
258-
uint32_t index = std::distance(total_distances.begin(), min_it);
259-
std::vector<double> centroid_vec = all_centroids[index];
260-
261-
for(uint32_t c=0; c < centroid_vec.size(); c++){
262-
initial_means.col(i) = (double)centroid_vec[c];
263-
std::cerr << "initial k means " << centroid_vec[c] << std::endl;
264-
cov.col(i) = 0.005;
265-
++i;
266-
}
267-
//original had this a 0.001, then 0.8, then 0.01
268234
arma::gmm_diag model;
269235
model.reset(1, n);
270236
model.set_means(initial_means);
271237
model.set_dcovs(cov);
272238
bool status = model.learn(data, n, arma::eucl_dist, arma::keep_existing, 1, 10, var_floor, false);
273239
if(!status){
274-
std::cerr << "model failed to converge" << std::endl;
275-
}
276-
uint32_t c = 0;
277-
for(auto m : model.means){
278-
std::cerr << "retrain means " << m << std::endl;
279-
c++;
240+
std::cerr << "GMM failed to converge" << std::endl;
241+
exit(1);
280242
}
281243
std::vector<double> means;
282244
std::vector<double> hefts;
@@ -303,24 +265,11 @@ gaussian_mixture_model retrain_model(uint32_t n, arma::mat data, std::vector<var
303265
prob_matrix.push_back(tmp);
304266
}
305267

306-
// Compute Log-Likelihood
307-
arma::rowvec log_probs = model.log_p(data);
308-
double log_likelihood = arma::accu(log_probs);
309-
310-
// Compute number of parameters (k)
311-
int d = data.n_rows; // Number of dimensions (features)
312-
int k = n * (d + d * d); // Number of estimated parameters
313-
314-
// Compute AIC
315-
double AIC = 2 * k - 2 * log_likelihood;
316-
317-
318268
gmodel.dcovs = dcovs;
319269
gmodel.prob_matrix = prob_matrix;
320270
gmodel.means = means;
321271
gmodel.hefts = hefts;
322272
gmodel.model = model;
323-
gmodel.aic = AIC;
324273
return(gmodel);
325274
}
326275

@@ -849,7 +798,7 @@ std::vector<uint32_t> find_deletion_positions(std::string filename, uint32_t dep
849798
return(deletion_positions);
850799
}
851800

852-
void parse_internal_variants(std::string filename, std::vector<variant> &variants, uint32_t depth_cutoff, float lower_bound, float upper_bound, std::vector<uint32_t> deletion_positions, uint32_t round_val, double quality_threshold){
801+
void parse_internal_variants(std::string filename, std::vector<variant> &variants, uint32_t depth_cutoff, float lower_bound, float upper_bound, std::vector<uint32_t> deletion_positions, uint32_t round_val, uint8_t quality_threshold){
853802
/*
854803
* Parses the variants file produced internally by reading bam file line-by-line.
855804
*/
@@ -952,44 +901,37 @@ void parse_internal_variants(std::string filename, std::vector<variant> &variant
952901
}
953902
}
954903

955-
std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
904+
std::vector<variant> gmm_model(std::string prefix, std::string output_prefix, uint32_t min_depth, uint8_t min_qual){
956905
uint32_t n=8;
957-
uint32_t depth_cutoff = 10;
958-
float quality_threshold = 20;
959906
uint32_t round_val = 4;
960-
961907
bool development_mode=true;
962-
double error_rate = cluster_error(prefix, quality_threshold, depth_cutoff);
963-
964-
float lower_bound = 1-error_rate;
965-
float upper_bound = error_rate;
908+
double error_rate = cluster_error(prefix, min_qual, min_depth);
909+
//add these adjusters because of rounding errors
910+
float lower_bound = 1-error_rate+0.0001;
911+
float upper_bound = error_rate-0.0001;
966912
std::cerr << "lower " << lower_bound << " upper " << upper_bound << std::endl;
967913
std::vector<variant> base_variants;
968-
std::vector<uint32_t> deletion_positions = find_deletion_positions(prefix, depth_cutoff, lower_bound, upper_bound, round_val);
914+
std::vector<uint32_t> deletion_positions = find_deletion_positions(prefix, min_depth, lower_bound, upper_bound, round_val);
969915

970-
parse_internal_variants(prefix, base_variants, depth_cutoff, lower_bound, upper_bound, deletion_positions, round_val, quality_threshold);
916+
parse_internal_variants(prefix, base_variants, min_depth, lower_bound, upper_bound, deletion_positions, round_val, min_qual);
971917

972918
//if ivar 1.0 is in use, calculate the frequency of the reference
973919
if(base_variants[0].version_1_var){
974-
calculate_reference_frequency(base_variants, prefix, depth_cutoff, lower_bound, upper_bound, deletion_positions);
920+
calculate_reference_frequency(base_variants, prefix, min_depth, lower_bound, upper_bound, deletion_positions);
975921
}
976922
std::string filename = prefix + ".txt";
977923

978924
//this whole things needs to be reconfigured
979925
uint32_t useful_var=0;
980-
std::vector<double> all_var;
981926
std::vector<variant> variants;
982927
std::vector<uint32_t> count_pos;
983928

984929
for(uint32_t i=0; i < base_variants.size(); i++){
985-
if(!base_variants[i].amplicon_flux && !base_variants[i].depth_flag && !base_variants[i].outside_freq_range && !base_variants[i].qual_flag && !base_variants[i].del_flag && !base_variants[i].amplicon_masked && !base_variants[i].primer_masked){
930+
if(!base_variants[i].amplicon_flux && !base_variants[i].depth_flag && !base_variants[i].outside_freq_range && !base_variants[i].qual_flag && !base_variants[i].amplicon_masked && !base_variants[i].primer_masked){
986931
useful_var++;
987932
variants.push_back(base_variants[i]);
988-
//all_var.push_back(base_variants[i].freq);
989-
//TESTLINES
990-
all_var.push_back(base_variants[i].gapped_freq);
991933
count_pos.push_back(base_variants[i].position);
992-
std::cerr << base_variants[i].freq << " " << base_variants[i].position << " " << base_variants[i].nuc << " " << base_variants[i].depth << " " << base_variants[i].gapped_freq << std::endl;
934+
//std::cerr << base_variants[i].freq << " " << base_variants[i].position << " " << base_variants[i].nuc << " " << base_variants[i].depth << " " << base_variants[i].gapped_freq << std::endl;
993935
}
994936
}
995937
std::cerr << "useful var " << useful_var << std::endl;
@@ -1005,9 +947,6 @@ std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
1005947
//(rows, cols) where each columns is a sample
1006948
uint32_t count=0;
1007949
for(uint32_t i = 0; i < variants.size(); i++){
1008-
//check if variant should be filtered for first pass model
1009-
//double tmp = static_cast<double>(variants[i].freq);
1010-
//TESTLINES
1011950
double tmp = static_cast<double>(variants[i].gapped_freq);
1012951
data.col(count) = tmp;
1013952
count += 1;
@@ -1063,9 +1002,6 @@ std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
10631002
std::cerr << "\n";
10641003
tmp_mads.push_back(mad);
10651004
float ratio = (float)useful_var / (float) n; //originally data.size()
1066-
/*for(auto d : data){
1067-
std::cerr << d << std::endl;
1068-
}*/
10691005
std::cerr << "mean " << mean << " mad " << mad << " cluster size " << data.size() << " ratio " << ratio << std::endl;
10701006
if(data.size() > 5){
10711007
if(mad <= 0.10){
@@ -1146,9 +1082,9 @@ std::vector<variant> gmm_model(std::string prefix, std::string output_prefix){
11461082
//look at all variants despite other parameters
11471083
base_variants.clear();
11481084
variants.clear();
1149-
parse_internal_variants(prefix, base_variants, depth_cutoff, lower_bound, upper_bound, deletion_positions, round_val, quality_threshold);
1085+
parse_internal_variants(prefix, base_variants, min_depth, lower_bound, upper_bound, deletion_positions, round_val, min_qual);
11501086
if(base_variants[0].version_1_var){
1151-
calculate_reference_frequency(base_variants, prefix, depth_cutoff, lower_bound, upper_bound, deletion_positions);
1087+
calculate_reference_frequency(base_variants, prefix, min_depth, lower_bound, upper_bound, deletion_positions);
11521088
}
11531089
if(!variants[0].version_1_var){
11541090
calculate_gapped_frequency(base_variants, upper_bound, lower_bound);

0 commit comments

Comments
 (0)