File tree Expand file tree Collapse file tree 1 file changed +12
-12
lines changed
Expand file tree Collapse file tree 1 file changed +12
-12
lines changed Original file line number Diff line number Diff line change @@ -285,7 +285,18 @@ def __init__(self, config: InferenceBenchmarkConfig):
285285 "*.layers.*.feed_forward.down_proj" : RowwiseParallel (use_local_output = True ),
286286 }
287287
288- if not self .config .disable_moe_replacement :
288+ if self .config .disable_moe_replacement :
289+ tp_plan .update (
290+ {
291+ # HF MoE
292+ "*.layers.*.feed_forward.shared_expert.gate_proj" : ColwiseParallel (use_local_output = False ),
293+ "*.layers.*.feed_forward.shared_expert.up_proj" : ColwiseParallel (use_local_output = False ),
294+ "*.layers.*.feed_forward.shared_expert.down_proj" : RowwiseParallel (use_local_output = True ),
295+ # TODO: Need to write ParallelStyle for HF's grouped_mm implementation.
296+ }
297+ )
298+
299+ else :
289300 tp_plan .update (
290301 {
291302 # Custom MoE
@@ -306,17 +317,6 @@ def __init__(self, config: InferenceBenchmarkConfig):
306317 }
307318 )
308319
309- else :
310- tp_plan .update (
311- {
312- # HF MoE
313- "*.layers.*.feed_forward.shared_expert.gate_proj" : ColwiseParallel (use_local_output = False ),
314- "*.layers.*.feed_forward.shared_expert.up_proj" : ColwiseParallel (use_local_output = False ),
315- "*.layers.*.feed_forward.shared_expert.down_proj" : RowwiseParallel (use_local_output = True ),
316- # TODO: Need to write ParallelStyle for HF's grouped_mm implementation.
317- }
318- )
319-
320320 if mesh :
321321 model = parallelize_module (model , mesh , tp_plan )
322322
You can’t perform that action at this time.
0 commit comments