diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index bc2fc6102..fcdb239f4 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -17,6 +17,8 @@
 import copy
 from typing import List, Optional
 
+from torch.distributed.device_mesh import init_device_mesh
+
 import deepspeed
 import json
 import logging
@@ -205,12 +207,17 @@ def add_fsdp_args(parser):
         '--fsdp_sharding_strategy',
         default='zero2',
         # TODO(Mddct): pipeline and model parallel (3-D parallelism)
-        choices=['no_shard', 'model', 'zero2', 'zero3'],
+        choices=[
+            'no_shard', 'model', 'zero2', 'zero3', 'hsdp_zero3', 'hsdp_zero2'
+        ],
         help='Sharding strategy for FSDP. Choose from the following options:\n'
         ' - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
         ' - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
         ' - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
         ' - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
+        ' - "hsdp_model": sharded within one host, replicated across hosts.\n'
+        ' - "hsdp_zero2": sharded within one host, replicated across hosts.\n'
+        ' - "hsdp_zero3": sharded within one host, replicated across hosts.\n'
         'For more information, refer to the FSDP API documentation.')
     return parser
 
@@ -221,9 +228,13 @@ def init_distributed(args):
     rank = int(os.environ.get('RANK', 0))
     logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
                  ', rank {}, world_size {}'.format(rank, world_size))
-    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
+    if args.train_engine == "torch_ddp":
         torch.cuda.set_device(local_rank)
         dist.init_process_group(args.dist_backend)
+    elif args.train_engine == "torch_fsdp":
+        # the device mesh is initialized later, in wrap_cuda_model
+        torch.cuda.set_device(local_rank)
+        pass
     elif args.train_engine == "deepspeed":
         deepspeed.init_distributed(dist_backend=args.dist_backend)
     else:
@@ -396,7 +407,11 @@ def wrap_cuda_model(args, model, configs=None):
         device = None  # Init device later
         pass  # Init DeepSpeed later
     elif args.train_engine == 'torch_fsdp':
+        device = torch.device("cuda")
         assert configs is not None
+        # 2-D device mesh (hosts x gpus), NCCL comms are set up automatically
+        device_mesh = init_device_mesh(
+            "cuda", (world_size // local_world_size, local_world_size))
         mixed_precision_dtype = {
             'fp32': torch.float32,
             "fp16": torch.float16,
@@ -408,6 +423,9 @@
             'zero2': ShardingStrategy.SHARD_GRAD_OP,
             'zero3': ShardingStrategy.FULL_SHARD,
             'no_shard': ShardingStrategy.NO_SHARD,
+            'hsdp_model': ShardingStrategy._HYBRID_SHARD_ZERO2,
+            'hsdp_zero2': ShardingStrategy._HYBRID_SHARD_ZERO2,
+            'hsdp_zero3': ShardingStrategy.HYBRID_SHARD,
         }[args.fsdp_sharding_strategy]
         wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
         layer_types = check_gradient_checkpoint(model)
@@ -428,9 +446,9 @@
             # init_distributed is called (torch.cuda.set_device),
             # we should set device_id, see FSDP api
             device_id=torch.cuda.current_device(),
+            device_mesh=device_mesh,
         )
         apply_fsdp_checkpointing(model, layer_types)
-        device = torch.device("cuda")
     else:
         logging.error("not supported engine: {}".format(args.train_engine))
     if args.train_engine in ["torch_fsdp", "torch_ddp"]:
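
For reviewers, a minimal, self-contained sketch of how the new HSDP path is expected to fit together, assuming PyTorch 2.2+ (where FSDP accepts a `device_mesh` argument) and a `torchrun` launch that sets `WORLD_SIZE`/`LOCAL_WORLD_SIZE`. The `nn.Linear` stand-in and the script/file name are illustrative only and not part of this change:

```python
# Hypothetical file name: hsdp_sketch.py (not part of this PR).
# Launch, for example, with:
#   torchrun --nnodes=2 --nproc_per_node=4 hsdp_sketch.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    torch.cuda.set_device(local_rank)

    # Outer dim replicates across hosts, inner dim shards within a host.
    # init_device_mesh() also creates the default process group, which is
    # why the diff skips dist.init_process_group() for torch_fsdp.
    device_mesh = init_device_mesh(
        "cuda", (world_size // local_world_size, local_world_size))

    model = nn.Linear(256, 256).cuda()  # stand-in for the WeNet model
    model = FSDP(
        model,
        # HYBRID_SHARD maps to "hsdp_zero3"; _HYBRID_SHARD_ZERO2 to "hsdp_zero2"
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_mesh=device_mesh,
        device_id=torch.cuda.current_device(),
    )
    out = model(torch.randn(8, 256, device="cuda"))
    if dist.get_rank() == 0:
        print(out.shape)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

On a single node the outer replica dimension is simply 1, so the same code degenerates to plain FULL_SHARD-style behavior within the host.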