diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 6586b11b..a8ef77bc 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -38,7 +38,10 @@ def _convert_cu_seqlens_for_qksplited(kwargs: Dict): def split_fused_wqkv_weight(wqkv, *args, **kwargs): # pylint: disable=W0613 q_dim = kwargs["q_dim"] kv_dim = kwargs["kv_dim"] - wq, wk, wv = torch.split(wqkv, [q_dim, kv_dim, kv_dim], dim=0) + split_size = [q_dim, kv_dim, kv_dim] + assert (q_dim + 2 * kv_dim) % wqkv.size(0) == 0 + divisor = (q_dim + 2 * kv_dim) // wqkv.size(0) + wq, wk, wv = torch.split(wqkv, [x // divisor for x in split_size], dim=0) return wq, wk, wv diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index b51f2b24..5547a9fb 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -934,10 +934,12 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt for mod in modules: inject_funcs[mod](_chunk, inject, interactive) - # reset parameters + # reset parameters and move model to device for _chunk in model: - if inject and reset_params: - _chunk.reset_parameters() + if inject: + if reset_params: + _chunk.reset_parameters() + _chunk.to(get_current_device()) # inject configs if inject: