+from typing import Optional, Tuple
+
+import numpy as np
+import stk
+import torch
+from megablocks import ops
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.moe.base_moe import BaseMoELayer
+from internlm.moe.megablock.megablock_moe import MegaBlockMoE
+from internlm.moe.megablock.mlp import MegaBlockGroupedFeedForward
+from internlm.moe.megablock.utils import promote_scalar
+from internlm.utils.registry import MODEL_INITIALIZER
+
+
+@MODEL_INITIALIZER.register_module(module_name="MegaBlock-D")
+class MegaBlockdMoE(MegaBlockMoE):
+    """
+    Built on the MegaBlocks paper and library described in
+    https://arxiv.org/abs/2211.15841. This implementation is
+    strictly equivalent to standard MoE with full capacity (no
+    dropped tokens). It is faster because it formulates the MoE
+    computation as block-sparse operations that accommodate
+    imbalanced token-to-expert assignments, whereas standard MoE
+    implementations either (1) drop tokens at the cost of reduced
+    performance or (2) set the capacity factor to the number of
+    experts and thus waste computation and memory on padding.
+    """
+
+    def __init__(  # pylint: disable=W0231
+        self,
+        hidden_size: int,
+        ep_group: Optional[torch.distributed.ProcessGroup],
+        ep_size: int,
+        num_experts: int,
+        top_k: int = 1,
+        parallel_mode: str = "tensor",
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        multiple_of: int = 256,
+    ) -> None:
+        assert gpc.expert_parallel_size == 1, "MegaBlock-D does not support expert parallelism"
+        self.top_k = top_k
+        self.num_experts = num_experts
+
+        tp_size = gpc.get_world_size(ParallelMode.TENSOR)
+        self.ffn_dim = multiple_of * ((int(hidden_size * gpc.config.model.mlp_ratio) + multiple_of - 1) // multiple_of)
+        assert self.ffn_dim % tp_size == 0
+        if parallel_mode == "tensor":
+            self.ffn_dim_per_row = self.ffn_dim // tp_size // ep_size
+        else:
+            self.ffn_dim_per_row = self.ffn_dim // ep_size
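+        # Worked example with illustrative numbers (not tied to any config):
+        # hidden_size=4096, mlp_ratio=8/3, multiple_of=256 gives
+        # int(4096 * 8 / 3) = 10922, rounded up to ffn_dim = 43 * 256 = 11008,
+        # which is then split across tp_size and/or ep_size for the per-row width.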
+        BaseMoELayer.__init__(  # pylint: disable=W0233
+            self,
+            torch.nn.Linear(hidden_size, num_experts, bias=False),
+            MegaBlockGroupedFeedForward(
+                hidden_size,
+                (self.ffn_dim // tp_size) * (num_experts // ep_size),
+                parallel_mode,
+                device,
+                dtype,
+            ),
+            ep_group,
+            ep_size,
+            1,
+        )
+
+        # Calculate the number of bits needed to represent the expert indices
+        # so that we can pass it to radix sort.
+        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
+        self.blocking = 128
+        self.quantize_scatter_num_bits = -1
+
+        # Calculate the number of bits needed to represent the column indices
+        # in the intermediate sparse matrix.
+        max_column_index = (self.ffn_dim * (self.num_experts // ep_size)) // self.blocking
+        self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)
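+        # For example (illustrative numbers): num_experts=8 needs
+        # ceil(log2(8)) = 3 sort bits, and ffn_dim=11008 with 8 local experts
+        # and blocking=128 gives max_column_index = 11008 * 8 // 128 = 688,
+        # i.e. 10 bits for the transpose sort.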
+
+        # Re-initialize the number of experts on each device.
+        self.num_local_experts = num_experts // ep_size
+
+        self.forward_fn = self._forward
+
+    def sparse_transpose(
+        self, size: Tuple[int, int], row_indices, column_indices
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        block_columns = size[1] // self.blocking
+
+        # Sort row indices by column indices to get the transposed matrix's
+        # column indices.
+        #
+        # NOTE: Our sort operation uses the same width indices as the input
+        # values. To avoid overflow when we have large activation matrices
+        # we cast to 32-bit before sorting.
+        _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit)
+
+        # There is a constant number of blocks in every row of the sparse
+        # matrix. A block's offset is:
+        #
+        # row_index * blocks_per_row + column_index % blocks_per_row
+        #
+        # Once we have the block offsets ordered for transposition we can
+        # divide by blocks_per_row to get the transposed column indices.
+        column_indices_t = row_indices.gather(0, gather_indices.long())
+        block_offsets_t = gather_indices.int()
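+        # gather_indices doubles as the transposed block offsets: entry i gives
+        # the position, in the original row-major block ordering, of the i-th
+        # nonzero block when the blocks are visited column by column.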
+
+        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+        nnz_per_column = ops.histogram(column_indices, block_columns)
+        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+        offsets_t = torch.cat([zero, nnz_per_column])
+        return column_indices_t, offsets_t, block_offsets_t
+
+    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor) -> stk.Matrix:
+        padded_tokens, _ = x.size()
+        assert padded_tokens % self.blocking == 0
+        assert self.ffn_dim_per_row % self.blocking == 0
+
+        # Offsets for the sparse matrix. All rows have the
+        # same number of nonzero blocks dictated by the
+        # dimensionality of a single expert.
+        block_rows = padded_tokens // self.blocking
+        blocks_per_row = self.ffn_dim_per_row // self.blocking
+        offsets = torch.arange(
+            0,
+            block_rows * blocks_per_row + 1,
+            blocks_per_row,
+            dtype=torch.int32,
+            device=x.device,
+        )
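+        # Illustrative sizes: padded_tokens=256 and ffn_dim_per_row=512 with
+        # blocking=128 give block_rows=2 and blocks_per_row=4, so
+        # offsets = [0, 4, 8].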
+
+        # Indices for the sparse matrix. The indices for
+        # the intermediate matrix are dynamic depending
+        # on the mapping of tokens to experts.
+        column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row)
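+        # column_indices has one entry per nonzero block, giving that block's
+        # block-column within the padded intermediate matrix of shape
+        # (padded_tokens, ffn_dim_per_row * num_experts).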
+
+        # TODO(tgale): This is unused. Remove the need for this in stk.
+        # For now, use meta init to save the device memory.
+        data = torch.empty(
+            column_indices.numel(),
+            self.blocking,
+            self.blocking,
+            dtype=x.dtype,
+            device="meta",
+        )
+        shape = (padded_tokens, self.ffn_dim_per_row * self.num_experts)
+        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(shape, row_indices, column_indices)
+        return stk.Matrix(
+            shape,
+            data,
+            row_indices,
+            column_indices,
+            offsets,
+            column_indices_t,
+            offsets_t,
+            block_offsets_t,
+        )
+
+    def indices_and_padded_bins(
+        self, selected_experts: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Sort the expert ids to produce the scatter/gather
+        # indices for the permutation.
+        selected_experts = selected_experts.int()
+        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
+
+        # Histogram the expert ids to identify the number of
+        # tokens routed to each expert.
+        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
+
+        # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+        # position of each bin.
+        padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
+        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+        padded_bins = promote_scalar(padded_bins)
+
+        # Calculate the bin bounds for the sorted tokens.
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = promote_scalar(bins)
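+        # Illustrative example (made-up counts): tokens_per_expert = [3, 130]
+        # with blocking=128 pads to [128, 256], giving padded_bins = [128, 384]
+        # and bins = [3, 133] (inclusive cumulative sums).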
+        return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+    def _forward(self, x, expert_weights, top_experts) -> Tuple[torch.Tensor, torch.Tensor]:
+        with torch.no_grad():
+            (indices, bin_ids, bins, padded_bins, tokens_per_expert) = self.indices_and_padded_bins(top_experts)
+
+        # Permute tokens and pad to prepare for the expert computation.
+        # (top_k * sequence_length + padding, model_dim)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
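+        # After padded_gather, x has padded_bins[-1] rows, i.e. the per-expert
+        # token counts rounded up to the blocking size and summed.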
+
+        # Create the sparse matrix topology.
+        with torch.no_grad():
+            topo = self.topology(x, padded_bins)
+
+        # Perform the expert computation.
+        # First Dense x Dense -> Sparse for w1 and w3,
+        # (top_k * sequence_length + padding, ffn_dim * n_experts)
+        x = self.experts(x, topo=topo)
+
+        # Permute back and remove padding.
+        # (top_k * sequence_length, model_dim)
+        x = ops.padded_scatter(
+            x,
+            indices,
+            bin_ids,
+            expert_weights,
+            bins,
+            padded_bins,
+            self.top_k,
+            self.quantize_scatter_num_bits,
+        )
+
+        return x, tokens_per_expert.flatten()