
Commit 41545ce

Fix(mha,linear): fix norm_head and mha inference (#234)
Co-authored-by: shidongxing <shidongxing@>
1 parent c355767 commit 41545ce

File tree

5 files changed: +32 −23 lines changed

internlm/core/parallel/comm/__init__.py

Whitespace-only changes.

internlm/core/parallel/comm/tensor.py

Lines changed: 4 additions & 4 deletions
@@ -233,7 +233,7 @@ def grad_output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return grad_output, DUMMY_HANDLE_CONST
 
-        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1)
+        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST
 
     def output_hook(
         self, output: torch.Tensor, async_op: bool = False  # pylint: disable=W0613
@@ -244,7 +244,7 @@ def output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return output, DUMMY_HANDLE_CONST
 
-        return _gather(output, parallel_mode=self._parallel_mode, dim=-1)
+        return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST
 
 
 class HeadSequenceParallelCommunicator(SequenceParallelCommunicator):
@@ -274,7 +274,7 @@ def grad_output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return grad_output, DUMMY_HANDLE_CONST
 
-        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1)
+        return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST
 
     # rewrite ouput communication hook
     def output_hook(
@@ -286,7 +286,7 @@ def output_hook(
         if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
             return output, DUMMY_HANDLE_CONST
 
-        return _gather(output, parallel_mode=self._parallel_mode, dim=-1)
+        return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST
 
 
 class MoESequenceParallelCommunicator:
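Why the change matters: on the early-return path both hooks already return a (tensor, DUMMY_HANDLE_CONST) pair, but on the split/gather path they returned a bare tensor, so any caller that unpacks two values breaks. A minimal sketch of the calling convention the fix restores; the helper below is illustrative and not part of this commit:

# Illustrative caller (not repository code): after the fix, output_hook always
# returns a (tensor, handle) pair, so unpacking is uniform on every path.
def apply_output_hook(communicator, output, dummy_handle):
    out, handle = communicator.output_hook(output, async_op=False)
    if handle is not dummy_handle:   # a real async work handle, not the dummy sentinel
        handle.wait()                # finish the dim=-1 gather/split before using out
    return out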

internlm/model/modules/linear.py

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ def forward(self, input):  # pylint: disable=W0622
 
         return fused_dense_func(
             input,
-            self.weight,
+            weight,
             communicator=self._communicator,
             module=self,
             bias=self.bias,
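Context for the one-line change: the commit message says it fixes norm_head, which suggests forward() derives a local weight (e.g. a normalized copy of self.weight) earlier in the method but then passed the raw parameter to fused_dense_func, silently discarding that work. A hedged sketch of that bug pattern, with an assumed normalization formula (the real one lives elsewhere in linear.py):

import torch
import torch.nn.functional as F

class Head(torch.nn.Module):
    # Illustrative module, not the repository's class.
    def __init__(self, dim: int, vocab: int, norm_head: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab, dim))
        self.norm_head = norm_head

    def forward(self, x):
        weight = self.weight
        if self.norm_head:
            # assumed normalization, for illustration only
            weight = weight / (weight.norm(dim=-1, keepdim=True) + 1e-7)
        return F.linear(x, weight)  # must pass the local `weight`, not self.weight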

internlm/model/modules/mha.py

Lines changed: 23 additions & 15 deletions
@@ -184,18 +184,22 @@ def _convert_unpacked_qkv_to_packed(
         max_seqlen_q = attention_mask.shape[-1]
         max_seqlen_k = attention_mask.shape[-1]
 
-        q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
-        kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
-            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
         )
 
         return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 
     def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         assert inference_params is not None, "inference_params is required for inference"
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        attention_mask = inference_params.get("attention_mask", None)
-        sequence_len_offset = inference_params.get("sequence_len_offset", 0)
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset
         batch_size = x.shape[0]
 
         # wqkv, output: q, kv
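Two fixes in this hunk: inference_params exposes its fields as attributes rather than dict entries (so the old .get() calls raised AttributeError), and the packed q/kv now get a leading batch dimension of 1 via unsqueeze(0), which the downstream packed-attention path evidently expects. A small shape check, illustrative only and not from the repository:

import torch

batch_size, seqlen, n_head, head_dim = 2, 8, 4, 16
q = torch.randn(batch_size, seqlen, n_head, head_dim)
attention_mask = torch.ones(batch_size, seqlen, dtype=torch.bool)

q_packed = (
    q.masked_select(attention_mask.view(batch_size, -1, 1, 1))
    .view(-1, q.shape[-2], q.shape[-1])
    .unsqueeze(0)
)
print(q_packed.shape)  # torch.Size([1, 16, 4, 16]) -> (1, total_tokens, n_head, head_dim)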
@@ -230,21 +234,21 @@ def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
                 q = self.rotary_emb(
                     q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved
                 )
-                k = kv[:, :, 0].squeueze(2)
+                k = kv[:, :, 0].squeeze(2)
                 self.rotary_emb(
                     k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                 )  # in-place is important
             else:
                 if self.rotary_emb_dim > 0:
                     q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved)
-                    k = kv[:, :, 0].squeueze(2)
+                    k = kv[:, :, 0].squeeze(2)
                     self.rotary_emb(
                         k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
                     )  # in-place is important
         else:
             assert self.rotary_emb_dim > 0, "You should use rotary_emb."
 
-            k, v = kv[:, :, 0].squeueze(2), kv[:, :, 1].squeueze(2)
+            k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2)
 
             if attention_mask is None:
                 q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved)
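The remaining changes in this hunk only correct a method-name typo: torch.Tensor has squeeze, not squeueze, so the old spelling raised AttributeError as soon as inference reached these branches. A tiny reproduction with illustrative shapes:

import torch

kv = torch.zeros(1, 4, 2, 8, 64)    # (batch, seqlen, 2, n_head, head_dim)
k = kv[:, :, 0].squeeze(2)          # fixed spelling: runs
try:
    kv[:, :, 0].squeueze(2)         # old spelling
except AttributeError as err:
    print(err)                      # 'Tensor' object has no attribute 'squeueze'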
@@ -474,27 +478,31 @@ def _convert_unpacked_qkv_to_packed(
         max_seqlen_q = attention_mask.shape[-1]
         max_seqlen_k = attention_mask.shape[-1]
 
-        q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
-        kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
-            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+        q_packed = (
+            q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+        )
+        kv_packed = (
+            kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+            .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+            .unsqueeze(0)
         )
 
         return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 
     def _inference(self, x, inference_params, **kwargs):  # pylint: disable=W0613
         assert inference_params is not None, "inference_params is required for inference"
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        attention_mask = inference_params.get("attention_mask", None)
-        sequence_len_offset = inference_params.get("sequence_len_offset", 0)
-        window_size = inference_params.get("window_size", None)
+        attention_mask = inference_params.attention_mask
+        sequence_len_offset = inference_params.sequence_len_offset
+        window_size = inference_params.window_size
 
         batch_size = x.shape[0]
 
         # wqkv, output: q, k, v
         if self.enable_qkv_fusion:
             qkv = self.wqkv(x)
             qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim)
-            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :].unsqueeze(-2), qkv[..., -1, :].unsqueeze(-2))
+            q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :])
             q = rearrange(q, "b s h gs d -> b s (h gs) d")
         else:
             q, k, v = self.wq(x), self.wk(x), self.wv(x)
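Besides the same attribute-access and unsqueeze(0) fixes as above (now including window_size), this hunk changes the fused-QKV split: indexing the group axis with an integer (qkv[..., -2, :]) already removes that axis, so the old extra unsqueeze(-2) left a stray singleton dimension on k and v. An illustrative shape check with assumed sizes, not repository code:

import torch
from einops import rearrange

b, s, kv_heads, q_per_kv, d = 1, 4, 2, 4, 16
qkv = torch.randn(b, s, kv_heads * (q_per_kv + 2) * d)
qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=q_per_kv + 2, d=d)

q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]
q = rearrange(q, "b s h gs d -> b s (h gs) d")
print(q.shape, k.shape, v.shape)  # (1, 4, 8, 16) (1, 4, 2, 16) (1, 4, 2, 16)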

internlm/utils/utils.py

Lines changed: 4 additions & 3 deletions
@@ -62,14 +62,15 @@ def __kv_checker(num_args: int):
         # kv: [batch, seqlen, 3, n_head, headdim]
         return len(args[2].shape) == 5
 
-    def __cu_seqlens_checker(num_args: int, check_idx: int):
+    def __cu_seqlens_checker(args, check_idx: int):
+        num_args = len(args)
         if num_args < (check_idx + 1):
             if check_idx == 2:
                 return "cu_seqlens" in kwargs and kwargs["cu_seqlens"] is not None
             else:
                 return "cu_seqlens_q" in kwargs and kwargs["cu_seqlens_q"] is not None
         else:
-            return isinstance(num_args[check_idx], torch.Tensor)
+            return isinstance(args[check_idx], torch.Tensor)
 
     if __qkv_checker(len(args)):
         # qkv packed, and we should check cu_seqlens with index 2
@@ -81,7 +82,7 @@ def __cu_seqlens_checker(num_args: int, check_idx: int):
         # qkv splited, and we should check cu_seqlens with index 4
         qkv_pack_type = int(QKVPackType.QKVSPLITED)
 
-    with_cu_seqlens = __cu_seqlens_checker(len(args), qkv_pack_type)
+    with_cu_seqlens = __cu_seqlens_checker(args, qkv_pack_type)
 
     return str(qkv_pack_type), str(with_cu_seqlens)
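The checker fix: __cu_seqlens_checker used to receive len(args), an int, and then tried to index it, which can only raise TypeError; it now receives the args tuple itself and derives the length internally. A simplified, self-contained sketch of the corrected logic (names and signature are illustrative; in the repository, kwargs comes from the enclosing scope):

import torch

def cu_seqlens_present(args: tuple, kwargs: dict, check_idx: int) -> bool:
    # Fall back to keyword arguments when the positional slot is absent.
    if len(args) < check_idx + 1:
        key = "cu_seqlens" if check_idx == 2 else "cu_seqlens_q"
        return key in kwargs and kwargs[key] is not None
    # Otherwise the positional slot must actually hold a tensor.
    return isinstance(args[check_idx], torch.Tensor)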
