Skip to content

Commit 790d7d7

Browse files
committed
Simplify if statement
1 parent a7b18cb commit 790d7d7

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

benchmarks/python/benchmark_inference.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,18 @@ def __init__(self, config: InferenceBenchmarkConfig):
285285
"*.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True),
286286
}
287287

288-
if not self.config.disable_moe_replacement:
288+
if self.config.disable_moe_replacement:
289+
tp_plan.update(
290+
{
291+
# HF MoE
292+
"*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
293+
"*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
294+
"*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
295+
# TODO:Need to write ParallelStyle for HF's grouped_mm implementation.
296+
}
297+
)
298+
299+
else:
289300
tp_plan.update(
290301
{
291302
# Custom MoE
@@ -306,17 +317,6 @@ def __init__(self, config: InferenceBenchmarkConfig):
306317
}
307318
)
308319

309-
else:
310-
tp_plan.update(
311-
{
312-
# HF MoE
313-
"*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
314-
"*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
315-
"*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
316-
# TODO:Need to write ParallelStyle for HF's grouped_mm implementation.
317-
}
318-
)
319-
320320
if mesh:
321321
model = parallelize_module(model, mesh, tp_plan)
322322

0 commit comments

Comments
 (0)