@@ -16,8 +16,13 @@ def rmse_metric_function(predictions, labels, **kwargs):
1616 float: average rmse of the time predictions.
1717 """
1818 seq_mask = kwargs .get ('seq_mask' )
19- pred = predictions [PredOutputIndex .TimePredIndex ][seq_mask ]
20- label = labels [PredOutputIndex .TimePredIndex ][seq_mask ]
19+ if seq_mask is None or len (seq_mask ) == 0 :
20+ # If mask is empty or None, use all predictions
21+ pred = predictions [PredOutputIndex .TimePredIndex ]
22+ label = labels [PredOutputIndex .TimePredIndex ]
23+ else :
24+ pred = predictions [PredOutputIndex .TimePredIndex ][seq_mask ]
25+ label = labels [PredOutputIndex .TimePredIndex ][seq_mask ]
2126
2227 pred = np .reshape (pred , [- 1 ])
2328 label = np .reshape (label , [- 1 ])
@@ -36,8 +41,13 @@ def acc_metric_function(predictions, labels, **kwargs):
3641 float: accuracy ratio of the type predictions.
3742 """
3843 seq_mask = kwargs .get ('seq_mask' )
39- pred = predictions [PredOutputIndex .TypePredIndex ][seq_mask ]
40- label = labels [PredOutputIndex .TypePredIndex ][seq_mask ]
44+ if seq_mask is None or len (seq_mask ) == 0 :
45+ # If mask is empty or None, use all predictions
46+ pred = predictions [PredOutputIndex .TypePredIndex ]
47+ label = labels [PredOutputIndex .TypePredIndex ]
48+ else :
49+ pred = predictions [PredOutputIndex .TypePredIndex ][seq_mask ]
50+ label = labels [PredOutputIndex .TypePredIndex ][seq_mask ]
4151 pred = np .reshape (pred , [- 1 ])
4252 label = np .reshape (label , [- 1 ])
4353 return np .mean (pred == label )
0 commit comments