|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +from pathlib import Path |
| 6 | + |
5 | 7 | from ase import units |
6 | 8 | from ase.io import read |
7 | 9 | import numpy as np |
8 | 10 | from numpy.typing import NDArray |
9 | 11 | import pytest |
10 | 12 |
|
11 | 13 | from ml_peg.analysis.utils.decorators import build_table, plot_parity |
| 14 | +from ml_peg.analysis.utils.utils import load_metrics_config |
12 | 15 | from ml_peg.app import APP_ROOT |
13 | 16 | from ml_peg.calcs import CALCS_ROOT |
14 | 17 | from ml_peg.models.get_models import get_model_names |
|
18 | 21 | CALC_PATH = CALCS_ROOT / "molecular" / "GMTKN55" / "outputs" |
19 | 22 | OUT_PATH = APP_ROOT / "data" / "molecular" / "GMTKN55" |
20 | 23 |
|
21 | | -DEFAULT_WEIGHTS = { |
22 | | - "Small systems": 0, |
23 | | - "Large systems": 0, |
24 | | - "Barrier heights": 0, |
25 | | - "Intramolecular NCIs": 0, |
26 | | - "Intermolecular NCIs": 0, |
27 | | - "WTMAD": 1, |
28 | | -} |
29 | | - |
30 | | -DEFAULT_THRESHOLDS = { |
31 | | - "Small systems": (0.5, 50.0), |
32 | | - "Large systems": (0.5, 50.0), |
33 | | - "Barrier heights": (0.5, 50.0), |
34 | | - "Intramolecular NCIs": (0.5, 50.0), |
35 | | - "Intermolecular NCIs": (0.5, 50.0), |
36 | | - "WTMAD": (0.5, 50.0), |
37 | | -} |
| 24 | +METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml") |
| 25 | +DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, DEFAULT_WEIGHTS = load_metrics_config( |
| 26 | + METRICS_CONFIG_PATH |
| 27 | +) |
38 | 28 |
|
39 | 29 | # Unit conversion |
40 | 30 | EV_TO_KCAL_PER_MOL = units.mol / units.kcal |
@@ -290,15 +280,7 @@ def weighted_error(subset_errors: dict[str, dict[str, float]]) -> dict[str, floa |
290 | 280 | @pytest.fixture |
291 | 281 | @build_table( |
292 | 282 | filename=OUT_PATH / "gmtkn55_metrics_table.json", |
293 | | - metric_tooltips={ |
294 | | - "Model": "Name of the model", |
295 | | - "Small systems": "Weighted Mean Absolute Deviation (kcal/mol)", |
296 | | - "Large systems": "Weighted Mean Absolute Deviation (kcal/mol)", |
297 | | - "Barrier heights": "Weighted Mean Absolute Deviation (kcal/mol)", |
298 | | - "Intramolecular NCIs": "Weighted Mean Absolute Deviation (kcal/mol)", |
299 | | - "Intermolecular NCIs": "Weighted Mean Absolute Deviation (kcal/mol)", |
300 | | - "WTMAD": "Total Weighted Mean Absolute Deviation (kcal/mol)", |
301 | | - }, |
| 283 | + metric_tooltips=DEFAULT_TOOLTIPS, |
302 | 284 | thresholds=DEFAULT_THRESHOLDS, |
303 | 285 | weights=DEFAULT_WEIGHTS, |
304 | 286 | ) |
|
0 commit comments