@@ -333,26 +333,30 @@ def MARGE(confile):
333333 model_evaluate = None
334334 if "lossfunc" in conf .keys ():
335335 # Format: path/to/module.py function_name
336- lossfunc = conf ["lossfunc" ].split () # [path/to/module.py, function_name]
337- if lossfunc [0 ] == 'mse' :
338- lossfunc = keras .losses .MeanSquaredError
336+ # or string identifier
337+ lossfunc = conf ["lossfunc" ].split () # [path/to/module.py, function_name] or [string identifier]
338+ # Handle string identifiers:
339+ if lossfunc [0 ] in ['mse' , 'mean_squared_error' ]:
340+ lossfunc = keras .losses .MeanSquaredError ()
339341 lossfunc .__name__ = 'mse'
340- elif lossfunc [0 ] == 'mae' :
341- lossfunc = keras .losses .MeanAbsoluteError
342+ elif lossfunc [0 ] in [ 'mae' , 'mean_absolute_error' ] :
343+ lossfunc = keras .losses .MeanAbsoluteError ()
342344 lossfunc .__name__ = 'mae'
343- elif lossfunc [0 ] == 'mape' :
344- lossfunc = keras .losses .MeanAbsolutePercentageError
345+ elif lossfunc [0 ] in [ 'mape' , 'mean_absolute_percent_error' , 'mean_absolute_percentage_error' ] :
346+ lossfunc = keras .losses .MeanAbsolutePercentageError ()
345347 lossfunc .__name__ = 'mape'
346- elif lossfunc [0 ] == 'msle' :
347- lossfunc = keras .losses .MeanSquaredLogarithmicError
348+ elif lossfunc [0 ] in [ 'msle' , 'mean_squared_logarithmic_error' , 'mean_squared_log_error' ] :
349+ lossfunc = keras .losses .MeanSquaredLogarithmicError ()
348350 lossfunc .__name__ = 'msle'
349- elif lossfunc [0 ] in ['maxmse' , 'm3se' , 'mslse' , 'mse_per_ax' , 'maxse' ]:
351+ # Handle the custom loss functions in lib/losses.py
352+ elif lossfunc [0 ] in ['maxmse' , 'm3se' , 'mslse' , 'mse_per_ax' , 'maxse' , 'smape' , 'maxsape' ]:
350353 lossname = lossfunc [0 ]
351354 lossfunc = getattr (losses , lossfunc [0 ])
352355 lossfunc .__name__ = lossname
353356 elif lossfunc [0 ] in ['heteroscedastic' , 'heteroscedastic_loss' ]:
354357 lossfunc = functools .partial (losses .heteroscedastic_loss , D = np .product (oshape ), N = batch_size )
355358 lossfunc .__name__ = 'heteroscedastic_loss'
359+ # Handle user-supplied custom loss functions
356360 else :
357361 if lossfunc [0 ][- 3 :] == '.py' :
358362 lossfunc [0 ] = lossfunc [0 ][:- 3 ] # path/to/module
0 commit comments