Skip to content

Commit 175af22

Browse files
Refactor thread-local storage management to optimize memory usage and improve performance in similarity computation
1 parent ab378b0 commit 175af22

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/subcommand/similarity_main.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,16 +283,27 @@ int main_similarity(int argc, char** argv) {
283283
}
284284

285285
// Use a small, fixed pool of mutex-protected accumulator maps (shared by threads via thread_id % max_local_maps) to balance speed and memory usage
286-
const uint64_t max_local_maps = std::min(2UL, num_threads);
287-
std::vector<ska::flat_hash_map<uint64_t, uint64_t>> thread_local_maps(max_local_maps);
286+
// Use path_intersection_length as the first map to save one copy
287+
const uint64_t max_local_maps = std::min(3UL, num_threads);
288+
std::vector<ska::flat_hash_map<uint64_t, uint64_t>> thread_local_maps(max_local_maps - 1);
288289
std::vector<std::mutex> map_mutexes(max_local_maps);
289290

290291
#pragma omp parallel
291292
{
292293
int thread_id = omp_get_thread_num();
293294
int map_id = thread_id % max_local_maps;
294-
auto& local_map = thread_local_maps[map_id];
295-
auto& map_mutex = map_mutexes[map_id];
295+
296+
// map_id 0 accumulates directly into path_intersection_length; map_id k > 0 uses thread_local_maps[k - 1]
297+
ska::flat_hash_map<uint64_t, uint64_t>* local_map;
298+
std::mutex* map_mutex;
299+
300+
if (map_id == 0) {
301+
local_map = &path_intersection_length;
302+
map_mutex = &map_mutexes[0];
303+
} else {
304+
local_map = &thread_local_maps[map_id - 1];
305+
map_mutex = &map_mutexes[map_id];
306+
}
296307

297308
#pragma omp for
298309
for (uint64_t node_id = 1; node_id <= graph.get_node_count(); ++node_id) {
@@ -311,10 +322,10 @@ int main_similarity(int argc, char** argv) {
311322

312323
// Update shared local map with mutex protection
313324
{
314-
std::lock_guard<std::mutex> lock(map_mutex);
325+
std::lock_guard<std::mutex> lock(*map_mutex);
315326
for (auto& p : local_path_lengths) {
316327
for (auto& q : local_path_lengths) {
317-
local_map[encode_pair(p.first, q.first)] += std::min(p.second, q.second);
328+
(*local_map)[encode_pair(p.first, q.first)] += std::min(p.second, q.second);
318329
}
319330
}
320331
}
@@ -326,6 +337,7 @@ int main_similarity(int argc, char** argv) {
326337
}
327338

328339
// Merge thread-local maps into the main map
340+
// Nothing to skip here: map_id 0 wrote directly into path_intersection_length, so thread_local_maps holds only the other maps and every one of them is merged
329341
for (const auto& local_map : thread_local_maps) {
330342
for (const auto& pair : local_map) {
331343
path_intersection_length[pair.first] += pair.second;

0 commit comments

Comments
 (0)