[cp][flex_attention] integration test trial #1228

Draft · wants to merge 1 commit into base: gh/XilunWu/20/base

21 changes: 19 additions & 2 deletions torchtitan/distributed/utils.py
@@ -9,14 +9,17 @@
import os
from collections.abc import Generator, Iterable
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental._attention import _FlexAttentionSharder
from torch.nn.attention import SDPBackend
from torch.nn.attention.flex_attention import BlockMask

from torchtitan.models.attention import ScaledDotProductAttention
from torchtitan.tools.logging import logger
@@ -156,22 +159,35 @@ def create_context_parallel_ctx(
cp_seq_dims: list[int],
cp_no_restore_buffers: set[torch.Tensor],
cp_rotate_method: str,
sharder: Optional[_FlexAttentionSharder] = None,
):
try:
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.distributed.tensor.experimental._attention import (
_DispatchMode,
_set_dispatch_mode,
set_rotate_method,
)
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)

set_rotate_method(cp_rotate_method)
"""
_set_dispatch_mode("torch_dispatch")
assert (
torch.distributed.tensor.experimental._attention._dispatch_mode
== _DispatchMode.TORCH_DISPATCH
)
"""
return context_parallel(
cp_mesh,
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
sharder=sharder,
)


@@ -192,8 +208,9 @@ def context(cp_context: Generator[None, None, None] | None = None):
if cp_context is not None:
if SDPBackend.MATH in ScaledDotProductAttention.backends:
ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
# TODO: add logic for flex-attention
assert (
ScaledDotProductAttention.backends
ScaledDotProductAttention.backends or True
), "No valid SDPA backends with CP."
stack.enter_context(cp_context)

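For orientation, the sketch below (not part of the diff) shows how the extended `create_context_parallel_ctx` signature might be exercised end to end under `torchrun`. Apart from the imported PyTorch/torchtitan names, everything here is an illustrative placeholder; the rotate method, tensor shapes, and launch command are assumptions.

```python
# Hypothetical standalone driver; launch with `torchrun --nproc-per-node=<cp_degree> demo.py`.
# It mirrors, in condensed form, what train.py further below does with the new `sharder` argument.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental._attention import (
    _FlexAttentionSequentialSharder,
)

from torchtitan.distributed import utils as dist_utils
from torchtitan.models.attention import FlexAttention

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
cp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("cp",))

seq_len = 2048
inputs = torch.randint(0, 32_000, (1, seq_len), device="cuda")
labels = torch.randint(0, 32_000, (1, seq_len), device="cuda")
freqs_cis = torch.randn(seq_len, 64, device="cuda")  # placeholder buffer, sequence dim 0

# Full-sequence block mask, built the same way the train.py hunk below does.
mask_mod = FlexAttention._get_causal_mask_mod()
block_mask = FlexAttention.compiled_create_block_mask(mask_mod, 1, None, seq_len, seq_len)

cp_ctx = dist_utils.create_context_parallel_ctx(
    cp_mesh=cp_mesh,
    cp_buffers=[inputs, labels, freqs_cis],
    cp_seq_dims=[1, 1, 0],                    # sequence dim of each buffer above
    cp_no_restore_buffers={inputs, labels},
    cp_rotate_method="allgather",             # torchtitan also supports "alltoall"
    sharder=_FlexAttentionSequentialSharder(mesh=cp_mesh, block_mask=block_mask),
)
with cp_ctx:
    pass  # forward/backward would run here with buffers sharded along the sequence dim

dist.destroy_process_group()
```
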
3 changes: 3 additions & 0 deletions torchtitan/experiments/llama4/__init__.py
@@ -40,6 +40,9 @@
rope_theta=500000,
num_experts=16,
interleave_moe_layer_step=1,
use_flex_attn=True,
attn_mask_type="block_causal",
# attn_mask_type="causal",
),
"17bx128e": TransformerModelArgs(
dim=5120,
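For context on what these two flags turn on, here is a minimal, self-contained sketch of PyTorch's public flex_attention API (CUDA assumed). It uses a plain causal mask_mod for brevity; torchtitan's "block_causal" variant additionally masks attention across packed-document boundaries, which requires per-token document IDs.

```python
# Minimal BlockMask + flex_attention example using torch.nn.attention.flex_attention.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # True means "query position q_idx may attend to key position kv_idx".
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 2048, 64
# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)  # compile flex_attention for real runs
print(out.shape)  # torch.Size([1, 8, 2048, 64])
```
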
5 changes: 3 additions & 2 deletions torchtitan/experiments/llama4/model/args.py
@@ -55,7 +55,7 @@ class TransformerModelArgs(BaseModelArgs):
interleave_moe_layer_step: int = 2
# token-choice
top_k: int = 1
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
use_grouped_mm: bool = False # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"FlexAttention is not compatible with selective AC yet. "
"See https://github.com/pytorch/pytorch/issues/147879"
)

"""
if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise ValueError(
"FlexAttention is not compatible with CP yet. "
"We are still working on this."
)
"""

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
@@ -41,7 +41,7 @@ batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
compile = true
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
14 changes: 11 additions & 3 deletions torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
@@ -11,7 +11,7 @@ profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = false
enable_tensorboard = true
save_tb_folder = "tb"

[model]
@@ -27,23 +27,31 @@ eps = 1e-15

[lr_scheduler]
warmup_steps = 600
# warmup_steps = 20
lr_min = 0.1

[training]
batch_size = 8
# batch_size = 8
batch_size = 4
seq_len = 8192
# seq_len = 16384
# seq_len = 32768
# seq_len = 65536
max_norm = 1.0 # grad norm clipping
steps = 3000
# steps = 100
compile = false
# compile = true
dataset = "c4"
deterministic = true

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
context_parallel_degree = 4

[checkpoint]
enable_checkpoint = false
11 changes: 11 additions & 0 deletions torchtitan/models/llama3/__init__.py
@@ -47,6 +47,17 @@
multiple_of=1024,
rope_theta=500000,
),
"8B_flex_attn": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=1024,
rope_theta=500000,
use_flex_attn=True,
attn_mask_type="block_causal",
),
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
6 changes: 0 additions & 6 deletions torchtitan/models/llama3/model.py
@@ -51,12 +51,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"See https://github.com/pytorch/pytorch/issues/147879"
)

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise ValueError(
"FlexAttention is not compatible with CP yet. "
"We are still working on this."
)

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
nparams = sum(p.numel() for p in model.parameters())
nparams_embedding = sum(
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/debug_model.toml
@@ -43,7 +43,7 @@ batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
compile = true
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
20 changes: 14 additions & 6 deletions torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -8,11 +8,15 @@ description = "Llama 3 8B training"
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
# profile_freq = 100
profile_freq = 10
enable_memory_snapshot = true
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 10
enable_tensorboard = true
# enable_tensorboard = false
save_tb_folder = "tb"

[model]
@@ -27,22 +31,25 @@ lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200 # lr scheduler warm up
# warmup_steps = 200 # lr scheduler warm up
warmup_steps = 600

[training]
batch_size = 1
batch_size = 4
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 1000
# steps = 1000
steps = 20
compile = false
dataset = "c4"
deterministic = false

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
context_parallel_degree = 4

[checkpoint]
enable_checkpoint = false
@@ -53,7 +60,8 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
# mode = "selective" # ["none", "selective", "full"]
mode = "full"
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[float8]
25 changes: 24 additions & 1 deletion torchtitan/train.py
@@ -11,10 +11,14 @@
from typing import Any, Generator, Iterable, Optional

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.experimental._attention import (
_FlexAttentionSequentialSharder,
)

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.metrics import (
build_metrics_processor,
@@ -132,7 +136,9 @@ def __init__(self, job_config: JobConfig):

# build model (using meta init)
model_cls = self.train_spec.cls
# NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
model_args = self.train_spec.config[job_config.model.flavor]
self.model_args = model_args
# set the model args from training job configs
model_args.update_from_config(job_config, tokenizer)

@@ -323,13 +329,30 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
inputs = input_dict["input"]

# TODO: move this into `create_context_parallel_ctx`
# init block_mask for flex_attention
block_mask = None
if self.model_args.use_flex_attn:
from torchtitan.models.attention import FlexAttention

mask_mod = FlexAttention._get_causal_mask_mod()
batch_dimension = 1
seq_len = inputs.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
mask_mod, batch_dimension, None, seq_len, seq_len
)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
sharder=_FlexAttentionSequentialSharder(
mesh=world_mesh["cp"], block_mask=block_mask
),
)
if parallel_dims.cp_enabled
else None
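Not visible in the truncated hunk: the returned optional_context_parallel_ctx is entered a few lines further down via the training context, so the forward and backward passes run with inputs, labels, and freqs_cis sharded along the sequence dimension while the sharder passed above handles the full-length BlockMask across CP ranks. A simplified sketch of that surrounding pattern, based on torchtitan's existing train loop rather than code added in this PR:

```python
# Assumed continuation of train_step (simplified: single model part, no pipeline parallelism).
with self.train_context(optional_context_parallel_ctx):
    pred = model_parts[0](inputs)
    loss = self.loss_fn(pred, labels)
    # pred holds full logits (batch, seq, vocab); drop it before backward to save memory
    del pred
    loss.backward()
```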