1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414"""Training and evaluation logic."""
15+ from typing import Any
1516
1617from absl import logging
1718from flax import linen as nn
2425import jax .numpy as jnp
2526import ml_collections
2627import optax
28+ import tensorflow as tf
2729import tensorflow_datasets as tfds
2830
2931
@@ -47,6 +49,7 @@ def compute_metrics(recon_x, x, mean, logvar):
4749
4850
4951def train_step (state , batch , z_rng , latents ):
52+ """Train step."""
5053 def loss_fn (params ):
5154 recon_x , mean , logvar = models .model (latents ).apply (
5255 {'params' : params }, batch , z_rng
@@ -62,6 +65,7 @@ def loss_fn(params):
6265
6366
6467def eval_f (params , images , z , z_rng , latents ):
68+ """Evaluation function."""
6569 def eval_model (vae ):
6670 recon_images , mean , logvar = vae (images , z_rng )
6771 comparison = jnp .concatenate ([
@@ -77,8 +81,10 @@ def eval_model(vae):
7781 return nn .apply (eval_model , models .model (latents ))({'params' : params })
7882
7983
80- def train_and_evaluate (config : ml_collections .ConfigDict ):
84+ def train_and_evaluate (config : ml_collections .ConfigDict , workdir : str ):
8185 """Train and evaulate pipeline."""
86+ tf .io .gfile .makedirs (workdir )
87+
8288 rng = random .key (0 )
8389 rng , key = random .split (rng )
8490
@@ -116,9 +122,11 @@ def train_and_evaluate(config: ml_collections.ConfigDict):
116122 state .params , test_ds , z , eval_rng , config .latents
117123 )
118124 vae_utils .save_image (
119- comparison , f'results/reconstruction_{ epoch } .png' , nrow = 8
125+ comparison , f'{ workdir } /reconstruction_{ epoch } .png' , nrow = 8
126+ )
127+ vae_utils .save_image (
128+ sample , f'{ workdir } /sample_{ epoch } .png' , nrow = 8
120129 )
121- vae_utils .save_image (sample , f'results/sample_{ epoch } .png' , nrow = 8 )
122130
123131 print (
124132 'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}' .format (
0 commit comments