Introduce the scale enum flag in Embedding layer for LLM embedding. (#909)

Each activation component should have a magnitude of roughly 1. Since the embedding tensor is
initialized with a scale of `1/sqrt(dim)`, the activation is multiplied by `sqrt(dim)` to
maintain the desired scale, e.g., as in Gemma [1].
[1] https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80

In addition, Unsloth [2] discovered that `sqrt(dim)` needs to be computed in float32.
[2] Sec 3 in https://unsloth.ai/blog/gemma-bugs

TODO(axlearn-team): Use UNIT scale enum for AFM+. This will require re-sweeping
hyperparameters (e.g., learning rate).
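
For reference, a minimal standalone JAX sketch of the scaling described above (this does not use the axlearn `Embedding` layer; the table size, `dim`, and variable names are illustrative). The float32 cast mirrors the `_scale` change in `axlearn/common/layers.py` below.

```python
import math

import jax
import jax.numpy as jnp

dim, vocab_size = 2048, 256

# Embedding table initialized with scale 1/sqrt(dim), stored in bfloat16.
key = jax.random.PRNGKey(0)
emb = (jax.random.normal(key, (vocab_size, dim)) / math.sqrt(dim)).astype(jnp.bfloat16)

ids = jnp.array([[3, 7, 11]])
act = emb[ids]  # activation components have std ~ 1/sqrt(dim)

# UNIT scale: multiply by sqrt(dim) so components have magnitude ~1.
# Following the Unsloth note, do the multiply in float32, then cast back.
scaled = (act.astype(jnp.float32) * math.sqrt(dim)).astype(act.dtype)

print(jnp.std(act.astype(jnp.float32)))     # ~ 0.022 (= 1/sqrt(2048))
print(jnp.std(scaled.astype(jnp.float32)))  # ~ 1.0
```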
ds-hwang authored Jan 8, 2025
1 parent 6559036 commit 2d1fb29
Showing 3 changed files with 92 additions and 5 deletions.
30 changes: 30 additions & 0 deletions axlearn/common/embedding_test.py
@@ -111,6 +111,36 @@ def test_embed_attend(self, soft_cap_logits, is_training):
ref = soft_cap_logits * jnp.tanh(ref / soft_cap_logits)
assert_allclose(ref, actual_attends)

def test_embed_with_emb_scale(self):
seq_len = 5
vocab_size = 24
hidden_dim = 256

emb = TransformerTextEmbeddings.default_config().set(
name="embed",
dim=hidden_dim,
vocab_size=vocab_size,
)
emb.token_emb.set(scale=emb.token_emb.klass.Scale.UNIT)
layer = emb.instantiate(parent=None)

prng_key = jax.random.PRNGKey(1)
prng_key, init_key, data_key, fwd_key = jax.random.split(prng_key, num=4)
state = layer.initialize_parameters_recursively(init_key)

input_ids = jax.random.randint(data_key, shape=(3, seq_len), minval=1, maxval=vocab_size)
test_inputs = dict(inputs=input_ids)
outputs, _ = module.functional(
layer,
prng_key=fwd_key,
state=state,
inputs=dict(input_batch=test_inputs),
is_training=False,
)

assert_allclose(jnp.mean(outputs), 0.0, atol=0.05)
assert_allclose(jnp.std(outputs), 1.0, atol=0.05)


if __name__ == "__main__":
with utils.numeric_checks(True):
36 changes: 36 additions & 0 deletions axlearn/common/layers.py
@@ -18,6 +18,7 @@
"""Basic layers."""

import enum
import math
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

@@ -808,6 +809,21 @@ class Embedding(BaseLayer):
Batched map for int in [0, <num_embeddings>) -> <dim> float vector.
"""

class Scale(enum.Enum):
"""Defines the scale method on embedding activations.
Available types:
1. **UNIT**: Scale the activation components to ~1.
Each activation component should have a magnitude of roughly 1. Since the embedding tensor is
initialized with a scale of `1/√dim`, the activation is multiplied by `√dim` to
maintain the desired scale, e.g., as in Gemma [1].
[1]
https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80
"""

UNIT = "unit"

@config_class
class Config(BaseLayer.Config):
"""Configures Embedding."""
@@ -820,6 +836,8 @@ class Config(BaseLayer.Config):
embedding_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition output activation values.
output_partition_spec: Optional[tuple[Optional[str]]] = None
# Optional scaling of the embedding activations.
scale: Optional["Embedding.Scale"] = None

@classmethod
def default_config(cls):
@@ -859,9 +877,27 @@ def forward(self, x: Tensor) -> Tensor:
emb = self.parameters["weight"]
emb = maybe_shard(emb, cfg.embedding_partition_spec)
activation = emb[x]
activation = self._scale(activation)
activation = maybe_shard(activation, cfg.output_partition_spec)
return activation

def _scale(self, x: Tensor) -> Tensor:
"""Scale the activation if needed."""
cfg = self.config
if cfg.scale is None:
return x

# Unsloth [1] discovered that `sqrt(dim)` needs to be computed in float32.
# [1] Sec 3 in https://unsloth.ai/blog/gemma-bugs.html
x_dtype = x.dtype
x = x.astype(jnp.float32)
if cfg.scale == self.Scale.UNIT:
x = x * math.sqrt(x.shape[-1])
else:
raise ValueError(f"Unknown scale {cfg.scale}.")
x = x.astype(x_dtype)
return x

def attend(self, x: Tensor) -> Tensor:
"""Apply query array 'x' to the embedding weight array.
31 changes: 26 additions & 5 deletions axlearn/common/layers_test.py
@@ -1230,11 +1230,10 @@ def test_moving_average(self):

class EmbedTest(parameterized.TestCase):
@staticmethod
def build_embedder(dim, num_embeddings, rng):
cfg = Embedding.default_config()
cfg.dim = dim
cfg.num_embeddings = num_embeddings
cfg.name = "embed"
def build_embedder(dim, num_embeddings, rng, **kwargs):
cfg = Embedding.default_config().set(name="embed", dim=dim, num_embeddings=num_embeddings)
if kwargs:
cfg = cfg.set(**kwargs)
emb = cfg.instantiate(parent=None)
state = emb.initialize_parameters_recursively(rng)
return (emb, state)
@@ -1249,6 +1248,28 @@ def test_embed_lookup(self, seq_len, dim, num_embeddings, is_training):
)
np.testing.assert_array_equal(state["weight"][ixs], actual_embeds)

def test_embed_with_scale(self):
dim = 256
num_embeddings = 16
prng_key = jax.random.PRNGKey(123)
prng_key, input_key, fwd_key = jax.random.split(prng_key, num=3)
embedder, state = EmbedTest.build_embedder(
dim, num_embeddings, input_key, scale=Embedding.Scale.UNIT
)
batch, seq_len = 5, 8
ixs = jax.random.randint(input_key, minval=0, maxval=num_embeddings, shape=(batch, seq_len))

outputs, _ = F(
embedder,
inputs=(ixs,),
is_training=True,
state=state,
prng_key=fwd_key,
)

assert_allclose(jnp.mean(outputs), 0.0, atol=0.05)
assert_allclose(jnp.std(outputs), 1.0, atol=0.05)

@parameterized.parameters(itertools.product((5, 7), (2, 16), (10, 100), (True, False)))
def test_embed_attend(self, seq_len, dim, num_embeddings, is_training):
rng = jax.random.PRNGKey(1)
