Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
'config',
None,
'File path to the training hyperparameter configuration.',
lock_config=True,
)
flags.mark_flags_as_required(['config', 'workdir'])


def main(argv):
Expand All @@ -56,7 +58,7 @@ def main(argv):
f'process_count: {jax.process_count()}'
)

train.train_and_evaluate(FLAGS.config)
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == '__main__':
Expand Down
14 changes: 11 additions & 3 deletions examples/vae/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation logic."""
from typing import Any

from absl import logging
from flax import linen as nn
Expand All @@ -24,6 +25,7 @@
import jax.numpy as jnp
import ml_collections
import optax
import tensorflow as tf
import tensorflow_datasets as tfds


Expand All @@ -47,6 +49,7 @@ def compute_metrics(recon_x, x, mean, logvar):


def train_step(state, batch, z_rng, latents):
"""Train step."""
def loss_fn(params):
recon_x, mean, logvar = models.model(latents).apply(
{'params': params}, batch, z_rng
Expand All @@ -62,6 +65,7 @@ def loss_fn(params):


def eval_f(params, images, z, z_rng, latents):
"""Evaluation function."""
def eval_model(vae):
recon_images, mean, logvar = vae(images, z_rng)
comparison = jnp.concatenate([
Expand All @@ -77,8 +81,10 @@ def eval_model(vae):
return nn.apply(eval_model, models.model(latents))({'params': params})


def train_and_evaluate(config: ml_collections.ConfigDict):
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
"""Train and evaulate pipeline."""
tf.io.gfile.makedirs(workdir)

rng = random.key(0)
rng, key = random.split(rng)

Expand Down Expand Up @@ -116,9 +122,11 @@ def train_and_evaluate(config: ml_collections.ConfigDict):
state.params, test_ds, z, eval_rng, config.latents
)
vae_utils.save_image(
comparison, f'results/reconstruction_{epoch}.png', nrow=8
comparison, f'{workdir}/reconstruction_{epoch}.png', nrow=8
)
vae_utils.save_image(
sample, f'{workdir}/sample_{epoch}.png', nrow=8
)
vae_utils.save_image(sample, f'results/sample_{epoch}.png', nrow=8)

print(
'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
Expand Down
Loading