Skip to content

Commit 497d046

Browse files
Merge pull request #614 from pangenome/odgi_similarity_parallelized
`odgi similarity`: a bit more parallelization for deep graphs
2 parents 892c702 + 175af22 commit 497d046

File tree

1 file changed

+103
-34
lines changed

1 file changed

+103
-34
lines changed

src/subcommand/similarity_main.cpp

Lines changed: 103 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include "split.hpp"
55
#include <omp.h>
66
#include "utils.hpp"
7+
#include <sstream>
8+
#include <iomanip>
79

810
namespace odgi {
911

@@ -225,16 +227,13 @@ int main_similarity(int argc, char** argv) {
225227
}
226228
path_length += graph.get_length(h);
227229
});
230+
231+
uint32_t path_id = get_path_id(p);
228232
#pragma omp critical (bp_count)
229-
bp_count[get_path_id(p)] += path_length;
233+
bp_count[path_id] += path_length;
230234
}
231235

232236
const bool show_progress = args::get(progress);
233-
std::unique_ptr<algorithms::progress_meter::ProgressMeter> progress_meter;
234-
if (show_progress) {
235-
progress_meter = std::make_unique<algorithms::progress_meter::ProgressMeter>(
236-
graph.get_node_count(), "[odgi::similarity] collecting path intersection lengths");
237-
}
238237

239238
ska::flat_hash_map<uint64_t, uint64_t> path_intersection_length;
240239

@@ -245,9 +244,11 @@ int main_similarity(int argc, char** argv) {
245244
}
246245
if (using_delim) {
247246
const uint32_t num_groups = path_groups.size();
247+
#pragma omp parallel for collapse(2)
248248
for (uint32_t i = 0; i < num_groups; ++i) {
249249
for (uint32_t j = 0; j < num_groups; ++j) {
250250
// Initialize with 0 intersection. Will be updated later if intersection > 0.
251+
#pragma omp critical
251252
path_intersection_length[encode_pair(i, j)] = 0;
252253
}
253254
}
@@ -259,11 +260,14 @@ int main_similarity(int argc, char** argv) {
259260
actual_path_ids.push_back((uint32_t)as_integer(p));
260261
});
261262

262-
// Iterate through all actual path integer IDs collected earlier
263-
for (const uint32_t id_i : actual_path_ids) {
264-
for (const uint32_t id_j : actual_path_ids) {
265-
// Initialize with 0 intersection.
266-
path_intersection_length[encode_pair(id_i, id_j)] = 0;
263+
// Parallelize the nested loop for path pairs
264+
const size_t num_paths = actual_path_ids.size();
265+
#pragma omp parallel for collapse(2) // Both loops are flattened into one iteration space, so all (i, j) combinations are distributed across threads
266+
for (size_t i = 0; i < num_paths; ++i) {
267+
for (size_t j = 0; j < num_paths; ++j) {
268+
// Initialize with 0 intersection.
269+
#pragma omp critical
270+
path_intersection_length[encode_pair(actual_path_ids[i], actual_path_ids[j])] = 0;
267271
}
268272
}
269273
}
@@ -272,11 +276,41 @@ int main_similarity(int argc, char** argv) {
272276
}
273277
}
274278

