feat(model): add rope_base interface (InternLM#512)
00INDEX authored Nov 23, 2023
1 parent 7776693 commit 0d3811c
Showing 2 changed files with 15 additions and 1 deletion.
internlm/model/modeling_internlm.py (9 additions, 0 deletions)
@@ -53,6 +53,7 @@ class PackedFlashBaseLayer1D(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""

def __init__(
@@ -75,6 +76,7 @@ def __init__(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
+ rope_base: int = 10000,
):
super().__init__()
self.checkpoint = checkpoint
@@ -98,6 +100,7 @@ def __init__(
rotary_emb_dim=head_dim,
rotary_emb_scale_base=0,
use_flash_attn=use_flash_attn,
+ rope_base=rope_base,
device=device,
dtype=dtype,
)
@@ -264,6 +267,7 @@ class PackedFlashInternLm1D(nn.Module):
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""

@@ -295,6 +299,7 @@ def __init__(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
+ rope_base: int = 10000,
):
super().__init__()

@@ -344,6 +349,7 @@ def __init__(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
+ rope_base=rope_base,
)
for lid in range(num_layers)
]
@@ -490,6 +496,7 @@ def build_model_with_cfg(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
+ rope_base: int = 10000,
):
"""
Build model with config.
@@ -520,6 +527,7 @@ def build_model_with_cfg(
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""

@@ -545,6 +553,7 @@ def build_model_with_cfg(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
+ rope_base=rope_base,
)

return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
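For context, the commit plumbs the new argument from the model config down through PackedFlashInternLm1D and PackedFlashBaseLayer1D to the attention module. Below is a minimal, illustrative sketch of how the knob might be set in a training config, assuming the usual MODEL = dict(...) style of InternLM configs; the field names other than rope_base and all of the values are placeholders for illustration, not part of this diff:

    # Hypothetical training-config snippet (illustrative values).
    MODEL = dict(
        num_layers=32,
        hidden_size=4096,
        num_attention_heads=32,
        vocab_size=103168,
        norm_type="rmsnorm",
        use_flash_attn=True,
        rope_base=10000,  # new knob added by this commit; defaults to 10000
    )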
internlm/model/multi_head_attention.py (6 additions, 1 deletion)
@@ -63,6 +63,7 @@ class MHA(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
use_flash_attn (bool): Whether to use flash-attn. True by default.
+ rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""

@@ -80,6 +81,7 @@ def __init__(
rotary_emb_dim: int = 0,
rotary_emb_scale_base: int = 0,
use_flash_attn: bool = True,
+ rope_base: int = 10000,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
@@ -100,13 +102,16 @@ def __init__(
if self.use_dynamic_ntk_rope:
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
self.rotary_emb_dim,
+ base=rope_base,
scale_base=rotary_emb_scale_base,
device=device,
max_position_embeddings=max_position_embeddings,
scaling_factor=1.0, # Currently do not support dynamic scaling.
)
else:
- self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
+ self.rotary_emb = RotaryEmbedding(
+     self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device
+ )

# notice here should change bias=True
self.Wqkv = ColumnParallelLinearTorch(
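As a quick illustration of what the base argument changes: in the standard RoPE formulation the per-pair rotation frequencies are derived from base, so a larger base slows the rotation of the higher-index dimension pairs and stretches the usable position range. The sketch below shows only that standard formula, not the flash-attn RotaryEmbedding implementation itself:

    import torch

    def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
        # Standard RoPE inverse frequencies: one rotation rate per pair of dims.
        return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    # With head_dim = 128, raising base from 10000 to 1000000 shrinks every
    # frequency, i.e. positions rotate more slowly and wrap around much later.
    print(rope_inv_freq(128, base=10_000.0)[-4:])
    print(rope_inv_freq(128, base=1_000_000.0)[-4:])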
