Skip to content

Commit bea999b

Browse files
rohitc33changlan
authored andcommitted
Support shard_map for data parallel support in Softserve
GitOrigin-RevId: fbf0e9509aec5db2d006a24c8cf4cf9548e4cf36
1 parent 8c52c7f commit bea999b

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

axlearn/common/attention.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,8 @@ class Config(BaseQKVLinear.Config):
10231023
num_kv_heads: Required[int] = REQUIRED
10241024
# The layer used to project.
10251025
layer: MultiheadInputLinear.Config = MultiheadInputLinear.default_config()
1026+
# Optional partition spec for query, key, value output activations.
1027+
output_partition_spec: Optional[PartitionSpec] = None
10261028

10271029
def __init__(self, cfg: Config, *, parent: Module):
10281030
super().__init__(cfg, parent=parent)
@@ -1073,17 +1075,18 @@ def forward(
10731075
q_proj, k_proj, v_proj = jnp.split(
10741076
proj, [cfg.num_heads, cfg.num_heads + cfg.num_kv_heads], axis=-2
10751077
)
1076-
# This sharding hint is needed since compiler sometimes will generate large allgather
1077-
# before the split and then slice, which is not the ideal compilation. Ensure sharding
1078-
# after the split to ensure allgather is inserted after the split.
1079-
axis_names = thread_resources.env.physical_mesh.axis_names
1080-
batch_axes = tuple(x for x in axis_names if x in ("data", "fsdp")) or None
1081-
spec = PartitionSpec(
1082-
batch_axes,
1083-
"seq" if "seq" in axis_names else None,
1084-
"model" if "model" in axis_names else None,
1085-
None,
1086-
)
1078+
if (spec := cfg.output_partition_spec) is None:
1079+
# This sharding hint is needed since compiler sometimes will generate large allgather
1080+
# before the split and then slice, which is not the ideal compilation. Ensure sharding
1081+
# after the split to ensure allgather is inserted after the split.
1082+
axis_names = thread_resources.env.physical_mesh.axis_names
1083+
batch_axes = tuple(x for x in axis_names if x in ("data", "fsdp")) or None
1084+
spec = PartitionSpec(
1085+
batch_axes,
1086+
"seq" if "seq" in axis_names else None,
1087+
"model" if "model" in axis_names else None,
1088+
None,
1089+
)
10871090
q_proj = with_sharding_constraint(q_proj, spec)
10881091
k_proj = with_sharding_constraint(k_proj, spec)
10891092
v_proj = with_sharding_constraint(v_proj, spec)

0 commit comments

Comments
 (0)