Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[train_engine] fsdp support mesh #2512

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions wenet/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import copy
from typing import List, Optional

from torch.distributed.device_mesh import init_device_mesh

import deepspeed
import json
import logging
Expand Down Expand Up @@ -205,12 +207,17 @@ def add_fsdp_args(parser):
'--fsdp_sharding_strategy',
default='zero2',
# TODO(Mddct): pipeline and model parallel (3-D parallelism)
choices=['no_shard', 'model', 'zero2', 'zero3'],
choices=[
'no_shard', 'model', 'zero2', 'zero3', 'hsdp_zero3', 'hsdp_zero2'
],
help='Sharding strategy for FSDP. Choose from the following options:\n'
' - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
' - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
' - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
' - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
' - "hsdp_model": one host sharded and replica across hosts.\n'
' - "hsdp_zero2": one host sharded and replica across hosts.\n'
' - "hsdp_zero3": one host sharded and replica across hosts.\n'
'For more information, refer to the FSDP API documentation.')
return parser

Expand All @@ -221,9 +228,13 @@ def init_distributed(args):
rank = int(os.environ.get('RANK', 0))
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
', rank {}, world_size {}'.format(rank, world_size))
if args.train_engine in ["torch_ddp", "torch_fsdp"]:
if args.train_engine == "torch_ddp":
torch.cuda.set_device(local_rank)
dist.init_process_group(args.dist_backend)
elif args.train_engint == "torch_fsdp":
# use mesh in wrap_cuda_model
torch.cuda.set_device(local_rank)
pass
elif args.train_engine == "deepspeed":
deepspeed.init_distributed(dist_backend=args.dist_backend)
else:
Expand Down Expand Up @@ -396,7 +407,11 @@ def wrap_cuda_model(args, model, configs=None):
device = None # Init device later
pass # Init DeepSpeed later
elif args.train_engine == 'torch_fsdp':
device = torch.device("cuda")
assert configs is not None
# 2x2 device mesh, set up NCCL communicatitor automatically
device_mesh = init_device_mesh(
"cuda", (world_size // local_world_size, local_world_size))
mixed_precision_dtype = {
'fp32': torch.float32,
"fp16": torch.float16,
Expand All @@ -408,6 +423,9 @@ def wrap_cuda_model(args, model, configs=None):
'zero2': ShardingStrategy.SHARD_GRAD_OP,
'zero3': ShardingStrategy.FULL_SHARD,
'no_shard': ShardingStrategy.NO_SHARD,
'hsdp_model': ShardingStrategy._HYBRID_SHARD_ZERO2,
'hsdp_zero2': ShardingStrategy._HYBRID_SHARD_ZERO2,
'hsdp_zero3': ShardingStrategy.HYBRID_SHARD,
}[args.fsdp_sharding_strategy]
wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
layer_types = check_gradient_checkpoint(model)
Expand All @@ -428,9 +446,9 @@ def wrap_cuda_model(args, model, configs=None):
# init_distributed is called (torch.cuda.set_device),
# we should set device_id, see FSDP api
device_id=torch.cuda.current_device(),
device_mesh=device_mesh,
)
apply_fsdp_checkpointing(model, layer_types)
device = torch.device("cuda")
else:
logging.error("not supported engine: {}".format(args.train_engine))
if args.train_engine in ["torch_fsdp", "torch_ddp"]:
Expand Down
Loading