@@ -51,9 +51,18 @@ def add_argument():
                         type=int,
                         help='(moe) expert parallel world size')
     parser.add_argument('--num-experts',
-                        default=1,
                         type=int,
-                        help='(moe) number of total experts')
+                        nargs='+',
+                        default=[
+                            1,
+                        ],
+                        help='number of experts list, MoE related.')
+    parser.add_argument(
+        '--mlp-type',
+        type=str,
+        default='standard',
+        help=
+        'Only applicable when num-experts > 1, accepts [standard, residual]')
     parser.add_argument('--top-k',
                         default=1,
                         type=int,
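With this change, `--num-experts` takes one value per MoE layer instead of a single integer, and `--mlp-type residual` switches the expert MLP to a residual variant. A minimal standalone sketch of how the new flags parse (plain argparse mirroring only the two arguments above, not the example's full `add_argument()`):

```python
import argparse

# Standalone sketch mirroring only the two flags changed above.
parser = argparse.ArgumentParser()
parser.add_argument('--num-experts', type=int, nargs='+', default=[1])
parser.add_argument('--mlp-type', type=str, default='standard')

# nargs='+' collects repeated values into a list, one entry per MoE layer.
args = parser.parse_args(['--num-experts', '4', '8', '--mlp-type', 'residual'])
assert args.num_experts == [4, 8]
assert args.mlp_type == 'residual'
```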
@@ -168,9 +177,6 @@ def imshow(img):
 
 args = add_argument()
 
-if args.moe:
-    deepspeed.utils.groups.initialize(ep_size=args.ep_world_size)
-
 
 class Net(nn.Module):
     def __init__(self):
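The explicit `deepspeed.utils.groups.initialize(...)` call goes away because each `MoE` layer now receives `ep_size` directly in its constructor (next hunk) and sets up its own expert-parallel groups. A hedged sketch of the constraint that presumably still holds, namely that `ep_size` must evenly divide the world size so that ranks partition cleanly into expert-parallel groups (`check_ep_size` is a hypothetical helper, not a DeepSpeed API):

```python
def check_ep_size(world_size: int, ep_size: int) -> None:
    # Hypothetical helper: ranks are partitioned into expert-parallel
    # groups of ep_size members each, so the division must be exact.
    if world_size % ep_size != 0:
        raise ValueError(
            f'world size {world_size} is not divisible by ep_size {ep_size}')

check_ep_size(world_size=8, ep_size=2)  # 4 expert-parallel groups of 2 ranks
```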
@@ -181,14 +187,21 @@ def __init__(self):
         self.fc1 = nn.Linear(16 * 5 * 5, 120)
         self.fc2 = nn.Linear(120, 84)
         if args.moe:
-            self.fc3 = nn.Linear(84, 84)
-            self.fc3 = deepspeed.moe.layer.MoE(
-                hidden_size=84,
-                expert=self.fc3,
-                num_experts=args.num_experts,
-                k=args.top_k,
-                min_capacity=args.min_capacity,
-                noisy_gate_policy=args.noisy_gate_policy)
+            fc3 = nn.Linear(84, 84)
+            self.moe_layer_list = []
+            for n_e in args.num_experts:
+                # create moe layers based on the number of experts
+                self.moe_layer_list.append(
+                    deepspeed.moe.layer.MoE(
+                        hidden_size=84,
+                        expert=fc3,
+                        num_experts=n_e,
+                        ep_size=args.ep_world_size,
+                        use_residual=args.mlp_type == 'residual',
+                        k=args.top_k,
+                        min_capacity=args.min_capacity,
+                        noisy_gate_policy=args.noisy_gate_policy))
+            self.moe_layer_list = nn.ModuleList(self.moe_layer_list)
             self.fc4 = nn.Linear(84, 10)
         else:
             self.fc3 = nn.Linear(84, 10)
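One detail worth noting in the new constructor: the layers are collected in a plain Python list and only wrapped in `nn.ModuleList` at the end; without that wrap, the expert parameters would be invisible to `model.parameters()` and to `.to()`/`.cuda()`. A minimal standalone sketch of the pattern (plain `nn.Linear` stand-ins for `deepspeed.moe.layer.MoE`, which needs a distributed launch):

```python
import torch.nn as nn

class Sketch(nn.Module):
    def __init__(self, num_experts_list=(4, 8)):
        super().__init__()
        layers = []
        for n_e in num_experts_list:
            # Stand-in for deepspeed.moe.layer.MoE(hidden_size=84, ...,
            # num_experts=n_e); one layer per entry in the list.
            layers.append(nn.Linear(84, 84))
        # nn.ModuleList registers the layers as submodules, so their
        # parameters are tracked by the parent module.
        self.moe_layer_list = nn.ModuleList(layers)

net = Sketch()
assert sum(p.numel() for p in net.parameters()) == 2 * (84 * 84 + 84)
```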
@@ -200,7 +213,8 @@ def forward(self, x):
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         if args.moe:
-            x, _, _ = self.fc3(x)
+            for layer in self.moe_layer_list:
+                x, _, _ = layer(x)
             x = self.fc4(x)
         else:
             x = self.fc3(x)
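The loop unpacks a 3-tuple from each layer and keeps only the first element; in DeepSpeed's `MoE` the extra outputs are, to the best of my recollection, the auxiliary load-balancing loss and per-expert counts, both discarded here. A stand-in sketch of the chaining pattern:

```python
import torch
import torch.nn as nn

class FakeMoE(nn.Module):
    """Mimics the 3-tuple return of deepspeed.moe.layer.MoE.forward."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(84, 84)

    def forward(self, x):
        # (hidden_states, aux_loss, expert_counts) -- hedged naming.
        return self.proj(x), torch.tensor(0.0), None

moe_layer_list = nn.ModuleList([FakeMoE(), FakeMoE()])
x = torch.randn(2, 84)
for layer in moe_layer_list:
    x, _, _ = layer(x)  # keep hidden states, drop the aux outputs
assert x.shape == (2, 84)
```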
@@ -213,7 +227,10 @@ def forward(self, x):
 def create_moe_param_groups(model):
     from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
 
-    parameters = {'params': model.parameters(), 'name': 'parameters'}
+    parameters = {
+        'params': [p for p in model.parameters()],
+        'name': 'parameters'
+    }
 
     return split_params_into_different_moe_groups_for_optimizer(parameters)
 
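The switch from the lazy `model.parameters()` generator to a materialized list matters because a generator can only be consumed once; anything downstream that iterates the `'params'` entry a second time would see nothing. The result has the param-group-dict shape that any torch optimizer accepts (in the full example, the split groups are handed to `deepspeed.initialize` as `model_parameters`). A hedged sketch of that shape:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)

# Materialize the parameters: a list, unlike the model.parameters()
# generator, survives repeated iteration by downstream utilities.
parameters = {'params': [p for p in model.parameters()], 'name': 'parameters'}

# A list of such group dicts is valid optimizer input; extra keys like
# 'name' are preserved on the param group.
optimizer = torch.optim.SGD([parameters], lr=0.1)
```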