
Commit cbed07d

ds-hwangchanglan authored and committed
Add streaming/decoding API support to embeddings.
GitOrigin-RevId: 4572855cc635b40ea1637948f39d8374d02bf86e
1 parent 2126f88 commit cbed07d

File tree: 2 files changed (+60, -24 lines)


axlearn/common/decoder.py

Lines changed: 28 additions & 23 deletions
@@ -37,7 +37,7 @@
     infer_initial_time_step,
     sample_decode,
 )
-from axlearn.common.embedding import TransformerTextEmbeddings
+from axlearn.common.embedding import BaseEmbedding, TransformerTextEmbeddings
 from axlearn.common.layers import Dropout, LayerNorm, set_dropout_rate_recursively
 from axlearn.common.logit_modifiers import LogitsToLogitsFn
 from axlearn.common.module import (
@@ -446,7 +446,7 @@ class Config(BaseLayer.Config):
         # explicitly.
         dropout_rate: float = 0.0
         # Vector from input ids table.
-        emb: TransformerTextEmbeddings.Config = TransformerTextEmbeddings.default_config()
+        emb: BaseEmbedding.Config = TransformerTextEmbeddings.default_config()
         # Transformer model trunk.
         transformer: BaseStackedTransformerLayer.Config = StackedTransformerLayer.default_config()
         # Layer norm applied to transformer output.
@@ -519,25 +519,25 @@ def _forward_for_mode(
         emb_batch = {**input_batch}
         emb_batch["inputs"] = emb_batch["input_ids"]

-        x = self.emb(input_batch=emb_batch)
-
         if mode == ForwardMode.FORWARD:
-            transformer_state, x = (
-                None,
-                self.transformer(
-                    x,
-                    self_attention_logit_biases=self_attention_logit_biases,
-                    target_segment_ids=input_segment_ids,
-                    target_positions=positions,
-                    cross_attention_data=cross_attention_data,
-                    cross_attention_logit_biases=cross_attention_logit_biases,
-                ),
+            x = self.emb(input_batch=emb_batch)
+            x = self.transformer(
+                x,
+                self_attention_logit_biases=self_attention_logit_biases,
+                target_segment_ids=input_segment_ids,
+                target_positions=positions,
+                cross_attention_data=cross_attention_data,
+                cross_attention_logit_biases=cross_attention_logit_biases,
             )
+            cached_states = None
         elif mode == ForwardMode.INIT_STATES:
             assert cached_states is not None
             if input_segment_ids is not None:
                 raise ValueError("input_segment_ids is not supported in INIT_STATES.")
-            transformer_state, x = self.transformer.init_states(
+            cached_states["emb"], x = self.emb.extend_step(
+                cached_states=cached_states["emb"], input_batch=emb_batch
+            )
+            cached_states["transformer_state"], x = self.transformer.init_states(
                 time_step=cached_states["transformer_state"],
                 data=x,
                 self_attention_logit_biases=self_attention_logit_biases,
@@ -548,7 +548,10 @@ def _forward_for_mode(
             assert cached_states is not None
             if input_segment_ids is not None:
                 raise ValueError("input_segment_ids is not supported in EXTEND_STEP.")
-            transformer_state, x = self.transformer.extend_step(
+            cached_states["emb"], x = self.emb.extend_step(
+                cached_states=cached_states["emb"], input_batch=emb_batch
+            )
+            cached_states["transformer_state"], x = self.transformer.extend_step(
                 cached_states=cached_states["transformer_state"],
                 data=x,
                 self_attention_logit_biases=self_attention_logit_biases,
@@ -588,7 +591,7 @@ def _forward_for_mode(
             logits = self._output_logits_modifier(logits)
         logits = with_sharding_constraint(logits, PartitionSpec(*self.config.logits_partition_spec))
         # TODO(markblee): Rename to just "transformer". "transformer_state" is a bit redundant.
-        return dict(transformer_state=transformer_state), dict(logits=logits, hidden_states=x)
+        return cached_states, dict(logits=logits, hidden_states=x)

     def forward(
         self,
@@ -647,12 +650,14 @@ def init_states(
     ) -> NestedTensor:
         """See `BaseDecoder.init_states` for details."""
         cfg: Decoder.Config = self.config
-        init_state, _ = self.transformer.init_states(
+        emb = self.emb.init_states(batch_size=batch_size, dtype=dtype)
+        transformer_state, _ = self.transformer.init_states(
             time_step=None,
             data=TensorSpec([batch_size, max_sequence_length, cfg.dim], dtype=dtype),
         )
         return dict(
-            transformer_state=init_state,
+            emb=emb,
+            transformer_state=transformer_state,
             input_ids=jnp.full(
                 (batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32
             ),
@@ -677,13 +682,14 @@ def prefill_states(
         See `BaseDecoder.prefill_states` for details.
         """
         validate_contains_paths(input_batch, paths=["input_ids"])
-        input_ids = input_batch["input_ids"]
+        input_ids: Tensor = input_batch["input_ids"]
         input_segment_ids = input_batch.get("input_segment_ids", None)
         positions = input_batch.get("positions", None)

+        emb = self.emb.init_states(batch_size=input_ids.shape[0], dtype=self.dtype())
         states, outputs = self._forward_for_mode(
             mode=ForwardMode.INIT_STATES,
-            cached_states=dict(transformer_state=time_step),
+            cached_states=dict(emb=emb, transformer_state=time_step),
             input_batch=input_batch,
             # TODO(markblee): Consider supporting packed inputs for more efficient prefilling.
             self_attention_logit_biases=self.compute_attention_logit_biases(
@@ -748,14 +754,13 @@ def extend_step(
             cached_states=cached_states,
             **kwargs,
         )
-        updated_states = dict(
+        updated_states.update(
             input_ids=updated_inputs,
             # There are some non-greedy DFS/BFS and sliding attention algorithms that
             # recursively search through potentials.
             # They backtrace to some anchor time step after exploring for t steps.
             # This requires tracking time_step separately from the attention time_step.
             time_step=cached_states["time_step"] + 1,
-            **updated_states,
         )
         return updated_states, outputs

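Illustration (not part of the commit): a minimal sketch of how the decoder's decode cache is nested after this change, assuming the default BaseEmbedding state (an empty dict). The embedding's streaming state now sits next to the transformer KV state and is threaded through prefill_states/extend_step; the values below are placeholders, not real axlearn state.

# Illustrative sketch only. "emb" holds whatever the embedding layer's init_states()
# returns (an empty dict by default); "transformer_state" is the per-layer KV cache,
# elided here; time_step is the counter advanced by Decoder.extend_step.
import jax.numpy as jnp

batch_size, max_sequence_length, pad_token_id = 2, 16, 0
decode_cache = dict(
    emb=dict(),  # BaseEmbedding.init_states() default: no streaming state
    transformer_state=None,  # placeholder for the transformer's own decode state
    input_ids=jnp.full((batch_size, max_sequence_length), pad_token_id, dtype=jnp.int32),
    time_step=jnp.zeros([batch_size], dtype=jnp.int32),
)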
axlearn/common/embedding.py

Lines changed: 32 additions & 1 deletion
@@ -11,7 +11,7 @@
 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
 from axlearn.common.layers import Dropout, Embedding
-from axlearn.common.module import Module, Tensor, child_context
+from axlearn.common.module import Module, Tensor, child_context, nowrap
 from axlearn.common.utils import Nested, validate_contains_paths


@@ -50,6 +50,37 @@ def attend(self, x: Tensor):
         """
         raise NotImplementedError(type(self))

+    @nowrap
+    def init_states(self, *, batch_size: int, dtype: jnp.dtype) -> Nested[Tensor]:
+        """Initializes state for streaming decode.
+
+        Args:
+            batch_size: Batch size.
+            dtype: dtype for the decoding cache.
+
+        Returns:
+            A nested dict of initial states. Returns empty dict by default.
+        """
+        del batch_size, dtype
+        return dict()
+
+    def extend_step(
+        self, *, cached_states: Nested[Tensor], input_batch: Nested[Tensor]
+    ) -> tuple[Nested[Tensor], Tensor]:
+        """Extends one step for streaming decode.
+
+        TODO(dhwang2): prefill uses `extend_step`, which has a performance issue. This will be
+        resolved after #2057.
+
+        Args:
+            cached_states: Cached states from previous step or init_state.
+            input_batch: Input batch for current step.
+
+        Returns:
+            A tuple of (updated_states, embeddings).
+        """
+        return cached_states, self.forward(input_batch)
+

 class TransformerTextEmbeddings(BaseEmbedding):
     """Textual embeddings from token id, position and token type embeddings."""

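Usage sketch (not from this commit): a hypothetical stateful embedding that follows the new init_states/extend_step contract, keeping a per-example time step so position lookups advance across decode steps. ToyStreamingEmbedding and its fields are illustrative stand-ins written in plain JAX, not axlearn's BaseEmbedding/config machinery.

# Hypothetical example of the streaming contract introduced above.
import jax
import jax.numpy as jnp


class ToyStreamingEmbedding:
    """Toy embedding whose streaming state is a per-example time-step counter."""

    def __init__(self, vocab_size: int, max_len: int, dim: int, key: jax.Array):
        k_tok, k_pos = jax.random.split(key)
        self.tok = 0.02 * jax.random.normal(k_tok, (vocab_size, dim))
        self.pos = 0.02 * jax.random.normal(k_pos, (max_len, dim))

    def init_states(self, *, batch_size: int, dtype=jnp.float32):
        # Analogue of BaseEmbedding.init_states; dtype is kept for signature parity
        # but unused here since this toy state is just an int32 step counter.
        del dtype
        return dict(time_step=jnp.zeros([batch_size], dtype=jnp.int32))

    def extend_step(self, *, cached_states, input_batch):
        # Analogue of BaseEmbedding.extend_step: consume one step of ids and
        # return (updated_states, embeddings).
        ids = input_batch["inputs"]              # [batch, step_len]
        step = cached_states["time_step"]        # [batch]
        positions = step[:, None] + jnp.arange(ids.shape[1])[None, :]
        x = self.tok[ids] + self.pos[positions]  # [batch, step_len, dim]
        return dict(time_step=step + ids.shape[1]), x


# One decode step for a batch of 2 sequences.
emb = ToyStreamingEmbedding(vocab_size=100, max_len=32, dim=8, key=jax.random.PRNGKey(0))
state = emb.init_states(batch_size=2)
state, x = emb.extend_step(cached_states=state, input_batch=dict(inputs=jnp.array([[3], [7]])))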