
Commit def75dd

feat(moe): impl moe with megablock kernel (#76)
1 parent e9fcf55 commit def75dd

File tree

11 files changed: +1022 -9 lines

configs/7B_MoE4_sft.py

Lines changed: 10 additions & 1 deletion
@@ -149,7 +149,7 @@
     num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
     num_experts=4,
     moe_use_residual=False,
-    moe_type="GShard",
+    moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D"
 )
 """
 zero1 parallel (dict):
@@ -200,6 +200,7 @@
 )
 
 # custom moe impl configs
+# GShard MoE config
 moe = dict(
     top_k=2,
     capacity_factor=1.0,
@@ -210,6 +211,14 @@
     use_rts=True,
 )
 
+# MegaBlock MoE config
+# moe = dict(
+#     top_k=2,
+#     capacity_factor=1.0, # only used in MegaBlock(non-dmoe)
+#     drop_tokens=True, # only used in MegaBlock(non-dmoe)
+#     #parallel_mode="tensor", # only used in MegaBlock-D(dmoe), parallel_mode can be tensor or weight
+# )
+
 model_type = "INTERNLM_MoE"
 
 # metric_dtype can be "fp32" or other string
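For reference, switching this config to the dropless MegaBlock kernel touches only the two blocks above. A minimal sketch of the MoE-related keys (values are illustrative, not the full 7B_MoE4_sft.py):

model = dict(
    num_experts=4,
    moe_use_residual=False,
    moe_type="MegaBlock-D",  # block-sparse dropless MoE
)

moe = dict(
    top_k=2,
    parallel_mode="tensor",  # "tensor" or "weight"; capacity_factor/drop_tokens apply only to the non-dropless "MegaBlock"
)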

internlm/initialize/launch.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,7 @@
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
+from internlm.moe.megablock.utils import check_megablock_installed, check_stk_installed
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
 from internlm.utils.gputest import warmup_process_group
@@ -314,6 +315,13 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_type" not in model:
         model._add_item("moe_type", "GShard")
+    # check dependency
+    if gpc.config.model.moe_type == "MegaBlock":
+        check_megablock_installed()
+    if gpc.config.model.moe_type == "MegaBlock-D":
+        check_megablock_installed()
+        check_stk_installed()
+
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
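check_megablock_installed and check_stk_installed come from internlm/moe/megablock/utils.py, which this commit adds but which is not shown in this excerpt. A minimal sketch of what such dependency probes could look like, assuming they simply attempt the optional imports:

def check_megablock_installed():
    # Sketch only; the actual body lives in internlm/moe/megablock/utils.py.
    try:
        from megablocks import ops  # noqa: F401
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError("please install megablocks to use MegaBlock/MegaBlock-D MoE") from exc


def check_stk_installed():
    # Sketch only; stk provides the block-sparse toolkit used by the dropless kernel.
    try:
        import stk  # noqa: F401
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError("please install stk to use MegaBlock-D MoE") from exc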

internlm/moe/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -1,3 +1,26 @@
 from .gshard_moe import GShardMOELayer
 
 __all__ = ["GShardMOELayer"]
+
+try:
+    from megablocks import ops  # noqa # pylint: disable=W0611
+except ModuleNotFoundError:
+    pass
+else:
+    from internlm.moe.megablock.megablock_moe import ( # noqa # pylint: disable=W0611
+        MegaBlockMoE,
+    )
+
+    __all__ += ["MegaBlockMoE"]
+
+try:
+    import stk  # noqa # pylint: disable=W0611
+    from megablocks import ops  # noqa # pylint: disable=W0611
+except ModuleNotFoundError:
+    pass
+else:
+    from internlm.moe.megablock.megablock_dmoe import ( # noqa # pylint: disable=W0611
+        MegaBlockdMoE,
+    )
+
+    __all__ += ["MegaBlockdMoE"]

internlm/moe/gshard_moe.py

Lines changed: 4 additions & 4 deletions
@@ -385,9 +385,9 @@ class GShardMOELayer(BaseMoELayer):
 
     def __init__(
         self,
-        hidden_size,
+        hidden_size: int,
         num_experts: int,
-        ep_group,
+        ep_group: Optional[torch.distributed.ProcessGroup],
         ep_size: int,
         top_k: int = 1,
         capacity_factor: float = 1.0,
@@ -396,8 +396,8 @@ def __init__(
         noisy_gate_policy: str = None,
         drop_tokens: bool = True,
         use_rts: bool = True,
-        device=None,
-        dtype=None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ) -> None:
         assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
             "Unsupported noisy_gate_policy: " + noisy_gate_policy

internlm/moe/megablock/__init__.py

Whitespace-only changes.
internlm/moe/megablock/megablock_dmoe.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import stk
+import torch
+from megablocks import ops
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.moe.base_moe import BaseMoELayer
+from internlm.moe.megablock.megablock_moe import MegaBlockMoE
+from internlm.moe.megablock.mlp import MegaBlockGroupedFeedForward
+from internlm.moe.megablock.utils import promote_scalar
+from internlm.utils.registry import MODEL_INITIALIZER
+
+
+@MODEL_INITIALIZER.register_module(module_name="MegaBlock-D")
+class MegaBlockdMoE(MegaBlockMoE):
+    """
+    Built on the paper and library Megablocks as described in
+    https://arxiv.org/abs/2211.15841. This implementation is
+    strictly equivalent to standard MoE with full capacity (no
+    dropped tokens). It's faster since it formulates MoE operations
+    in terms of block-sparse operations to accommodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either
+    (1) drops tokens at the cost of reduced performance or (2) sets
+    the capacity factor to the number of experts and thus wastes
+    computation and memory on padding.
+    """
+
+    def __init__(  # pylint: disable=W0231
+        self,
+        hidden_size: int,
+        ep_group: Optional[torch.distributed.ProcessGroup],
+        ep_size: int,
+        num_experts: int,
+        top_k: int = 1,
+        parallel_mode: str = "tensor",
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        multiple_of: int = 256,
+    ) -> None:
+        assert gpc.expert_parallel_size == 1, "do not support expert parallel"
+        self.top_k = top_k
+        self.num_experts = num_experts
+
+        tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+        self.ffn_dim = multiple_of * ((int(hidden_size * gpc.config.model.mlp_ratio) + multiple_of - 1) // multiple_of)
+        assert self.ffn_dim % tp_size == 0
+        if parallel_mode == "tensor":
+            self.ffn_dim_per_row = self.ffn_dim // tp_size // ep_size
+        else:
+            self.ffn_dim_per_row = self.ffn_dim // ep_size
+        BaseMoELayer.__init__(  # pylint: disable=W0233
+            self,
+            torch.nn.Linear(hidden_size, num_experts, bias=False),
+            MegaBlockGroupedFeedForward(
+                hidden_size,
+                (self.ffn_dim // tp_size) * (num_experts // ep_size),
+                parallel_mode,
+                device,
+                dtype,
+            ),
+            ep_group,
+            ep_size,
+            1,
+        )
+
+        # Calculate the number of bits needed to represent the expert indices
+        # so that we can pass it to radix sort.
+        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+        self.blocking = 128
+        self.quantize_scatter_num_bits = -1
+
+        # Calculate the number of bits needed to represent the column indices
+        # in the intermediate sparse matrix.
+        max_column_index = (self.ffn_dim * (self.num_experts // ep_size)) // self.blocking
+        self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)
+
+        # re-init the number of experts on each device
+        self.num_local_experts = num_experts // ep_size
+
+        self.forward_fn = self._forward
+
+    def sparse_transpose(
+        self, size: Tuple[int, int], row_indices, column_indices
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        block_columns = size[1] // self.blocking
+
+        # Sort row indices by column indices to get the transposed matrix's
+        # column indices.
+        #
+        # NOTE: Our sort operation uses the same width indices as the input
+        # values. To avoid overflow when we have large activation matrices
+        # we cast to 32-bit before sorting.
+        _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit)
+
+        # There are a constant number of blocks in every row of the sparse
+        # matrix. A block's offset is:
+        #
+        # row_index * blocks_per_row + column_index % blocks_per_row
+        #
+        # Once we have the block offsets ordered for transposition we can
+        # divide by blocks_per_row to get the transposed column indices.
+        column_indices_t = row_indices.gather(0, gather_indices.long())
+        block_offsets_t = gather_indices.int()
+
+        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+        nnz_per_column = ops.histogram(column_indices, block_columns)
+        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+        offsets_t = torch.cat([zero, nnz_per_column])
+        return column_indices_t, offsets_t, block_offsets_t
+
+    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor) -> stk.Matrix:
+        padded_tokens, _ = x.size()
+        assert padded_tokens % self.blocking == 0
+        assert self.ffn_dim_per_row % self.blocking == 0
+
+        # Offsets for the sparse matrix. All rows have the
+        # same number of nonzero blocks dictated by the
+        # dimensionality of a single expert.
+        block_rows = padded_tokens // self.blocking
+        blocks_per_row = self.ffn_dim_per_row // self.blocking
+        offsets = torch.arange(
+            0,
+            block_rows * blocks_per_row + 1,
+            blocks_per_row,
+            dtype=torch.int32,
+            device=x.device,
+        )
+
+        # Indices for the sparse matrix. The indices for
+        # the intermediate matrix are dynamic depending
+        # on the mapping of tokens to experts.
+        column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row)
+
+        # TODO(tgale): This is unused. Remove the need for this in stk.
+        # For now, use meta init to save the device memory.
+        data = torch.empty(
+            column_indices.numel(),
+            self.blocking,
+            self.blocking,
+            dtype=x.dtype,
+            device="meta",
+        )
+        shape = (padded_tokens, self.ffn_dim_per_row * self.num_experts)
+        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(shape, row_indices, column_indices)
+        return stk.Matrix(
+            shape,
+            data,
+            row_indices,
+            column_indices,
+            offsets,
+            column_indices_t,
+            offsets_t,
+            block_offsets_t,
+        )
+
+    def indices_and_padded_bins(
+        self, selected_experts: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Sort the expert ids to produce the scatter/gather
+        # indices for the permutation.
+        selected_experts = selected_experts.int()
+        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
+
+        # Histogram the expert ids to identify the number of
+        # tokens routed to each expert.
+        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
+
+        # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+        # position of each bin.
+        padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
+        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+        padded_bins = promote_scalar(padded_bins)
+
+        # Calculate the bin bounds for the sorted tokens.
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = promote_scalar(bins)
+        return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+    def _forward(self, x, expert_weights, top_experts) -> Tuple[torch.Tensor, torch.Tensor]:
+        with torch.no_grad():
+            (indices, bin_ids, bins, padded_bins, tokens_per_expert) = self.indices_and_padded_bins(top_experts)
+
+        # Permute tokens and pad to prepare for expert computation.
+        # (top_k * sequence_length + padding, model_dim)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
+
+        # Create the sparse matrix topology.
+        with torch.no_grad():
+            topo = self.topology(x, padded_bins)
+
+        # Perform the expert computation.
+        # First Dense x Dense -> Sparse for w1 and w3,
+        # (top_k * sequence_length + padding, ffn_dim * n_experts)
+        x = self.experts(x, topo=topo)
+
+        # Permute back and remove padding.
+        # (top_k * sequence_length, model_dim)
+        x = ops.padded_scatter(
+            x,
+            indices,
+            bin_ids,
+            expert_weights,
+            bins,
+            padded_bins,
+            self.top_k,
+            self.quantize_scatter_num_bits,
+        )
+
+        return x, tokens_per_expert.flatten()
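To make the padding arithmetic behind indices_and_padded_bins concrete, here is the same bin bookkeeping re-done with plain torch ops (the real code uses fused megablocks.ops kernels; blocking=128 matches the class above, and the per-expert token counts are illustrative):

import torch

blocking = 128

# Suppose the router assigned these token counts to 4 experts (imbalanced routing).
tokens_per_expert = torch.tensor([200, 5, 310, 0])

# Round each count up to a multiple of the block size (what ops.round_up does) ...
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking  # [256, 128, 384, 0]

# ... then take the inclusive cumsum (ops.inclusive_cumsum) to get each expert's end offset.
padded_bins = torch.cumsum(padded, dim=0)  # [256, 384, 768, 768]

# topology() then builds a block-sparse matrix with padded_bins[-1] // blocking rows of blocks,
# so no tokens are dropped no matter how unbalanced the routing is.
print(padded.tolist(), padded_bins.tolist(), int(padded_bins[-1]) // blocking)  # last value: 6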
