Hi, I have been trying to implement a naive transformer using JAX and Flax. I have managed to implement something that looks like it works. The problem is that when I train the model, after 2-3 epochs it starts predicting the same token - "the" - no matter what the input is. I have checked for exploding gradients, parameters not being updated, and all transformer blocks pointing to the same instance (a rough sketch of these checks is included after the training code below), but none of those seems to be the case. The dataset I'm using is WikiText-2-raw.

Here are my functions regarding training:
```python
@nnx.jit
def train_step(model: Transformer, state: optax.OptState, inputs: jax.Array, targets: jax.Array):
    def loss_fn(model: Transformer, inputs: jax.Array, targets: jax.Array):
        logits = model(inputs)
        batch, seq_len, _ = logits.shape
        targets = jnp.reshape(targets, (batch * seq_len))
        logits = jnp.reshape(logits, (batch * seq_len, vocab_size))
        loss_fn = optax.softmax_cross_entropy_with_integer_labels
        loss = loss_fn(logits, targets)
        return jnp.mean(loss)

    loss, grad = nnx.value_and_grad(loss_fn)(model, inputs, targets)
    updates, state = optimizer.update(grad, state)
    params = optax.apply_updates(nnx.state(model, nnx.Param), updates)
    nnx.update(model, params)
    return loss, state


@nnx.jit
def valid_step(model: Transformer, inputs: jax.Array, targets: jax.Array):
    def loss_fn(model: Transformer, inputs: jax.Array, targets: jax.Array):
        logits = model(inputs)
        batch, seq_len, _ = logits.shape
        targets = jnp.reshape(targets, (batch * seq_len))
        logits = jnp.reshape(logits, (batch * seq_len, vocab_size))
        loss_fn = optax.softmax_cross_entropy_with_integer_labels
        loss = loss_fn(logits, targets)
        return jnp.mean(loss)

    loss = loss_fn(model, inputs, targets)
    return loss
```
```python
num_epochs = 20
epoch_loss = []
epoch_val = []
state = init_state

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

for epoch in range(num_epochs):
    train_loss = []
    model.train()
    for input, target in data_generator(trainset, batch // 2, mngr.next()):
        loss, state = train_step(model=model, state=state, inputs=input, targets=target)
        train_loss.append(float(loss))
    epoch_loss = jnp.mean(jnp.array(train_loss))
    print(f'Training Loss = {epoch_loss} at epoch {epoch+1}')

    input_text = "Before the morning of 1 September had passed , reports coming in to US 2nd Division headquarters made it clear that North Koreans had penetrated to the north @-@ south Changnyong @-@ Yongsan road and cut the division in two ; the 38th and 23d Infantry Regiments with the bulk of the division artillery in the north were separated from the division headquarters and the 9th Infantry Regiment in the south . Keiser decided that this situation made it advisable to control and direct the divided division as two special forces . Accordingly , he placed the division artillery commander , Brigadier General Loyal M. Haynes , in command of the northern group ."
    generate_text(model=model, tokenizer=tokenizer, seq_len=seq_len, input_text=input_text)

    valid_loss = []
    for input, target in data_generator(valset, 1, mngr.next()):
        loss = valid_step(model=model, inputs=input, targets=target)
        valid_loss.append(float(loss))
    epoch_loss = jnp.mean(jnp.array(valid_loss))
    print(f'Validation Loss = {epoch_loss} at epoch {epoch+1}')

    ax1.plot(train_loss, label=f'Epoch {epoch+1}')
    ax2.plot(valid_loss, label=f'Epoch {epoch+1}')

ax1.set_title('Training loss')
ax1.set_ylabel('Loss')
ax1.legend()
ax2.set_title('Validation loss')
ax2.set_ylabel('Loss')
ax2.legend()
plt.tight_layout()
plt.show()
```
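For context, here is a rough sketch of how the sanity checks mentioned at the top can be done. This is illustrative, not the exact cells from my notebook: `model.blocks` is a placeholder for however the blocks are stored in the Transformer, `grad` is the gradient pytree returned by nnx.value_and_grad, and the parameter snapshots are assumed to flatten to JAX arrays under jax.tree_util.

```python
import jax
import jax.numpy as jnp


def global_grad_norm(grad):
    """L2 norm over every leaf of the gradient pytree; should stay bounded."""
    leaves = jax.tree_util.tree_leaves(grad)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))


def max_param_change(params_before, params_after):
    """Largest absolute change across all parameters; should be > 0 after a step."""
    diffs = jax.tree_util.tree_map(lambda a, b: jnp.max(jnp.abs(a - b)),
                                   params_before, params_after)
    return max(float(d) for d in jax.tree_util.tree_leaves(diffs))


def blocks_are_distinct(model):
    """Every transformer block should be a separate Python object, not one shared instance."""
    ids = [id(block) for block in model.blocks]  # `blocks` is a placeholder attribute name
    return len(ids) == len(set(ids))
```

Checks along these lines are what led me to rule out exploding gradients, frozen parameters, and shared block instances.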
Now, the problem is that the model seems to converge to predicting the same token no matter the input (it looks like it is the most frequent token in the dataset, since the token is "the"). I have been trying to figure out why this happens but cannot seem to work it out. I have attached my Jupyter Notebook with the whole code; it contains some other interesting debugging data, like visualizations of some parameter matrices as they are updated during training. The model seems to be updating just fine, yet I cannot understand why training converges to this degenerate behavior. Any help is greatly appreciated 🙏 All cells after saving and restoring the model are my debugging cells.
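For what it's worth, this is roughly how the "most frequent token" hypothesis can be checked. It is an illustrative sketch, not code from the notebook: it assumes `trainset` is a flat sequence of token ids, that `tokenizer.decode` maps a list of ids back to text, and it reuses `data_generator`, `batch`, and `mngr` from the training loop above (with the generator yielding (inputs, targets) pairs).

```python
from collections import Counter

import jax.numpy as jnp

# Which token dominates the corpus? (assumes `trainset` is a flat sequence of token ids)
token_counts = Counter(int(t) for t in trainset)
top_id, top_count = token_counts.most_common(1)[0]
print('Most frequent token in the corpus:', repr(tokenizer.decode([top_id])), top_count)

# What does the trained model predict at the last position of a batch from the generator?
inputs, _ = next(data_generator(trainset, batch // 2, mngr.next()))
logits = model(inputs)                          # (batch, seq_len, vocab_size)
predicted = jnp.argmax(logits[:, -1, :], axis=-1)
print('Predicted next tokens:', [repr(tokenizer.decode([int(i)])) for i in predicted])
```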
The link to the ipynb file - Notebook