Skip to content

Commit 72d22cf

Browse files
authored
Merge pull request #3 from exosports/lossfix
Fixed using custom loss functions
2 parents 6b37a04 + 18e1bee commit 72d22cf

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

MARGE.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -333,26 +333,30 @@ def MARGE(confile):
333333
model_evaluate = None
334334
if "lossfunc" in conf.keys():
335335
# Format: path/to/module.py function_name
336-
lossfunc = conf["lossfunc"].split() # [path/to/module.py, function_name]
337-
if lossfunc[0] == 'mse':
338-
lossfunc = keras.losses.MeanSquaredError
336+
# or string identifier
337+
lossfunc = conf["lossfunc"].split() # [path/to/module.py, function_name] or [string identifier]
338+
# Handle string identifiers:
339+
if lossfunc[0] in ['mse', 'mean_squared_error']:
340+
lossfunc = keras.losses.MeanSquaredError()
339341
lossfunc.__name__ = 'mse'
340-
elif lossfunc[0] == 'mae':
341-
lossfunc = keras.losses.MeanAbsoluteError
342+
elif lossfunc[0] in ['mae', 'mean_absolute_error']:
343+
lossfunc = keras.losses.MeanAbsoluteError()
342344
lossfunc.__name__ = 'mae'
343-
elif lossfunc[0] == 'mape':
344-
lossfunc = keras.losses.MeanAbsolutePercentageError
345+
elif lossfunc[0] in ['mape', 'mean_absolute_percent_error', 'mean_absolute_percentage_error']:
346+
lossfunc = keras.losses.MeanAbsolutePercentageError()
345347
lossfunc.__name__ = 'mape'
346-
elif lossfunc[0] == 'msle':
347-
lossfunc = keras.losses.MeanSquaredLogarithmicError
348+
elif lossfunc[0] in ['msle', 'mean_squared_logarithmic_error', 'mean_squared_log_error']:
349+
lossfunc = keras.losses.MeanSquaredLogarithmicError()
348350
lossfunc.__name__ = 'msle'
349-
elif lossfunc[0] in ['maxmse', 'm3se', 'mslse', 'mse_per_ax', 'maxse']:
351+
# Handle the custom loss functions in lib/losses.py
352+
elif lossfunc[0] in ['maxmse', 'm3se', 'mslse', 'mse_per_ax', 'maxse', 'smape', 'maxsape']:
350353
lossname = lossfunc[0]
351354
lossfunc = getattr(losses, lossfunc[0])
352355
lossfunc.__name__ = lossname
353356
elif lossfunc[0] in ['heteroscedastic', 'heteroscedastic_loss']:
354357
lossfunc = functools.partial(losses.heteroscedastic_loss, D=np.product(oshape), N=batch_size)
355358
lossfunc.__name__ = 'heteroscedastic_loss'
359+
# Handle user-supplied custom loss functions
356360
else:
357361
if lossfunc[0][-3:] == '.py':
358362
lossfunc[0] = lossfunc[0][:-3] # path/to/module

lib/NN.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __init__(self, ftrain_TFR, fvalid_TFR, ftest_TFR,
185185
self.stop_file = stop_file
186186

187187
if lossfunc is None:
188-
lossfunc = keras.losses.MeanSquaredError
188+
lossfunc = keras.losses.MeanSquaredError()
189189
lossfunc.__name__ = 'mse'
190190
#else:
191191
# self.lossfunc = lossfunc
@@ -382,7 +382,7 @@ def __init__(self, ftrain_TFR, fvalid_TFR, ftest_TFR,
382382

383383
# Compile model
384384
self.model.compile(optimizer=Adam(learning_rate=self.lengthscale, amsgrad=True),
385-
loss=lossfunc())
385+
loss=lossfunc)
386386
if self.verb:
387387
self.model.summary(print_fn=logger.info)
388388

lib/loss/losses.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,14 @@ def heteroscedastic_loss(true, pred, D, N):
8888

8989
return K.mean(quad + log_det, 0)
9090

91+
def smape(y_true, y_pred):
92+
epsilon = 1e-8 # Small constant to prevent division by zero
93+
numerator = tf.abs(y_pred - y_true)
94+
denominator = tf.abs(y_true) + tf.abs(y_pred) + epsilon
95+
return K.mean(numerator / denominator) * 200
96+
97+
def maxsape(y_true, y_pred):
98+
epsilon = 1e-8 # Small constant to prevent division by zero
99+
numerator = tf.abs(y_pred - y_true)
100+
denominator = tf.abs(y_true) + tf.abs(y_pred) + epsilon
101+
return K.max(numerator / denominator) * 200

0 commit comments

Comments
 (0)