Skip to content

Commit 549fef0

Browse files
committed
clean up app and generalise building the dropdowns for the ptable
1 parent c5146db commit 549fef0

File tree

4 files changed

+481
-247
lines changed

4 files changed

+481
-247
lines changed

ml_peg/analysis/molecular/Diatomics/analyse_Diatomics.py

Lines changed: 186 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44

55
import json
66

7+
import matplotlib.pyplot as plt
78
import numpy as np
89
import pandas as pd
910
import pytest
1011

11-
from ml_peg.analysis.utils.decorators import build_table, plot_periodic_table
12+
from ml_peg.analysis.utils.decorators import (
13+
PERIODIC_TABLE_COLS,
14+
PERIODIC_TABLE_POSITIONS,
15+
PERIODIC_TABLE_ROWS,
16+
build_table,
17+
render_periodic_table_grid,
18+
)
1219
from ml_peg.app import APP_ROOT
1320
from ml_peg.calcs import CALCS_ROOT
1421
from ml_peg.models.get_models import get_model_names
@@ -22,13 +29,10 @@
2229

2330
DIATOMICS_THRESHOLDS = {
2431
"Force flips": (1.0, 5.0),
25-
"Tortuosity": (1.0, 5.0),
2632
"Energy minima": (1.0, 5.0),
2733
"Energy inflections": (1.0, 5.0),
28-
"Spearman's coefficient (E: repulsion)": (-1.0, 1.0),
29-
"Spearman's coefficient (F: descending)": (-1.0, 1.0),
30-
"Spearman's coefficient (E: attraction)": (1.0, -1.0),
31-
"Spearman's coefficient (F: ascending)": (1.0, -1.0),
34+
"ρ(E, repulsion)": (-1.0, 1.0),
35+
"ρ(E, attraction)": (1.0, -1.0),
3236
}
3337

3438

