Skip to content

Commit 305b57f

Browse files
authored
Improve type hints in io.ase (#4556)
* type hinting in io/ase * type hinting in io/ase
1 parent 9808204 commit 305b57f

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/pymatgen/io/ase.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import warnings
99
from copy import deepcopy
1010
from importlib.metadata import PackageNotFoundError
11-
from typing import TYPE_CHECKING, TypeVar
11+
from typing import TYPE_CHECKING, Literal, TypeVar, overload
1212

1313
import numpy as np
1414
from monty.json import MontyDecoder, MSONable, jsanitize
@@ -48,8 +48,8 @@ def __init__(self, *args, **kwargs):
4848
__email__ = "[email protected]"
4949
__date__ = "Mar 8, 2012"
5050

51-
StructT = TypeVar("StructT", bound=IStructure | IMolecule | Structure | Molecule)
52-
MolT = TypeVar("MolT", bound=IMolecule)
51+
IMoleculeT = TypeVar("IMoleculeT", bound=IMolecule)
52+
StructOrMolT = TypeVar("StructOrMolT", bound=Structure | Molecule)
5353

5454

5555
class MSONAtoms(Atoms, MSONable):
@@ -86,6 +86,30 @@ def from_dict(cls, dct: dict[str, Any]) -> Self:
8686
class AseAtomsAdaptor:
8787
"""Adaptor serves as a bridge between ASE Atoms and pymatgen objects."""
8888

89+
@overload
90+
@staticmethod
91+
def get_atoms(
92+
structure: SiteCollection,
93+
msonable: Literal[True] = ...,
94+
**kwargs: Any,
95+
) -> MSONAtoms: ...
96+
97+
@overload
98+
@staticmethod
99+
def get_atoms(
100+
structure: SiteCollection,
101+
msonable: Literal[False],
102+
**kwargs: Any,
103+
) -> Atoms: ...
104+
105+
@overload
106+
@staticmethod
107+
def get_atoms(
108+
structure: SiteCollection,
109+
msonable: bool = True,
110+
**kwargs: Any,
111+
) -> MSONAtoms | Atoms: ...
112+
89113
@staticmethod
90114
def get_atoms(
91115
structure: SiteCollection,
@@ -239,9 +263,9 @@ def get_atoms(
239263
@staticmethod
240264
def get_structure(
241265
atoms: Atoms,
242-
cls=Structure,
266+
cls: type[StructOrMolT] = Structure,
243267
**cls_kwargs,
244-
) -> Structure | Molecule:
268+
) -> StructOrMolT:
245269
"""Get pymatgen structure from ASE Atoms.
246270
247271
Args:
@@ -392,7 +416,7 @@ def get_structure(
392416
return structure
393417

394418
@staticmethod
395-
def get_molecule(atoms: Atoms, cls: type[MolT] = Molecule, **cls_kwargs) -> Molecule | IMolecule: # type:ignore[assignment]
419+
def get_molecule(atoms: Atoms, cls: type[IMoleculeT] = Molecule, **cls_kwargs) -> IMoleculeT:
396420
"""Get pymatgen molecule from ASE Atoms.
397421
398422
Args:
@@ -401,7 +425,7 @@ def get_molecule(atoms: Atoms, cls: type[MolT] = Molecule, **cls_kwargs) -> Mole
401425
**cls_kwargs: Any additional kwargs to pass to the cls constructor
402426
403427
Returns:
404-
(I)Molecule: Equivalent pymatgen (I)Molecule
428+
MolT: Equivalent pymatgen (I)Molecule
405429
"""
406430
molecule = AseAtomsAdaptor.get_structure(atoms, cls=cls, **cls_kwargs)
407431

0 commit comments

Comments
 (0)