Skip to content

Commit 63caac7

Browse files
danielsuoFlax Authors
authored andcommitted
[flax:examples:vae] Small linter fixes.
PiperOrigin-RevId: 817312045
1 parent c29b6d6 commit 63caac7

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

examples/vae/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@
3131

3232
FLAGS = flags.FLAGS
3333

34+
flags.DEFINE_string('workdir', None, 'Directory to store model data.')
3435
config_flags.DEFINE_config_file(
3536
'config',
3637
None,
3738
'File path to the training hyperparameter configuration.',
3839
lock_config=True,
3940
)
41+
flags.mark_flags_as_required(['config', 'workdir'])
4042

4143

4244
def main(argv):
@@ -56,7 +58,7 @@ def main(argv):
5658
f'process_count: {jax.process_count()}'
5759
)
5860

59-
train.train_and_evaluate(FLAGS.config)
61+
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
6062

6163

6264
if __name__ == '__main__':

examples/vae/train.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Training and evaluation logic."""
15+
from typing import Any
1516

1617
from absl import logging
1718
from flax import linen as nn
@@ -24,6 +25,7 @@
2425
import jax.numpy as jnp
2526
import ml_collections
2627
import optax
28+
import tensorflow as tf
2729
import tensorflow_datasets as tfds
2830

2931

@@ -47,6 +49,7 @@ def compute_metrics(recon_x, x, mean, logvar):
4749

4850

4951
def train_step(state, batch, z_rng, latents):
52+
"""Train step."""
5053
def loss_fn(params):
5154
recon_x, mean, logvar = models.model(latents).apply(
5255
{'params': params}, batch, z_rng
@@ -62,6 +65,7 @@ def loss_fn(params):
6265

6366

6467
def eval_f(params, images, z, z_rng, latents):
68+
"""Evaluation function."""
6569
def eval_model(vae):
6670
recon_images, mean, logvar = vae(images, z_rng)
6771
comparison = jnp.concatenate([
@@ -77,8 +81,10 @@ def eval_model(vae):
7781
return nn.apply(eval_model, models.model(latents))({'params': params})
7882

7983

80-
def train_and_evaluate(config: ml_collections.ConfigDict):
84+
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
8185
"""Train and evaulate pipeline."""
86+
tf.io.gfile.makedirs(workdir)
87+
8288
rng = random.key(0)
8389
rng, key = random.split(rng)
8490

@@ -116,9 +122,11 @@ def train_and_evaluate(config: ml_collections.ConfigDict):
116122
state.params, test_ds, z, eval_rng, config.latents
117123
)
118124
vae_utils.save_image(
119-
comparison, f'results/reconstruction_{epoch}.png', nrow=8
125+
comparison, f'{workdir}/reconstruction_{epoch}.png', nrow=8
126+
)
127+
vae_utils.save_image(
128+
sample, f'{workdir}/sample_{epoch}.png', nrow=8
120129
)
121-
vae_utils.save_image(sample, f'results/sample_{epoch}.png', nrow=8)
122130

123131
print(
124132
'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(

0 commit comments

Comments
 (0)