8 changes: 5 additions & 3 deletions examples/vae/main.py
@@ -25,7 +25,7 @@
 import jax
 from ml_collections import config_flags
 import tensorflow as tf
-
+import time
 import train
 
 
@@ -38,6 +38,8 @@
     lock_config=True,
 )
 
+flags.DEFINE_string('workdir', None, 'Directory to store logs and checkpoints.')
+
 
 def main(argv):
   if len(argv) > 1:
@@ -55,9 +57,9 @@ def main(argv):
       f'process_index: {jax.process_index()}, '
       f'process_count: {jax.process_count()}'
   )
-
+  start = time.time()
   train.train_and_evaluate(FLAGS.config)
-
+  logging.info('Total training time: %.2f seconds', time.time() - start)
 
 if __name__ == '__main__':
   app.run(main)
Collaborator: @sanepunk why did you remove the abseil app and the usage of the config file?
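For reference, the abseil + config-file wiring the reviewer is asking about is the pattern still visible in this diff: app.run drives main(), and ml_collections' config_flags supplies FLAGS.config. A minimal sketch of that wiring, reassembled from the lines shown above (not the PR's exact file):

from absl import app
from absl import flags
from ml_collections import config_flags

import train

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  train.train_and_evaluate(FLAGS.config)


if __name__ == '__main__':
  app.run(main)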

47 changes: 25 additions & 22 deletions examples/vae/models.py
@@ -14,44 +14,47 @@
 
 """VAE model definitions."""
 
-from flax import linen as nn
+from flax import nnx
 from jax import random
 import jax.numpy as jnp
 
 
-class Encoder(nn.Module):
+class Encoder(nnx.Module):
   """VAE Encoder."""
 
-  latents: int
+  def __init__(self, input_features: int, latents:int, *, rngs: nnx.Rngs):
+    self.linear_1 = nnx.Linear(input_features, 500, rngs=rngs)
+    self.mean_linear = nnx.Linear(500, latents, rngs=rngs)
+    self.logvar_linear = nnx.Linear(500, latents, rngs=rngs)
 
-  @nn.compact
   def __call__(self, x):
-    x = nn.Dense(500, name='fc1')(x)
-    x = nn.relu(x)
-    mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
-    logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
+    x = self.linear_1(x)
+    x = nnx.relu(x)
+    mean_x = self.mean_linear(x)
+    logvar_x = self.logvar_linear(x)
     return mean_x, logvar_x
 
 
-class Decoder(nn.Module):
+class Decoder(nnx.Module):
   """VAE Decoder."""
 
-  @nn.compact
+  def __init__(self, latents: int, output_features:int, *, rngs: nnx.Rngs):
+    self.linear_1 = nnx.Linear(latents, 500, rngs=rngs)
+    self.linear_2 = nnx.Linear(500, output_features, rngs=rngs)
+
   def __call__(self, z):
-    z = nn.Dense(500, name='fc1')(z)
-    z = nn.relu(z)
-    z = nn.Dense(784, name='fc2')(z)
+    z = self.linear_1(z)
+    z = nnx.relu(z)
+    z = self.linear_2(z)
     return z
 
 
-class VAE(nn.Module):
+class VAE(nnx.Module):
   """Full VAE model."""
 
-  latents: int = 20
-
-  def setup(self):
-    self.encoder = Encoder(self.latents)
-    self.decoder = Decoder()
+  def __init__(self, input_features:int, latents: int, rngs: nnx.Rngs):
+    self.encoder = Encoder(input_features=input_features, latents=latents, rngs=rngs)
+    self.decoder = Decoder(latents=latents, output_features=input_features, rngs=rngs)
 
   def __call__(self, x, z_rng):
     mean, logvar = self.encoder(x)
@@ -60,7 +63,7 @@ def __call__(self, x, z_rng):
     return recon_x, mean, logvar
 
   def generate(self, z):
-    return nn.sigmoid(self.decoder(z))
+    return nnx.sigmoid(self.decoder(z))
 
 
 def reparameterize(rng, mean, logvar):
@@ -69,5 +72,5 @@ def reparameterize(rng, mean, logvar):
   return mean + eps * std
 
 
-def model(latents):
-  return VAE(latents=latents)
+def model(input_features: int, latents: int, rngs: nnx.Rngs):
+  return VAE(input_features=input_features, latents=latents, rngs=rngs)
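
For context, here is a minimal usage sketch of the rewritten nnx modules, assuming flattened 28x28 MNIST inputs (input_features=784) and the old default of 20 latents; the batch size and keys are illustrative:

import jax.numpy as jnp
from jax import random
from flax import nnx

import models

rngs = nnx.Rngs(0)                      # RNG stream for parameter initialization
vae = models.model(784, 20, rngs=rngs)  # VAE with 784 input features, 20 latents

x = jnp.ones((32, 784))                 # dummy batch of flattened images
recon_x, mean, logvar = vae(x, random.PRNGKey(1))  # reparameterized forward pass

z = random.normal(random.PRNGKey(2), (16, 20))
samples = vae.generate(z)               # decoded images in [0, 1] via sigmoid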
10 changes: 6 additions & 4 deletions examples/vae/requirements.txt
@@ -1,7 +1,9 @@
 absl-py==1.4.0
-flax==0.6.9
-numpy==1.23.5
+flax~=0.10
+numpy>=1.26.4
 optax==0.1.5
 Pillow==10.2.0
-tensorflow==2.12.0
-tensorflow-datasets==4.9.2
+tensorflow~=2.16.0
+tensorflow-datasets==4.9.2
+clu==0.0.12
+ml-collections>=0.1.1
62 changes: 27 additions & 35 deletions examples/vae/train.py
@@ -14,11 +14,10 @@
"""Training and evaluation logic."""

from absl import logging
from flax import linen as nn
from flax import nnx
import input_pipeline
import models
import utils as vae_utils
from flax.training import train_state
import jax
from jax import random
import jax.numpy as jnp
@@ -34,7 +33,7 @@ def kl_divergence(mean, logvar):
 
 @jax.vmap
 def binary_cross_entropy_with_logits(logits, labels):
-  logits = nn.log_sigmoid(logits)
+  logits = nnx.log_sigmoid(logits)
   return -jnp.sum(
       labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits))
   )
@@ -45,36 +44,34 @@ def compute_metrics(recon_x, x, mean, logvar):
   kld_loss = kl_divergence(mean, logvar).mean()
   return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss}
 
-
-def train_step(state, batch, z_rng, latents):
-  def loss_fn(params):
-    recon_x, mean, logvar = models.model(latents).apply(
-        {'params': params}, batch, z_rng
-    )
+@nnx.jit

Collaborator: Let's use donate args to donate the model and optimizer to reduce GPU memory usage.

Contributor (author): I tried adding donate_argnums to nnx.jit in train_step, but was getting NaN loss and KL divergence. What should I do?

Follow-up: What to do about this?

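For reference, the donation the reviewer suggests would look roughly like the sketch below, assuming nnx.jit forwards jax.jit's donate_argnums. Given the NaN report above, treat it as an untested illustration, not a fix:

# Hedged sketch: donating args lets XLA reuse the optimizer's and model's
# buffers for the outputs, lowering peak GPU memory. Untested here: the
# author reports NaN loss/KLD with donation enabled on this example.
@nnx.jit(donate_argnums=(0, 1))
def train_step(optimizer: nnx.Optimizer, model: nnx.Module, batch, z_rng, latents):
  ...  # body identical to the train_step in the diff below
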
+def train_step(optimizer: nnx.Optimizer, model: nnx.Module, batch, z_rng, latents):
+  """Single training step for the VAE model."""
+  def loss_fn(model):
+    recon_x, mean, logvar = model(batch, z_rng)
     bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
     kld_loss = kl_divergence(mean, logvar).mean()
     loss = bce_loss + kld_loss
     return loss
 
-  grads = jax.grad(loss_fn)(state.params)
-  return state.apply_gradients(grads=grads)
+  loss, grads = nnx.value_and_grad(loss_fn)(model)
+  optimizer.update(grads)
+  return loss
 
 
-def eval_f(params, images, z, z_rng, latents):
-  def eval_model(vae):
-    recon_images, mean, logvar = vae(images, z_rng)
-    comparison = jnp.concatenate([
-        images[:8].reshape(-1, 28, 28, 1),
-        recon_images[:8].reshape(-1, 28, 28, 1),
-    ])
-
-    generate_images = vae.generate(z)
-    generate_images = generate_images.reshape(-1, 28, 28, 1)
-    metrics = compute_metrics(recon_images, images, mean, logvar)
-    return metrics, comparison, generate_images
-
-  return nn.apply(eval_model, models.model(latents))({'params': params})
+@nnx.jit
+def eval_f(model: nnx.Module, images, z, z_rng, latents):
+  """Evaluation function for the VAE model."""
+  recon_images, mean, logvar = model(images, z_rng)
+  comparison = jnp.concatenate([
+      images[:8].reshape(-1, 28, 28, 1),
+      recon_images[:8].reshape(-1, 28, 28, 1),
+  ])
+  generate_images = model.generate(z)
+  generate_images = generate_images.reshape(-1, 28, 28, 1)
+  metrics = compute_metrics(recon_images, images, mean, logvar)
+  return metrics, comparison, generate_images
 
 
 def train_and_evaluate(config: ml_collections.ConfigDict):
@@ -90,14 +87,9 @@ def train_and_evaluate(config: ml_collections.ConfigDict):
   test_ds = input_pipeline.build_test_set(ds_builder)
 
   logging.info('Initializing model.')
-  init_data = jnp.ones((config.batch_size, 784), jnp.float32)
-  params = models.model(config.latents).init(key, init_data, rng)['params']
-
-  state = train_state.TrainState.create(
-      apply_fn=models.model(config.latents).apply,
-      params=params,
-      tx=optax.adam(config.learning_rate),
-  )
+  rngs = nnx.Rngs(0)
+  model = models.model(784, config.latents, rngs=rngs)
+  optimizer = nnx.Optimizer(model, optax.adam(config.learning_rate))
 
   rng, z_key, eval_rng = random.split(rng, 3)
   z = random.normal(z_key, (64, config.latents))
@@ -110,10 +102,10 @@
     for _ in range(steps_per_epoch):
       batch = next(train_ds)
       rng, key = random.split(rng)
-      state = train_step(state, batch, key, config.latents)
+      loss_val = train_step(optimizer, model, batch, key, config.latents)
 
     metrics, comparison, sample = eval_f(
-        state.params, test_ds, z, eval_rng, config.latents
+        model, test_ds, z, eval_rng, config.latents
     )
     vae_utils.save_image(
         comparison, f'results/reconstruction_{epoch}.png', nrow=8