Skip to content

Commit cf1ebfb

Browse files
authored
Merge pull request #167 from nasa/bug/calc_error_zero
Fix 0 case
2 parents a9edf31 + ba06ba9 commit cf1ebfb

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

src/progpy/data_models/lstm_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
585585
output_data.append(t_all)
586586

587587
model = keras.Model(inputs, outputs)
588-
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"]*len(outputs))
588+
model.compile(optimizer="rmsprop", loss="mse", metrics=[["mae"]]*len(outputs))
589589

590590
# Train model
591591
history = model.fit(

src/progpy/utils/calc_error.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def MSE(m, times: List[float], inputs: List[dict], outputs: List[dict], **kwargs
142142
If the model goes unstable before stability_tol is met and short_sim_penalty is None, then exception is raised
143143
Else if the model goes unstable before stability_tol is met and short_sim_penalty is not None- the penalty is added to the score
144144
Else, model goes unstable after stability_tol is met, the error calculated from data up to the instability is returned.
145-
short_sim_penalty (float, optional): penalty added for simulation becoming unstable before stability_tol, added for each % below tol. Default is 100
145+
short_sim_penalty (float, optional): penalty added for simulation becoming unstable before stability_tol, added for each % below tol. If set to None, operation will return an error if simulation becomes unstable before stability_tol. Default is 100
146146
147147
Returns:
148148
float: Total error
@@ -190,6 +190,8 @@ def MSE(m, times: List[float], inputs: List[dict], outputs: List[dict], **kwargs
190190
warn(f"Model unstable- NAN reached in simulation (t={t}) before cutoff threshold. "
191191
f"Cutoff threshold is {cutoffThreshold}, or roughly {stability_tol * 100}% of the data. Penalty added to score.")
192192
# Return value with Penalty added
193+
if counter == 0:
194+
return 100*short_sim_penalty
193195
return err_total/counter + (100-(t/cutoffThreshold)*100)*short_sim_penalty
194196
else:
195197
warn("Model unstable- NaN reached in simulation (t={})".format(t))

tests/test_calc_error.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,14 @@ def future_loading(t, x=None):
111111

112112
# With our current set parameters, our model goes unstable immediately
113113
with self.assertRaises(ValueError) as cm:
114-
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1)
114+
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1, short_sim_penalty=None)
115115
self.assertEqual(
116116
"Model unstable- NAN reached in simulation (t=0.0) before cutoff threshold. Cutoff threshold is 1900.0, or roughly 95.0% of the data",
117117
str(cm.exception)
118-
)
118+
)
119+
120+
# Shouldn't raise error for default case (i.e., short_sim_penalty is not None)
121+
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1)
119122

120123
# Creating duplicate model to check if consistent results occur
121124
m1 = BatteryElectroChemEOD()
@@ -131,20 +134,26 @@ def future_loading(t, x=None):
131134

132135
# Checks to see if model goes unstable before default stability tolerance is met.
133136
with self.assertRaises(ValueError) as cm:
134-
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt = 1)
137+
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1, short_sim_penalty=None)
135138
self.assertEqual(
136139
"Model unstable- NAN reached in simulation (t=1800.0) before cutoff threshold. Cutoff threshold is 1900.0, or roughly 95.0% of the data",
137140
str(cm.exception)
138141
)
142+
143+
# Shouldn't happen for default case (short_sim_penalty is not none)
144+
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs, dt=1)
139145

140146
# Checks to see if m1 throws the same exception.
141147
with self.assertRaises(ValueError):
142-
m1.calc_error(m1_sim_results.times, m1_sim_results.inputs, m1_sim_results.outputs, dt = 1)
148+
m1.calc_error(m1_sim_results.times, m1_sim_results.inputs, m1_sim_results.outputs, dt=1, short_sim_penalty=None)
143149
self.assertEqual(
144150
"Model unstable- NAN reached in simulation (t=1800.0) before cutoff threshold. Cutoff threshold is 1900.0, or roughly 95.0% of the data",
145151
str(cm.exception)
146152
)
147153

154+
# Shouldn't for default case
155+
m1.calc_error(m1_sim_results.times, m1_sim_results.inputs, m1_sim_results.outputs, dt=1)
156+
148157
# Checks to see if stability_tolerance throws Warning rather than an Error when the model goes unstable after threshold
149158
with self.assertWarns(UserWarning) as cm:
150159
m.calc_error(simulated_results.times, simulated_results.inputs, simulated_results.outputs,

0 commit comments

Comments
 (0)