Input batch sharding strategy BATCH
apoorvtintin committed Dec 11, 2024
1 parent c20387c commit 03052e4
Showing 8 changed files with 86 additions and 31 deletions.
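In short: this commit adds a third DataPartitionType, BATCH, which shards each leaf of the input batch along the configured batch_axis_names only, and threads that choice (together with the axis names) through the trainer, the evalers, and host_to_global_device_array. A minimal sketch of the resulting spec mapping, based on the utils.py hunk below (the "data" axis name is illustrative):

from jax.sharding import PartitionSpec

from axlearn.common.utils import DataPartitionType, data_partition_type_to_spec

# BATCH shards the input batch along the named batch axes only.
spec = data_partition_type_to_spec(DataPartitionType.BATCH, batch_axis_names="data")
assert spec == PartitionSpec("data")

# REPLICATED still maps to None (fully replicated inputs); FULL keeps the
# existing input_partition_spec() behavior over all batch axes.
assert data_partition_type_to_spec(DataPartitionType.REPLICATED) is None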
19 changes: 15 additions & 4 deletions axlearn/common/evaler.py
@@ -31,6 +31,7 @@
from axlearn.common.module import Module, OutputCollection
from axlearn.common.module import functional as F
from axlearn.common.utils import (
DataPartitionType,
NestedPartitionSpec,
NestedTensor,
Tensor,
@@ -81,6 +82,11 @@ class Config(Module.Config):
# evalers, not setting prefix will show the accuracies on the same plot for comparison
# across evalers.
prefix: Optional[str] = None
# Subset of mesh axis names over which the leaves of the input batch are sharded.
batch_axis_names: Union[str, Sequence[str]] = "data"
# The input partition type.
# Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
@@ -188,11 +194,11 @@ def _pjit(self, fn: Callable) -> Callable:
in_shardings=(
self._model_param_partition_specs, # model_params.
None, # replicated_inputs (e.g., prng_key).
utils.input_partition_spec(), # per_example_inputs.
utils.data_partition_type_to_spec(
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),  # per_example_inputs.
),
out_shardings=dict(
replicated=None,
per_example=utils.input_partition_spec(),
per_example=utils.data_partition_type_to_spec(
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),
),
)

@@ -574,6 +580,11 @@ class Config(Module.Config):
metric_calculator: BaseMetricCalculator.Config = ModelSummaryAccumulator.default_config()
# If not None, writes input batches and `metric_calculator` forward outputs.
output_writer: Optional[BaseOutputWriter.Config] = None
# Subset of mesh axis names over which the leaves of the input batch are sharded.
batch_axis_names: Union[str, Sequence[str]] = "data"
# The input partition type.
# Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
@@ -595,7 +606,7 @@ def __init__(
self._add_child("input", maybe_set_config(cfg.input, is_training=False))
self._add_child(
"metric_calculator",
cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype),
cfg.metric_calculator.set(
    eval_dtype=cfg.eval_dtype,
    batch_axis_names=cfg.batch_axis_names,
    input_partition_type=cfg.input_partition_type,
),
model=model,
model_param_partition_specs=model_param_partition_specs,
)
@@ -691,7 +702,7 @@ def eval_step(

with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step):
with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"):
global_input_batch = utils.host_to_global_device_array(input_batch)
global_input_batch = utils.host_to_global_device_array(
    input_batch,
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
)
forward_outputs = self.metric_calculator.forward(
global_input_batch,
model_params=model_params,
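The evaler grows the same two knobs as the trainer. A sketch of setting them on a default config (assuming SpmdEvaler.default_config() as the starting point; the values are illustrative):

from axlearn.common.evaler import SpmdEvaler
from axlearn.common.utils import DataPartitionType

evaler_cfg = SpmdEvaler.default_config().set(
    input_partition_type=DataPartitionType.BATCH,
    batch_axis_names="data",
)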
9 changes: 6 additions & 3 deletions axlearn/common/gda_test.py
@@ -26,9 +26,9 @@
class GDATest(TestCase):
@parameterized.parameters(
itertools.product(
((1, 1), (8, 1), (4, 2)), # mesh_shape
((1, 1), (8, 1), (4, 2), (16, 4)), # mesh_shape
(1, 16), # per_host_batch_size
(DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition
(DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH), # data_partition
)
)
def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
@@ -41,13 +41,16 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
if not is_supported_mesh_shape(mesh_shape):
return
devices = mesh_utils.create_device_mesh(mesh_shape)
if data_partition == DataPartitionType.FULL:
if data_partition in (DataPartitionType.FULL, DataPartitionType.BATCH):
global_batch_size = per_host_batch_size * jax.process_count()
else:
assert data_partition == DataPartitionType.REPLICATED
global_batch_size = per_host_batch_size
if data_partition == DataPartitionType.FULL and global_batch_size < jax.device_count():
return
# The first mesh axis is assumed to be the batch axis, so the global batch
# must divide evenly across it; skip configurations where it does not.
if data_partition == DataPartitionType.BATCH and global_batch_size % mesh_shape[0] != 0:
return
per_host_input_batch = dict(x=jnp.zeros((per_host_batch_size, 8), dtype=jnp.float32))
with jax.sharding.Mesh(devices, ("data", "model")):
global_input_batch = host_to_global_device_array(
6 changes: 2 additions & 4 deletions axlearn/common/host_array_test.py
@@ -21,7 +21,6 @@
host_to_global_device_array,
)


def is_supported(
platform: str,
mesh_shape: tuple[int, int],
@@ -37,16 +36,15 @@ def is_supported(
)
)


class HostArrayTest(TestCase):
@parameterized.parameters(
filter(
lambda params: is_supported(*params),
itertools.product(
("cpu", "tpu"), # platform,
((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape
((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4)), # mesh_shape
(1, 16), # global_batch_size
(DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition
(DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH), # data_partition
),
)
)
9 changes: 7 additions & 2 deletions axlearn/common/trainer.py
@@ -47,6 +47,7 @@
from axlearn.common.summary_writer import BaseWriter, SummaryWriter
from axlearn.common.update_transformation import ForwardOutputs
from axlearn.common.utils import (
DataPartitionType,
HybridMeshShape,
MeshShape,
Nested,
@@ -200,6 +201,10 @@ class Config(Module.Config):
# The provided config should instantiate to a thunk that returns the context manager.
context_manager: Optional[ConfigOr[Callable[[], ContextManager]]] = None

# The input partition type.
# Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
cfg: Config,
@@ -343,7 +348,7 @@ def trainer_state_partition_specs(self):

def _train_step_input_partition_specs(self):
# By default, each input tensor is fully partitioned along the batch axis.
return utils.input_partition_spec()
return utils.data_partition_type_to_spec(self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names)

def model_params_for_eval(self):
state = self.trainer_state
@@ -568,7 +573,7 @@ def run(
self._step = self._step + 1
self.vlog(3, "Start step %s", self.step)
output = self._run_step(
utils.host_to_global_device_array(input_batch),
utils.host_to_global_device_array(
    input_batch,
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),
force_run_evals=(
force_run_eval_sets_at_max_step if self.step >= cfg.max_step else None
),
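Opting a trainer into the new strategy is a config-level change. A sketch, assuming an existing SpmdTrainer config cfg whose mesh includes a "data" axis (batch_axis_names predates this commit; only input_partition_type is new):

from axlearn.common.utils import DataPartitionType

# Shard input batches along the "data" axis only, instead of the default
# FULL partitioning over input_partition_spec().
cfg.input_partition_type = DataPartitionType.BATCH
cfg.batch_axis_names = "data"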
12 changes: 9 additions & 3 deletions axlearn/common/utils.py
@@ -591,14 +591,17 @@ class DataPartitionType(Enum):
FULL = "full"
# Data are fully replicated across all devices.
REPLICATED = "replicated"
# Data are partitioned across batch axis only.
BATCH = "batch"


def data_partition_type_to_spec(partition: DataPartitionType) -> PartitionSpec:
def data_partition_type_to_spec(
    partition: DataPartitionType,
    *,
    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
) -> PartitionSpec:
"""Returns a PartitionSpec for the given partition type."""
if partition == DataPartitionType.FULL:
return input_partition_spec()
elif partition == DataPartitionType.REPLICATED:
return None
elif partition == DataPartitionType.BATCH:
return PartitionSpec(batch_axis_names)
else:
raise NotImplementedError(f"Unsupported partition: {partition}")

@@ -607,6 +610,7 @@ def host_to_global_device_array(
host_arrays: Nested[Union[np.ndarray, Tensor]],
*,
partition: DataPartitionType = DataPartitionType.FULL,
batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
) -> NestedTensor:
"""Converts the given host device arrays to global device arrays.
@@ -625,7 +629,7 @@
NotImplementedError: if the given `partition` type is not supported.
"""
mesh = thread_resources.env.physical_mesh
partition_spec = data_partition_type_to_spec(partition)
partition_spec = data_partition_type_to_spec(partition, batch_axis_names=batch_axis_names)
partition_specs = complete_partition_spec_tree(
jax.tree_util.tree_structure(host_arrays), partition_spec
)
@@ -636,6 +640,8 @@ def make_gda(x, partition_spec):
global_shape = (x.shape[0] * process_count, *x.shape[1:])
elif partition == DataPartitionType.REPLICATED:
global_shape = (x.shape[0], *x.shape[1:])
elif partition == DataPartitionType.BATCH:
global_shape = (x.shape[0] * process_count, *x.shape[1:])
else:
raise NotImplementedError(f"Unsupported partition: {partition}")
return jax.make_array_from_process_local_data(
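Note that make_gda computes the same global shape for BATCH as for FULL: each process contributes a distinct slice, so the global batch is the per-process batch times process_count; only REPLICATED keeps the local size. A worked example with illustrative numbers:

# Four processes, each feeding a local batch of shape (8, 16).
process_count = 4
local_shape = (8, 16)

# FULL and BATCH: per-process batches are concatenated along the batch axis.
assert (local_shape[0] * process_count, *local_shape[1:]) == (32, 16)

# REPLICATED: every process feeds the same batch, so the shape is unchanged.
assert (local_shape[0], *local_shape[1:]) == (8, 16)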
26 changes: 26 additions & 0 deletions axlearn/common/utils_test.py
@@ -42,6 +42,7 @@
)
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import (
DataPartitionType,
PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
HybridMeshShape,
MeshShape,
@@ -1701,6 +1702,31 @@ def test_length(self):
class HostToGlobalArrayTest(TestCase):
"""Tests host_to_global_device_array."""

@pytest.mark.neuron
def test_partition_batch(self):
"""Test a case where each process produces a slice."""
device_count = jax.device_count()
process_count = jax.process_count()
print(f"{device_count=}, {process_count=}")
assert device_count > 1

global_shape = (device_count // 2, 1)
assert global_shape[0] % process_count == 0
per_feed_size = global_shape[0] // process_count
feed_index = jax.process_index()

with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")):
start = feed_index * per_feed_size
local_x = jnp.arange(start, start + per_feed_size)[:, None]

# Construct global array.
global_x = host_to_global_device_array(
    local_x, partition=DataPartitionType.BATCH, batch_axis_names="x"
)

# Compare against expected.
expected = jnp.arange(global_shape[0])[:, None]
self.assertEqual(jnp.mean(expected), jnp.mean(global_x))
self.assertNestedEqual(expected, replicate_to_local_data(global_x))

@pytest.mark.tpu
def test_partition_full(self):
"""Test a case where each process produces a slice."""
33 changes: 19 additions & 14 deletions axlearn/experiments/text/gpt/common.py
@@ -57,7 +57,7 @@
from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
from axlearn.common.summary_writer import BaseWriter
from axlearn.common.trainer import MeshShape, SpmdTrainer
from axlearn.common.utils import HybridMeshShape, Nested, get_data_dir
from axlearn.common.utils import DataPartitionType, HybridMeshShape, Nested, get_data_dir
from axlearn.experiments.text.common import DataMixtureComponent, tfds_text_source
from axlearn.experiments.trainer_config_utils import TrainerConfigFn

@@ -640,6 +640,7 @@ def get_trainer_config_fn(
mesh_shape: Union[MeshShape, HybridMeshShape],
mesh_axis_names: Sequence[str] = MESH_AXIS_NAMES,
mesh_rules: Optional[Sequence[tuple[str, Optional[Union[MeshShape, HybridMeshShape]]]]] = None,
input_partition_type: Optional[DataPartitionType] = None,
eval_every_n_steps: int = 5000,
eval_batch_size: Optional[int] = None,
keep_every_n_steps: int = 50_000,
@@ -689,9 +690,26 @@ def config_fn() -> InstantiableConfig:
pad_example_fn=input_tf_data.default_pad_example_fn,
),
)
if input_partition_type:
cfg.input_partition_type = input_partition_type
if len(mesh_axis_names) != len(mesh_shape):
raise ValueError(
f"Number of mesh axis names ({mesh_axis_names}) "
f"must match number of mesh dims ({mesh_shape})."
)
cfg.mesh_axis_names = mesh_axis_names
cfg.mesh_shape = mesh_shape
# Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
# "pipeline" axis (for pipeline parallelism).
cfg.batch_axis_names = tuple(
el for el in mesh_axis_names if el not in ("model", "pipeline")
)
cfg.mesh_rules = mesh_rules
cfg.evalers = {}
for name, evaler_cfg in evalers.items():
evaler_cfg.input.batcher.set(global_batch_size=eval_batch_size or train_batch_size)
evaler_cfg.set(input_partition_type=input_partition_type)
evaler_cfg.set(batch_axis_names=cfg.batch_axis_names)
evaler_cfg.set(
eval_policy=config_for_function(eval_every_n_steps_policy).set(
n=eval_every_n_steps,
@@ -708,19 +726,6 @@ def config_fn() -> InstantiableConfig:
cfg.checkpointer.keep_last_n = 3
cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
cfg.summary_writer.max_queue = 1000
if len(mesh_axis_names) != len(mesh_shape):
raise ValueError(
f"Number of mesh axis names ({mesh_axis_names}) "
f"must match number of mesh dims ({mesh_shape})."
)
cfg.mesh_axis_names = mesh_axis_names
cfg.mesh_shape = mesh_shape
# Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
# "pipeline" axis (for pipeline parallelism).
cfg.batch_axis_names = tuple(
el for el in mesh_axis_names if el not in ("model", "pipeline")
)
cfg.mesh_rules = mesh_rules
# Maybe load state.
if init_state_builder:
cfg.init_state_builder = init_state_builder
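Since batch_axis_names excludes the "model" and "pipeline" axes, a BATCH input partition lands on the remaining mesh axes. A self-contained sketch of that filter (axis names illustrative), mirroring the tuple comprehension in config_fn:

mesh_axis_names = ("data", "fsdp", "model")
# Exclude tensor-parallel ("model") and pipeline-parallel ("pipeline") axes.
batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "pipeline"))
assert batch_axis_names == ("data", "fsdp")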
3 changes: 2 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -39,7 +39,7 @@
MeshShapeModifier,
RematSpecModifier,
)
from axlearn.common.utils import extended_checkpoint_policies
from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies
from axlearn.experiments.text.gpt.common import (
STEP_DTYPE,
SourceBuilder,
@@ -423,6 +423,7 @@ def get_trainer_kwargs(
raise NotImplementedError(f"Unknown model size {model_size}.")
model_kwargs = trainer_kwargs.pop("model_kwargs")
model_kwargs.setdefault("vocab_size", vocab_size)
trainer_kwargs["input_partition_type"] = None if backend != "neuron" else DataPartitionType.BATCH
trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
max_step=trainer_kwargs["max_step"],
