Skip to content

Commit e7ee5a4

Browse files
umak1106jbarnoud
andauthored
Type hints for lib module (#3729)
* Type hints for mdamath.py * fix errors in mdamath.py * changed input from NDarray to arraylike * Added type annotations to init.py Added type annotations to init to avoid mypy from raising errors when other modules are type checked . * Allowing mypy to type check lib module * Update changes in init.py * Update __init__.py * type hints for pkdtree * Fix all errros in pkdtree.py * Chage npt.NDArray to np.ndarray * Update pkdtree.py * Update pkdtree.py * Update pkdtree.py * Update NeighborSearch.py * Update NeighborSearch.py * Update pkdtree.py Co-authored-by: Jonathan Barnoud <[email protected]>
1 parent 0788165 commit e7ee5a4

File tree

5 files changed

+62
-37
lines changed

5 files changed

+62
-37
lines changed

mypy.ini

-8
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ ignore_errors = True
1818
[mypy-MDAnalysis.core.*]
1919
ignore_errors = True
2020

21-
[mypy-MDAnalysis.lib.*]
22-
ignore_errors = True
23-
2421
[mypy-MDAnalysis.selections.*]
2522
ignore_errors = True
2623

@@ -47,8 +44,3 @@ ignore_errors = True
4744

4845
[mypy-MDAnalysis.version]
4946
ignore_errors = True
50-
51-
[mypy-MDAnalysis.*]
52-
ignore_errors = True
53-
54-

package/MDAnalysis/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@
179179
_CONVERTERS: Dict = {}
180180
# Registry of TopologyAttributes
181181
_TOPOLOGY_ATTRS: Dict = {} # {attrname: cls}
182-
_TOPOLOGY_TRANSPLANTS: Dict = {}
183-
# {name: [attrname, method, transplant class]}
182+
_TOPOLOGY_TRANSPLANTS: Dict = {} # {name: [attrname, method, transplant class]}
184183
_TOPOLOGY_ATTRNAMES: Dict = {} # {lower case name w/o _ : name}
185184

186185

package/MDAnalysis/lib/NeighborSearch.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
import numpy as np
3232
from MDAnalysis.lib.distances import capped_distance
3333
from MDAnalysis.lib.util import unique_int_1d
34+
from MDAnalysis.core.groups import AtomGroup, SegmentGroup, ResidueGroup
35+
import numpy.typing as npt
36+
from typing import Optional, Union, List
3437

3538

3639
class AtomNeighborSearch(object):
@@ -41,7 +44,8 @@ class AtomNeighborSearch(object):
4144
:class:`~MDAnalysis.lib.distances.capped_distance`.
4245
"""
4346

44-
def __init__(self, atom_group, box=None):
47+
def __init__(self, atom_group: AtomGroup,
48+
box: Optional[npt.ArrayLike] = None) -> None:
4549
"""
4650
4751
Parameters
@@ -58,7 +62,10 @@ def __init__(self, atom_group, box=None):
5862
self._u = atom_group.universe
5963
self._box = box
6064

61-
def search(self, atoms, radius, level='A'):
65+
def search(self, atoms: AtomGroup,
66+
radius: float,
67+
level: str = 'A'
68+
) -> Optional[Union[AtomGroup, ResidueGroup, SegmentGroup]]:
6269
"""
6370
Return all atoms/residues/segments that are within *radius* of the
6471
atoms in *atoms*.
@@ -102,7 +109,10 @@ def search(self, atoms, radius, level='A'):
102109
unique_idx = unique_int_1d(np.asarray(pairs[:, 1], dtype=np.intp))
103110
return self._index2level(unique_idx, level)
104111

105-
def _index2level(self, indices, level):
112+
def _index2level(self,
113+
indices: List[int],
114+
level: str
115+
) -> Union[AtomGroup, ResidueGroup, SegmentGroup]:
106116
"""Convert list of atom_indices in a AtomGroup to either the
107117
Atoms or segments/residues containing these atoms.
108118

package/MDAnalysis/lib/mdamath.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,13 @@
6363
from . import util
6464
from ._cutil import (make_whole, find_fragments, _sarrus_det_single,
6565
_sarrus_det_multiple)
66+
import numpy.typing as npt
67+
from typing import Union
6668

6769
# geometric functions
6870

6971

70-
def norm(v):
72+
def norm(v: npt.ArrayLike) -> float:
7173
r"""Calculate the norm of a vector v.
7274
7375
.. math:: v = \sqrt{\mathbf{v}\cdot\mathbf{v}}
@@ -90,7 +92,8 @@ def norm(v):
9092
return np.sqrt(np.dot(v, v))
9193

9294

93-
def normal(vec1, vec2):
95+
# typing: numpy
96+
def normal(vec1: npt.ArrayLike, vec2: npt.ArrayLike) -> np.ndarray:
9497
r"""Returns the unit vector normal to two vectors.
9598
9699
.. math::
@@ -110,7 +113,8 @@ def normal(vec1, vec2):
110113
return normal / n
111114

112115

113-
def pdot(a, b):
116+
# typing: numpy
117+
def pdot(a: npt.ArrayLike, b: npt.ArrayLike) -> np.ndarray:
114118
"""Pairwise dot product.
115119
116120
``a`` must be the same shape as ``b``.
@@ -127,7 +131,8 @@ def pdot(a, b):
127131
return np.einsum('ij,ij->i', a, b)
128132

129133

130-
def pnorm(a):
134+
# typing: numpy
135+
def pnorm(a: npt.ArrayLike) -> np.ndarray:
131136
"""Euclidean norm of each vector in a matrix
132137
133138
Parameters
@@ -141,7 +146,7 @@ def pnorm(a):
141146
return pdot(a, a)**0.5
142147

143148

144-
def angle(a, b):
149+
def angle(a: npt.ArrayLike, b: npt.ArrayLike) -> float:
145150
"""Returns the angle between two vectors in radians
146151
147152
.. versionchanged:: 0.11.0
@@ -156,7 +161,7 @@ def angle(a, b):
156161
return np.arccos(x)
157162

158163

159-
def stp(vec1, vec2, vec3):
164+
def stp(vec1: npt.ArrayLike, vec2: npt.ArrayLike, vec3: npt.ArrayLike) -> float:
160165
r"""Takes the scalar triple product of three vectors.
161166
162167
Returns the volume *V* of the parallel epiped spanned by the three
@@ -172,7 +177,7 @@ def stp(vec1, vec2, vec3):
172177
return np.dot(vec3, np.cross(vec1, vec2))
173178

174179

175-
def dihedral(ab, bc, cd):
180+
def dihedral(ab: npt.ArrayLike, bc: npt.ArrayLike, cd: npt.ArrayLike) -> float:
176181
r"""Returns the dihedral angle in radians between vectors connecting A,B,C,D.
177182
178183
The dihedral measures the rotation around bc::
@@ -194,7 +199,8 @@ def dihedral(ab, bc, cd):
194199
return (x if stp(ab, bc, cd) <= 0.0 else -x)
195200

196201

197-
def sarrus_det(matrix):
202+
# typing: numpy
203+
def sarrus_det(matrix: np.ndarray) -> Union[float, np.ndarray]:
198204
"""Computes the determinant of a 3x3 matrix according to the
199205
`rule of Sarrus`_.
200206
@@ -239,7 +245,8 @@ def sarrus_det(matrix):
239245
return _sarrus_det_multiple(m.reshape((-1, 3, 3))).reshape(shape[:-2])
240246

241247

242-
def triclinic_box(x, y, z):
248+
# typing: numpy
249+
def triclinic_box(x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike) -> np.ndarray:
243250
"""Convert the three triclinic box vectors to
244251
``[lx, ly, lz, alpha, beta, gamma]``.
245252
@@ -301,7 +308,8 @@ def triclinic_box(x, y, z):
301308
return np.zeros(6, dtype=np.float32)
302309

303310

304-
def triclinic_vectors(dimensions, dtype=np.float32):
311+
# typing: numpy
312+
def triclinic_vectors(dimensions: npt.ArrayLike, dtype: npt.DTypeLike = np.float32) -> np.ndarray:
305313
"""Convert ``[lx, ly, lz, alpha, beta, gamma]`` to a triclinic matrix
306314
representation.
307315
@@ -399,7 +407,7 @@ def triclinic_vectors(dimensions, dtype=np.float32):
399407
return box_matrix
400408

401409

402-
def box_volume(dimensions):
410+
def box_volume(dimensions: npt.ArrayLike) -> float:
403411
"""Return the volume of the unitcell described by `dimensions`.
404412
405413
The volume is computed as the product of the box matrix trace, with the

package/MDAnalysis/lib/pkdtree.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
from .util import unique_rows
3838

3939
from MDAnalysis.lib.distances import apply_PBC
40+
import numpy.typing as npt
41+
from typing import Optional, ClassVar
4042

4143
__all__ = [
4244
'PeriodicKDTree'
@@ -61,7 +63,8 @@ class PeriodicKDTree(object):
6163
:func:`MDAnalysis.lib.distances.undo_augment` function.
6264
6365
"""
64-
def __init__(self, box=None, leafsize=10):
66+
67+
def __init__(self, box: npt.ArrayLike = None, leafsize: int = 10) -> None:
6568
"""
6669
6770
Parameters
@@ -82,7 +85,7 @@ def __init__(self, box=None, leafsize=10):
8285
self.dim = 3 # 3D systems
8386
self.box = box
8487
self._built = False
85-
self.cutoff = None
88+
self.cutoff: Optional[float] = None
8689

8790
@property
8891
def pbc(self):
@@ -95,7 +98,7 @@ def pbc(self):
9598
"""
9699
return self.box is not None
97100

98-
def set_coords(self, coords, cutoff=None):
101+
def set_coords(self, coords: npt.ArrayLike, cutoff: Optional[float] = None) -> None:
99102
"""Constructs KDTree from the coordinates
100103
101104
Wrapping of coordinates to the primary unit cell is enforced
@@ -126,23 +129,24 @@ def set_coords(self, coords, cutoff=None):
126129
MDAnalysis.lib.distances.augment_coordinates
127130
128131
"""
129-
# If no cutoff distance is provided but PBC aware
130-
if self.pbc and (cutoff is None):
131-
raise RuntimeError('Provide a cutoff distance'
132-
' with tree.set_coords(...)')
133132

134133
# set coords dtype to float32
135134
# augment coordinates will work only with float32
136135
coords = np.asarray(coords, dtype=np.float32)
137136

137+
# If no cutoff distance is provided but PBC aware
138138
if self.pbc:
139139
self.cutoff = cutoff
140+
if cutoff is None:
141+
raise RuntimeError('Provide a cutoff distance'
142+
' with tree.set_coords(...)')
143+
140144
# Bring the coordinates in the central cell
141145
self.coords = apply_PBC(coords, self.box)
142146
# generate duplicate images
143147
self.aug, self.mapping = augment_coordinates(self.coords,
144148
self.box,
145-
self.cutoff)
149+
cutoff)
146150
# Images + coords
147151
self.all_coords = np.concatenate([self.coords, self.aug])
148152
self.ckdt = cKDTree(self.all_coords, leafsize=self.leafsize)
@@ -155,7 +159,8 @@ def set_coords(self, coords, cutoff=None):
155159
self.ckdt = cKDTree(self.coords, self.leafsize)
156160
self._built = True
157161

