
Commit 6ddc039

[Flux] Fix parallel dimension names (#1199)
## Context

- Currently we do not enable CP for Flux model training, so the `_cp` suffix needs to be removed from the data-parallel mesh dimension names.
- For FSDP/HSDP, the correct dimension names are ("dp_replicate", "dp_shard") or ("dp_shard",).
1 parent a831e63 commit 6ddc039

File tree

1 file changed, +6 -10 lines changed

1 file changed

+6
-10
lines changed

torchtitan/experiments/flux/parallelize_flux.py

Lines changed: 6 additions & 10 deletions
```diff
@@ -28,13 +28,11 @@ def parallelize_flux(
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
-    if (
-        parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled
-    ):  # apply FSDP or HSDP
+    if parallel_dims.dp_shard_enabled:  # apply FSDP or HSDP
         if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
+            dp_mesh_dim_names = ("dp_shard",)
 
         apply_fsdp(
             model,
@@ -122,13 +120,11 @@ def parallelize_encoders(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
 ):
-    if (
-        parallel_dims.dp_shard_enabled or parallel_dims.dp_replicate_enabled
-    ):  # apply FSDP or HSDP
+    if parallel_dims.dp_shard_enabled:  # apply FSDP or HSDP
         if parallel_dims.dp_replicate_enabled:
-            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
+            dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         else:
-            dp_mesh_dim_names = ("dp_shard_cp",)
+            dp_mesh_dim_names = ("dp_shard",)
 
     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
```
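
For context, the tuple of dimension names is used to slice a sub-mesh out of the world `DeviceMesh` before sharding, which is why a stale `dp_shard_cp` name breaks when CP is disabled. Below is a minimal sketch of that pattern, not torchtitan's exact code: the 8-rank 2x4 HSDP layout, the `world_mesh` setup, and the simplified enable-flag logic are all illustrative assumptions.

```python
# Minimal sketch (assumptions, not torchtitan's exact code): how
# ("dp_replicate", "dp_shard") vs. ("dp_shard",) selects the sub-mesh
# handed to FSDP2's fully_shard. Assumes a torchrun launch with 8 ranks.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

# Illustrative HSDP layout: 2 replica groups x 4 shards.
world_mesh = init_device_mesh(
    "cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard")
)

dp_replicate_enabled = world_mesh["dp_replicate"].size() > 1
if dp_replicate_enabled:
    # HSDP: replicate across the first dim, shard across the second.
    dp_mesh_dim_names = ("dp_replicate", "dp_shard")
else:
    # Plain FSDP: shard only. With the old "dp_shard_cp" name this lookup
    # would fail, since no such mesh dim exists when CP is disabled.
    dp_mesh_dim_names = ("dp_shard",)

# Indexing a DeviceMesh with a tuple of names returns the matching sub-mesh.
dp_mesh = world_mesh[dp_mesh_dim_names]
fully_shard(nn.Linear(16, 16).cuda(), mesh=dp_mesh)
```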
