Skip to content

Commit 8469e53

Browse files
committed
2025/11/06-09:30:19 (Linux sv1224 x86_64)
1 parent c1c0df7 commit 8469e53

16 files changed

+105
-50
lines changed

tests/test_backend_jax_statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import pytest
77
import torch_geometric
88

9-
pytest.importorskip('jax')
10-
pytest.importorskip('mace_jax')
9+
pytest.importorskip('jax', reason='JAX runtime is required for JAX backend tests.')
10+
pytest.importorskip('mace_jax', reason='MACE JAX is required for these tests.')
1111

1212
from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable # noqa: E402
1313

tests/test_evaluate_mace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from pathlib import Path
22

3+
import pytest
4+
5+
pytest.importorskip('mace', reason='MACE is required for MACE integration tests.')
6+
37
from equitrain import evaluate, get_args_parser_evaluate
48
from equitrain.utility_test import MaceWrapper
59
from equitrain.utility_test.mace_support import get_mace_model_path

tests/test_finetune_mace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from pathlib import Path
22

3+
import pytest
34
import torch
45

6+
pytest.importorskip('mace', reason='MACE is required for MACE integration tests.')
7+
58
from equitrain import get_args_parser_train, train
69
from equitrain.checkpoint import load_checkpoint
710
from equitrain.finetune.delta_torch import DeltaFineTuneWrapper

tests/test_finetune_mace_jax.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,31 @@
88
from pathlib import Path
99
from unittest import mock
1010

11-
import jax
12-
import jax.numpy as jnp
13-
import jax.tree_util as jtu
14-
import jraph
1511
import numpy as np
1612
import pytest
1713
import torch
1814
import torch.nn.functional as F
1915
from ase import Atoms
20-
from flax import core as flax_core
21-
from flax import serialization, traverse_util
22-
from mace.data.atomic_data import AtomicData
23-
from mace.data.utils import config_from_atoms
24-
from mace.tools import torch_geometric
25-
from mace.tools.scripts_utils import extract_config_mace_model
26-
from mace_jax.cli import mace_torch2jax
27-
from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable
28-
from mace_jax.data.utils import Configuration as JaxConfiguration
29-
from mace_jax.data.utils import graph_from_configuration
30-
from torch.serialization import add_safe_globals
16+
17+
pytest.importorskip('mace', reason='MACE is required for MACE JAX integration tests.')
18+
pytest.importorskip('mace_jax', reason='MACE JAX is required for these tests.')
19+
pytest.importorskip('jax', reason='JAX runtime is required for these tests.')
20+
21+
import jax # noqa: E402
22+
import jax.numpy as jnp # noqa: E402
23+
import jax.tree_util as jtu # noqa: E402
24+
import jraph # noqa: E402
25+
from flax import core as flax_core # noqa: E402
26+
from flax import serialization, traverse_util # noqa: E402
27+
from mace.data.atomic_data import AtomicData # noqa: E402
28+
from mace.data.utils import config_from_atoms # noqa: E402
29+
from mace.tools import torch_geometric # noqa: E402
30+
from mace.tools.scripts_utils import extract_config_mace_model # noqa: E402
31+
from mace_jax.cli import mace_torch2jax # noqa: E402
32+
from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable # noqa: E402
33+
from mace_jax.data.utils import Configuration as JaxConfiguration # noqa: E402
34+
from mace_jax.data.utils import graph_from_configuration # noqa: E402
35+
from torch.serialization import add_safe_globals # noqa: E402
3136

3237
from equitrain import get_args_parser_train
3338
from equitrain import train as equitrain_train

tests/test_finetune_mace_readout.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import copy
22
from pathlib import Path
33

4+
import pytest
45
import torch
56

7+
pytest.importorskip('mace', reason='MACE is required for MACE integration tests.')
8+
69
from equitrain import get_args_parser_train, train
710
from equitrain.utility_test import MaceWrapper
811
from equitrain.utility_test.mace_support import get_mace_model_path

tests/test_jax_loss.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import numpy as np
77
import pytest
88

9-
pytest.importorskip('jax')
10-
import jax.numpy as jnp
11-
import jraph
9+
pytest.importorskip('jax', reason='JAX runtime is required for JAX backend tests.')
10+
11+
import jax.numpy as jnp # noqa: E402
12+
import jraph # noqa: E402
1213

1314
from equitrain.backends.jax_loss_fn import LossSettings, build_loss_fn
1415

tests/test_jax_model_equivalence.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,23 @@
22

33
import warnings
44

5-
import jax
65
import numpy as np
76
import pytest
87
import torch
9-
from flax import serialization
10-
from mace.data.atomic_data import AtomicData
11-
from mace.data.utils import config_from_atoms
12-
from mace.tools import torch_geometric
13-
from mace.tools.scripts_utils import extract_config_mace_model
14-
from mace_jax.cli import mace_torch2jax
15-
16-
from equitrain.utility_test.mace_support import (
8+
9+
pytest.importorskip('mace', reason='MACE is required for MACE JAX integration tests.')
10+
pytest.importorskip('mace_jax', reason='MACE JAX is required for these tests.')
11+
pytest.importorskip('jax', reason='JAX runtime is required for these tests.')
12+
13+
import jax # noqa: E402
14+
from flax import serialization # noqa: E402
15+
from mace.data.atomic_data import AtomicData # noqa: E402
16+
from mace.data.utils import config_from_atoms # noqa: E402
17+
from mace.tools import torch_geometric # noqa: E402
18+
from mace.tools.scripts_utils import extract_config_mace_model # noqa: E402
19+
from mace_jax.cli import mace_torch2jax # noqa: E402
20+
21+
from equitrain.utility_test.mace_support import ( # noqa: E402
1722
build_statistics,
1823
build_structures,
1924
create_model_args,

tests/test_predict_ani_atoms.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
"""
66

77
import numpy as np
8+
import pytest
89
import torch
910
from ase import Atoms
1011

12+
pytest.importorskip(
13+
'torchani', reason='TorchANI is required for ANI integration tests.'
14+
)
15+
1116
from equitrain import get_args_parser_predict, predict_atoms
1217
from equitrain.data.atomic import AtomicNumberTable
1318
from equitrain.utility_test import AniWrapper

tests/test_predict_mace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from pathlib import Path
22

3+
import pytest
4+
5+
pytest.importorskip('mace', reason='MACE is required for MACE integration tests.')
6+
37
from equitrain import get_args_parser_predict, predict
48
from equitrain.utility_test import MaceWrapper
59
from equitrain.utility_test.mace_support import get_mace_model_path

tests/test_predict_mace_atoms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from pathlib import Path
22

33
import ase.io
4+
import pytest
5+
6+
pytest.importorskip('mace', reason='MACE is required for MACE integration tests.')
47

58
from equitrain import get_args_parser_predict, predict_atoms
69
from equitrain.backends.torch_utils import set_dtype

0 commit comments

Comments
 (0)