Hello Flax community,

Basically:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import Array


class BareTransformerLayer(nn.Module):
    qkv_dim: int  # Can be smaller than the input dim (e.g., 512).
    mlp_dim: int
    "Inner mlp dim; the hidden layer expands to mlp_dim * 4."
    num_heads: int  # Number of attention heads (e.g., 2).
    decode: bool  # Whether in decoding mode.
    dropout_rate: float = 0.1
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask=None, dropout_rng: Array | None = None):
        input_dim = x.shape[-1]  # Automatically infer the input dimension (e.g., 5120).
        # LayerNorm + self-attention (with residual).
        h = nn.LayerNorm(dtype=self.dtype)(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.qkv_dim,  # Projects to qkv_dim internally.
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            decode=self.decode,
            deterministic=False,
        )(h, mask=mask, dropout_rng=dropout_rng)
        x = x + h  # Residual connection (shape preserved).
        # LayerNorm + MLP (with residual).
        h = nn.LayerNorm(dtype=self.dtype)(x)
        h = nn.Dense(self.mlp_dim * 4, dtype=self.dtype)(h)  # Expand to the hidden width.
        h = nn.relu(h)
        h = nn.Dense(input_dim, dtype=self.dtype)(h)  # Project back to input_dim.
        h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=False)
        return x + h  # Residual connection (shape preserved).


class BareTransformerBlock(nn.Module):
    "A class to stack transformer layers."

    num_heads: int
    qkv_dim: int
    num_layers: int
    mlp_dim: int
    decode: bool
    dropout_rate: float = 0.1
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact  # nn.compact so input_dim can be inferred without complications.
    def __call__(self, x, mask=None):
        input_dim = x.shape[-1]  # Infer the input dimension (e.g., 5120).
        # Stack multiple transformer layers.
        for _ in range(self.num_layers):
            x = BareTransformerLayer(
                qkv_dim=self.qkv_dim,
                num_heads=self.num_heads,
                decode=self.decode,
                dropout_rate=self.dropout_rate,
                mlp_dim=self.mlp_dim,
                dtype=self.dtype,
            )(x, mask=mask, dropout_rng=self.make_rng("dropout"))
        # Final projection (optional, in case the output dim should differ from input_dim).
        x = nn.Dense(input_dim, dtype=self.dtype)(x)  # Preserves the dim by default.
        return nn.LayerNorm(dtype=self.dtype)(x)
```

And an autoregressive decoder that calls them - pretty standard, but this particular module is a continuous-output transformer that uses discrete outputs as a stop head; basically that's all you need to know.

```python
class DecoderTransformerAction(nn.Module):
"A transformer block that uses discrete tokens to mark stop; and returns the tokens."
num_heads: int
num_layers: int
qkv_dim: int
mlp_dim: int
num_actions: int
vocab_size: int
out_features_per_action: int
stop_token: int # When this token is sampled, generation stops.
dropout_rate: float = 0.1
dtype: jnp.dtype = jnp.bfloat16
def setup(self):
self.transformer = BareTransformerBlock(
num_heads=self.num_heads,
qkv_dim=self.qkv_dim,
mlp_dim=self.mlp_dim,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
num_layers=self.num_layers,
decode=True,
)
self.norm = nn.LayerNorm(dtype=self.dtype)
# Removed discrete token embedding since input is continuous.
self.token_pred = nn.Dense(self.vocab_size, dtype=self.dtype)
def _causal_mask(self, seq_len):
return jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=bool), k=1)
def __call__(self, inp: jax.Array):
# Expect inp shape: (batch_size, embed_features) as a continuous query.
batch_size = inp.shape[:2]
# Initialize state by tiling the continuous input added to a learned start embedding.
initial_token = inp # + self.start_embedding #leave start embedding out because it will complicate training.
state = jnp.tile(initial_token[:, :, None, :], (1, 1, self.num_actions, 1))
# Preallocate discrete tokens and continuous representations.
discrete_tokens = jnp.zeros((*batch_size, self.num_actions), dtype=jnp.int32)
cont_reps = jnp.zeros(
(*batch_size, self.num_actions, self.mlp_dim),
dtype=self.dtype,
)
rng = self.make_rng("action")
finished = jnp.zeros((*batch_size,), dtype=bool)
def body_fn(i, carry):
state, rng, discrete_tokens, finished, stop_token = carry
full_seq = state
# Use negative dimension for the sequence dimension.
seq_len = full_seq.shape[-2]
# Create a valid positions mask using the sequence length.
valid_mask = jnp.arange(seq_len) <= i
valid_outer = jnp.outer(valid_mask, valid_mask)[None, None, :, :]
full_causal = self._causal_mask(seq_len)
combined_mask = full_causal & valid_outer
x = self.transformer(full_seq, mask=combined_mask)
x = self.norm(x)
# Use negative indexing to select the i-th token from the sequence dimension.
current_logits = self.token_pred(x)[..., i, :]
rng, subkey = jax.random.split(rng)
sampled_token = jax.random.categorical(subkey, current_logits, axis=-1)
discrete_tokens = discrete_tokens.at[..., i].set(sampled_token)
# state is a filtered sequence for stop token.
state = jax.lax.cond(
i < seq_len - 1,
lambda s: s.at[..., i + 1, :].set(x[..., i, :]),
lambda s: s,
state,
)
finished = jnp.where(finished, finished, sampled_token == stop_token)
return (state, rng, discrete_tokens, finished, stop_token)
init_carry = (state, rng, discrete_tokens, finished, self.stop_token)
@jax.jit # again, a workaround.
def jitted_func(init_carry, num_actions):
state, _rng, discrete_tokens, _finished, _stop_token = jax.lax.fori_loop(
0, num_actions, body_fn, init_carry
)
return state, discrete_tokens
        state, discrete_tokens = jitted_func(init_carry, self.num_actions)
        return discrete_tokens
```

After which I'm getting an error. So that begs the question: how do I instantiate the cache? It's not a variable that should be stored in params, so where and how do I store it, then? I've tried about 10 different approaches and possible bugfixes and none of them worked; LLMs are also clueless, and I couldn't actually find/understand the caching details of the transformer examples in "examples/". How could I solve this?
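For reference, outside of a module I understand the decode cache is supposed to be handled roughly like this (a minimal sketch with placeholder sizes, feeding one position at a time); what I can't figure out is how to do the equivalent from inside `DecoderTransformerAction`'s own `__call__`:

```python
import jax
import jax.numpy as jnp

# Placeholder sizes; only the shapes matter here.
batch, max_len, feat = 2, 8, 512

block = BareTransformerBlock(
    num_heads=2, qkv_dim=512, num_layers=2, mlp_dim=2048, decode=True,
)

rng = jax.random.PRNGKey(0)
# init with a full-length input so the attention layers allocate their cache
# (cached keys/values sized for max_len, plus a cache index).
variables = block.init(
    {"params": rng, "dropout": rng},
    jnp.zeros((batch, max_len, feat), dtype=jnp.bfloat16),
)
params, cache = variables["params"], variables["cache"]

# Then decode one position at a time: pass the cache in, mark it mutable,
# and carry the updated cache returned by apply into the next step.
tok = jnp.zeros((batch, 1, feat), dtype=jnp.bfloat16)
for _ in range(max_len):
    tok, mutated = block.apply(
        {"params": params, "cache": cache},
        tok,
        rngs={"dropout": rng},
        mutable=["cache"],
    )
    cache = mutated["cache"]
```

The point being that the cache is just another variable collection returned by `init` and threaded back through `apply` with `mutable=["cache"]`.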
Replies: 1 comment
Found the solution: modify TrainState to also carry the cache, then propagate it through the model. Pretty simple, though...
Will it be reset every time? I have no clue how to reset that cache, or what to reset it to.
If anybody knows, please lmk.
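In case it helps anyone later, this is roughly the shape of what I mean; a sketch only, where `TrainStateWithCache`, `model`, `dummy_inp`, and `optimizer` are my own placeholder names, and "resetting" just means swapping in a freshly initialized cache collection:

```python
from typing import Any

import jax
from flax.training import train_state


class TrainStateWithCache(train_state.TrainState):
    # Extra field that carries the attention 'cache' collection next to params.
    cache: Any


rng = jax.random.PRNGKey(0)
# init creates both 'params' and 'cache' (the attention layers were built with decode=True).
variables = model.init({"params": rng, "dropout": rng, "action": rng}, dummy_inp)

state = TrainStateWithCache.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optimizer,  # any optax optimizer
    cache=variables["cache"],
)

# At call time, pass the cache in, mark it mutable, and store the updated copy.
outputs, mutated = state.apply_fn(
    {"params": state.params, "cache": state.cache},
    dummy_inp,
    rngs={"dropout": rng, "action": rng},
    mutable=["cache"],
)
state = state.replace(cache=mutated["cache"])

# "Resetting" the cache = replacing it with a freshly initialized one; params stay untouched.
fresh = model.init({"params": rng, "dropout": rng, "action": rng}, dummy_inp)
state = state.replace(cache=fresh["cache"])
```

As far as I can tell, a freshly initialized cache is all zeros (cached keys/values and the cache index), so `jax.tree_util.tree_map(jnp.zeros_like, state.cache)` should be an equivalent reset without re-running init.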