From 2d1fb295801e954d72400a4c69e066254ff0c84d Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Wed, 8 Jan 2025 13:56:27 -0800 Subject: [PATCH] Introduce the scale enum flag in Embedding layer for LLM embedding. (#909) The activation component should roughly have a magnitude of 1. Since the embedding tensor is initialized with a scale of `1/sqrt(dim)`, the activation is multiplied by `sqrt(dim)` to maintain the desired scale. e.g. Gemma [1] [1] https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80 In addition, unsloth [2] discovered that `sqrt(dim)` needs to be computed in float32. [2] Sec 3 in https://unsloth.ai/blog/gemma-bugs TODO(axlearn-team): Use UNIT scale enum for AFM+. This will require re-sweeping hyperparameters (e.g., learning rate). --- axlearn/common/embedding_test.py | 30 ++++++++++++++++++++++++++ axlearn/common/layers.py | 36 ++++++++++++++++++++++++++++++++ axlearn/common/layers_test.py | 31 ++++++++++++++++++++++----- 3 files changed, 92 insertions(+), 5 deletions(-) diff --git a/axlearn/common/embedding_test.py b/axlearn/common/embedding_test.py index 880afd130..15637c9c6 100644 --- a/axlearn/common/embedding_test.py +++ b/axlearn/common/embedding_test.py @@ -111,6 +111,36 @@ def test_embed_attend(self, soft_cap_logits, is_training): ref = soft_cap_logits * jnp.tanh(ref / soft_cap_logits) assert_allclose(ref, actual_attends) + def test_embed_with_emb_scale(self): + seq_len = 5 + vocab_size = 24 + hidden_dim = 256 + + emb = TransformerTextEmbeddings.default_config().set( + name="embed", + dim=hidden_dim, + vocab_size=vocab_size, + ) + emb.token_emb.set(scale=emb.token_emb.klass.Scale.UNIT) + layer = emb.instantiate(parent=None) + + prng_key = jax.random.PRNGKey(1) + prng_key, init_key, data_key, fwd_key = jax.random.split(prng_key, num=4) + state = layer.initialize_parameters_recursively(init_key) + + input_ids = jax.random.randint(data_key, shape=(3, seq_len), minval=1, maxval=vocab_size) + test_inputs = dict(inputs=input_ids) + outputs, _ = module.functional( + layer, + prng_key=fwd_key, + state=state, + inputs=dict(input_batch=test_inputs), + is_training=False, + ) + + assert_allclose(jnp.mean(outputs), 0.0, atol=0.05) + assert_allclose(jnp.std(outputs), 1.0, atol=0.05) + if __name__ == "__main__": with utils.numeric_checks(True): diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index 5c95fee0a..3cbb4de0c 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -18,6 +18,7 @@ """Basic layers.""" import enum +import math from collections.abc import Sequence from typing import Any, Callable, Optional, Union @@ -808,6 +809,21 @@ class Embedding(BaseLayer): Batched map for int in [0, ) -> float vector. """ + class Scale(enum.Enum): + """Defines the scale method on embedding activations. + + Available types: + 1. **UNIT**: Scale the activation components to ~1. + + The activation component should roughly have a magnitude of 1. Since the embedding tensor is + initialized with a scale of `1/√dim`, the activation is multiplied by `√dim` to + maintain the desired scale. e.g. Gemma [1] + [1] + https://github.com/google-deepmind/gemma/blob/0d6ae857591248422127ca14c027909546362e6a/gemma/modules.py#L80 + """ + + UNIT = "unit" + @config_class class Config(BaseLayer.Config): """Configures Embedding.""" @@ -820,6 +836,8 @@ class Config(BaseLayer.Config): embedding_partition_spec: Optional[tuple[Optional[str]]] = None # If not None, how to partition output activation values. output_partition_spec: Optional[tuple[Optional[str]]] = None + # Optional scaling of the embedding activations. + scale: Optional["Embedding.Scale"] = None @classmethod def default_config(cls): @@ -859,9 +877,27 @@ def forward(self, x: Tensor) -> Tensor: emb = self.parameters["weight"] emb = maybe_shard(emb, cfg.embedding_partition_spec) activation = emb[x] + activation = self._scale(activation) activation = maybe_shard(activation, cfg.output_partition_spec) return activation + def _scale(self, x: Tensor) -> Tensor: + """Scale the activation if needed.""" + cfg = self.config + if cfg.scale is None: + return x + + # Unsloth [1] discovered that `sqrt(dim)` needs to be computed in float32. + # [1] Sec 3 in https://unsloth.ai/blog/gemma-bugs.html + x_dtype = x.dtype + x = x.astype(jnp.float32) + if cfg.scale == self.Scale.UNIT: + x = x * math.sqrt(x.shape[-1]) + else: + raise ValueError(f"Unknown scale {cfg.scale}.") + x = x.astype(x_dtype) + return x + def attend(self, x: Tensor) -> Tensor: """Apply query array 'x' to the embedding weight array. diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index 18e600851..14cbc1279 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -1230,11 +1230,10 @@ def test_moving_average(self): class EmbedTest(parameterized.TestCase): @staticmethod - def build_embedder(dim, num_embeddings, rng): - cfg = Embedding.default_config() - cfg.dim = dim - cfg.num_embeddings = num_embeddings - cfg.name = "embed" + def build_embedder(dim, num_embeddings, rng, **kwargs): + cfg = Embedding.default_config().set(name="embed", dim=dim, num_embeddings=num_embeddings) + if kwargs: + cfg = cfg.set(**kwargs) emb = cfg.instantiate(parent=None) state = emb.initialize_parameters_recursively(rng) return (emb, state) @@ -1249,6 +1248,28 @@ def test_embed_lookup(self, seq_len, dim, num_embeddings, is_training): ) np.testing.assert_array_equal(state["weight"][ixs], actual_embeds) + def test_embed_with_scale(self): + dim = 256 + num_embeddings = 16 + prng_key = jax.random.PRNGKey(123) + prng_key, input_key, fwd_key = jax.random.split(prng_key, num=3) + embedder, state = EmbedTest.build_embedder( + dim, num_embeddings, input_key, scale=Embedding.Scale.UNIT + ) + batch, seq_len = 5, 8 + ixs = jax.random.randint(input_key, minval=0, maxval=num_embeddings, shape=(batch, seq_len)) + + outputs, _ = F( + embedder, + inputs=(ixs,), + is_training=True, + state=state, + prng_key=fwd_key, + ) + + assert_allclose(jnp.mean(outputs), 0.0, atol=0.05) + assert_allclose(jnp.std(outputs), 1.0, atol=0.05) + @parameterized.parameters(itertools.product((5, 7), (2, 16), (10, 100), (True, False))) def test_embed_attend(self, seq_len, dim, num_embeddings, is_training): rng = jax.random.PRNGKey(1)