Skip to content

Commit

Permalink
Merge pull request #167 from nasa/bug/calc_error_zero
Browse files Browse the repository at this point in the history
Fix 0 case
  • Loading branch information
teubert authored Nov 5, 2024
2 parents a9edf31 + ba06ba9 commit cf1ebfb
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/progpy/data_models/lstm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
output_data.append(t_all)

model = keras.Model(inputs, outputs)
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"]*len(outputs))
model.compile(optimizer="rmsprop", loss="mse", metrics=[["mae"]]*len(outputs))

# Train model
history = model.fit(
Expand Down
4 changes: 3 additions & 1 deletion src/progpy/utils/calc_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def MSE(m, times: List[float], inputs: List[dict], outputs: List[dict], **kwargs
If the model goes unstable before stability_tol is met and short_sim_penalty is None, then exception is raised
Else if the model goes unstable before stability_tol is met and short_sim_penalty is not None- the penalty is added to the score
Else, model goes unstable after stability_tol is met, the error calculated from data up to the instability is returned.
short_sim_penalty (float, optional): penalty added for simulation becoming unstable before stability_tol, added for each % below tol. Default is 100
short_sim_penalty (float, optional): penalty added for simulation becoming unstable before stability_tol, added for each % below tol. If set to None, operation will return an error if simulation becomes unstable before stability_tol. Default is 100
Returns:
float: Total error
Expand Down Expand Up @@ -190,6 +190,8 @@ def MSE(m, times: List[float], inputs: List[dict], outputs: List[dict], **kwargs
warn(f"Model unstable- NAN reached in simulation (t={t}) before cutoff threshold. "
f"Cutoff threshold is {cutoffThreshold}, or roughly {stability_tol * 100}% of the data. Penalty added to score.")
# Return value with Penalty added
if counter == 0:
return 100*short_sim_penalty
return err_total/counter + (100-(t/cutoffThreshold)*100)*short_sim_penalty
else:
warn("Model unstable- NaN reached in simulation (t={})".format(t))
Expand Down
17 changes: 13 additions & 4 deletions tests/test_calc_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,14 @@ def future_loading(t, x=None):

# With our current set parameters, our model goes unstable immediately
with self.assertRaises(ValueError) as cm:
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1)
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1, short_sim_penalty=None)
self.assertEqual(
"Model unstable- NAN reached in simulation (t=0.0) before cutoff threshold. Cutoff threshold is 1900.0, or roughly 95.0% of the data",
str(cm.exception)
)
)

# Shouldn't raise error for default case (i.e., short_sim_penalty is not None)
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1)

# Creating duplicate model to check if consistent results occur
m1 = BatteryElectroChemEOD()
Expand All @@ -131,20 +134,26 @@ def future_loading(t, x=None):

# Checks to see if model goes unstable before default stability tolerance is met.
with self.assertRaises(ValueError) as cm:
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt = 1)
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1, short_sim_penalty=None)
self.assertEqual(
"Model unstable- NAN reached in simulation (t=1800.0) before cutoff threshold. Cutoff threshold is 1900.0, or roughly 95.0% of the data",
str(cm.exception)
)

# Shouldn't happen for default case (short_sim_penalty is not none)
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1)

# Checks to see if m1 throws the same exception.
with self.assertRaises(ValueError):
m1.calc_error(m1_sim_results.times, m1_sim_results.inputs, m1_sim_results.outputs, dt = 1)
m1.calc_error(m1_sim_results.times, m1_sim_results.inputs, m1_sim_results.outputs, dt=1, short_sim_penalty=None)
self.assertEqual(
"Model unstable- NAN reached in simulation (t=1800.0) before cutoff threshold. Cutoff threshold is 1900.0, or roughly 95.0% of the data",
str(cm.exception)
)

# Shouldn't for default case
m1.calc_error(m1_sim_results.times, m1_sim_results.inputs, m1_sim_results.outputs, dt=1)

# Checks to see if stability_tolerance throws Warning rather than an Error when the model goes unstable after threshold
with self.assertWarns(UserWarning) as cm:
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs,
Expand Down

0 comments on commit cf1ebfb

Please sign in to comment.