@@ -66,62 +66,74 @@ def create_virtual_table_local_metadata(
     local_metadata: ShardMetadata,
     param: Union[torch.Tensor, PartiallyMaterializedTensor],
     my_rank: int,
+    offset: Optional[int] = None,
+    weight_count_per_rank: Optional[List[int]] = None,
 ) -> None:
-    local_metadata.shard_sizes = list(param.size())  # pyre-ignore
-    local_metadata.shard_offsets = [0 for _ in range(len(param.size()))]  # pyre-ignore
+    if offset is None:
+        offset = (
+            my_rank
+            if weight_count_per_rank is None
+            else sum(weight_count_per_rank[:my_rank])
+        )
+    local_metadata.shard_sizes = list(param.size())  # pyre-ignore[6]
+    local_metadata.shard_offsets = [
+        offset if dim == 0 else 0 for dim in range(len(param.size()))  # pyre-ignore[6]
+    ]
 
 
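Note on the default-offset rule introduced above: when no explicit offset is passed, the local shard's row offset falls back to the rank index if no per-rank weight counts are known, otherwise it is the sum of the row counts owned by all lower ranks. A minimal stand-alone sketch of that rule, using made-up row counts (the helper name and numbers below are illustrative only):

from typing import List, Optional

def default_row_offset(my_rank: int, weight_count_per_rank: Optional[List[int]]) -> int:
    # No per-rank counts known: fall back to the rank index.
    if weight_count_per_rank is None:
        return my_rank
    # Otherwise the local shard starts after all rows owned by lower ranks.
    return sum(weight_count_per_rank[:my_rank])

assert default_row_offset(2, None) == 2
assert default_row_offset(2, [3, 5, 2]) == 8  # rank 0 owns rows 0-2, rank 1 owns rows 3-7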
 def create_virtual_table_global_metadata(
     metadata: ShardedTensorMetadata,
     my_rank: int,
     param: Union[torch.Tensor, PartiallyMaterializedTensor],
+    weight_count_per_rank: Optional[List[int]],
+    use_param_size_as_rows: bool,
 ) -> None:
     # update tensor properties from local tensor properties, this should be universal for all ranks
     metadata.tensor_properties.dtype = param.dtype
     metadata.tensor_properties.requires_grad = param.requires_grad
 
-    # manually craft metadata, faking the metadata in a way that all other rank only has 0 row
-    # NOTE this currently only works for row-wise sharding
-    fake_total_rows = param.size()[0]  # pyre-ignore
-    metadata.size = torch.Size(
-        [
-            fake_total_rows if dim == 0 else param.size(dim)
-            for dim in range(len(param.size()))  # pyre-ignore
-        ]
-    )
+    offset = 0
 
     for rank, shard_metadata in enumerate(metadata.shards_metadata):
+        if use_param_size_as_rows:  # respect the param size and treat it as rows
+            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
+        else:
+            curr_rank_rows = (
+                weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
+            )
         if rank < my_rank:
-            shard_metadata.shard_sizes = [  # pyre-ignore
-                0 if dim == 0 else param.size(dim)
-                # pyre-ignore
-                for dim in range(len(param.size()))
+            shard_metadata.shard_sizes = [
+                curr_rank_rows if dim == 0 else param.size(dim)
+                for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
             shard_metadata.shard_offsets = [
-                0 for _ in range(len(param.size()))  # pyre-ignore
+                offset if dim == 0 else 0 for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
         elif rank == my_rank:
-            create_virtual_table_local_metadata(shard_metadata, param, my_rank)
+            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
+            create_virtual_table_local_metadata(shard_metadata, param, my_rank, offset)
         else:
-            # pyre-ignore
             shard_metadata.shard_sizes = [
-                0 if dim == 0 else param.size(dim)
-                # pyre-ignore
-                for dim in range(len(param.size()))
+                curr_rank_rows if dim == 0 else param.size(dim)
+                for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
-            # pyre-ignore
             shard_metadata.shard_offsets = [
-                param.size(0) if dim == 0 else 0
-                # pyre-ignore
-                for dim in range(len(param.size()))
+                offset if dim == 0 else 0 for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
+        offset += curr_rank_rows
+
+    metadata.size = torch.Size(
+        [offset if dim == 0 else param.size(dim) for dim in range(len(param.size()))]  # pyre-ignore[6]
+    )
 
 
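To see what the rewritten loop produces, here is a minimal stand-alone sketch of the same row-wise accumulation, using plain lists in place of ShardMetadata and hypothetical per-rank row counts; the real function additionally special-cases rank == my_rank to take the row count from the local param:

# Hypothetical example: 3 ranks, per-rank row counts [3, 5, 2], embedding dim 4.
weight_count_per_rank = [3, 5, 2]
embedding_dim = 4

offset = 0
shards = []  # (shard_sizes, shard_offsets) per rank; rows are stacked along dim 0
for rows in weight_count_per_rank:
    shards.append(([rows, embedding_dim], [offset, 0]))
    offset += rows

global_size = [offset, embedding_dim]

assert shards == [([3, 4], [0, 0]), ([5, 4], [3, 0]), ([2, 4], [8, 0])]
assert global_size == [11, 4]  # total rows across ranks, stacked row-wise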
 def create_virtual_sharded_tensors(
     embedding_tables: List[ShardedEmbeddingTable],
     params: Union[List[torch.Tensor], List[PartiallyMaterializedTensor]],
     pg: Optional[dist.ProcessGroup] = None,
     prefix: str = "",
+    table_name_to_weight_count_per_rank: Optional[Dict[str, List[int]]] = None,
+    use_param_size_as_rows: bool = False,
 ) -> List[ShardedTensor]:
     """
     Create virtual sharded tensors for the given embedding tables and parameters.
@@ -139,19 +151,32 @@ def create_virtual_sharded_tensors(
     def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
         return prefix + f"{embedding_table.name}"
 
+    def get_weight_count_per_rank(table_name: str) -> Optional[List[int]]:
+        return (
+            table_name_to_weight_count_per_rank.get(table_name, None)
+            if table_name_to_weight_count_per_rank
+            and table_name in table_name_to_weight_count_per_rank.keys()
+            else None
+        )
+
     my_rank = dist.get_rank()
     for embedding_table, param in zip(embedding_tables, params):
         key = get_key_from_embedding_table(embedding_table)
         assert embedding_table.use_virtual_table
 
         assert embedding_table.global_metadata is not None
         global_metadata = copy.deepcopy(embedding_table.global_metadata)
-        create_virtual_table_global_metadata(global_metadata, my_rank, param)
+        weight_count_per_rank = get_weight_count_per_rank(embedding_table.name)
+        create_virtual_table_global_metadata(
+            global_metadata,
+            my_rank,
+            param,
+            weight_count_per_rank,
+            use_param_size_as_rows,
+        )
         key_to_global_metadata[key] = global_metadata
 
-        assert embedding_table.local_metadata is not None
-        local_metadata = copy.deepcopy(embedding_table.local_metadata)
-        create_virtual_table_local_metadata(local_metadata, param, my_rank)
+        local_metadata = copy.deepcopy(global_metadata.shards_metadata[my_rank])
 
         key_to_local_shards[key].append(Shard(param, local_metadata))  # pyre-ignore
 
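The key behavioral change in this hunk is how the local shard metadata is obtained: instead of crafting it separately from embedding_table.local_metadata, it is now a deep copy of this rank's entry in the freshly built global shards_metadata. A small sketch of that derivation, with the ShardMetadata entries stood in by (shard_sizes, shard_offsets) tuples from the earlier hypothetical example:

import copy

# Hypothetical stand-ins for the per-rank ShardMetadata entries built above.
shards_metadata = [([3, 4], [0, 0]), ([5, 4], [3, 0]), ([2, 4], [8, 0])]
my_rank = 1

# The local shard metadata is now simply this rank's entry in the global list.
local_metadata = copy.deepcopy(shards_metadata[my_rank])
assert local_metadata == ([5, 4], [3, 0])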