|
9 | 9 |
|
10 | 10 | import abc
|
11 | 11 | import operator
|
12 |
| -from dataclasses import dataclass |
| 12 | +from dataclasses import asdict, dataclass |
13 | 13 | from enum import Enum, unique
|
14 | 14 | from typing import (
|
15 | 15 | Any,
|
@@ -549,7 +549,8 @@ def impl(self, rhs):
|
549 | 549 |
|
550 | 550 |
|
551 | 551 | class ModuleShardingPlan:
|
552 |
| - pass |
| 552 | + def _serialize(self) -> dict[str, Any]: |
| 553 | + raise NotImplementedError() |
553 | 554 |
|
554 | 555 |
|
555 | 556 | class CacheStatistics(abc.ABC):
|
@@ -772,6 +773,25 @@ def __str__(self) -> str:
|
772 | 773 | )
|
773 | 774 | return out
|
774 | 775 |
|
| 776 | + def _serialize(self) -> dict[str, Any]: |
| 777 | + sharding_plan_dict = {} |
| 778 | + for param_name, param_sharding in self.items(): |
| 779 | + sharding_plan_dict[param_name] = { |
| 780 | + "sharding_type": param_sharding.sharding_type, |
| 781 | + "compute_kernel": param_sharding.compute_kernel, |
| 782 | + "ranks": param_sharding.ranks, |
| 783 | + } |
| 784 | + if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec): |
| 785 | + shards = param_sharding.sharding_spec.shards |
| 786 | + if shards is not None: |
| 787 | + sharding_plan_dict[param_name]["shards"] = [] |
| 788 | + for shard in shards: |
| 789 | + shard_dict = asdict(shard) |
| 790 | + shard_dict["placement"] = str(shard_dict["placement"]) |
| 791 | + sharding_plan_dict[param_name]["shards"].append(shard_dict) |
| 792 | + |
| 793 | + return sharding_plan_dict |
| 794 | + |
775 | 795 |
|
776 | 796 | @dataclass
|
777 | 797 | class ShardingPlan:
|
@@ -805,6 +825,15 @@ def __str__(self) -> str:
|
805 | 825 | out += str(module_plan)
|
806 | 826 | return out
|
807 | 827 |
|
| 828 | + def _serialize(self) -> dict[str, Any]: |
| 829 | + sharding_plan_dict = { |
| 830 | + "plan": { |
| 831 | + module_path: module_plan._serialize() |
| 832 | + for module_path, module_plan in self.plan.items() |
| 833 | + } |
| 834 | + } |
| 835 | + return sharding_plan_dict |
| 836 | + |
808 | 837 |
|
809 | 838 | ShardedModuleContext = Multistreamable
|
810 | 839 |
|
@@ -1240,6 +1269,12 @@ class ObjectPoolShardingPlan(ModuleShardingPlan):
|
1240 | 1269 | sharding_type: ObjectPoolShardingType
|
1241 | 1270 | inference: bool = False
|
1242 | 1271 |
|
| 1272 | + def _serialize(self) -> dict[str, Any]: |
| 1273 | + return { |
| 1274 | + "sharding_type": self.sharding_type.name, |
| 1275 | + "inference": self.inference, |
| 1276 | + } |
| 1277 | + |
1243 | 1278 |
|
1244 | 1279 | @dataclass
|
1245 | 1280 | class ShardingBucketMetadata:
|
|
0 commit comments