DistributedModelParallel resharding Interface #2945

Closed · wants to merge 1 commit
16 changes: 8 additions & 8 deletions torchrec/distributed/embeddingbag.py
@@ -1531,15 +1531,9 @@ def update_shards(
        current_state = self.state_dict()
        # TODO: Save Optimizers

        saved_weights = {}
        # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
        for i, lookup in enumerate(self._lookups):
            for attribute, tbe_module in lookup.named_modules():
                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
                    # Note: lookup.purge should delete tbe_module and weights
                    # del tbe_module.weights
                    # del tbe_module
        # TODO: Ensure lookup tensors are actually being deleted
        for _, lookup in enumerate(self._lookups):
            # pyre-ignore
            lookup.purge()

@@ -1603,6 +1597,12 @@ def update_shards(
            for embedding_configs in self.sharding_type_to_sharding_infos.values()
        ]

        # Reset input dists
        self._has_uninitialized_input_dist = True
        self._input_dists: List[nn.Module] = []
        self._features_order: List[int] = []
        self._feature_splits: List[int] = []

        self._create_lookups()
        self._update_output_dist()

79 changes: 79 additions & 0 deletions torchrec/distributed/model_parallel.py
@@ -35,6 +35,7 @@
from torchrec.distributed.types import (
    EnumerableShardingSpec,
    ModuleSharder,
    ParameterSharding,
    ShardedModule,
    ShardingEnv,
    ShardingEnv2D,
@@ -612,6 +613,84 @@ def _reset_parameters(module: nn.Module) -> None:
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()

    def reshard(
        self,
        sharded_module_fqn: str,
        changed_shard_to_params: Dict[str, ParameterSharding],
    ) -> ShardedModule:
"""
Reshards an already-sharded module in the DMP given a set of ParameterShardings to change placements.

This method allows you to dynamically change the sharding strategy for a specific module
without recreating the entire DMP. It's particularly useful for:
1. Adapting to changing requirements during training
2. Implementing progressive sharding strategies
3. Rebalancing load across devices
4. A/B Testing different sharding plans

Args:
path_to_sharded_module (str): The path to the sharded module in the DMP.
For example, "sparse.ebc".
changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
parameter names to their new ParameterSharding configurations. Includes
only the shards that needs to be moved.

Example:
```
# Original sharding plan might have table sharded across 2 GPUs
original_plan = {
"table_0': ParameterSharding(
sharding_type="table_wise",
ranks=[0, 1, 2, 3],
sharding_spec=EnumerableShardingSpec(...)
)
}

# New sharding plan to shard across 4 GPUs
new_plan = {
"weight": ParameterSharding(
sharding_type="table_wise",
ranks=[0, 1, 2, 3],
sharding_spec=EnumerableShardingSpec(...)
)
}

# Helper function for only selecting the delta between original and new plan
changed_sharding_params = output_sharding_plan_delta(new_plan)

# Reshard the module and redistribute the tensors
model.reshard("embedding_module", changed_sharding_params)
```

Notes:
- The sharder for the module must implement a `reshard` method
- Resharding involves redistributing tensor data across devices, which can be expensive
- After resharding, the optimizer state is maintained for the module
- The sharding plan is updated to reflect the new configuration
"""
        steps = sharded_module_fqn.split(".")
        sharded_module = self.module
        for s in steps:
            sharded_module = getattr(sharded_module, s)

        assert isinstance(sharded_module, ShardedModule)
        assert changed_shard_to_params is not None
        sharder_key = sharded_module.unsharded_module_type
        sharder = self._sharder_map[sharder_key]
        assert hasattr(
            sharder, "reshard"
        ), "reshard is not implemented for this sharder"
        sharded_module = sharder.reshard(  # pyre-ignore
            sharded_module,
            changed_shard_to_params,
            self._env,
            self.device,
        )

        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
        self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
        return sharded_module


class DMPCollection(DistributedModelParallel):
"""
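Putting the pieces of this PR together, a minimal end-to-end sketch of driving the new API could look like the following. Only `DistributedModelParallel.reshard` and `output_sharding_plan_delta` come from this diff; the module path `"sparse.ebc"`, the single table `"table_0"`, the unsharded `ebc` handle, and the two-rank setup are assumptions for illustration.

```python
from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, table_wise

# Sketch: `model` is an initialized DistributedModelParallel whose plan contains an
# EmbeddingBagCollection under "sparse.ebc" with a single table "table_0", and `ebc`
# is the corresponding unsharded EmbeddingBagCollection (both assumed to exist).
old_plan = model.plan.plan["sparse.ebc"]

# Candidate plan that moves "table_0" to rank 1 (table-wise placement).
new_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": table_wise(rank=1)},
    local_size=2,
    world_size=2,
    device_type="cuda",
)

# Only tables whose ranks changed are handed to reshard().
changed_sharding_params = output_sharding_plan_delta(old_plan, new_plan)
model.reshard("sparse.ebc", changed_sharding_params)
```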
24 changes: 24 additions & 0 deletions torchrec/distributed/sharding/dynamic_sharding.py
@@ -7,13 +7,15 @@

# pyre-strict

import copy
from typing import Any, Callable, Dict, List, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import Shard
from torchrec.distributed.types import (
    EmbeddingModuleShardingPlan,
    ParameterSharding,
    ShardedModule,
    ShardedTensor,
@@ -364,3 +366,25 @@ def pad_tensor_to_max_dims(
mode="constant",
value=0,
)


# Utils
def output_sharding_plan_delta(
    old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
) -> EmbeddingModuleShardingPlan:
    """
    Compute and return a new sharding plan that is the delta
    between the new and old embedding module plans. Assumes that the old and new plans
    have the same number of parameters/tables.

    This is useful for Dynamic Sharding, since the resharding API takes in only the
    ParameterShardings of the shards that need to be moved.
    """
    assert len(old_plan) == len(new_plan)
    return EmbeddingModuleShardingPlan(
        {
            k: copy.deepcopy(v)
            for k, v in new_plan.items()
            if v.ranks != old_plan[k].ranks
        }
    )
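For concreteness, a small self-contained sketch of the delta computation; the table names, ranks, and `"fused"` compute kernel are made up for this example, and the import path is the module added in this diff.

```python
import copy

from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta
from torchrec.distributed.types import EmbeddingModuleShardingPlan, ParameterSharding

# Two toy plans over the same two tables; only table_1 changes placement.
old_plan = EmbeddingModuleShardingPlan(
    {
        "table_0": ParameterSharding(
            sharding_type="table_wise", compute_kernel="fused", ranks=[0]
        ),
        "table_1": ParameterSharding(
            sharding_type="table_wise", compute_kernel="fused", ranks=[0]
        ),
    }
)
new_plan = copy.deepcopy(old_plan)
new_plan["table_1"].ranks = [1]  # move table_1 from rank 0 to rank 1

# The delta keeps only the entries whose ranks differ between the two plans.
delta = output_sharding_plan_delta(old_plan, new_plan)
assert set(delta.keys()) == {"table_1"}
```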
14 changes: 14 additions & 0 deletions torchrec/distributed/sharding_plan.py
@@ -410,6 +410,20 @@ def _get_parameter_sharding(
    ]


def get_sharding_constructor_from_type(
    sharding_type: ShardingType,
) -> Callable[..., ParameterShardingGenerator]:
    sharding_type_to_constructor = {
        ShardingType.TABLE_WISE: table_wise,
        ShardingType.ROW_WISE: row_wise,
        ShardingType.COLUMN_WISE: column_wise,
        ShardingType.TABLE_ROW_WISE: table_row_wise,
        ShardingType.GRID_SHARD: grid_shard,
        ShardingType.DATA_PARALLEL: data_parallel,
    }
    return sharding_type_to_constructor[sharding_type]


def data_parallel() -> ParameterShardingGenerator:
"""
Returns a generator of ParameterShardingPlan for `ShardingType::DATA_PARALLEL` for construct_module_sharding_plan.
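As a usage sketch for the new lookup helper: it maps a `ShardingType` value to the corresponding per-parameter generator, which can then be fed to `construct_module_sharding_plan`. The `EmbeddingBagConfig`, table name, and two-rank world size below are illustrative assumptions.

```python
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    get_sharding_constructor_from_type,
)
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Illustrative unsharded module with a single table.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0", embedding_dim=64, num_embeddings=1000, feature_names=["f0"]
        )
    ]
)

# Look up the generator for the desired sharding type, then build a plan with it.
table_wise_gen = get_sharding_constructor_from_type(ShardingType.TABLE_WISE)
plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": table_wise_gen(rank=0)},
    local_size=2,
    world_size=2,
    device_type="cuda",
)
```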
73 changes: 73 additions & 0 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -22,6 +22,7 @@
from torchrec.distributed.test_utils.test_model import TestSparseNN, TestSparseNNBase
from torchrec.distributed.test_utils.test_sharding import (
    create_test_sharder,
    dynamic_sharding_test,
    SharderType,
    sharding_single_rank_test,
)
@@ -190,6 +191,78 @@ def _test_sharding(
            lengths_dtype=lengths_dtype,
        )

    def _test_dynamic_sharding(
        self,
        sharders: List[ModuleSharder[nn.Module]],
        backend: str = "gloo",
        world_size: int = 2,
        local_size: Optional[int] = None,
        world_size_2D: Optional[int] = None,
        node_group_size: Optional[int] = None,
        model_class: Type[TestSparseNNBase] = TestSparseNN,
        qcomms_config: Optional[QCommsConfig] = None,
        apply_optimizer_in_backward_config: Optional[
            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
        ] = None,
        variable_batch_size: bool = False,
        variable_batch_per_feature: bool = False,
        has_weighted_tables: bool = True,
        global_constant_batch: bool = False,
        pooling: PoolingType = PoolingType.SUM,
        data_type: DataType = DataType.FP32,
        use_inter_host_allreduce: bool = False,
        allow_zero_batch_size: bool = False,
        custom_all_reduce: bool = False,
        use_offsets: bool = False,
        indices_dtype: torch.dtype = torch.int64,
        offsets_dtype: torch.dtype = torch.int64,
        lengths_dtype: torch.dtype = torch.int64,
        sharding_type: Optional[ShardingType] = None,
        random_seed: int = 0,
    ) -> None:
        """
        Tests the reshard API using dynamic_sharding_test, which creates two identical models,
        reshards one of them, and then compares the predictions of the two models.
        """
        self._build_tables_and_groups(data_type=data_type)
        constraints = {}
        if sharding_type is not None:
            for table in self.tables:
                name = table.name
                # Default sharding type constraints
                constraints[name] = ParameterConstraints(
                    sharding_types=[sharding_type.value],
                )

        self._run_multi_process_test(
            callable=dynamic_sharding_test,
            world_size=world_size,
            local_size=local_size,
            world_size_2D=world_size_2D,
            node_group_size=node_group_size,
            model_class=model_class,
            tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
            weighted_tables=self.weighted_tables if has_weighted_tables else None,
            embedding_groups=self.embedding_groups,
            sharders=sharders,
            backend=backend,
            optim=EmbOptimType.EXACT_SGD,
            constraints=constraints,
            qcomms_config=qcomms_config,
            variable_batch_size=variable_batch_size,
            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
            variable_batch_per_feature=variable_batch_per_feature,
            global_constant_batch=global_constant_batch,
            use_inter_host_allreduce=use_inter_host_allreduce,
            allow_zero_batch_size=allow_zero_batch_size,
            custom_all_reduce=custom_all_reduce,
            use_offsets=use_offsets,
            indices_dtype=indices_dtype,
            offsets_dtype=offsets_dtype,
            lengths_dtype=lengths_dtype,
            random_seed=random_seed,
        )


@skip_if_asan_class
class ModelParallelBase(ModelParallelTestShared):
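`dynamic_sharding_test` itself lives in `test_sharding.py` and is not part of this diff; conceptually, the check it is described as performing reduces to something like the sketch below, where `model`, `resharded_model`, and `batch` are placeholders.

```python
import torch

def assert_predictions_match(model, resharded_model, batch) -> None:
    # Run the original and resharded models on the same input and require
    # numerically matching predictions.
    model.eval()
    resharded_model.eval()
    with torch.no_grad():
        original_pred = model(batch)
        resharded_pred = resharded_model(batch)
    torch.testing.assert_close(original_pred, resharded_pred)
```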