37
37
from .util import unique_rows
38
38
39
39
from MDAnalysis .lib .distances import apply_PBC
40
+ import numpy .typing as npt
41
+ from typing import Optional , ClassVar
40
42
41
43
__all__ = [
42
44
'PeriodicKDTree'
@@ -61,7 +63,8 @@ class PeriodicKDTree(object):
61
63
:func:`MDAnalysis.lib.distances.undo_augment` function.
62
64
63
65
"""
64
- def __init__ (self , box = None , leafsize = 10 ):
66
+
67
+ def __init__ (self , box : npt .ArrayLike = None , leafsize : int = 10 ) -> None :
65
68
"""
66
69
67
70
Parameters
@@ -82,7 +85,7 @@ def __init__(self, box=None, leafsize=10):
82
85
self .dim = 3 # 3D systems
83
86
self .box = box
84
87
self ._built = False
85
- self .cutoff = None
88
+ self .cutoff : Optional [ float ] = None
86
89
87
90
@property
88
91
def pbc (self ):
@@ -95,7 +98,7 @@ def pbc(self):
95
98
"""
96
99
return self .box is not None
97
100
98
- def set_coords (self , coords , cutoff = None ):
101
+ def set_coords (self , coords : npt . ArrayLike , cutoff : Optional [ float ] = None ) -> None :
99
102
"""Constructs KDTree from the coordinates
100
103
101
104
Wrapping of coordinates to the primary unit cell is enforced
@@ -126,23 +129,24 @@ def set_coords(self, coords, cutoff=None):
126
129
MDAnalysis.lib.distances.augment_coordinates
127
130
128
131
"""
129
- # If no cutoff distance is provided but PBC aware
130
- if self .pbc and (cutoff is None ):
131
- raise RuntimeError ('Provide a cutoff distance'
132
- ' with tree.set_coords(...)' )
133
132
134
133
# set coords dtype to float32
135
134
# augment coordinates will work only with float32
136
135
coords = np .asarray (coords , dtype = np .float32 )
137
136
137
+ # If no cutoff distance is provided but PBC aware
138
138
if self .pbc :
139
139
self .cutoff = cutoff
140
+ if cutoff is None :
141
+ raise RuntimeError ('Provide a cutoff distance'
142
+ ' with tree.set_coords(...)' )
143
+
140
144
# Bring the coordinates in the central cell
141
145
self .coords = apply_PBC (coords , self .box )
142
146
# generate duplicate images
143
147
self .aug , self .mapping = augment_coordinates (self .coords ,
144
148
self .box ,
145
- self . cutoff )
149
+ cutoff )
146
150
# Images + coords
147
151
self .all_coords = np .concatenate ([self .coords , self .aug ])
148
152
self .ckdt = cKDTree (self .all_coords , leafsize = self .leafsize )
@@ -155,7 +159,8 @@ def set_coords(self, coords, cutoff=None):
155
159
self .ckdt = cKDTree (self .coords , self .leafsize )
156
160
self ._built = True
157
161
158
- def search (self , centers , radius ):
162
+ # typing: numpy
163
+ def search (self , centers : npt .ArrayLike , radius : float ) -> np .ndarray :
159
164
"""Search all points within radius from centers and their periodic images.
160
165
161
166
All the centers coordinates are wrapped around the central cell
@@ -179,6 +184,9 @@ def search(self, centers, radius):
179
184
180
185
# Sanity check
181
186
if self .pbc :
187
+ if self .cutoff is None :
188
+ raise ValueError (
189
+ "Cutoff needs to be provided when working with PBC." )
182
190
if self .cutoff < radius :
183
191
raise RuntimeError ('Set cutoff greater or equal to the radius.' )
184
192
# Bring all query points to the central cell
@@ -202,17 +210,19 @@ def search(self, centers, radius):
202
210
self ._indices = np .asarray (unique_int_1d (self ._indices ))
203
211
return self ._indices
204
212
205
- def get_indices (self ):
213
+ # typing: numpy
214
+ def get_indices (self ) -> np .ndarray :
206
215
"""Return the neighbors from the last query.
207
216
208
217
Returns
209
218
------
210
- indices : list
219
+ indices : NDArray
211
220
neighbors for the last query points and search radius
212
221
"""
213
222
return self ._indices
214
223
215
- def search_pairs (self , radius ):
224
+ # typing: numpy
225
+ def search_pairs (self , radius : float ) -> np .ndarray :
216
226
"""Search all the pairs within a specified radius
217
227
218
228
Parameters
@@ -229,6 +239,9 @@ def search_pairs(self, radius):
229
239
raise RuntimeError (' Unbuilt Tree. Run tree.set_coords(...)' )
230
240
231
241
if self .pbc :
242
+ if self .cutoff is None :
243
+ raise ValueError (
244
+ "Cutoff needs to be provided when working with PBC." )
232
245
if self .cutoff < radius :
233
246
raise RuntimeError ('Set cutoff greater or equal to the radius.' )
234
247
@@ -245,7 +258,7 @@ def search_pairs(self, radius):
245
258
pairs = unique_rows (pairs )
246
259
return pairs
247
260
248
- def search_tree (self , centers , radius ) :
261
+ def search_tree (self , centers : npt . ArrayLike , radius : float ) -> np . ndarray :
249
262
"""
250
263
Searches all the pairs within `radius` between `centers`
251
264
and ``coords``
@@ -285,6 +298,9 @@ class initialization
285
298
286
299
# Sanity check
287
300
if self .pbc :
301
+ if self .cutoff is None :
302
+ raise ValueError (
303
+ "Cutoff needs to be provided when working with PBC." )
288
304
if self .cutoff < radius :
289
305
raise RuntimeError ('Set cutoff greater or equal to the radius.' )
290
306
# Bring all query points to the central cell
0 commit comments