@@ -88,9 +88,20 @@ def predicted_data(self) -> FlattenedStorage:
8888 def plot (self ):
8989 """
9090 Plots correlation and (training) error histograms.
91+
92+ See :class:`.PotentialPlots`.
9193 """
9294 return PotentialPlots (self .training_data , self .predicted_data )
9395
96+ @property
97+ def metrics (self ):
98+ """
99+ Calculate error (training) metrics.
100+
101+ See :class:`.PotentialMetrics`.
102+ """
103+ return PotentialMetrics (self .training_data , self .predicted_data )
104+
94105 @abc .abstractmethod
95106 def get_lammps_potential (self ) -> pd .DataFrame :
96107 """
@@ -103,6 +114,7 @@ def get_lammps_potential(self) -> pd.DataFrame:
103114 pass
104115
105116
117+
106118class PotentialPlots :
107119 def __init__ (self , training_data , predicted_data ):
108120 self ._training_data = training_data
@@ -291,3 +303,51 @@ def force_angle_histogram(
291303 "Angular Deviation of Force [" + ["rad" , "deg" ][angle_in_degrees ] + "]"
292304 )
293305 plt .ylabel ("Count" )
306+
307+
308+ class PotentialMetrics :
309+ """
310+ Calculates various error metrics on training and test data.
311+ """
312+
313+ __slots__ = ("_true_data" , "_predicted_data" )
314+
315+ def __init__ (self , true_data : TrainingStorage , predicted_data : FlattenedStorage ):
316+ self ._true_data = true_data
317+ self ._predicted_data = predicted_data
318+
319+ def _rmse (self , a , b ):
320+ return np .sqrt (np .mean ((a - b )** 2 ))
321+
322+ def _mae (self , a , b ):
323+ return np .mean (np .abs (a - b ))
324+
325+ @property
326+ def energy_rmse (self ):
327+ N = self ._true_data ["length" ]
328+ return self ._rmse (
329+ self ._true_data ["energy" ]/ N ,
330+ self ._predicted_data ["energy" ]/ N
331+ )
332+
333+ @property
334+ def energy_mae (self ):
335+ N = self ._true_data ["length" ]
336+ return self ._mae (
337+ self ._true_data ["energy" ]/ N ,
338+ self ._predicted_data ["energy" ]/ N
339+ )
340+
341+ @property
342+ def force_rmse (self ):
343+ return self ._rmse (
344+ np .linalg .norm (self ._true_data ["forces" ], axis = - 1 ),
345+ np .linalg .norm (self ._predicted_data ["forces" ], axis = - 1 )
346+ )
347+
348+ @property
349+ def force_mae (self ):
350+ return self ._mae (
351+ np .linalg .norm (self ._true_data ["forces" ], axis = - 1 ),
352+ np .linalg .norm (self ._predicted_data ["forces" ], axis = - 1 )
353+ )
0 commit comments