@@ -33,7 +33,7 @@ def testDefaults(self):
3333
3434 self .assertEqual (slim .get_variables (), [])
3535 self .assertEqual (deploy_config .caching_device (), None )
36- self .assertDeviceEqual (deploy_config .clone_device (0 ), '' )
36+ self .assertDeviceEqual (deploy_config .clone_device (0 ), 'GPU:0 ' )
3737 self .assertEqual (deploy_config .clone_scope (0 ), '' )
3838 self .assertDeviceEqual (deploy_config .optimizer_device (), 'CPU:0' )
3939 self .assertDeviceEqual (deploy_config .inputs_device (), 'CPU:0' )
@@ -65,7 +65,7 @@ def testPS(self):
6565 deploy_config = model_deploy .DeploymentConfig (num_clones = 1 , num_ps_tasks = 1 )
6666
6767 self .assertDeviceEqual (deploy_config .clone_device (0 ),
68- '/job:worker' )
68+ '/job:worker/device:GPU:0 ' )
6969 self .assertEqual (deploy_config .clone_scope (0 ), '' )
7070 self .assertDeviceEqual (deploy_config .optimizer_device (),
7171 '/job:worker/device:CPU:0' )
@@ -105,7 +105,7 @@ def testReplicasPS(self):
105105 num_ps_tasks = 2 )
106106
107107 self .assertDeviceEqual (deploy_config .clone_device (0 ),
108- '/job:worker' )
108+ '/job:worker/device:GPU:0 ' )
109109 self .assertEqual (deploy_config .clone_scope (0 ), '' )
110110 self .assertDeviceEqual (deploy_config .optimizer_device (),
111111 '/job:worker/device:CPU:0' )
@@ -201,7 +201,7 @@ def testCreateLogisticClassifier(self):
201201 self .assertEqual (clone .outputs .op .name ,
202202 'LogisticClassifier/fully_connected/Sigmoid' )
203203 self .assertEqual (clone .scope , '' )
204- self .assertDeviceEqual (clone .device , '' )
204+ self .assertDeviceEqual (clone .device , 'GPU:0 ' )
205205 self .assertEqual (len (slim .losses .get_losses ()), 1 )
206206 update_ops = tf .get_collection (tf .GraphKeys .UPDATE_OPS )
207207 self .assertEqual (update_ops , [])
@@ -227,7 +227,7 @@ def testCreateSingleclone(self):
227227 self .assertEqual (clone .outputs .op .name ,
228228 'BatchNormClassifier/fully_connected/Sigmoid' )
229229 self .assertEqual (clone .scope , '' )
230- self .assertDeviceEqual (clone .device , '' )
230+ self .assertDeviceEqual (clone .device , 'GPU:0 ' )
231231 self .assertEqual (len (slim .losses .get_losses ()), 1 )
232232 update_ops = tf .get_collection (tf .GraphKeys .UPDATE_OPS )
233233 self .assertEqual (len (update_ops ), 2 )
@@ -278,7 +278,7 @@ def testCreateOnecloneWithPS(self):
278278 clone = clones [0 ]
279279 self .assertEqual (clone .outputs .op .name ,
280280 'BatchNormClassifier/fully_connected/Sigmoid' )
281- self .assertDeviceEqual (clone .device , '/job:worker' )
281+ self .assertDeviceEqual (clone .device , '/job:worker/device:GPU:0 ' )
282282 self .assertEqual (clone .scope , '' )
283283 self .assertEqual (len (slim .get_variables ()), 5 )
284284 for v in slim .get_variables ():
@@ -350,7 +350,7 @@ def testCreateLogisticClassifier(self):
350350 self .assertEqual (len (grads_and_vars ), len (tf .trainable_variables ()))
351351 self .assertEqual (total_loss .op .name , 'total_loss' )
352352 for g , v in grads_and_vars :
353- self .assertDeviceEqual (g .device , '' )
353+ self .assertDeviceEqual (g .device , 'GPU:0 ' )
354354 self .assertDeviceEqual (v .device , 'CPU:0' )
355355
356356 def testCreateSingleclone (self ):
@@ -376,7 +376,7 @@ def testCreateSingleclone(self):
376376 self .assertEqual (len (grads_and_vars ), len (tf .trainable_variables ()))
377377 self .assertEqual (total_loss .op .name , 'total_loss' )
378378 for g , v in grads_and_vars :
379- self .assertDeviceEqual (g .device , '' )
379+ self .assertDeviceEqual (g .device , 'GPU:0 ' )
380380 self .assertDeviceEqual (v .device , 'CPU:0' )
381381
382382 def testCreateMulticlone (self ):
@@ -458,7 +458,7 @@ def testCreateOnecloneWithPS(self):
458458 self .assertEqual (len (grads_and_vars ), len (tf .trainable_variables ()))
459459 self .assertEqual (total_loss .op .name , 'total_loss' )
460460 for g , v in grads_and_vars :
461- self .assertDeviceEqual (g .device , '/job:worker' )
461+ self .assertDeviceEqual (g .device , '/job:worker/device:GPU:0 ' )
462462 self .assertDeviceEqual (v .device , '/job:ps/task:0/CPU:0' )
463463
464464
@@ -515,7 +515,7 @@ def testLocalTrainOp(self):
515515 for _ in range (10 ):
516516 sess .run (model .train_op )
517517 final_loss = sess .run (model .total_loss )
518- self .assertLess (final_loss , initial_loss / 10 .0 )
518+ self .assertLess (final_loss , initial_loss / 5 .0 )
519519
520520 final_mean , final_variance = sess .run ([moving_mean ,
521521 moving_variance ])
0 commit comments