@@ -284,10 +284,13 @@ def model_fn(features, labels, mode, params):
284284 train_op = optimizer .minimize (combined_cost , global_step = global_step )
285285
286286 # Computations to be executed on CPU, outside of the main TPU queues.
287- def eval_metrics_host_call_fn (policy_output , value_output , pi_tensor ,
288- value_tensor , policy_cost , value_cost ,
289- l2_cost , combined_cost , step ,
290- est_mode = tf .estimator .ModeKeys .TRAIN ):
287+ def eval_metrics_host_call_fn (
288+ features ,
289+ policy_output , value_output ,
290+ pi_tensor , value_tensor ,
291+ policy_cost , value_cost ,
292+ l2_cost , combined_cost ,
293+ step , est_mode = tf .estimator .ModeKeys .TRAIN ):
291294 policy_entropy = - tf .reduce_mean (tf .reduce_sum (
292295 policy_output * tf .log (policy_output ), axis = 1 ))
293296 # pi_tensor is one_hot when generated from sgfs (for supervised learning)
@@ -306,6 +309,8 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
306309
307310 value_cost_normalized = value_cost / params ['value_cost_weight' ]
308311 avg_value_observed = tf .reduce_mean (value_tensor )
312+ avg_stones_black = tf .reduce_mean (tf .reduce_sum (features [:,:,:,1 ], [1 ,2 ]))
313+ avg_stones_white = tf .reduce_mean (tf .reduce_sum (features [:,:,:,0 ], [1 ,2 ]))
309314
310315 with tf .variable_scope ('metrics' ):
311316 metric_ops = {
@@ -315,13 +320,17 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
315320 'l2_cost' : tf .metrics .mean (l2_cost ),
316321 'policy_entropy' : tf .metrics .mean (policy_entropy ),
317322 'combined_cost' : tf .metrics .mean (combined_cost ),
318- 'avg_value_observed' : tf .metrics .mean (avg_value_observed ),
319323 'policy_accuracy_top_1' : tf .metrics .mean (policy_output_in_top1 ),
320324 'policy_accuracy_top_3' : tf .metrics .mean (policy_output_in_top3 ),
321325 'policy_top_1_confidence' : tf .metrics .mean (policy_top_1_confidence ),
326+ 'value_confidence' : tf .metrics .mean (tf .abs (value_output )),
327+
328+ # Metrics about input data
322329 'policy_target_top_1_confidence' : tf .metrics .mean (
323330 policy_target_top_1_confidence ),
324- 'value_confidence' : tf .metrics .mean (tf .abs (value_output )),
331+ 'avg_value_observed' : tf .metrics .mean (avg_value_observed ),
332+ 'avg_stones_black' : tf .metrics .mean (avg_stones_black ),
333+ 'avg_stones_white' : tf .metrics .mean (avg_stones_white ),
325334 }
326335
327336 if est_mode == tf .estimator .ModeKeys .EVAL :
@@ -349,6 +358,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
349358 return summary .all_summary_ops () + [cond_reset_op ]
350359
351360 metric_args = [
361+ features ,
352362 policy_output ,
353363 value_output ,
354364 labels ['pi_tensor' ],
0 commit comments