Piz/hybridmesh #66 (Open)

wants to merge 1 commit into base: flash_attention
8 changes: 7 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -20,6 +20,7 @@
 """PyTorch LLaMA model."""

 import math
+import os
 import warnings
 from typing import List, Optional, Tuple, Union

@@ -61,6 +62,7 @@

 _CONFIG_FOR_DOC = "LlamaConfig"

+NUM_SLICE = int(os.getenv('NUM_SLICE', 1))

 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -387,7 +389,11 @@ def forward(
         # Integrated with PyTorch/XLA Pallas Flash Attention:
         from torch_xla.experimental.custom_kernel import flash_attention
         query_states /= math.sqrt(self.head_dim)
-        attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', 'tensor', None, None))
+        if NUM_SLICE == 1:
+            attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=('fsdp', 'tensor', None, None))
+        else:
+            attn_output = flash_attention(query_states, key_states, value_states, causal=True, partition_spec=(('dcn', 'fsdp'), None, None, None))
+

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
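Note: the two branches above differ only in how the 4D attention tensors (batch, heads, seq, head_dim) are laid out on the mesh. A minimal sketch of that mapping; the helper name is hypothetical and not part of the diff:

```python
# Hypothetical helper illustrating the partition_spec choice made in the diff above.
import os

NUM_SLICE = int(os.getenv("NUM_SLICE", 1))  # same env switch the PR introduces


def attention_partition_spec(num_slice: int = NUM_SLICE):
    """Map (batch, heads, seq, head_dim) onto mesh axes; None means replicated."""
    if num_slice == 1:
        return ("fsdp", "tensor", None, None)   # shard batch on 'fsdp', heads on 'tensor'
    return (("dcn", "fsdp"), None, None, None)  # shard batch across both slices and chips
```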
49 changes: 45 additions & 4 deletions src/transformers/trainer.py
@@ -127,6 +127,7 @@
     set_seed,
     speed_metrics,
 )
+import torch_xla.distributed.parallel_loader as pl
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
     ADAPTER_CONFIG_NAME,
@@ -264,6 +265,7 @@ def _get_fsdp_ckpt_kwargs():

 logger = logging.get_logger(__name__)

+NUM_SLICE = int(os.getenv('NUM_SLICE', 1))

 # Name of the files used for checkpointing
 TRAINING_ARGS_NAME = "training_args.bin"
@@ -381,6 +383,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
+        set_seed(self.args.seed)
         enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
@@ -679,6 +682,18 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+            if NUM_SLICE == 1:
+                mesh_shape = (num_devices, 1)
+                device_ids = np.array(range(num_devices))
+                # To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
+                mesh = xs.Mesh(device_ids, mesh_shape, ('fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)
+            else:
+                dcn_axis = NUM_SLICE
+                ici_mesh_shape = (1, num_devices // dcn_axis, 1)
+                dcn_mesh_shape = (dcn_axis, 1, 1)
+                mesh = xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=('dcn', 'fsdp', 'tensor'))
+                xs.set_global_mesh(mesh)

     def _activate_neftune(self, model):
         r"""
@@ -877,6 +892,24 @@ def get_train_dataloader(self) -> DataLoader:
             dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

+
+        if is_torch_xla_available():
+            torch_dataloader = DataLoader(train_dataset, **dataloader_params)
+            device = xm.xla_device()
+            if NUM_SLICE == 1:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None)),
+                )
+            else:
+                mp_device_loader = pl.MpDeviceLoader(
+                    torch_dataloader,
+                    device,
+                    input_sharding=xs.ShardingSpec(xs.get_global_mesh(), (("dcn", "fsdp"), None)),
+                )
+            return mp_device_loader
+
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
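Note: the multi-slice input_sharding simply extends data parallelism across slices. A small worked example of the resulting per-chip batch, with illustrative numbers:

```python
# Illustrative arithmetic: with 2 slices of 4 chips each, a batch sharded along
# (('dcn', 'fsdp'), None) is split on its leading dimension across all 8 chips.
global_batch, num_slices, chips_per_slice = 16, 2, 4
per_chip_batch = global_batch // (num_slices * chips_per_slice)
assert per_chip_batch == 2  # each chip sees 2 examples; the second dim stays replicated
```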
@@ -1681,7 +1714,6 @@ def _wrap_model(self, model, training=True, dataloader=None):
                 # Transformer layer class to wrap
                 transformer_layer_cls=transformer_cls_to_wrap,
             )
-            fsdp_kwargs = self.args.xla_fsdp_config
             if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                 if model.config.use_cache:
                     logger.warning_once(
@@ -1709,7 +1741,11 @@ def shard_output(output, mesh):

                 if real_output is None:
                     raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
-                xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                if NUM_SLICE == 1:
+                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+                else:
+                    xs.mark_sharding(real_output, mesh, (("dcn", "fsdp"), None, None))

             self.model = model = FSDPv2(
                 model,
@@ -1718,10 +1754,12 @@ def shard_output(output, mesh):
                 auto_wrapper_callable=auto_wrapper_callable,
             )
         else:
+            fsdp_kwargs = self.args.xla_fsdp_config
             self.model = model = FSDP(
                 model,
                 auto_wrap_policy=auto_wrap_policy,
                 auto_wrapper_callable=auto_wrapper_callable,
+                reshard_after_forward=False,
                 **fsdp_kwargs,
             )

@@ -1854,6 +1892,7 @@ def train(
                 # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
                 hf_hub_utils.disable_progress_bars()
                 return inner_training_loop(
+                    batch_size=self._train_batch_size,
                     args=args,
                     resume_from_checkpoint=resume_from_checkpoint,
                     trial=trial,
@@ -1863,6 +1902,7 @@
                 hf_hub_utils.enable_progress_bars()
         else:
             return inner_training_loop(
+                batch_size=self._train_batch_size,
                 args=args,
                 resume_from_checkpoint=resume_from_checkpoint,
                 trial=trial,
@@ -1892,8 +1932,8 @@ def _inner_training_loop(
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
-        if self.is_fsdp_xla_v2_enabled:
-            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+        # if self.is_fsdp_xla_v2_enabled:
+        #     train_dataloader = tpu_spmd_dataloader(train_dataloader)

         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -4454,3 +4494,4 @@ def _fsdp_qlora_plugin_updates(self):
             fsdp_plugin.set_mixed_precision(
                 self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
             )
+
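Note: both files read NUM_SLICE into a module-level constant at import time, so the environment variable has to be set before transformers is imported. A minimal sketch; the values below are placeholders, not part of the PR:

```python
# Hypothetical launch snippet: set NUM_SLICE before importing transformers so the
# module-level constants in modeling_llama.py and trainer.py pick it up.
import os

os.environ.setdefault("NUM_SLICE", "2")  # number of TPU slices connected over DCN

from transformers import Trainer, TrainingArguments  # noqa: E402

# Build the model, dataset and TrainingArguments as usual; with NUM_SLICE > 1 the
# trainer constructs the HybridMesh and multi-slice input sharding added above.
```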