
Commit be5f2e1

aporialiao authored and facebook-github-bot committed
DistributedModelParallel resharding Interface (#2945)
Summary: Finally! DMP interface for resharding; most of the changes here are to enable proper testing of DMP.

## Main changes:

### 1. DMP reshard API
* `reshard` calls the underlying sharder of the sharded module to reshard it.

### 2. Proper testing
* A multi-rank test which generates a full model and exercises the DMP interface. Currently only tests TW (table-wise) sharding.
* This test is called from `test_dynamic_sharding.py` -> `test_model_parallel.py` -> `test_sharding.py`, which follows the same structure as the current DMP unit tests.
* This is how the test checks for correctness:
```
1. Generate a global model and inputs
2. Create 2 identical local models based on the global model
3. Use the planner to generate a sharding plan for the local model
4. Based on the planner output, generate a second, different sharding plan
5. Shard local models 1 and 2 through DMP with plans 1 and 2 respectively
6. Reshard (dynamic sharding API) model 1 with plan 2
7. Generate predictions for the local models and compare them to the global model prediction; expect them to be the same
```
* This also verifies that the `optimizer` is correctly saved during resharding.
* The test is set up with other variables to be set once more functionality is enabled for dynamic sharding, e.g. `variable_batch_size`.

### 3. Helper functions for testing
* `get_sharding_constructor_from_type` to enable setting the sharding type for each unit test.
* `compare_model_pred_one_step`, only used for debugging, to get more information on whether models are identical after resharding/running the initial step.
* `compare_model_weights`, also for debugging.

### 4. Small refactoring of the `update_shards` call

Differential Revision: D73049934
1 parent 60442e6 commit be5f2e1
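For orientation, the correctness flow above can be sketched as a single helper. In the sketch below, the identically-initialized local models, the plans, the batch, and the `"sparse.ebc"` FQN are illustrative assumptions; only `DistributedModelParallel.reshard` and `output_sharding_plan_delta` are the pieces this commit actually adds.

```python
# Hedged sketch of steps 5-7 of the correctness flow; models, plans and inputs are
# assumed to be built elsewhere (steps 1-4). Only DMP.reshard and
# output_sharding_plan_delta come from this commit.
from typing import Any

import torch
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.types import ShardingEnv, ShardingPlan


def reshard_and_compare(
    model_1: nn.Module,  # identically-initialized local models (step 2)
    model_2: nn.Module,
    plan_1: ShardingPlan,  # planner output (step 3)
    plan_2: ShardingPlan,  # second, different plan (step 4)
    env: ShardingEnv,
    device: torch.device,
    batch: Any,
    module_fqn: str = "sparse.ebc",  # illustrative FQN of the sharded module
) -> None:
    # Step 5: shard both local models, one per plan.
    dmp_1 = DistributedModelParallel(model_1, env=env, plan=plan_1, device=device)
    dmp_2 = DistributedModelParallel(model_2, env=env, plan=plan_2, device=device)

    # Step 6: move model 1 onto plan 2, transferring only the shards that changed.
    delta = output_sharding_plan_delta(plan_1.plan[module_fqn], plan_2.plan[module_fqn])
    dmp_1.reshard(module_fqn, delta)

    # Step 7: after resharding, both models should produce identical predictions.
    torch.testing.assert_close(dmp_1(batch), dmp_2(batch))
```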

File tree: 7 files changed (+854, -47 lines)


torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -1531,15 +1531,9 @@ def update_shards(
         current_state = self.state_dict()
         # TODO: Save Optimizers
 
-        saved_weights = {}
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
-        for i, lookup in enumerate(self._lookups):
-            for attribute, tbe_module in lookup.named_modules():
-                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
-                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
-                    # Note: lookup.purge should delete tbe_module and weights
-                    # del tbe_module.weights
-                    # del tbe_module
+        # TODO: Ensure lookup tensors are actually being deleted
+        for _, lookup in enumerate(self._lookups):
             # pyre-ignore
             lookup.purge()
 
```

torchrec/distributed/model_parallel.py

Lines changed: 79 additions & 0 deletions
```diff
@@ -35,6 +35,7 @@
 from torchrec.distributed.types import (
     EnumerableShardingSpec,
     ModuleSharder,
+    ParameterSharding,
     ShardedModule,
     ShardingEnv,
     ShardingEnv2D,
```
```diff
@@ -612,6 +613,84 @@ def _reset_parameters(module: nn.Module) -> None:
             if hasattr(m, "reset_parameters"):
                 m.reset_parameters()
 
+    def reshard(
+        self,
+        sharded_module_fqn: str,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+    ) -> None:
+        """
+        Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements.
+
+        This method allows you to dynamically change the sharding strategy for a specific module
+        without recreating the entire DMP. It's particularly useful for:
+        1. Adapting to changing requirements during training
+        2. Implementing progressive sharding strategies
+        3. Rebalancing load across devices
+        4. A/B testing different sharding plans
+
+        Args:
+            sharded_module_fqn (str): The path to the sharded module in the DMP.
+                For example, "sparse.ebc".
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                parameter names to their new ParameterSharding configurations. Includes
+                only the shards that need to be moved.
+
+        Example:
+            ```
+            # Original sharding plan might have the table sharded across 2 GPUs
+            original_plan = {
+                "table_0": ParameterSharding(
+                    sharding_type="table_wise",
+                    ranks=[0, 1],
+                    sharding_spec=EnumerableShardingSpec(...)
+                )
+            }
+
+            # New sharding plan to shard the table across 4 GPUs
+            new_plan = {
+                "table_0": ParameterSharding(
+                    sharding_type="table_wise",
+                    ranks=[0, 1, 2, 3],
+                    sharding_spec=EnumerableShardingSpec(...)
+                )
+            }
+
+            # Helper function for selecting only the delta between the original and new plan
+            changed_sharding_params = output_sharding_plan_delta(original_plan, new_plan)
+
+            # Reshard the module and redistribute the tensors
+            model.reshard("embedding_module", changed_sharding_params)
+            ```
+
+        Notes:
+            - The sharder for the module must implement a `reshard` method
+            - Resharding involves redistributing tensor data across devices, which can be expensive
+            - After resharding, the optimizer state is maintained for the module
+            - The sharding plan is updated to reflect the new configuration
+        """
+        steps = sharded_module_fqn.split(".")
+        sharded_module = self.module
+        for s in steps:
+            sharded_module = getattr(sharded_module, s)
+
+        assert isinstance(sharded_module, ShardedModule)
+        assert changed_shard_to_params is not None
+        sharder_key = sharded_module.unsharded_module_type
+        sharder = self._sharder_map[sharder_key]
+        assert hasattr(
+            sharder, "reshard"
+        ), "reshard is not implemented for this sharder"
+        sharded_module = sharder.reshard(  # pyre-ignore
+            sharded_module,
+            changed_shard_to_params,
+            self._env,
+            self.device,
+        )
+
+        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
+        self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
+        return sharded_module
+
 
 class DMPCollection(DistributedModelParallel):
     """
```

torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -7,13 +7,15 @@
 
 # pyre-strict
 
+import copy
 from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import Shard
 from torchrec.distributed.types import (
+    EmbeddingModuleShardingPlan,
     ParameterSharding,
     ShardedModule,
     ShardedTensor,
```
```diff
@@ -364,3 +366,29 @@ def pad_tensor_to_max_dims(
         mode="constant",
         value=0,
     )
+
+
+# Utils
+def output_sharding_plan_delta(
+    old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
+) -> EmbeddingModuleShardingPlan:
+    """
+    Compute and return a new sharding plan that is the delta
+    between the new and old embedding module plans. Assumes that the old and new plans
+    have the same number of parameters/tables.
+
+    This is useful for Dynamic Sharding since the Resharding API takes in only the
+    ParameterShardings of the shards that need to be moved.
+    """
+    assert len(old_plan) == len(new_plan)
+    return_plan = copy.deepcopy(new_plan)
+    for table_name, old_param in old_plan.items():
+        if table_name not in return_plan:
+            raise ValueError(f"Table {table_name} not found in new plan")
+        new_param = return_plan[table_name]
+        old_ranks = old_param.ranks
+        new_ranks = new_param.ranks
+        if old_ranks == new_ranks:
+            del return_plan[table_name]
+
+    return return_plan
```
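A hedged usage sketch of `output_sharding_plan_delta` follows: two toy plans that differ only for one table, so the returned delta keeps just that entry. The `ParameterSharding` field values (`"fused"`, the rank lists) are illustrative, not taken from this diff.

```python
# Toy plans: "table_0" keeps its placement, "table_1" moves from rank 0 to rank 1,
# so only "table_1" should survive in the delta.
from typing import List, Optional

from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.types import EmbeddingModuleShardingPlan, ParameterSharding


def tw(ranks: Optional[List[int]]) -> ParameterSharding:
    # Illustrative table-wise placement; the compute_kernel value is an assumption.
    return ParameterSharding(
        sharding_type="table_wise", compute_kernel="fused", ranks=ranks
    )


old_plan = EmbeddingModuleShardingPlan({"table_0": tw([0]), "table_1": tw([0])})
new_plan = EmbeddingModuleShardingPlan({"table_0": tw([0]), "table_1": tw([1])})

delta = output_sharding_plan_delta(old_plan, new_plan)
assert list(delta.keys()) == ["table_1"]  # only the moved table remains
```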

torchrec/distributed/sharding_plan.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -410,6 +410,20 @@ def _get_parameter_sharding(
     ]
 
 
+def get_sharding_constructor_from_type(
+    sharding_type: ShardingType,
+) -> Callable[..., ParameterShardingGenerator]:
+    sharding_type_to_constructor = {
+        ShardingType.TABLE_WISE: table_wise,
+        ShardingType.ROW_WISE: row_wise,
+        ShardingType.COLUMN_WISE: column_wise,
+        ShardingType.TABLE_ROW_WISE: table_row_wise,
+        ShardingType.GRID_SHARD: grid_shard,
+        ShardingType.DATA_PARALLEL: data_parallel,
+    }
+    return sharding_type_to_constructor[sharding_type]
+
+
 def data_parallel() -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::DATA_PARALLEL` for construct_module_sharding_plan.
```
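A short, hedged usage sketch of the new helper; the `rank=0` call into `table_wise` assumes that constructor's usual signature rather than anything shown in this diff.

```python
# Map a ShardingType onto the matching per-parameter sharding constructor, then use it
# to describe how one table should be placed. The rank value is illustrative.
from torchrec.distributed.sharding_plan import get_sharding_constructor_from_type
from torchrec.distributed.types import ShardingType

constructor = get_sharding_constructor_from_type(ShardingType.TABLE_WISE)  # -> table_wise
per_param_sharding = {"table_0": constructor(rank=0)}  # ParameterShardingGenerator
```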

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 73 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@
 from torchrec.distributed.test_utils.test_model import TestSparseNN, TestSparseNNBase
 from torchrec.distributed.test_utils.test_sharding import (
     create_test_sharder,
+    dynamic_sharding_test,
     SharderType,
     sharding_single_rank_test,
 )
```
```diff
@@ -186,6 +187,78 @@ def _test_sharding(
             lengths_dtype=lengths_dtype,
         )
 
+    def _test_dynamic_sharding(
+        self,
+        sharders: List[ModuleSharder[nn.Module]],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        world_size_2D: Optional[int] = None,
+        node_group_size: Optional[int] = None,
+        model_class: Type[TestSparseNNBase] = TestSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+        has_weighted_tables: bool = True,
+        global_constant_batch: bool = False,
+        pooling: PoolingType = PoolingType.SUM,
+        data_type: DataType = DataType.FP32,
+        use_inter_host_allreduce: bool = False,
+        allow_zero_batch_size: bool = False,
+        custom_all_reduce: bool = False,
+        use_offsets: bool = False,
+        indices_dtype: torch.dtype = torch.int64,
+        offsets_dtype: torch.dtype = torch.int64,
+        lengths_dtype: torch.dtype = torch.int64,
+        sharding_type: ShardingType = None,  # pyre-ignore
+        random_seed: int = 0,
+    ) -> None:
+        """
+        Tests the reshard API with dynamic_sharding_test, which creates two identical models,
+        reshards one of them, and then compares the predictions of the two models.
+        """
+        self._build_tables_and_groups(data_type=data_type)
+        constraints = {}
+        if sharding_type is not None:
+            for table in self.tables:
+                name = table.name
+                # Default sharding type constraints
+                constraints[name] = ParameterConstraints(
+                    sharding_types=[sharding_type.value],
+                )
+
+        self._run_multi_process_test(
+            callable=dynamic_sharding_test,
+            world_size=world_size,
+            local_size=local_size,
+            world_size_2D=world_size_2D,
+            node_group_size=node_group_size,
+            model_class=model_class,
+            tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
+            weighted_tables=self.weighted_tables if has_weighted_tables else None,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            backend=backend,
+            optim=EmbOptimType.EXACT_SGD,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            variable_batch_size=variable_batch_size,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=global_constant_batch,
+            use_inter_host_allreduce=use_inter_host_allreduce,
+            allow_zero_batch_size=allow_zero_batch_size,
+            custom_all_reduce=custom_all_reduce,
+            use_offsets=use_offsets,
+            indices_dtype=indices_dtype,
+            offsets_dtype=offsets_dtype,
+            lengths_dtype=lengths_dtype,
+            random_seed=random_seed,
+        )
+
 
 @skip_if_asan_class
 class ModelParallelBase(ModelParallelTestShared):
```
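A concrete caller in `test_dynamic_sharding.py` could then look roughly like the sketch below; the test-class name, sharder choice, and backend selection are illustrative assumptions (per the summary, only table-wise sharding is exercised so far).

```python
# Illustrative only: a unit test layered on _test_dynamic_sharding, restricted to
# table-wise sharding. The class name and sharder construction are assumptions.
import torch
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared
from torchrec.distributed.types import ShardingType


class ModelParallelDynamicShardingTest(ModelParallelTestShared):
    def test_tw_resharding(self) -> None:
        self._test_dynamic_sharding(
            sharders=[EmbeddingBagCollectionSharder()],
            backend="nccl" if torch.cuda.is_available() else "gloo",
            world_size=2,
            sharding_type=ShardingType.TABLE_WISE,
            random_seed=0,
        )
```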