158-
def search(self, centers, radius):
162+
# typing: numpy
163+
def search(self, centers: npt.ArrayLike, radius: float) -> np.ndarray:
159164
"""Search all points within radius from centers and their periodic images.
160165
161166
All the centers coordinates are wrapped around the central cell
@@ -179,6 +184,9 @@ def search(self, centers, radius):
179184

180185
# Sanity check
181186
if self.pbc:
187+
if self.cutoff is None:
188+
raise ValueError(
189+
"Cutoff needs to be provided when working with PBC.")
182190
if self.cutoff < radius:
183191
raise RuntimeError('Set cutoff greater or equal to the radius.')
184192
# Bring all query points to the central cell
@@ -202,17 +210,19 @@ def search(self, centers, radius):
202210
self._indices = np.asarray(unique_int_1d(self._indices))
203211
return self._indices
204212

205-
def get_indices(self):
213+
# typing: numpy
214+
def get_indices(self) -> np.ndarray:
206215
"""Return the neighbors from the last query.
207216
208217
Returns
209218
------
210-
indices : list
219+
indices : NDArray
211220
neighbors for the last query points and search radius
212221
"""
213222
return self._indices
214223

215-
def search_pairs(self, radius):
224+
# typing: numpy
225+
def search_pairs(self, radius: float) -> np.ndarray:
216226
"""Search all the pairs within a specified radius
217227
218228
Parameters
@@ -229,6 +239,9 @@ def search_pairs(self, radius):
229239
raise RuntimeError(' Unbuilt Tree. Run tree.set_coords(...)')
230240

231241
if self.pbc:
242+
if self.cutoff is None:
243+
raise ValueError(
244+
"Cutoff needs to be provided when working with PBC.")
232245
if self.cutoff < radius:
233246
raise RuntimeError('Set cutoff greater or equal to the radius.')
234247

@@ -245,7 +258,7 @@ def search_pairs(self, radius):
245258
pairs = unique_rows(pairs)
246259
return pairs
247260

248-
def search_tree(self, centers, radius):
261+
def search_tree(self, centers: npt.ArrayLike, radius: float) -> np.ndarray:
249262
"""
250263
Searches all the pairs within `radius` between `centers`
251264
and ``coords``
@@ -285,6 +298,9 @@ class initialization
285298

286299
# Sanity check
287300
if self.pbc:
301+
if self.cutoff is None:
302+
raise ValueError(
303+
"Cutoff needs to be provided when working with PBC.")
288304
if self.cutoff < radius:
289305
raise RuntimeError('Set cutoff greater or equal to the radius.')
290306
# Bring all query points to the central cell

0 commit comments

Comments
 (0)