44
55import json
66
7+ import matplotlib .pyplot as plt
78import numpy as np
89import pandas as pd
910import pytest
1011
11- from ml_peg .analysis .utils .decorators import build_table , plot_periodic_table
12+ from ml_peg .analysis .utils .decorators import (
13+ PERIODIC_TABLE_COLS ,
14+ PERIODIC_TABLE_POSITIONS ,
15+ PERIODIC_TABLE_ROWS ,
16+ build_table ,
17+ render_periodic_table_grid ,
18+ )
1219from ml_peg .app import APP_ROOT
1320from ml_peg .calcs import CALCS_ROOT
1421from ml_peg .models .get_models import get_model_names
2229
2330DIATOMICS_THRESHOLDS = {
2431 "Force flips" : (1.0 , 5.0 ),
25- "Tortuosity" : (1.0 , 5.0 ),
2632 "Energy minima" : (1.0 , 5.0 ),
2733 "Energy inflections" : (1.0 , 5.0 ),
28- "Spearman's coefficient (E: repulsion)" : (- 1.0 , 1.0 ),
29- "Spearman's coefficient (F: descending)" : (- 1.0 , 1.0 ),
30- "Spearman's coefficient (E: attraction)" : (1.0 , - 1.0 ),
31- "Spearman's coefficient (F: ascending)" : (1.0 , - 1.0 ),
34+ "ρ(E, repulsion)" : (- 1.0 , 1.0 ),
35+ "ρ(E, attraction)" : (1.0 , - 1.0 ),
3236}
3337
3438
@@ -146,51 +150,15 @@ def compute_pair_metrics(
146150 mask = fs_sign != 0
147151 f_flip = int (np .sum (np .diff (fs_sign [mask ]) != 0 )) if mask .any () else 0
148152
149- fdiff = np .diff (fs )
150- fdiff [np .abs (fdiff ) < 1e-3 ] = 0.0
151- fdiff_sign = np .sign (fdiff )
152- mask = fdiff_sign != 0
153- fjump = 0.0
154- if mask .any ():
155- diff = fdiff [mask ]
156- diff_sign = fdiff_sign [mask ]
157- flips = np .diff (diff_sign ) != 0
158- if flips .any ():
159- fjump = float (
160- np .abs (diff [:- 1 ][flips ]).sum () + np .abs (diff [1 :][flips ]).sum ()
161- )
162-
163- ediff = np .diff (es )
164- ediff [np .abs (ediff ) < 1e-3 ] = 0.0
165- ediff_sign = np .sign (ediff )
166- mask = ediff_sign != 0
167- ejump = 0.0
168- ediff_flip_times = 0
169- if mask .any ():
170- diff = ediff [mask ]
171- diff_sign = ediff_sign [mask ]
172- flips = np .diff (diff_sign ) != 0
173- ediff_flip_times = int (np .sum (flips ))
174- if flips .any ():
175- ejump = float (
176- np .abs (diff [:- 1 ][flips ]).sum () + np .abs (diff [1 :][flips ]).sum ()
177- )
178-
179- conservation_deviation = float (np .mean (np .abs (fs + de_dr )))
180- energy_total_variation = float (np .sum (np .abs (np .diff (es ))))
181-
182153 well_depth = float (es .min ())
183154
184155 spearman_repulsion = np .nan
185156 spearman_attraction = np .nan
186- spearman_force_desc = np .nan
187- spearman_force_asc = np .nan
188157
189158 try :
190159 from scipy import stats
191160
192161 imine = int (np .argmin (es ))
193- iminf = int (np .argmin (fs ))
194162 if rs [imine :].size > 1 :
195163 spearman_repulsion = float (
196164 stats .spearmanr (rs [imine :], es [imine :]).statistic
@@ -199,36 +167,15 @@ def compute_pair_metrics(
199167 spearman_attraction = float (
200168 stats .spearmanr (rs [:imine ], es [:imine ]).statistic
201169 )
202- if rs [iminf :].size > 1 :
203- spearman_force_desc = float (
204- stats .spearmanr (rs [iminf :], fs [iminf :]).statistic
205- )
206- if rs [:iminf ].size > 1 :
207- spearman_force_asc = float (
208- stats .spearmanr (rs [:iminf ], fs [:iminf ]).statistic
209- )
210170 except Exception :
211171 pass
212172
213- tortuosity = 0.0
214- denominator = abs (es [0 ] - es .min ()) + (es [- 1 ] - es .min ())
215- if denominator > 0 :
216- tortuosity = float (energy_total_variation / denominator )
217-
218173 metrics = {
219174 "Force flips" : float (f_flip ),
220- "Tortuosity" : tortuosity ,
221175 "Energy minima" : float (minima ),
222176 "Energy inflections" : float (inflections ),
223- "Spearman's coefficient (E: repulsion)" : float (spearman_repulsion ),
224- "Spearman's coefficient (F: descending)" : float (spearman_force_desc ),
225- "Spearman's coefficient (E: attraction)" : float (spearman_attraction ),
226- "Spearman's coefficient (F: ascending)" : float (spearman_force_asc ),
227- "Energy diff flips" : float (ediff_flip_times ),
228- "Energy jump" : float (ejump ),
229- "Force jump" : float (fjump ),
230- "Conservation deviation" : float (conservation_deviation ),
231- "Energy total variation" : float (energy_total_variation ),
177+ "ρ(E, repulsion)" : float (spearman_repulsion ),
178+ "ρ(E, attraction)" : float (spearman_attraction ),
232179 }
233180 return metrics , well_depth
234181
@@ -296,13 +243,10 @@ def score_diatomics(
296243 """
297244 ideal_targets = {
298245 "Force flips" : 1.0 ,
299- "Tortuosity" : 1.0 ,
300246 "Energy minima" : 1.0 ,
301247 "Energy inflections" : 1.0 ,
302- "Spearman's coefficient (E: repulsion)" : - 1.0 ,
303- "Spearman's coefficient (F: descending)" : - 1.0 ,
304- "Spearman's coefficient (E: attraction)" : 1.0 ,
305- "Spearman's coefficient (F: ascending)" : 1.0 ,
248+ "ρ(E, repulsion)" : - 1.0 ,
249+ "ρ(E, attraction)" : 1.0 ,
306250 }
307251
308252 metrics_df ["Score" ] = 0.0
@@ -353,43 +297,195 @@ def write_curve_data(model_name: str, df: pd.DataFrame) -> None:
353297 json .dump (payload , fh )
354298
355299
356- def write_periodic_table_figures (
357- model_name : str , well_depths : dict [str , float ]
300+ def write_periodic_table_assets (
301+ model_name : str ,
302+ df : pd .DataFrame ,
303+ well_depths : dict [str , float ],
358304) -> None :
359305 """
360- Create periodic-table figure JSON for a model .
306+ Create periodic-table overview and element-focused plots for app consumption .
361307
362308 Parameters
363309 ----------
364310 model_name
365311 Name of the model being processed.
312+ df
313+ Dataframe containing curve samples for the model.
366314 well_depths
367315 Mapping of element symbol to homonuclear well depth.
368316 """
317+ if df .empty :
318+ return
369319
370- @plot_periodic_table (
371- title = f"{ model_name } homonuclear well depths" ,
372- colorbar_title = "Well depth / eV" ,
373- filename = str (PERIODIC_TABLE_PATH / f"{ model_name } .json" ),
374- colorscale = "Viridis" ,
375- )
376- def generate_plot (values : dict [str , float ]) -> dict [str , float ]:
320+ model_dir = PERIODIC_TABLE_PATH / model_name
321+ elements_dir = model_dir / "elements"
322+ model_dir .mkdir (parents = True , exist_ok = True )
323+ elements_dir .mkdir (parents = True , exist_ok = True )
324+
325+ def _plot_overview (ax , element : str ) -> bool :
377326 """
378- Identity helper to leverage the periodic-table decorator .
327+ Render the homonuclear curve for a single element into ``ax`` .
379328
380329 Parameters
381330 ----------
382- values
383- Mapping of element symbol to well depth.
331+ ax
332+ Matplotlib axes to draw on.
333+ element
334+ Chemical symbol identifying the homonuclear pair.
384335
385336 Returns
386337 -------
387- dict[str, float]
388- The unchanged mapping of element well depths .
338+ bool
339+ ``True`` if the element had data and was plotted, else ``False`` .
389340 """
390- return values
341+ pair_label = f"{ element } -{ element } "
342+ pair_df = (
343+ df [df ["pair" ] == pair_label ]
344+ .sort_values ("distance" )
345+ .drop_duplicates ("distance" )
346+ )
347+ if pair_df .empty :
348+ return False
349+
350+ x = pair_df ["distance" ].to_numpy ()
351+ y = pair_df ["energy" ].to_numpy ()
352+ y_shifted = y - y [- 1 ]
353+
354+ ax .plot (x , y_shifted , linewidth = 1 , color = "tab:blue" , zorder = 1 )
355+ ax .axhline (0 , color = "lightgray" , linewidth = 0.6 , zorder = 0 )
356+ ax .set_facecolor ("white" )
357+ ax .set_xlim (0.0 , 6.0 )
358+ ax .set_ylim (- 20.0 , 20.0 )
359+ ax .set_xticks ([0 , 2 , 4 , 6 ])
360+ ax .set_yticks ([- 20 , - 10 , 0 , 10 , 20 ])
361+ ax .tick_params (labelsize = 7 , length = 2 , pad = 1 )
362+
363+ depth = well_depths .get (element )
364+ label = f"{ element } \n { depth :.2f} eV" if depth is not None else element
365+ ax .text (
366+ 0.02 ,
367+ 0.95 ,
368+ label ,
369+ transform = ax .transAxes ,
370+ ha = "left" ,
371+ va = "top" ,
372+ fontsize = 8 ,
373+ fontweight = "bold" ,
374+ )
375+ return True
376+
377+ render_periodic_table_grid (
378+ title = f"Homonuclear diatomic curves: { model_name } " ,
379+ filename_stem = model_dir / "overview" ,
380+ plot_cell = _plot_overview ,
381+ figsize = (36 , 20 ),
382+ formats = ("svg" ,),
383+ suptitle_kwargs = {"fontsize" : 28 , "fontweight" : "bold" },
384+ )
385+
386+ manifest : dict [str , dict [str , str ] | str ] = {
387+ "overview" : "overview.svg" ,
388+ "elements" : {},
389+ }
390+
391+ available_elements = sorted (
392+ e
393+ for e in (set (df ["element_1" ].tolist ()) | set (df ["element_2" ].tolist ()))
394+ if isinstance (e , str ) and e
395+ )
396+ for element in available_elements :
397+ rel_path = f"elements/{ element } .png"
398+ output_path = elements_dir / f"{ element } .png"
399+ if _render_element_focus (df , element , output_path ):
400+ manifest ["elements" ][element ] = rel_path
391401
392- generate_plot (well_depths )
402+ manifest_path = model_dir / "manifest.json"
403+ manifest_path .write_text (json .dumps (manifest , indent = 2 ))
404+
405+
406+ def _render_element_focus (df : pd .DataFrame , selected_element : str , output_path ) -> bool :
407+ """
408+ Render heteronuclear overview for a selected element.
409+
410+ Parameters
411+ ----------
412+ df
413+ Dataframe containing pair data.
414+ selected_element
415+ Element to highlight in the periodic table.
416+ output_path
417+ File path to save the PNG figure.
418+
419+ Returns
420+ -------
421+ bool
422+ ``True`` if any data was rendered for the element.
423+ """
424+ pair_groups : dict [str , pd .DataFrame ] = {}
425+ for pair , df_pair in df .groupby ("pair" ):
426+ try :
427+ element1 , element2 = pair .split ("-" )
428+ except ValueError :
429+ continue
430+ if selected_element not in {element1 , element2 }:
431+ continue
432+ other = element2 if element1 == selected_element else element1
433+ pair_groups [other ] = df_pair .sort_values ("distance" ).drop_duplicates ("distance" )
434+
435+ if not pair_groups :
436+ return False
437+
438+ fig , axes = plt .subplots (
439+ PERIODIC_TABLE_ROWS ,
440+ PERIODIC_TABLE_COLS ,
441+ figsize = (30 , 15 ),
442+ constrained_layout = True ,
443+ )
444+ axes = axes .reshape (PERIODIC_TABLE_ROWS , PERIODIC_TABLE_COLS )
445+ for ax in axes .ravel ():
446+ ax .axis ("off" )
447+
448+ has_data = False
449+ for element , (row , col ) in PERIODIC_TABLE_POSITIONS .items ():
450+ pair_df = pair_groups .get (element )
451+ if pair_df is None :
452+ continue
453+ x = pair_df ["distance" ].to_numpy ()
454+ y = pair_df ["energy" ].to_numpy ()
455+ shifted = y - y [- 1 ]
456+
457+ ax = axes [row , col ]
458+ ax .axis ("on" )
459+ ax .set_facecolor ("white" )
460+ ax .plot (x , shifted , linewidth = 1 , color = "tab:blue" , zorder = 1 )
461+ ax .axhline (0 , color = "lightgray" , linewidth = 0.6 , zorder = 0 )
462+ ax .set_xlim (0.0 , 6.0 )
463+ ax .set_ylim (- 20.0 , 20.0 )
464+ ax .set_xticks ([0 , 2 , 4 , 6 ])
465+ ax .set_yticks ([- 20 , - 10 , 0 , 10 , 20 ])
466+ ax .tick_params (labelsize = 7 , length = 2 , pad = 1 )
467+ ax .set_title (
468+ f"{ selected_element } -{ element } , shift: { float (y [- 1 ]):.4f} " ,
469+ fontsize = 8 ,
470+ )
471+ if element == selected_element :
472+ for spine in ax .spines .values ():
473+ spine .set_edgecolor ("crimson" )
474+ spine .set_linewidth (2 )
475+ has_data = True
476+
477+ if not has_data :
478+ plt .close (fig )
479+ return False
480+
481+ fig .suptitle (
482+ f"Diatomics involving { selected_element } " ,
483+ fontsize = 22 ,
484+ fontweight = "bold" ,
485+ )
486+ fig .savefig (output_path , format = "png" , dpi = 200 )
487+ plt .close (fig )
488+ return True
393489
394490
395491def collect_metrics (
@@ -428,7 +524,7 @@ def collect_metrics(
428524 rows .append (row )
429525
430526 write_curve_data (model_name , df )
431- write_periodic_table_figures (model_name , well_depths )
527+ write_periodic_table_assets (model_name , df , well_depths )
432528 model_well_depths [model_name ] = well_depths
433529
434530 if not rows :
@@ -521,21 +617,14 @@ def diatomics_well_depths(
521617 metric_tooltips = {
522618 "Model" : "Name of the model" ,
523619 "Force flips" : "Mean count of force-direction changes per pair (ideal 1)" ,
524- "Tortuosity" : "Energy curve tortuosity (lower is smoother)" ,
525620 "Energy minima" : "Average number of energy minima per pair" ,
526621 "Energy inflections" : "Average number of energy inflection points per pair" ,
527- "Spearman's coefficient (E: repulsion)" : (
622+ "ρ(E, repulsion)" : (
528623 "Spearman correlation for energy in repulsive regime (ideal -1)"
529624 ),
530- "Spearman's coefficient (F: descending)" : (
531- "Spearman correlation for force in descending regime (ideal -1)"
532- ),
533- "Spearman's coefficient (E: attraction)" : (
625+ "ρ(E, attraction)" : (
534626 "Spearman correlation for energy in attractive regime (ideal +1)"
535627 ),
536- "Spearman's coefficient (F: ascending)" : (
537- "Spearman correlation for force in ascending regime (ideal +1)"
538- ),
539628 "Score" : "Aggregate deviation from physical targets (lower is better)" ,
540629 "Rank" : "Model ranking based on score (lower is better)" ,
541630 },
0 commit comments