     infer_initial_time_step,
     sample_decode,
 )
-from axlearn.common.embedding import TransformerTextEmbeddings
+from axlearn.common.embedding import BaseEmbedding, TransformerTextEmbeddings
 from axlearn.common.layers import Dropout, LayerNorm, set_dropout_rate_recursively
 from axlearn.common.logit_modifiers import LogitsToLogitsFn
 from axlearn.common.module import (
@@ -446,7 +446,7 @@ class Config(BaseLayer.Config):
         # explicitly.
         dropout_rate: float = 0.0
         # Vector from input ids table.
-        emb: TransformerTextEmbeddings.Config = TransformerTextEmbeddings.default_config()
+        emb: BaseEmbedding.Config = TransformerTextEmbeddings.default_config()
         # Transformer model trunk.
         transformer: BaseStackedTransformerLayer.Config = StackedTransformerLayer.default_config()
         # Layer norm applied to transformer output.
@@ -519,25 +519,25 @@ def _forward_for_mode(
         emb_batch = {**input_batch}
         emb_batch["inputs"] = emb_batch["input_ids"]

-        x = self.emb(input_batch=emb_batch)
-
         if mode == ForwardMode.FORWARD:
-            transformer_state, x = (
-                None,
-                self.transformer(
-                    x,
-                    self_attention_logit_biases=self_attention_logit_biases,
-                    target_segment_ids=input_segment_ids,
-                    target_positions=positions,
-                    cross_attention_data=cross_attention_data,
-                    cross_attention_logit_biases=cross_attention_logit_biases,
-                ),
+            x = self.emb(input_batch=emb_batch)
+            x = self.transformer(
+                x,
+                self_attention_logit_biases=self_attention_logit_biases,
+                target_segment_ids=input_segment_ids,
+                target_positions=positions,
+                cross_attention_data=cross_attention_data,
+                cross_attention_logit_biases=cross_attention_logit_biases,
             )
+            cached_states = None
         elif mode == ForwardMode.INIT_STATES:
             assert cached_states is not None
             if input_segment_ids is not None:
                 raise ValueError("input_segment_ids is not supported in INIT_STATES.")
-            transformer_state, x = self.transformer.init_states(
+            cached_states["emb"], x = self.emb.extend_step(
+                cached_states=cached_states["emb"], input_batch=emb_batch
+            )
+            cached_states["transformer_state"], x = self.transformer.init_states(
                 time_step=cached_states["transformer_state"],
                 data=x,
                 self_attention_logit_biases=self_attention_logit_biases,
@@ -548,7 +548,10 @@ def _forward_for_mode(
             assert cached_states is not None
             if input_segment_ids is not None:
                 raise ValueError("input_segment_ids is not supported in EXTEND_STEP.")
-            transformer_state, x = self.transformer.extend_step(
+            cached_states["emb"], x = self.emb.extend_step(
+                cached_states=cached_states["emb"], input_batch=emb_batch
+            )
+            cached_states["transformer_state"], x = self.transformer.extend_step(
                 cached_states=cached_states["transformer_state"],
                 data=x,
                 self_attention_logit_biases=self_attention_logit_biases,
@@ -588,7 +591,7 @@ def _forward_for_mode(
             logits = self._output_logits_modifier(logits)
         logits = with_sharding_constraint(logits, PartitionSpec(*self.config.logits_partition_spec))
         # TODO(markblee): Rename to just "transformer". "transformer_state" is a bit redundant.
-        return dict(transformer_state=transformer_state), dict(logits=logits, hidden_states=x)
+        return cached_states, dict(logits=logits, hidden_states=x)

     def forward(
         self,
@@ -647,12 +650,14 @@ def init_states(
     ) -> NestedTensor:
         """See `BaseDecoder.init_states` for details."""
         cfg: Decoder.Config = self.config
-        init_state, _ = self.transformer.init_states(
+        emb = self.emb.init_states(batch_size=batch_size, dtype=dtype)
+        transformer_state, _ = self.transformer.init_states(
             time_step=None,
             data=TensorSpec([batch_size, max_sequence_length, cfg.dim], dtype=dtype),
         )
         return dict(
-            transformer_state=init_state,
+            emb=emb,
+            transformer_state=transformer_state,
             input_ids=jnp.full(
                 (batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32
             ),
@@ -677,13 +682,14 @@ def prefill_states(
         See `BaseDecoder.prefill_states` for details.
         """
         validate_contains_paths(input_batch, paths=["input_ids"])
-        input_ids = input_batch["input_ids"]
+        input_ids: Tensor = input_batch["input_ids"]
         input_segment_ids = input_batch.get("input_segment_ids", None)
         positions = input_batch.get("positions", None)

+        emb = self.emb.init_states(batch_size=input_ids.shape[0], dtype=self.dtype())
         states, outputs = self._forward_for_mode(
             mode=ForwardMode.INIT_STATES,
-            cached_states=dict(transformer_state=time_step),
+            cached_states=dict(emb=emb, transformer_state=time_step),
             input_batch=input_batch,
             # TODO(markblee): Consider supporting packed inputs for more efficient prefilling.
             self_attention_logit_biases=self.compute_attention_logit_biases(
@@ -748,14 +754,13 @@ def extend_step(
             cached_states=cached_states,
             **kwargs,
         )
-        updated_states = dict(
+        updated_states.update(
             input_ids=updated_inputs,
             # There are some non-greedy DFS/BFS and sliding attention algorithms that
             # recursively search through potentials.
             # They backtrace to some anchor time step after exploring for t steps.
             # This requires tracking time_step separately from the attention time_step.
             time_step=cached_states["time_step"] + 1,
-            **updated_states,
         )
         return updated_states, outputs

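Taken together, these changes move the embedding layer's decode cache into the same nested `cached_states` dict that already carries the transformer cache, so prefill and per-token decoding thread a single pytree. A minimal, self-contained sketch of that threading pattern, using hypothetical `prefill`/`step` helpers and a toy state layout rather than the real `Decoder.prefill_states` / `Decoder.extend_step`:

```python
import jax.numpy as jnp

MAX_LEN = 8  # toy maximum sequence length


def prefill(prompt_ids: jnp.ndarray) -> dict:
    """Hypothetical stand-in for prefill_states: builds the nested decode state."""
    batch, prompt_len = prompt_ids.shape
    time_step = jnp.full((batch,), prompt_len, dtype=jnp.int32)
    return dict(
        # New in this change: the embedding layer keeps its own cache entry.
        emb=dict(time_step=time_step),
        transformer_state=dict(time_step=time_step),
        input_ids=jnp.pad(prompt_ids, ((0, 0), (0, MAX_LEN - prompt_len))),
        time_step=time_step,
    )


def step(cached_states: dict, token_ids: jnp.ndarray) -> dict:
    """Hypothetical stand-in for extend_step: each sub-layer updates its own entry."""
    updated = dict(cached_states)
    updated["emb"] = dict(time_step=cached_states["emb"]["time_step"] + 1)
    updated["transformer_state"] = dict(
        time_step=cached_states["transformer_state"]["time_step"] + 1
    )
    # Mirror `updated_states.update(...)`: append the new ids, advance time_step.
    t = cached_states["time_step"]
    batch = jnp.arange(t.shape[0])
    updated["input_ids"] = cached_states["input_ids"].at[batch, t].set(token_ids)
    updated["time_step"] = t + 1
    return updated


states = prefill(jnp.array([[5, 7, 9]], dtype=jnp.int32))  # prompt of length 3
states = step(states, jnp.array([11], dtype=jnp.int32))    # decode one more token
```

Because `_forward_for_mode` now returns the same (updated) dict it was given, adding the `emb` entry does not change how callers drive the decode loop; they simply keep passing the whole dict back in.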