275-
graph.for_each_handle(
276-
[&](const handle_t& h) {
279+
std::unique_ptr<algorithms::progress_meter::ProgressMeter> progress_meter;
280+
if (show_progress) {
281+
progress_meter = std::make_unique<algorithms::progress_meter::ProgressMeter>(
282+
graph.get_node_count(), "[odgi::similarity] collecting path intersection lengths");
283+
}
284+
285+
// Use a small pool of maps (at most 3), each shared by the threads that map to it and guarded by its own mutex, to balance speed and memory usage
286+
// Use path_intersection_length as the first map to save one copy
287+
const uint64_t max_local_maps = std::min(3UL, num_threads);
288+
std::vector<ska::flat_hash_map<uint64_t, uint64_t>> thread_local_maps(max_local_maps - 1);
289+
std::vector<std::mutex> map_mutexes(max_local_maps);
290+
291+
#pragma omp parallel
292+
{
293+
int thread_id = omp_get_thread_num();
294+
int map_id = thread_id % max_local_maps;
295+
296+
// Use path_intersection_length as first map, thread_local_maps for others
297+
ska::flat_hash_map<uint64_t, uint64_t>* local_map;
298+
std::mutex* map_mutex;
299+
300+
if (map_id == 0) {
301+
local_map = &path_intersection_length;
302+
map_mutex = &map_mutexes[0];
303+
} else {
304+
local_map = &thread_local_maps[map_id - 1];
305+
map_mutex = &map_mutexes[map_id];
306+
}
307+
308+
#pragma omp for
309+
for (uint64_t node_id = 1; node_id <= graph.get_node_count(); ++node_id) {
310+
handle_t h = graph.get_handle(node_id);
277311
// Skip masked-out nodes
278312
if (!node_mask[graph.get_id(h) - 1]) {
279-
return;
313+
continue;
280314
}
281315
ska::flat_hash_map<uint32_t, uint64_t> local_path_lengths;
282316
size_t l = graph.get_length(h);
@@ -286,22 +320,38 @@ int main_similarity(int argc, char** argv) {
286320
local_path_lengths[get_path_id(graph.get_path_handle_of_step(s))] += l;
287321
});
288322

289-
#pragma omp critical (path_intersection_length)
290-
for (auto& p : local_path_lengths) {
291-
for (auto& q : local_path_lengths) {
292-
path_intersection_length[encode_pair(p.first, q.first)] += std::min(p.second, q.second);
323+
// Update shared local map with mutex protection
324+
{
325+
std::lock_guard<std::mutex> lock(*map_mutex);
326+
for (auto& p : local_path_lengths) {
327+
for (auto& q : local_path_lengths) {
328+
(*local_map)[encode_pair(p.first, q.first)] += std::min(p.second, q.second);
329+
}
293330
}
294331
}
295332

296333
if (show_progress) {
297334
progress_meter->increment(1);
298335
}
299-
}, true);
336+
}
337+
}
338+
339+
// Merge thread-local maps into the main map
340+
// Skip index 0 since path_intersection_length is already used directly
341+
for (const auto& local_map : thread_local_maps) {
342+
for (const auto& pair : local_map) {
343+
path_intersection_length[pair.first] += pair.second;
344+
}
345+
}
300346

301347
if (show_progress) {
302348
progress_meter->finish();
303349
}
304350

351+
if (show_progress) {
352+
std::cerr << "[odgi::similarity] Writing the output..." << std::endl;
353+
}
354+
305355
/*if (using_delim) {
306356
std::cout << "group.a" << "\t"
307357
<< "group.b" << "\t"
@@ -335,6 +385,13 @@ int main_similarity(int argc, char** argv) {
335385
}
336386

337387
std::cout << std::endl;
388+
389+
// Use chunked buffering to balance speed and memory usage
390+
std::ostringstream output_buffer;
391+
output_buffer << std::fixed << std::setprecision(6);
392+
const size_t buffer_chunk_size = 100000; // Lines per chunk
393+
size_t lines_written = 0;
394+
338395
for (auto& p : path_intersection_length) {
339396
uint32_t id_a, id_b;
340397
decode_pair(p.first, &id_a, &id_b);
@@ -347,27 +404,39 @@ int main_similarity(int argc, char** argv) {
347404
const double dice = 2.0 * ((double) intersection / (double)(bp_count[id_a] + bp_count[id_b]));
348405
const double estimated_identity = 2.0 * jaccard / (1.0 + jaccard);
349406

350-
std::cout << get_path_name(id_a) << "\t"
351-
<< get_path_name(id_b) << "\t"
352-
<< bp_count[id_a] << "\t"
353-
<< bp_count[id_b] << "\t"
354-
<< intersection << "\t";
407+
output_buffer << get_path_name(id_a) << "\t"
408+
<< get_path_name(id_b) << "\t"
409+
<< bp_count[id_a] << "\t"
410+
<< bp_count[id_b] << "\t"
411+
<< intersection << "\t";
355412

356413
if (emit_distances) {
357414
const double euclidian_distance = std::sqrt((double)((bp_count[id_a] + bp_count[id_b] - intersection) - intersection));
358415
const uint64_t manhattan_distance = (bp_count[id_a] + bp_count[id_b] - intersection) - intersection;
359-
std::cout << (1.0 - jaccard) << "\t"
360-
<< (1.0 - cosine) << "\t"
361-
<< (1.0 - dice) << "\t"
362-
<< (1.0 - estimated_identity) << "\t"
363-
<< euclidian_distance << "\t"
364-
<< manhattan_distance << std::endl;
416+
output_buffer << (1.0 - jaccard) << "\t"
417+
<< (1.0 - cosine) << "\t"
418+
<< (1.0 - dice) << "\t"
419+
<< (1.0 - estimated_identity) << "\t"
420+
<< euclidian_distance << "\t"
421+
<< manhattan_distance << "\n";
365422
} else {
366-
std::cout << jaccard << "\t"
367-
<< cosine << "\t"
368-
<< dice << "\t"
369-
<< estimated_identity << std::endl;
423+
output_buffer << jaccard << "\t"
424+
<< cosine << "\t"
425+
<< dice << "\t"
426+
<< estimated_identity << "\n";
370427
}
428+
429+
// Flush buffer every chunk_size lines
430+
if (++lines_written % buffer_chunk_size == 0) {
431+
std::cout << output_buffer.str();
432+
output_buffer.str("");
433+
output_buffer.clear();
434+
}
435+
}
436+
437+
// Write remaining buffer
438+
if (!output_buffer.str().empty()) {
439+
std::cout << output_buffer.str();
371440
}
372441

373442
return 0;

0 commit comments

Comments
 (0)