44#include " split.hpp"
55#include < omp.h>
66#include " utils.hpp"
7+ #include < sstream>
8+ #include < iomanip>
79
810namespace odgi {
911
@@ -225,16 +227,13 @@ int main_similarity(int argc, char** argv) {
225227 }
226228 path_length += graph.get_length (h);
227229 });
230+
231+ uint32_t path_id = get_path_id (p);
228232#pragma omp critical (bp_count)
229- bp_count[get_path_id (p) ] += path_length;
233+ bp_count[path_id ] += path_length;
230234 }
231235
232236 const bool show_progress = args::get (progress);
233- std::unique_ptr<algorithms::progress_meter::ProgressMeter> progress_meter;
234- if (show_progress) {
235- progress_meter = std::make_unique<algorithms::progress_meter::ProgressMeter>(
236- graph.get_node_count (), " [odgi::similarity] collecting path intersection lengths" );
237- }
238237
239238 ska::flat_hash_map<uint64_t , uint64_t > path_intersection_length;
240239
@@ -245,9 +244,11 @@ int main_similarity(int argc, char** argv) {
245244 }
246245 if (using_delim) {
247246 const uint32_t num_groups = path_groups.size ();
247+ #pragma omp parallel for collapse(2)
248248 for (uint32_t i = 0 ; i < num_groups; ++i) {
249249 for (uint32_t j = 0 ; j < num_groups; ++j) {
250250 // Initialize with 0 intersection. Will be updated later if intersection > 0.
251+ #pragma omp critical
251252 path_intersection_length[encode_pair (i, j)] = 0 ;
252253 }
253254 }
@@ -259,11 +260,14 @@ int main_similarity(int argc, char** argv) {
259260 actual_path_ids.push_back ((uint32_t )as_integer (p));
260261 });
261262
262- // Iterate through all actual path integer IDs collected earlier
263- for (const uint32_t id_i : actual_path_ids) {
264- for (const uint32_t id_j : actual_path_ids) {
265- // Initialize with 0 intersection.
266- path_intersection_length[encode_pair (id_i, id_j)] = 0 ;
263+ // Parallelize the nested loop for path pairs
264+ const size_t num_paths = actual_path_ids.size ();
265+ #pragma omp parallel for collapse(2) // / Both loops are flattened into one iteration space, so all i,j combinations distributed across threads
266+ for (size_t i = 0 ; i < num_paths; ++i) {
267+ for (size_t j = 0 ; j < num_paths; ++j) {
268+ // Initialize with 0 intersection.
269+ #pragma omp critical
270+ path_intersection_length[encode_pair (actual_path_ids[i], actual_path_ids[j])] = 0 ;
267271 }
268272 }
269273 }
@@ -272,11 +276,41 @@ int main_similarity(int argc, char** argv) {
272276 }
273277 }
274278
275- graph.for_each_handle (
276- [&](const handle_t & h) {
279+ std::unique_ptr<algorithms::progress_meter::ProgressMeter> progress_meter;
280+ if (show_progress) {
281+ progress_meter = std::make_unique<algorithms::progress_meter::ProgressMeter>(
282+ graph.get_node_count (), " [odgi::similarity] collecting path intersection lengths" );
283+ }
284+
285+ // Use limited thread-local storage to balance speed and memory usage
286+ // Use path_intersection_length as the first map to save one copy
287+ const uint64_t max_local_maps = std::min (3UL , num_threads);
288+ std::vector<ska::flat_hash_map<uint64_t , uint64_t >> thread_local_maps (max_local_maps - 1 );
289+ std::vector<std::mutex> map_mutexes (max_local_maps);
290+
291+ #pragma omp parallel
292+ {
293+ int thread_id = omp_get_thread_num ();
294+ int map_id = thread_id % max_local_maps;
295+
296+ // Use path_intersection_length as first map, thread_local_maps for others
297+ ska::flat_hash_map<uint64_t , uint64_t >* local_map;
298+ std::mutex* map_mutex;
299+
300+ if (map_id == 0 ) {
301+ local_map = &path_intersection_length;
302+ map_mutex = &map_mutexes[0 ];
303+ } else {
304+ local_map = &thread_local_maps[map_id - 1 ];
305+ map_mutex = &map_mutexes[map_id];
306+ }
307+
308+ #pragma omp for
309+ for (uint64_t node_id = 1 ; node_id <= graph.get_node_count (); ++node_id) {
310+ handle_t h = graph.get_handle (node_id);
277311 // Skip masked-out nodes
278312 if (!node_mask[graph.get_id (h) - 1 ]) {
279- return ;
313+ continue ;
280314 }
281315 ska::flat_hash_map<uint32_t , uint64_t > local_path_lengths;
282316 size_t l = graph.get_length (h);
@@ -286,22 +320,38 @@ int main_similarity(int argc, char** argv) {
286320 local_path_lengths[get_path_id (graph.get_path_handle_of_step (s))] += l;
287321 });
288322
289- #pragma omp critical (path_intersection_length)
290- for (auto & p : local_path_lengths) {
291- for (auto & q : local_path_lengths) {
292- path_intersection_length[encode_pair (p.first , q.first )] += std::min (p.second , q.second );
323+ // Update shared local map with mutex protection
324+ {
325+ std::lock_guard<std::mutex> lock (*map_mutex);
326+ for (auto & p : local_path_lengths) {
327+ for (auto & q : local_path_lengths) {
328+ (*local_map)[encode_pair (p.first , q.first )] += std::min (p.second , q.second );
329+ }
293330 }
294331 }
295332
296333 if (show_progress) {
297334 progress_meter->increment (1 );
298335 }
299- }, true );
336+ }
337+ }
338+
339+ // Merge thread-local maps into the main map
340+ // Skip index 0 since path_intersection_length is already used directly
341+ for (const auto & local_map : thread_local_maps) {
342+ for (const auto & pair : local_map) {
343+ path_intersection_length[pair.first ] += pair.second ;
344+ }
345+ }
300346
301347 if (show_progress) {
302348 progress_meter->finish ();
303349 }
304350
351+ if (show_progress) {
352+ std::cerr << " [odgi::similarity] Writing the output..." << std::endl;
353+ }
354+
305355 /* if (using_delim) {
306356 std::cout << "group.a" << "\t"
307357 << "group.b" << "\t"
@@ -335,6 +385,13 @@ int main_similarity(int argc, char** argv) {
335385 }
336386
337387 std::cout << std::endl;
388+
389+ // Use chunked buffering to balance speed and memory usage
390+ std::ostringstream output_buffer;
391+ output_buffer << std::fixed << std::setprecision (6 );
392+ const size_t buffer_chunk_size = 100000 ; // Lines per chunk
393+ size_t lines_written = 0 ;
394+
338395 for (auto & p : path_intersection_length) {
339396 uint32_t id_a, id_b;
340397 decode_pair (p.first , &id_a, &id_b);
@@ -347,27 +404,39 @@ int main_similarity(int argc, char** argv) {
347404 const double dice = 2.0 * ((double ) intersection / (double )(bp_count[id_a] + bp_count[id_b]));
348405 const double estimated_identity = 2.0 * jaccard / (1.0 + jaccard);
349406
350- std::cout << get_path_name (id_a) << " \t "
351- << get_path_name (id_b) << " \t "
352- << bp_count[id_a] << " \t "
353- << bp_count[id_b] << " \t "
354- << intersection << " \t " ;
407+ output_buffer << get_path_name (id_a) << " \t "
408+ << get_path_name (id_b) << " \t "
409+ << bp_count[id_a] << " \t "
410+ << bp_count[id_b] << " \t "
411+ << intersection << " \t " ;
355412
356413 if (emit_distances) {
357414 const double euclidian_distance = std::sqrt ((double )((bp_count[id_a] + bp_count[id_b] - intersection) - intersection));
358415 const uint64_t manhattan_distance = (bp_count[id_a] + bp_count[id_b] - intersection) - intersection;
359- std::cout << (1.0 - jaccard) << " \t "
360- << (1.0 - cosine) << " \t "
361- << (1.0 - dice) << " \t "
362- << (1.0 - estimated_identity) << " \t "
363- << euclidian_distance << " \t "
364- << manhattan_distance << std::endl ;
416+ output_buffer << (1.0 - jaccard) << " \t "
417+ << (1.0 - cosine) << " \t "
418+ << (1.0 - dice) << " \t "
419+ << (1.0 - estimated_identity) << " \t "
420+ << euclidian_distance << " \t "
421+ << manhattan_distance << " \n " ;
365422 } else {
366- std::cout << jaccard << " \t "
367- << cosine << " \t "
368- << dice << " \t "
369- << estimated_identity << std::endl ;
423+ output_buffer << jaccard << " \t "
424+ << cosine << " \t "
425+ << dice << " \t "
426+ << estimated_identity << " \n " ;
370427 }
428+
429+ // Flush buffer every chunk_size lines
430+ if (++lines_written % buffer_chunk_size == 0 ) {
431+ std::cout << output_buffer.str ();
432+ output_buffer.str (" " );
433+ output_buffer.clear ();
434+ }
435+ }
436+
437+ // Write remaining buffer
438+ if (!output_buffer.str ().empty ()) {
439+ std::cout << output_buffer.str ();
371440 }
372441
373442 return 0 ;
0 commit comments