From 88d74979c27bb40888168140d7cfc66c67bd100d Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 1 Dec 2025 13:38:13 -0800 Subject: [PATCH] [flax:examples:vae] Small linter fixes. PiperOrigin-RevId: 838913868 --- examples/vae/main.py | 4 +++- examples/vae/train.py | 14 +++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/vae/main.py b/examples/vae/main.py index 537ec08d6..20ed9dba1 100644 --- a/examples/vae/main.py +++ b/examples/vae/main.py @@ -31,12 +31,14 @@ FLAGS = flags.FLAGS +flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) +flags.mark_flags_as_required(['config', 'workdir']) def main(argv): @@ -56,7 +58,7 @@ def main(argv): f'process_count: {jax.process_count()}' ) - train.train_and_evaluate(FLAGS.config) + train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': diff --git a/examples/vae/train.py b/examples/vae/train.py index 84f1b582a..2d7056bac 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Training and evaluation logic.""" +from typing import Any from absl import logging from flax import linen as nn @@ -24,6 +25,7 @@ import jax.numpy as jnp import ml_collections import optax +import tensorflow as tf import tensorflow_datasets as tfds @@ -47,6 +49,7 @@ def compute_metrics(recon_x, x, mean, logvar): def train_step(state, batch, z_rng, latents): + """Train step.""" def loss_fn(params): recon_x, mean, logvar = models.model(latents).apply( {'params': params}, batch, z_rng @@ -62,6 +65,7 @@ def loss_fn(params): def eval_f(params, images, z, z_rng, latents): + """Evaluation function.""" def eval_model(vae): recon_images, mean, logvar = vae(images, z_rng) comparison = jnp.concatenate([ @@ -77,8 +81,10 @@ def eval_model(vae): return nn.apply(eval_model, models.model(latents))({'params': params}) -def train_and_evaluate(config: ml_collections.ConfigDict): +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Train and evaulate pipeline.""" + tf.io.gfile.makedirs(workdir) + rng = random.key(0) rng, key = random.split(rng) @@ -116,9 +122,11 @@ def train_and_evaluate(config: ml_collections.ConfigDict): state.params, test_ds, z, eval_rng, config.latents ) vae_utils.save_image( - comparison, f'results/reconstruction_{epoch}.png', nrow=8 + comparison, f'{workdir}/reconstruction_{epoch}.png', nrow=8 + ) + vae_utils.save_image( + sample, f'{workdir}/sample_{epoch}.png', nrow=8 ) - vae_utils.save_image(sample, f'results/sample_{epoch}.png', nrow=8) print( 'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(