|
8 | 8 | from pathlib import Path |
9 | 9 | from unittest import mock |
10 | 10 |
|
11 | | -import jax |
12 | | -import jax.numpy as jnp |
13 | | -import jax.tree_util as jtu |
14 | | -import jraph |
15 | 11 | import numpy as np |
16 | 12 | import pytest |
17 | 13 | import torch |
18 | 14 | import torch.nn.functional as F |
19 | 15 | from ase import Atoms |
20 | | -from flax import core as flax_core |
21 | | -from flax import serialization, traverse_util |
22 | | -from mace.data.atomic_data import AtomicData |
23 | | -from mace.data.utils import config_from_atoms |
24 | | -from mace.tools import torch_geometric |
25 | | -from mace.tools.scripts_utils import extract_config_mace_model |
26 | | -from mace_jax.cli import mace_torch2jax |
27 | | -from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable |
28 | | -from mace_jax.data.utils import Configuration as JaxConfiguration |
29 | | -from mace_jax.data.utils import graph_from_configuration |
30 | | -from torch.serialization import add_safe_globals |
| 16 | + |
| 17 | +pytest.importorskip('mace', reason='MACE is required for MACE JAX integration tests.') |
| 18 | +pytest.importorskip('mace_jax', reason='MACE JAX is required for these tests.') |
| 19 | +pytest.importorskip('jax', reason='JAX runtime is required for these tests.') |
| 20 | + |
| 21 | +import jax # noqa: E402 |
| 22 | +import jax.numpy as jnp # noqa: E402 |
| 23 | +import jax.tree_util as jtu # noqa: E402 |
| 24 | +import jraph # noqa: E402 |
| 25 | +from flax import core as flax_core # noqa: E402 |
| 26 | +from flax import serialization, traverse_util # noqa: E402 |
| 27 | +from mace.data.atomic_data import AtomicData # noqa: E402 |
| 28 | +from mace.data.utils import config_from_atoms # noqa: E402 |
| 29 | +from mace.tools import torch_geometric # noqa: E402 |
| 30 | +from mace.tools.scripts_utils import extract_config_mace_model # noqa: E402 |
| 31 | +from mace_jax.cli import mace_torch2jax # noqa: E402 |
| 32 | +from mace_jax.data.utils import AtomicNumberTable as JaxAtomicNumberTable # noqa: E402 |
| 33 | +from mace_jax.data.utils import Configuration as JaxConfiguration # noqa: E402 |
| 34 | +from mace_jax.data.utils import graph_from_configuration # noqa: E402 |
| 35 | +from torch.serialization import add_safe_globals # noqa: E402 |
31 | 36 |
|
32 | 37 | from equitrain import get_args_parser_train |
33 | 38 | from equitrain import train as equitrain_train |
|
0 commit comments