@@ -1023,6 +1023,8 @@ class Config(BaseQKVLinear.Config):
10231023 num_kv_heads : Required [int ] = REQUIRED
10241024 # The layer used to project.
10251025 layer : MultiheadInputLinear .Config = MultiheadInputLinear .default_config ()
1026+ # Optional partition spec for query, key, value output activations.
1027+ output_partition_spec : Optional [PartitionSpec ] = None
10261028
10271029 def __init__ (self , cfg : Config , * , parent : Module ):
10281030 super ().__init__ (cfg , parent = parent )
@@ -1073,17 +1075,18 @@ def forward(
10731075 q_proj , k_proj , v_proj = jnp .split (
10741076 proj , [cfg .num_heads , cfg .num_heads + cfg .num_kv_heads ], axis = - 2
10751077 )
1076- # This sharding hint is needed since compiler sometimes will generate large allgather
1077- # before the split and then slice, which is not the ideal compilation. Ensure sharding
1078- # after the split to ensure allgather is inserted after the split.
1079- axis_names = thread_resources .env .physical_mesh .axis_names
1080- batch_axes = tuple (x for x in axis_names if x in ("data" , "fsdp" )) or None
1081- spec = PartitionSpec (
1082- batch_axes ,
1083- "seq" if "seq" in axis_names else None ,
1084- "model" if "model" in axis_names else None ,
1085- None ,
1086- )
1078+ if (spec := cfg .output_partition_spec ) is None :
1079+ # This sharding hint is needed because the compiler sometimes generates a large
1080+ # all-gather before the split and then slices, which is suboptimal. Constraining the
1081+ # sharding after the split ensures the all-gather is inserted after the split.
1082+ axis_names = thread_resources .env .physical_mesh .axis_names
1083+ batch_axes = tuple (x for x in axis_names if x in ("data" , "fsdp" )) or None
1084+ spec = PartitionSpec (
1085+ batch_axes ,
1086+ "seq" if "seq" in axis_names else None ,
1087+ "model" if "model" in axis_names else None ,
1088+ None ,
1089+ )
10871090 q_proj = with_sharding_constraint (q_proj , spec )
10881091 k_proj = with_sharding_constraint (k_proj , spec )
10891092 v_proj = with_sharding_constraint (v_proj , spec )
0 commit comments