Skip to content

Commit db79bae

Browse files
Dasha Yefymenkofacebook-github-bot
Dasha Yefymenko
authored andcommitted
Serialize ModuleShardingPlan to JSON-serializable dict (#3020)
Summary: Serialize ModuleShardingPlan and its implementation classes to JSON-serializable dicts. Reviewed By: seanx92 Differential Revision: D75708433
1 parent a5f8103 commit db79bae

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

torchrec/distributed/types.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import abc
1111
import operator
12-
from dataclasses import dataclass
12+
from dataclasses import asdict, dataclass
1313
from enum import Enum, unique
1414
from typing import (
1515
Any,
@@ -549,7 +549,8 @@ def impl(self, rhs):
549549

550550

551551
class ModuleShardingPlan:
552-
pass
552+
def _serialize(self) -> dict[str, Any]:
553+
raise NotImplementedError()
553554

554555

555556
class CacheStatistics(abc.ABC):
@@ -772,6 +773,25 @@ def __str__(self) -> str:
772773
)
773774
return out
774775

776+
def _serialize(self) -> dict[str, Any]:
777+
sharding_plan_dict = {}
778+
for param_name, param_sharding in self.items():
779+
sharding_plan_dict[param_name] = {
780+
"sharding_type": param_sharding.sharding_type,
781+
"compute_kernel": param_sharding.compute_kernel,
782+
"ranks": param_sharding.ranks,
783+
}
784+
if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
785+
shards = param_sharding.sharding_spec.shards
786+
if shards is not None:
787+
sharding_plan_dict[param_name]["shards"] = []
788+
for shard in shards:
789+
shard_dict = asdict(shard)
790+
shard_dict["placement"] = str(shard_dict["placement"])
791+
sharding_plan_dict[param_name]["shards"].append(shard_dict)
792+
793+
return sharding_plan_dict
794+
775795

776796
@dataclass
777797
class ShardingPlan:
@@ -805,6 +825,15 @@ def __str__(self) -> str:
805825
out += str(module_plan)
806826
return out
807827

828+
def _serialize(self) -> dict[str, Any]:
829+
sharding_plan_dict = {
830+
"plan": {
831+
module_path: module_plan._serialize()
832+
for module_path, module_plan in self.plan.items()
833+
}
834+
}
835+
return sharding_plan_dict
836+
808837

809838
ShardedModuleContext = Multistreamable
810839

@@ -1240,6 +1269,12 @@ class ObjectPoolShardingPlan(ModuleShardingPlan):
12401269
sharding_type: ObjectPoolShardingType
12411270
inference: bool = False
12421271

1272+
def _serialize(self) -> dict[str, Any]:
1273+
return {
1274+
"sharding_type": self.sharding_type.name,
1275+
"inference": self.inference,
1276+
}
1277+
12431278

12441279
@dataclass
12451280
class ShardingBucketMetadata:

0 commit comments

Comments
 (0)