Skip to content

Commit 86f0641

Browse files
committed
[CausalLM] Avoid Race condition on evict experts
Avoid Race Condition on eviction experts **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Donghak PARK <[email protected]>
1 parent 6bdbf2f commit 86f0641

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

Applications/CausalLM/layers/ernie_moe_layer.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,10 +458,14 @@ void ErnieMoELayer::incremental_forwarding(nntrainer::RunLayerContext &context,
458458

459459
// Evict experts
460460
#pragma omp parallel
461-
while (loaded_expert_deque.size() > 16) {
461+
while (true) {
462462
int target_idx;
463463
{
464464
std::lock_guard<std::mutex> lock(cache_mutex);
465+
466+
if (loaded_expert_deque.size() > 16)
467+
break;
468+
465469
target_idx = loaded_expert_deque.front();
466470
loaded_expert_deque.pop_front();
467471
iteration_map.erase(target_idx);

Applications/CausalLM/layers/gpt_oss_moe_layer_cached.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,13 @@ void CachedSlimGptOssMoELayer::incremental_forwarding(
378378

379379
// Evict experts
380380
#pragma omp parallel
381-
while (loaded_expert_deque.size() > 16) {
381+
while (true) {
382382
int target_idx;
383383
{
384384
std::lock_guard<std::mutex> lock(cache_mutex);
385+
if (loaded_expert_deque.size() > 16)
386+
break;
387+
385388
target_idx = loaded_expert_deque.front();
386389
loaded_expert_deque.pop_front();
387390
iteration_map.erase(target_idx);

Applications/CausalLM/layers/qwen_moe_layer_cached.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,10 +438,11 @@ void CachedSlimMoELayer::incremental_forwarding(
438438

439439
// Evict experts
440440
#pragma omp parallel
441-
while (loaded_expert_deque.size() > 32) {
441+
while (true) {
442442
int target_idx;
443443
{
444444
std::lock_guard<std::mutex> lock(cache_mutex);
445+
if (loaded_expert_deque.size() > 16) break;
445446
target_idx = loaded_expert_deque.front();
446447
loaded_expert_deque.pop_front();
447448
iteration_map.erase(target_idx);

0 commit comments

Comments
 (0)