
Commit 26ab9ed

aporialiao authored and facebook-github-bot committed
DistributedModelParallel resharding Interface (pytorch#2945)
Summary: Finally! DMP interface for resharding. Most of the changes here are to enable proper testing of DMP.

## Main changes

### 1. DMP reshard API
* Calls the underlying sharder of the sharded module to reshard it.

### 2. Proper testing
* A multi-rank test which generates a full model and exercises the DMP interface. Currently only tests TW (table-wise) sharding.
* This test is called from `test_dynamic_sharding.py` -> `test_model_parallel.py` -> `test_sharding.py`, which follows the same structure as the current DMP unit tests.
* How the test checks correctness:
```
1. Generate global model and inputs
2. Create 2 identical local models based on global model
3. Use planner to generate sharding plan for local model
4. Based on planner output, generate a second, different sharding plan
5. Shard both local models 1 and 2 through DMP with plan 1 and 2 respectively
6. Reshard (dynamic sharding API) model 1 with plan 2
7. Generate predictions for local models and compare them to global model prediction. Expect to be the same.
```
* This also verifies that the `optimizer` state is correctly saved across resharding.
* The test is set up with other variables to be set once more functionality is enabled for dynamic sharding, e.g. `variable_batch_size` etc.

### 3. Helper functions for testing
* `get_sharding_constructor_from_type` to enable setting the sharding_type for each unit test.
* `compare_model_pred_one_step`, only used for debugging, to get more information on whether models are identical after resharding/running the initial step.
* `compare_model_weights`, also for debugging.

### 4. Small refactoring of the `update_shards` call

Differential Revision: D73049934
1 parent cc7f1d0 commit 26ab9ed
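The correctness flow in the summary (steps 1-7) boils down to sharding two identical local models under different plans, resharding one onto the other's plan, and checking that their predictions agree. The sketch below is a simplified, hedged reconstruction of that flow, not the actual `dynamic_sharding_test` code: the `sparse.ebc` module path, the argument plumbing, and the output comparison are assumptions for illustration.

```python
from typing import Any

import torch
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.types import ShardingEnv, ShardingPlan


def check_reshard_matches(
    local_model_1: torch.nn.Module,  # identical copies of the global model (steps 1-2)
    local_model_2: torch.nn.Module,
    plan_1: ShardingPlan,            # planner output (step 3)
    plan_2: ShardingPlan,            # second, different plan (step 4)
    env: ShardingEnv,
    device: torch.device,
    batch: Any,                      # shared input batch
    ebc_path: str = "sparse.ebc",    # assumed path of the sharded module
) -> None:
    # Step 5: shard both local models through DMP, one per plan.
    dmp_1 = DistributedModelParallel(local_model_1, env=env, plan=plan_1, device=device)
    dmp_2 = DistributedModelParallel(local_model_2, env=env, plan=plan_2, device=device)

    # Step 6: reshard model 1 onto plan 2, passing only the shards that moved.
    changed = output_sharding_plan_delta(plan_1.plan[ebc_path], plan_2.plan[ebc_path])
    dmp_1.reshard(ebc_path, changed)

    # Step 7: both models should now produce the same predictions (and should
    # also match the unsharded global model evaluated on the same batch).
    torch.testing.assert_close(dmp_1(batch), dmp_2(batch))
```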

File tree

7 files changed: +776, -47 lines


torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 8 deletions

```diff
@@ -1531,15 +1531,9 @@ def update_shards(
         current_state = self.state_dict()
         # TODO: Save Optimizers
 
-        saved_weights = {}
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
-        for i, lookup in enumerate(self._lookups):
-            for attribute, tbe_module in lookup.named_modules():
-                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
-                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
-                    # Note: lookup.purge should delete tbe_module and weights
-                    # del tbe_module.weights
-                    # del tbe_module
+        # TODO: Ensure lookup tensors are actually being deleted
+        for _, lookup in enumerate(self._lookups):
             # pyre-ignore
             lookup.purge()
 
```

torchrec/distributed/model_parallel.py

Lines changed: 80 additions & 0 deletions

````diff
@@ -35,6 +35,7 @@
 from torchrec.distributed.types import (
     EnumerableShardingSpec,
     ModuleSharder,
+    ParameterSharding,
     ShardedModule,
     ShardingEnv,
     ShardingEnv2D,
@@ -612,6 +613,85 @@ def _reset_parameters(module: nn.Module) -> None:
         if hasattr(m, "reset_parameters"):
             m.reset_parameters()
 
+    def reshard(
+        self,
+        path_to_sharded_module: str,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+    ) -> nn.Module:
+        """
+        Reshards a module in the DMP. This is useful when the sharding plan for a module
+        changes during training.
+
+        This method allows you to dynamically change the sharding strategy for a specific
+        module without recreating the entire DMP. It is particularly useful for:
+        1. Adapting to changing requirements during training
+        2. Implementing progressive sharding strategies
+        3. Rebalancing load across devices
+        4. A/B testing different sharding plans
+
+        Args:
+            path_to_sharded_module (str): The path to the sharded module in the DMP.
+                For example, "sparse.ebc".
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                parameter names to their new ParameterSharding configurations. Includes
+                only the shards that need to be moved.
+
+        Example:
+            ```
+            # Original sharding plan might have the table sharded across 2 GPUs
+            original_plan = {
+                "table_0": ParameterSharding(
+                    sharding_type="table_wise",
+                    ranks=[0, 1],
+                    sharding_spec=EnumerableShardingSpec(...)
+                )
+            }
+
+            # New sharding plan to shard the table across 4 GPUs
+            new_plan = {
+                "table_0": ParameterSharding(
+                    sharding_type="table_wise",
+                    ranks=[0, 1, 2, 3],
+                    sharding_spec=EnumerableShardingSpec(...)
+                )
+            }
+
+            # Helper function for selecting only the delta between the original and new plan
+            changed_sharding_params = output_sharding_plan_delta(original_plan, new_plan)
+
+            # Reshard the module and redistribute the tensors
+            model.reshard("embedding_module", changed_sharding_params)
+            ```
+
+        Notes:
+            - The sharder for the module must implement a `reshard` method
+            - Resharding involves redistributing tensor data across devices, which can be expensive
+            - After resharding, the optimizer state is maintained for the module
+            - The sharding plan is updated to reflect the new configuration
+        """
+        steps = path_to_sharded_module.split(".")
+        sharded_module = self.module
+        for s in steps:
+            sharded_module = getattr(sharded_module, s)
+
+        assert isinstance(sharded_module, ShardedModule)
+        assert changed_shard_to_params is not None
+        sharder_key = sharded_module.unsharded_module_type
+        sharder = self._sharder_map[sharder_key]
+        assert hasattr(
+            sharder, "reshard"
+        ), "reshard is not implemented for this sharder"
+        sharded_module = sharder.reshard(  # pyre-ignore
+            sharded_module,
+            changed_shard_to_params,
+            self._env,
+            self.device,
+        )
+
+        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
+        self._plan.plan[path_to_sharded_module] = sharded_module.module_sharding_plan
+        return sharded_module
+
 
 class DMPCollection(DistributedModelParallel):
     """
````

torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -7,13 +7,15 @@
 
 # pyre-strict
 
+import copy
 from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import Shard
 from torchrec.distributed.types import (
+    EmbeddingModuleShardingPlan,
     ParameterSharding,
     ShardedModule,
     ShardedTensor,
@@ -364,3 +366,29 @@ def pad_tensor_to_max_dims(
         mode="constant",
         value=0,
     )
+
+
+# Utils
+def output_sharding_plan_delta(
+    old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
+) -> EmbeddingModuleShardingPlan:
+    """
+    Helper function for outputting a new sharding plan that is the delta
+    between two embedding module plans. Assumes that the old and new plans
+    have the same number of items.
+
+    This is useful for dynamic sharding, since the resharding API takes in
+    only the ParameterSharding of the shards that need to be moved.
+    """
+    assert len(old_plan) == len(new_plan)
+    return_plan = copy.deepcopy(new_plan)
+    for shard_name, old_param in old_plan.items():
+        if shard_name not in return_plan:
+            raise ValueError(f"Shard {shard_name} not found in new plan")
+        new_param = return_plan[shard_name]
+        old_ranks = old_param.ranks
+        new_ranks = new_param.ranks
+        if old_ranks == new_ranks:
+            del return_plan[shard_name]
+
+    return return_plan
```
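As a quick illustration of the delta semantics: tables whose ranks are unchanged between the two plans are dropped, so only the moved table survives in the returned plan. The sketch below is illustrative only; the `compute_kernel` string and the omitted `sharding_spec` are simplifications, not values taken from this commit.

```python
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.types import EmbeddingModuleShardingPlan, ParameterSharding

old_plan = EmbeddingModuleShardingPlan(
    table_0=ParameterSharding(sharding_type="table_wise", compute_kernel="fused", ranks=[0]),
    table_1=ParameterSharding(sharding_type="table_wise", compute_kernel="fused", ranks=[1]),
)
new_plan = EmbeddingModuleShardingPlan(
    table_0=ParameterSharding(sharding_type="table_wise", compute_kernel="fused", ranks=[0]),  # unchanged
    table_1=ParameterSharding(sharding_type="table_wise", compute_kernel="fused", ranks=[0]),  # moved 1 -> 0
)

# Only table_1 moved, so the delta contains a single entry that can be fed
# directly into DistributedModelParallel.reshard.
delta = output_sharding_plan_delta(old_plan, new_plan)
assert list(delta.keys()) == ["table_1"]
```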

torchrec/distributed/sharding_plan.py

Lines changed: 14 additions & 0 deletions

```diff
@@ -410,6 +410,20 @@ def _get_parameter_sharding(
     ]
 
 
+def get_sharding_constructor_from_type(
+    sharding_type: ShardingType,
+) -> Callable[..., ParameterShardingGenerator]:
+    sharding_type_to_constructor = {
+        ShardingType.TABLE_WISE: table_wise,
+        ShardingType.ROW_WISE: row_wise,
+        ShardingType.COLUMN_WISE: column_wise,
+        ShardingType.TABLE_ROW_WISE: table_row_wise,
+        ShardingType.GRID_SHARD: grid_shard,
+        ShardingType.DATA_PARALLEL: data_parallel,
+    }
+    return sharding_type_to_constructor[sharding_type]
+
+
 def data_parallel() -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::DATA_PARALLEL` for construct_module_sharding_plan.
```
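A hedged sketch of how a unit test might use the new `get_sharding_constructor_from_type` helper to turn a `ShardingType` into a per-parameter sharding generator for `construct_module_sharding_plan`; the `ebc` module, table name, and rank/world-size values are placeholders, and the constructor arguments vary by sharding type (e.g. `table_wise` takes a target rank).

```python
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    get_sharding_constructor_from_type,
)
from torchrec.distributed.types import ShardingType

# Pick the generator that matches the sharding type under test.
constructor = get_sharding_constructor_from_type(ShardingType.TABLE_WISE)  # -> table_wise

# Build a module sharding plan for a single table; `ebc` is an unsharded
# EmbeddingBagCollection assumed to exist in the test.
module_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": constructor(rank=0)},
    world_size=2,
)
```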

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 73 additions & 0 deletions

```diff
@@ -22,6 +22,7 @@
 from torchrec.distributed.test_utils.test_model import TestSparseNN, TestSparseNNBase
 from torchrec.distributed.test_utils.test_sharding import (
     create_test_sharder,
+    dynamic_sharding_test,
     SharderType,
     sharding_single_rank_test,
 )
@@ -186,6 +187,78 @@ def _test_sharding(
             lengths_dtype=lengths_dtype,
         )
 
+    def _test_dynamic_sharding(
+        self,
+        sharders: List[ModuleSharder[nn.Module]],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        world_size_2D: Optional[int] = None,
+        node_group_size: Optional[int] = None,
+        model_class: Type[TestSparseNNBase] = TestSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+        has_weighted_tables: bool = True,
+        global_constant_batch: bool = False,
+        pooling: PoolingType = PoolingType.SUM,
+        data_type: DataType = DataType.FP32,
+        use_inter_host_allreduce: bool = False,
+        allow_zero_batch_size: bool = False,
+        custom_all_reduce: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
+        sharding_type: ShardingType = None,  # pyre-ignore
+        random_seed: int = 0,
+    ) -> None:
+        """
+        Tests the reshard API with dynamic_sharding_test, which creates 2 identical models,
+        one of which is resharded, and then compares the predictions of the 2 models.
+        """
+        self._build_tables_and_groups(data_type=data_type)
+        constraints = {}
+        if sharding_type is not None:
+            for table in self.tables:
+                name = table.name
+                # Default sharding type constraints
+                constraints[name] = ParameterConstraints(
+                    sharding_types=[sharding_type.value],
+                )
+
+        self._run_multi_process_test(
+            callable=dynamic_sharding_test,
+            world_size=world_size,
+            local_size=local_size,
+            world_size_2D=world_size_2D,
+            node_group_size=node_group_size,
+            model_class=model_class,
+            tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
+            weighted_tables=self.weighted_tables if has_weighted_tables else None,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            backend=backend,
+            optim=EmbOptimType.EXACT_SGD,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            variable_batch_size=variable_batch_size,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=global_constant_batch,
+            use_inter_host_allreduce=use_inter_host_allreduce,
+            allow_zero_batch_size=allow_zero_batch_size,
+            custom_all_reduce=custom_all_reduce,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
+            random_seed=random_seed,
+        )
+
 
 @skip_if_asan_class
 class ModelParallelBase(ModelParallelTestShared):
```
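For context, a concrete test in `test_dynamic_sharding.py` might drive the helper above roughly as follows; this is a hedged sketch based on the existing `_test_sharding` call pattern, and the exact `create_test_sharder` arguments, backend, and test name are assumptions rather than code from this commit.

```python
from typing import cast

import torch.nn as nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared
from torchrec.distributed.test_utils.test_sharding import create_test_sharder, SharderType
from torchrec.distributed.types import ModuleSharder, ShardingType


class DynamicShardingTest(ModelParallelTestShared):
    def test_dynamic_sharding_ebc_tw(self) -> None:
        # Reshard an EmbeddingBagCollection sharded table-wise across 2 ranks.
        self._test_dynamic_sharding(
            sharders=[
                cast(
                    ModuleSharder[nn.Module],
                    create_test_sharder(
                        SharderType.EMBEDDING_BAG_COLLECTION.value,
                        ShardingType.TABLE_WISE.value,
                        EmbeddingComputeKernel.FUSED.value,
                    ),
                ),
            ],
            backend="nccl",
            world_size=2,
            sharding_type=ShardingType.TABLE_WISE,
        )
```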
