From aee457c43f68fcbe2aa5b406e71448b7911c5b8e Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 18 Nov 2024 14:01:06 +0800
Subject: [PATCH] fix(mha.py): fix evaluation argu key err (#370)

---
 internlm/model/modules/mha.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py
index a8ef77bc1..42418a212 100644
--- a/internlm/model/modules/mha.py
+++ b/internlm/model/modules/mha.py
@@ -211,7 +211,7 @@ def _training(self, x, **kwargs):
 
         # self attention
         kwargs = _convert_cu_seqlens_for_qksplited(kwargs)
-        if gpc.config.data.use_packed_dataset is False:
+        if gpc.config.data.use_packed_dataset is False or self.training is False:
             kwargs.pop("max_seqlen_q", None)
             kwargs.pop("max_seqlen_k", None)
         context = self.inner_attn(q, k, v, **kwargs)
@@ -529,7 +529,7 @@ def _training(self, x, **kwargs):
 
         kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
 
-        if gpc.config.data.use_packed_dataset is False:
+        if gpc.config.data.use_packed_dataset is False or self.training is False:
             kwargs.pop("max_seqlen_q", None)
             kwargs.pop("max_seqlen_k", None)
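
Note: below is a minimal standalone sketch of the guard this patch adds, for illustration only. The helper run_attention and the dummy kernel are hypothetical names; only the kwargs-popping logic mirrors the patched code in internlm/model/modules/mha.py. The idea is that during evaluation (self.training is False), or when the unpacked-dataset path is used, the attention call is made without the varlen arguments, so max_seqlen_q / max_seqlen_k must be removed from kwargs before calling the inner attention to avoid the unexpected-keyword error this commit fixes.

# sketch.py -- hypothetical, not part of the patched module
def run_attention(inner_attn, q, k, v, use_packed_dataset, training, **kwargs):
    """Drop varlen-only kwargs when the kernel is not expecting them.

    With an unpacked dataset, or in evaluation mode, the inner attention
    is called without cu_seqlens-style arguments, so max_seqlen_q and
    max_seqlen_k have to be popped from kwargs first.
    """
    if use_packed_dataset is False or training is False:
        kwargs.pop("max_seqlen_q", None)
        kwargs.pop("max_seqlen_k", None)
    return inner_attn(q, k, v, **kwargs)


if __name__ == "__main__":
    def dummy_attn(q, k, v):
        # Stand-in kernel that, like the eval-time path, does not accept
        # max_seqlen_q / max_seqlen_k keyword arguments.
        return q + k + v

    # Without the guard this call would raise:
    #   TypeError: dummy_attn() got an unexpected keyword argument 'max_seqlen_q'
    out = run_attention(dummy_attn, 1, 2, 3,
                        use_packed_dataset=True, training=False,
                        max_seqlen_q=2048, max_seqlen_k=2048)
    print(out)  # 6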