@@ -146,51 +150,15 @@ def compute_pair_metrics(
146150
mask = fs_sign != 0
147151
f_flip = int(np.sum(np.diff(fs_sign[mask]) != 0)) if mask.any() else 0
148152

149-
fdiff = np.diff(fs)
150-
fdiff[np.abs(fdiff) < 1e-3] = 0.0
151-
fdiff_sign = np.sign(fdiff)
152-
mask = fdiff_sign != 0
153-
fjump = 0.0
154-
if mask.any():
155-
diff = fdiff[mask]
156-
diff_sign = fdiff_sign[mask]
157-
flips = np.diff(diff_sign) != 0
158-
if flips.any():
159-
fjump = float(
160-
np.abs(diff[:-1][flips]).sum() + np.abs(diff[1:][flips]).sum()
161-
)
162-
163-
ediff = np.diff(es)
164-
ediff[np.abs(ediff) < 1e-3] = 0.0
165-
ediff_sign = np.sign(ediff)
166-
mask = ediff_sign != 0
167-
ejump = 0.0
168-
ediff_flip_times = 0
169-
if mask.any():
170-
diff = ediff[mask]
171-
diff_sign = ediff_sign[mask]
172-
flips = np.diff(diff_sign) != 0
173-
ediff_flip_times = int(np.sum(flips))
174-
if flips.any():
175-
ejump = float(
176-
np.abs(diff[:-1][flips]).sum() + np.abs(diff[1:][flips]).sum()
177-
)
178-
179-
conservation_deviation = float(np.mean(np.abs(fs + de_dr)))
180-
energy_total_variation = float(np.sum(np.abs(np.diff(es))))
181-
182153
well_depth = float(es.min())
183154

184155
spearman_repulsion = np.nan
185156
spearman_attraction = np.nan
186-
spearman_force_desc = np.nan
187-
spearman_force_asc = np.nan
188157

189158
try:
190159
from scipy import stats
191160

192161
imine = int(np.argmin(es))
193-
iminf = int(np.argmin(fs))
194162
if rs[imine:].size > 1:
195163
spearman_repulsion = float(
196164
stats.spearmanr(rs[imine:], es[imine:]).statistic
@@ -199,36 +167,15 @@ def compute_pair_metrics(
199167
spearman_attraction = float(
200168
stats.spearmanr(rs[:imine], es[:imine]).statistic
201169
)
202-
if rs[iminf:].size > 1:
203-
spearman_force_desc = float(
204-
stats.spearmanr(rs[iminf:], fs[iminf:]).statistic
205-
)
206-
if rs[:iminf].size > 1:
207-
spearman_force_asc = float(
208-
stats.spearmanr(rs[:iminf], fs[:iminf]).statistic
209-
)
210170
except Exception:
211171
pass
212172

213-
tortuosity = 0.0
214-
denominator = abs(es[0] - es.min()) + (es[-1] - es.min())
215-
if denominator > 0:
216-
tortuosity = float(energy_total_variation / denominator)
217-
218173
metrics = {
219174
"Force flips": float(f_flip),
220-
"Tortuosity": tortuosity,
221175
"Energy minima": float(minima),
222176
"Energy inflections": float(inflections),
223-
"Spearman's coefficient (E: repulsion)": float(spearman_repulsion),
224-
"Spearman's coefficient (F: descending)": float(spearman_force_desc),
225-
"Spearman's coefficient (E: attraction)": float(spearman_attraction),
226-
"Spearman's coefficient (F: ascending)": float(spearman_force_asc),
227-
"Energy diff flips": float(ediff_flip_times),
228-
"Energy jump": float(ejump),
229-
"Force jump": float(fjump),
230-
"Conservation deviation": float(conservation_deviation),
231-
"Energy total variation": float(energy_total_variation),
177+
"ρ(E, repulsion)": float(spearman_repulsion),
178+
"ρ(E, attraction)": float(spearman_attraction),
232179
}
233180
return metrics, well_depth
234181

@@ -296,13 +243,10 @@ def score_diatomics(
296243
"""
297244
ideal_targets = {
298245
"Force flips": 1.0,
299-
"Tortuosity": 1.0,
300246
"Energy minima": 1.0,
301247
"Energy inflections": 1.0,
302-
"Spearman's coefficient (E: repulsion)": -1.0,
303-
"Spearman's coefficient (F: descending)": -1.0,
304-
"Spearman's coefficient (E: attraction)": 1.0,
305-
"Spearman's coefficient (F: ascending)": 1.0,
248+
"ρ(E, repulsion)": -1.0,
249+
"ρ(E, attraction)": 1.0,
306250
}
307251

308252
metrics_df["Score"] = 0.0
@@ -353,43 +297,195 @@ def write_curve_data(model_name: str, df: pd.DataFrame) -> None:
353297
json.dump(payload, fh)
354298

355299

356-
def write_periodic_table_figures(
357-
model_name: str, well_depths: dict[str, float]
300+
def write_periodic_table_assets(
    model_name: str,
    df: pd.DataFrame,
    well_depths: dict[str, float],
) -> None:
    """
    Write the periodic-table overview, per-element plots and manifest for a model.

    Nothing is written when ``df`` is empty. Otherwise the output lives under
    ``PERIODIC_TABLE_PATH / model_name``: an ``overview`` figure rendered through
    ``render_periodic_table_grid``, one ``elements/<symbol>.png`` per element for
    which ``_render_element_focus`` reports success, and a ``manifest.json``
    listing the relative paths of those files.

    Parameters
    ----------
    model_name
        Name of the model being processed.
    df
        Dataframe containing curve samples for the model.
    well_depths
        Mapping of element symbol to homonuclear well depth.
    """
    if df.empty:
        return

    model_dir = PERIODIC_TABLE_PATH / model_name
    elements_dir = model_dir / "elements"
    for directory in (model_dir, elements_dir):
        directory.mkdir(parents=True, exist_ok=True)

    def _plot_overview(ax, element: str) -> bool:
        """
        Draw the homonuclear curve of ``element`` on ``ax``.

        Parameters
        ----------
        ax
            Matplotlib axes to draw on.
        element
            Chemical symbol identifying the homonuclear pair.

        Returns
        -------
        bool
            ``True`` if the element had data and was plotted, else ``False``.
        """
        homonuclear = df[df["pair"] == f"{element}-{element}"]
        curve = homonuclear.sort_values("distance").drop_duplicates("distance")
        if curve.empty:
            return False

        distances = curve["distance"].to_numpy()
        energies = curve["energy"].to_numpy()
        # Reference energies to the sample at the largest distance.
        relative = energies - energies[-1]

        ax.plot(distances, relative, linewidth=1, color="tab:blue", zorder=1)
        ax.axhline(0, color="lightgray", linewidth=0.6, zorder=0)
        ax.set_facecolor("white")
        ax.set_xlim(0.0, 6.0)
        ax.set_ylim(-20.0, 20.0)
        ax.set_xticks([0, 2, 4, 6])
        ax.set_yticks([-20, -10, 0, 10, 20])
        ax.tick_params(labelsize=7, length=2, pad=1)

        depth = well_depths.get(element)
        if depth is None:
            label = element
        else:
            label = f"{element}\n{depth:.2f} eV"
        ax.text(
            0.02,
            0.95,
            label,
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=8,
            fontweight="bold",
        )
        return True

    render_periodic_table_grid(
        title=f"Homonuclear diatomic curves: {model_name}",
        filename_stem=model_dir / "overview",
        plot_cell=_plot_overview,
        figsize=(36, 20),
        formats=("svg",),
        suptitle_kwargs={"fontsize": 28, "fontweight": "bold"},
    )

    # Every non-empty string symbol seen in either element column gets a plot.
    symbols = set(df["element_1"].tolist()) | set(df["element_2"].tolist())
    element_plots: dict[str, str] = {}
    for element in sorted(s for s in symbols if isinstance(s, str) and s):
        if _render_element_focus(df, element, elements_dir / f"{element}.png"):
            element_plots[element] = f"elements/{element}.png"

    manifest: dict[str, dict[str, str] | str] = {
        "overview": "overview.svg",
        "elements": element_plots,
    }
    (model_dir / "manifest.json").write_text(json.dumps(manifest, indent=2))
404+
405+
406+
def _render_element_focus(df: pd.DataFrame, selected_element: str, output_path) -> bool:
407+
"""
408+
Render heteronuclear overview for a selected element.
409+
410+
Parameters
411+
----------
412+
df
413+
Dataframe containing pair data.
414+
selected_element
415+
Element to highlight in the periodic table.
416+
output_path
417+
File path to save the PNG figure.
418+
419+
Returns
420+
-------
421+
bool
422+
``True`` if any data was rendered for the element.
423+
"""
424+
pair_groups: dict[str, pd.DataFrame] = {}
425+
for pair, df_pair in df.groupby("pair"):
426+
try:
427+
element1, element2 = pair.split("-")
428+
except ValueError:
429+
continue
430+
if selected_element not in {element1, element2}:
431+
continue
432+
other = element2 if element1 == selected_element else element1
433+
pair_groups[other] = df_pair.sort_values("distance").drop_duplicates("distance")
434+
435+
if not pair_groups:
436+
return False
437+
438+
fig, axes = plt.subplots(
439+
PERIODIC_TABLE_ROWS,
440+
PERIODIC_TABLE_COLS,
441+
figsize=(30, 15),
442+
constrained_layout=True,
443+
)
444+
axes = axes.reshape(PERIODIC_TABLE_ROWS, PERIODIC_TABLE_COLS)
445+
for ax in axes.ravel():
446+
ax.axis("off")
447+
448+
has_data = False
449+
for element, (row, col) in PERIODIC_TABLE_POSITIONS.items():
450+
pair_df = pair_groups.get(element)
451+
if pair_df is None:
452+
continue
453+
x = pair_df["distance"].to_numpy()
454+
y = pair_df["energy"].to_numpy()
455+
shifted = y - y[-1]
456+
457+
ax = axes[row, col]
458+
ax.axis("on")
459+
ax.set_facecolor("white")
460+
ax.plot(x, shifted, linewidth=1, color="tab:blue", zorder=1)
461+
ax.axhline(0, color="lightgray", linewidth=0.6, zorder=0)
462+
ax.set_xlim(0.0, 6.0)
463+
ax.set_ylim(-20.0, 20.0)
464+
ax.set_xticks([0, 2, 4, 6])
465+
ax.set_yticks([-20, -10, 0, 10, 20])
466+
ax.tick_params(labelsize=7, length=2, pad=1)
467+
ax.set_title(
468+
f"{selected_element}-{element}, shift: {float(y[-1]):.4f}",
469+
fontsize=8,
470+
)
471+
if element == selected_element:
472+
for spine in ax.spines.values():
473+
spine.set_edgecolor("crimson")
474+
spine.set_linewidth(2)
475+
has_data = True
476+
477+
if not has_data:
478+
plt.close(fig)
479+
return False
480+
481+
fig.suptitle(
482+
f"Diatomics involving {selected_element}",
483+
fontsize=22,
484+
fontweight="bold",
485+
)
486+
fig.savefig(output_path, format="png", dpi=200)
487+
plt.close(fig)
488+
return True
393489

394490

395491
def collect_metrics(
@@ -428,7 +524,7 @@ def collect_metrics(
428524
rows.append(row)
429525

430526
write_curve_data(model_name, df)
431-
write_periodic_table_figures(model_name, well_depths)
527+
write_periodic_table_assets(model_name, df, well_depths)
432528
model_well_depths[model_name] = well_depths
433529

434530
if not rows:
@@ -521,21 +617,14 @@ def diatomics_well_depths(
521617
metric_tooltips={
522618
"Model": "Name of the model",
523619
"Force flips": "Mean count of force-direction changes per pair (ideal 1)",
524-
"Tortuosity": "Energy curve tortuosity (lower is smoother)",
525620
"Energy minima": "Average number of energy minima per pair",
526621
"Energy inflections": "Average number of energy inflection points per pair",
527-
"Spearman's coefficient (E: repulsion)": (
622+
"ρ(E, repulsion)": (
528623
"Spearman correlation for energy in repulsive regime (ideal -1)"
529624
),
530-
"Spearman's coefficient (F: descending)": (
531-
"Spearman correlation for force in descending regime (ideal -1)"
532-
),
533-
"Spearman's coefficient (E: attraction)": (
625+
"ρ(E, attraction)": (
534626
"Spearman correlation for energy in attractive regime (ideal +1)"
535627
),
536-
"Spearman's coefficient (F: ascending)": (
537-
"Spearman correlation for force in ascending regime (ideal +1)"
538-
),
539628
"Score": "Aggregate deviation from physical targets (lower is better)",
540629
"Rank": "Model ranking based on score (lower is better)",
541630
},

0 commit comments

Comments
 (0)