2020import argparse
2121import sys
2222
23- import tensorflow as tf
23+ import tensorflow as tf # pylint: disable=g-bad-import-order
2424
2525from official .mnist import dataset
2626from official .utils .arg_parsers import parsers
2727from official .utils .logging import hooks_helper
2828
2929LEARNING_RATE = 1e-4
3030
31+
3132class Model (tf .keras .Model ):
3233 """Model to recognize digits in the MNIST dataset.
3334
@@ -145,31 +146,36 @@ def model_fn(features, labels, mode, params):
145146
146147
def validate_batch_size_for_multi_gpu(batch_size):
  """For multi-gpu, batch-size must be a multiple of the number of GPUs.

  Note that this should eventually be handled by replicate_model_fn
  directly. Multi-GPU support is currently experimental, however,
  so doing the work here until that feature is in place.

  Args:
    batch_size: the number of examples processed in each training batch.

  Raises:
    ValueError: if no GPUs are found, or selected batch_size is invalid.
  """
  # Deliberate function-scope import of a private TF module; suppressed
  # for pylint just like the surrounding file does elsewhere.
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top

  # Count the GPU devices visible to this process.
  devices = device_lib.list_local_devices()
  num_gpus = len([d for d in devices if d.device_type == 'GPU'])

  if num_gpus == 0:
    raise ValueError('Multi-GPU mode was specified, but no GPUs '
                     'were found. To use CPU, run without --multi_gpu.')

  # The batch is split evenly across GPUs, so it must divide cleanly.
  remainder = batch_size % num_gpus
  if remainder != 0:
    err = ('When running with multiple GPUs, batch size '
           'must be a multiple of the number of available GPUs. '
           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
          ).format(num_gpus, batch_size, batch_size - remainder)
    raise ValueError(err)
170176
171177
172- def main (unused_argv ):
178+ def main (_ ):
173179 model_function = model_fn
174180
175181 if FLAGS .multi_gpu :
@@ -195,6 +201,8 @@ def main(unused_argv):
195201
196202 # Set up training and evaluation input functions.
197203 def train_input_fn ():
204+ """Prepare data for training."""
205+
198206 # When choosing shuffle buffer sizes, larger sizes result in better
199207 # randomness, while smaller sizes use less memory. MNIST is a small
200208 # enough dataset that we can easily shuffle the full epoch.
@@ -215,7 +223,7 @@ def eval_input_fn():
215223 FLAGS .hooks , batch_size = FLAGS .batch_size )
216224
217225 # Train and evaluate model.
218- for n in range (FLAGS .train_epochs // FLAGS .epochs_between_evals ):
226+ for _ in range (FLAGS .train_epochs // FLAGS .epochs_between_evals ):
219227 mnist_classifier .train (input_fn = train_input_fn , hooks = train_hooks )
220228 eval_results = mnist_classifier .evaluate (input_fn = eval_input_fn )
221229 print ('\n Evaluation results:\n \t %s\n ' % eval_results )
@@ -231,10 +239,11 @@ def eval_input_fn():
231239
232240class MNISTArgParser (argparse .ArgumentParser ):
233241 """Argument parser for running MNIST model."""
242+
234243 def __init__ (self ):
235244 super (MNISTArgParser , self ).__init__ (parents = [
236- parsers .BaseParser (),
237- parsers .ImageModelParser ()])
245+ parsers .BaseParser (),
246+ parsers .ImageModelParser ()])
238247
239248 self .add_argument (
240249 '--export_dir' ,
0 commit comments