Skip to content

Commit 4011cc0

Browse files
committed
tidy elasticity, move to interactive cells (but required a bit more code to get the legend to update too)
1 parent 4c8fb96 commit 4011cc0

File tree

7 files changed

+189
-133
lines changed

7 files changed

+189
-133
lines changed

ml_peg/analysis/bulk_crystal/elasticity/analyse_elasticity.py

Lines changed: 13 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
build_table,
1313
plot_density_scatter,
1414
)
15-
from ml_peg.analysis.utils.utils import load_metrics_config, mae
15+
from ml_peg.analysis.utils.utils import (
16+
build_density_inputs,
17+
load_metrics_config,
18+
mae,
19+
)
1620
from ml_peg.app import APP_ROOT
1721
from ml_peg.calcs import CALCS_ROOT
1822
from ml_peg.models.get_models import get_model_names
@@ -31,26 +35,6 @@
3135
G_COLUMN = "G_vrh"
3236

3337

34-
def _load_results(model_name: str) -> pd.DataFrame | None:
35-
"""
36-
Load benchmark results for a model, if available.
37-
38-
Parameters
39-
----------
40-
model_name
41-
Name of the model whose results should be read.
42-
43-
Returns
44-
-------
45-
pd.DataFrame | None
46-
Results dataframe if present, otherwise ``None``.
47-
"""
48-
results_path = CALC_PATH / model_name / "moduli_results.csv"
49-
if not results_path.exists():
50-
return None
51-
return pd.read_csv(results_path)
52-
53-
5438
def _filter_results(df: pd.DataFrame, model_name: str) -> tuple[pd.DataFrame, int]:
5539
"""
5640
Filter outlier predictions and return remaining data with exclusion count.
@@ -85,22 +69,11 @@ def _collect_model_data() -> dict[str, dict[str, Any]]:
8569
"""
8670
stats: dict[str, dict[str, Any]] = {}
8771
for model_name in MODELS:
88-
df = _load_results(model_name)
89-
if df is None:
90-
stats[model_name] = {
91-
"bulk": None,
92-
"shear": None,
93-
"excluded": None,
94-
}
95-
continue
72+
results_path = CALC_PATH / model_name / "moduli_results.csv"
73+
df = pd.read_csv(results_path)
74+
9675
filtered, excluded = _filter_results(df, model_name)
97-
if filtered.empty:
98-
stats[model_name] = {
99-
"bulk": None,
100-
"shear": None,
101-
"excluded": excluded,
102-
}
103-
continue
76+
10477
stats[model_name] = {
10578
"bulk": {
10679
"ref": filtered[f"{K_COLUMN}_DFT"].tolist(),
@@ -115,44 +88,6 @@ def _collect_model_data() -> dict[str, dict[str, Any]]:
11588
return stats
11689

11790

118-
def _density_inputs(property_key: str, model_stats: dict[str, dict[str, Any]]) -> dict:
119-
"""
120-
Prepare mapping for density scatter decorator.
121-
122-
Parameters
123-
----------
124-
property_key
125-
Property key to extract (``"bulk"`` or ``"shear"``).
126-
model_stats
127-
Aggregated statistics for each model.
128-
129-
Returns
130-
-------
131-
dict
132-
Mapping of model name to density-scatter inputs.
133-
"""
134-
inputs = {}
135-
for model_name in MODELS:
136-
prop = model_stats.get(model_name, {}).get(property_key)
137-
excluded = model_stats.get(model_name, {}).get("excluded")
138-
if not prop:
139-
inputs[model_name] = {
140-
"ref": [],
141-
"pred": [],
142-
"meta": {"excluded": excluded},
143-
}
144-
continue
145-
ref_vals = prop["ref"]
146-
pred_vals = prop["pred"]
147-
inputs[model_name] = {
148-
"ref": ref_vals,
149-
"pred": pred_vals,
150-
"mae": mae(ref_vals, pred_vals) if ref_vals else None,
151-
"meta": {"excluded": excluded},
152-
}
153-
return inputs
154-
155-
15691
@pytest.fixture
15792
def elasticity_stats() -> dict[str, dict[str, Any]]:
15893
"""
@@ -185,9 +120,6 @@ def bulk_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float | N
185120
results: dict[str, float | None] = {}
186121
for model_name in MODELS:
187122
prop = elasticity_stats.get(model_name, {}).get("bulk")
188-
if not prop:
189-
results[model_name] = None
190-
continue
191123
results[model_name] = mae(prop["ref"], prop["pred"])
192124
return results
193125

@@ -210,17 +142,14 @@ def shear_mae(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, float |
210142
results: dict[str, float | None] = {}
211143
for model_name in MODELS:
212144
prop = elasticity_stats.get(model_name, {}).get("shear")
213-
if not prop:
214-
results[model_name] = None
215-
continue
216145
results[model_name] = mae(prop["ref"], prop["pred"])
217146
return results
218147

219148

220149
@pytest.fixture
221150
@plot_density_scatter(
222151
filename=OUT_PATH / "figure_bulk_density.json",
223-
title="Bulk modulus density",
152+
title="Bulk modulus density plot",
224153
x_label="Reference bulk modulus / GPa",
225154
y_label="Predicted bulk modulus / GPa",
226155
)
@@ -238,13 +167,13 @@ def bulk_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict]
238167
dict[str, dict]
239168
Mapping of model name to density-scatter data.
240169
"""
241-
return _density_inputs("bulk", elasticity_stats)
170+
return build_density_inputs(MODELS, elasticity_stats, "bulk", mae_fn=mae)
242171

243172

244173
@pytest.fixture
245174
@plot_density_scatter(
246175
filename=OUT_PATH / "figure_shear_density.json",
247-
title="Shear modulus density",
176+
title="Shear modulus density plot",
248177
x_label="Reference shear modulus / GPa",
249178
y_label="Predicted shear modulus / GPa",
250179
)
@@ -262,7 +191,7 @@ def shear_density(elasticity_stats: dict[str, dict[str, Any]]) -> dict[str, dict
262191
dict[str, dict]
263192
Mapping of model name to density-scatter data.
264193
"""
265-
return _density_inputs("shear", elasticity_stats)
194+
return build_density_inputs(MODELS, elasticity_stats, "shear", mae_fn=mae)
266195

267196

268197
@pytest.fixture

ml_peg/analysis/utils/decorators.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,12 @@ def plot_density_scatter(
261261
seed: int = 0,
262262
) -> Callable:
263263
"""
264-
Plot density-coloured parity scatter with model selector.
264+
Plot density-coloured parity scatter with legend-based model toggling.
265265
266266
The decorated function must return a mapping of model name to a dictionary with
267267
``ref`` and ``pred`` arrays (and optional ``mae``). Each model is rendered as a
268268
scatter trace with marker colours indicating local data density.
269-
Only one model is shown at a time using a dropdown selector.
269+
Only one model is shown at a time; use the legend to toggle models.
270270
271271
Parameters
272272
----------
@@ -486,41 +486,21 @@ def _downsample(
486486
)
487487
)
488488

489-
buttons = []
490-
n_models = len(model_names)
491-
for idx, model in enumerate(model_names):
492-
visible = [False] * n_models + [True]
493-
visible[idx] = True
494-
buttons.append(
495-
{
496-
"label": model,
497-
"method": "update",
498-
"args": [
499-
{"visible": visible},
500-
{
501-
"title": f"{title} - {model}" if title else None,
502-
"annotations": [annotations[idx]],
503-
},
504-
],
505-
}
506-
)
489+
# Store all annotations and model order in layout meta so consumers
490+
# can swap annotation text when filtering per-model on the frontend.
491+
layout_meta = {
492+
"annotations": annotations,
493+
"models": model_names,
494+
}
507495

508496
fig.update_layout(
509-
title={"text": f"{title} - {model_names[0]}" if title else None},
497+
title={"text": title} if title else None,
510498
xaxis={"title": {"text": x_label}},
511499
yaxis={"title": {"text": y_label}},
512-
updatemenus=[
513-
{
514-
"buttons": buttons,
515-
"direction": "down",
516-
"showactive": True,
517-
"x": 0.0,
518-
"xanchor": "left",
519-
"y": 1.15,
520-
"yanchor": "top",
521-
}
522-
],
523500
annotations=[annotations[0]],
501+
meta=layout_meta,
502+
showlegend=True,
503+
legend_title_text="Model",
524504
)
525505

526506
Path(filename).parent.mkdir(parents=True, exist_ok=True)

ml_peg/analysis/utils/utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from collections.abc import Callable
66
from pathlib import Path
7+
from typing import Any
78

89
from matplotlib import cm
910
from matplotlib.colors import Colormap
@@ -114,6 +115,54 @@ def rmse(ref: list, prediction: list) -> float:
114115
return mean_squared_error(ref, prediction)
115116

116117

118+
def build_density_inputs(
119+
models: list[str],
120+
model_stats: dict[str, dict[str, Any]],
121+
property_key: str,
122+
*,
123+
mae_fn: Callable[[list, list], float] | None = None,
124+
) -> dict[str, dict[str, Any]]:
125+
"""
126+
Prepare a model->data mapping for density scatter plots.
127+
128+
Parameters
129+
----------
130+
models
131+
Ordered list of model names to include.
132+
model_stats
133+
Mapping of model -> {"<property_key>": {"ref": [...], "pred": [...]},
134+
"excluded": int}.
135+
property_key
136+
Key to extract from ``model_stats`` for each model (e.g. ``"bulk"`` or
137+
``"shear"``).
138+
mae_fn
139+
Optional callable to compute MAE. Defaults to :func:`mae` when None.
140+
141+
Returns
142+
-------
143+
dict[str, dict[str, Any]]
144+
Mapping ready for ``plot_density_scatter``.
145+
"""
146+
mae_fn = mae if mae_fn is None else mae_fn
147+
inputs: dict[str, dict[str, Any]] = {}
148+
149+
for model_name in models:
150+
stats = model_stats.get(model_name, {})
151+
prop = stats.get(property_key)
152+
excluded = stats.get("excluded")
153+
154+
ref_vals = prop.get("ref", [])
155+
pred_vals = prop.get("pred", [])
156+
inputs[model_name] = {
157+
"ref": ref_vals,
158+
"pred": pred_vals,
159+
"mae": mae_fn(ref_vals, pred_vals) if ref_vals else None,
160+
"meta": {"excluded": excluded} if excluded is not None else {},
161+
}
162+
163+
return inputs
164+
165+
117166
def calc_metric_scores(
118167
metrics_data: list[MetricRow],
119168
thresholds: Thresholds | None = None,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
title: Bulk Crystal Systems
1+
title: Bulk Crystals
22
description: Bulk crystal properties, including elastic moduli, phonons, and lattice constants.

ml_peg/app/bulk_crystal/elasticity/app_elasticity.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77

88
from ml_peg.app import APP_ROOT
99
from ml_peg.app.base_app import BaseApp
10-
from ml_peg.app.utils.build_callbacks import plot_from_table_column
11-
from ml_peg.app.utils.load import read_plot
10+
from ml_peg.app.utils.build_callbacks import plot_from_table_cell
11+
from ml_peg.app.utils.load import read_density_plot_for_model
12+
from ml_peg.models.get_models import get_model_names
13+
from ml_peg.models.models import current_models
1214

15+
MODELS = get_model_names(current_models)
1316
BENCHMARK_NAME = "Elasticity"
1417
DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/bulk.html#elasticity"
1518
DATA_PATH = APP_ROOT / "data" / "bulk_crystal" / "elasticity"
@@ -20,22 +23,26 @@ class ElasticityApp(BaseApp):
2023

2124
def register_callbacks(self) -> None:
2225
"""Register callbacks to app."""
23-
bulk_plot = read_plot(
24-
DATA_PATH / "figure_bulk_density.json",
25-
id=f"{BENCHMARK_NAME}-bulk-figure",
26-
)
27-
shear_plot = read_plot(
28-
DATA_PATH / "figure_shear_density.json",
29-
id=f"{BENCHMARK_NAME}-shear-figure",
30-
)
26+
density_plots = {
27+
model: {
28+
"Bulk modulus MAE": read_density_plot_for_model(
29+
filename=DATA_PATH / "figure_bulk_density.json",
30+
model=model,
31+
id=f"{BENCHMARK_NAME}-{model}-bulk-figure",
32+
),
33+
"Shear modulus MAE": read_density_plot_for_model(
34+
filename=DATA_PATH / "figure_shear_density.json",
35+
model=model,
36+
id=f"{BENCHMARK_NAME}-{model}-shear-figure",
37+
),
38+
}
39+
for model in MODELS
40+
}
3141

32-
plot_from_table_column(
42+
plot_from_table_cell(
3343
table_id=self.table_id,
3444
plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
35-
column_to_plot={
36-
"Bulk modulus MAE": bulk_plot,
37-
"Shear modulus MAE": shear_plot,
38-
},
45+
cell_to_plot=density_plots,
3946
)
4047

4148

0 commit comments

Comments
 (0)