Skip to content

Commit 0a06e9a

Browse files
authored
Merge pull request astropy#16277 from eerovaher/coord-matching-tests
Rewrite an overlong coordinate matching test function
2 parents aa4dfa7 + 4947a2f commit 0a06e9a

File tree

1 file changed

+146
-104
lines changed

1 file changed

+146
-104
lines changed

astropy/coordinates/tests/test_matching.py

Lines changed: 146 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,34 @@
55
from numpy import testing as npt
66

77
from astropy import units as u
8-
from astropy.coordinates import matching
8+
from astropy.coordinates import (
9+
ICRS,
10+
Angle,
11+
CartesianRepresentation,
12+
Galactic,
13+
SkyCoord,
14+
match_coordinates_3d,
15+
match_coordinates_sky,
16+
search_around_3d,
17+
search_around_sky,
18+
)
919
from astropy.tests.helper import assert_quantity_allclose as assert_allclose
20+
from astropy.utils import NumpyRNGContext
1021
from astropy.utils.compat.optional_deps import HAS_SCIPY
1122

1223
"""
1324
These are the tests for coordinate matching.
1425
15-
Note that this requires scipy.
26+
Coordinate matching can involve caching, so it is best to recreate the
27+
coordinate objects in every test instead of trying to reuse module-level
28+
variables.
1629
"""
1730

31+
if not HAS_SCIPY:
32+
pytest.skip("Coordinate matching requires scipy", allow_module_level=True)
1833

19-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
20-
def test_matching_function():
21-
from astropy.coordinates import ICRS
22-
from astropy.coordinates.matching import match_coordinates_3d
2334

35+
def test_matching_function():
2436
# this only uses match_coordinates_3d because that's the actual implementation
2537

2638
cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree)
@@ -37,11 +49,7 @@ def test_matching_function():
3749
npt.assert_array_less(d3d.value, 0.02)
3850

3951

40-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
4152
def test_matching_function_3d_and_sky():
42-
from astropy.coordinates import ICRS
43-
from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky
44-
4553
cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc)
4654
ccatalog = ICRS(
4755
[1, 2, 3, 4] * u.degree, [0, 0, 0, 0] * u.degree, distance=[1, 1, 1, 5] * u.kpc
@@ -64,16 +72,13 @@ def test_matching_function_3d_and_sky():
6472
@pytest.mark.parametrize(
6573
"functocheck, args, defaultkdtname, bothsaved",
6674
[
67-
(matching.match_coordinates_3d, [], "kdtree_3d", False),
68-
(matching.match_coordinates_sky, [], "kdtree_sky", False),
69-
(matching.search_around_3d, [1 * u.kpc], "kdtree_3d", True),
70-
(matching.search_around_sky, [1 * u.deg], "kdtree_sky", False),
75+
(match_coordinates_3d, [], "kdtree_3d", False),
76+
(match_coordinates_sky, [], "kdtree_sky", False),
77+
(search_around_3d, [1 * u.kpc], "kdtree_3d", True),
78+
(search_around_sky, [1 * u.deg], "kdtree_sky", False),
7179
],
7280
)
73-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
7481
def test_kdtree_storage(functocheck, args, defaultkdtname, bothsaved):
75-
from astropy.coordinates import ICRS
76-
7782
def make_scs():
7883
cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 2] * u.kpc)
7984
ccatalog = ICRS(
@@ -119,12 +124,7 @@ def make_scs():
119124
assert "KD" in e.value.args[0]
120125

121126

122-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
123127
def test_matching_method():
124-
from astropy.coordinates import ICRS, SkyCoord
125-
from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky
126-
from astropy.utils import NumpyRNGContext
127-
128128
with NumpyRNGContext(987654321):
129129
cmatch = ICRS(
130130
np.random.rand(20) * 360.0 * u.degree,
@@ -153,98 +153,150 @@ def test_matching_method():
153153
assert len(idx1) == len(d2d1) == len(d3d1) == 20
154154

155155

156-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
157-
def test_search_around():
158-
from astropy.coordinates import ICRS, SkyCoord
159-
from astropy.coordinates.matching import search_around_3d, search_around_sky
160-
161-
coo1 = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc)
162-
coo2 = ICRS(
163-
[1, 2, 3, 4] * u.degree, [0, 0, 0, 0] * u.degree, distance=[1, 1, 1, 5] * u.kpc
164-
)
165-
166-
idx1_1deg, idx2_1deg, d2d_1deg, d3d_1deg = search_around_sky(
167-
coo1, coo2, 1.01 * u.deg
168-
)
169-
idx1_0p05deg, idx2_0p05deg, d2d_0p05deg, d3d_0p05deg = search_around_sky(
170-
coo1, coo2, 0.05 * u.deg
156+
@pytest.mark.parametrize(
157+
"search_limit,expected_idx1,expected_idx2,expected_d2d,expected_d3d",
158+
[
159+
pytest.param(
160+
1.01 * u.deg,
161+
[0, 0, 1, 1],
162+
[2, 3, 1, 2],
163+
[1, 0, 0.1, 0.9] * u.deg,
164+
[0.01745307, 4.0, 4.0000019, 4.00015421] * u.kpc,
165+
id="1.01_deg",
166+
),
167+
pytest.param(0.05 * u.deg, [0], [3], [0] * u.deg, [4] * u.kpc, id="0.05_deg"),
168+
],
169+
)
170+
def test_search_around_sky(
171+
search_limit, expected_idx1, expected_idx2, expected_d2d, expected_d3d
172+
):
173+
idx1, idx2, d2d, d3d = search_around_sky(
174+
ICRS([4, 2.1] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc),
175+
ICRS([1, 2, 3, 4] * u.deg, [0, 0, 0, 0] * u.deg, distance=[1, 1, 1, 5] * u.kpc),
176+
search_limit,
171177
)
178+
npt.assert_array_equal(idx1, expected_idx1)
179+
npt.assert_array_equal(idx2, expected_idx2)
180+
assert_allclose(d2d, expected_d2d)
181+
assert_allclose(d3d, expected_d3d)
172182

173-
assert list(zip(idx1_1deg, idx2_1deg)) == [(0, 2), (0, 3), (1, 1), (1, 2)]
174-
assert_allclose(d2d_1deg[0], 1.0 * u.deg, atol=1e-14 * u.deg, rtol=0)
175-
assert_allclose(d2d_1deg, [1, 0, 0.1, 0.9] * u.deg)
176183

177-
assert list(zip(idx1_0p05deg, idx2_0p05deg)) == [(0, 3)]
178-
179-
idx1_1kpc, idx2_1kpc, d2d_1kpc, d3d_1kpc = search_around_3d(coo1, coo2, 1 * u.kpc)
180-
idx1_sm, idx2_sm, d2d_sm, d3d_sm = search_around_3d(coo1, coo2, 0.05 * u.kpc)
184+
@pytest.mark.parametrize(
185+
"search_limit,expected_idx1,expected_idx2,expected_d2d,expected_d3d",
186+
[
187+
pytest.param(
188+
1 * u.kpc,
189+
[0, 0, 0, 1],
190+
[0, 1, 2, 3],
191+
[3, 2, 1, 1.9] * u.deg,
192+
[0.0523539, 0.03490481, 0.01745307, 0.16579868] * u.kpc,
193+
id="1_kpc",
194+
),
195+
pytest.param(
196+
0.05 * u.kpc,
197+
[0, 0],
198+
[1, 2],
199+
[2, 1] * u.deg,
200+
[0.03490481, 0.01745307] * u.kpc,
201+
id="0.05_kpc",
202+
),
203+
],
204+
)
205+
def test_search_around_3d(
206+
search_limit, expected_idx1, expected_idx2, expected_d2d, expected_d3d
207+
):
208+
idx1, idx2, d2d, d3d = search_around_3d(
209+
ICRS([4, 2.1] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc),
210+
ICRS([1, 2, 3, 4] * u.deg, [0, 0, 0, 0] * u.deg, distance=[1, 1, 1, 5] * u.kpc),
211+
search_limit,
212+
)
213+
npt.assert_array_equal(idx1, expected_idx1)
214+
npt.assert_array_equal(idx2, expected_idx2)
215+
assert_allclose(d2d, expected_d2d)
216+
assert_allclose(d3d, expected_d3d)
181217

182-
assert list(zip(idx1_1kpc, idx2_1kpc)) == [(0, 0), (0, 1), (0, 2), (1, 3)]
183-
assert list(zip(idx1_sm, idx2_sm)) == [(0, 1), (0, 2)]
184-
assert_allclose(d2d_sm, [2, 1] * u.deg)
185218

219+
@pytest.mark.parametrize(
220+
"function,search_limit",
221+
[
222+
pytest.param(func, limit, id=func.__name__)
223+
for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg])
224+
],
225+
)
226+
def test_search_around_no_matches(function, search_limit):
186227
# Test for the non-matches, #4877
187-
coo1 = ICRS([4.1, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc)
188-
idx1, idx2, d2d, d3d = search_around_sky(coo1, coo2, 1 * u.arcsec)
189-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
190-
assert idx1.dtype == idx2.dtype == int
191-
assert d2d.unit == u.deg
192-
assert d3d.unit == u.kpc
193-
idx1, idx2, d2d, d3d = search_around_3d(coo1, coo2, 1 * u.m)
194-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
195-
assert idx1.dtype == idx2.dtype == int
228+
idx1, idx2, d2d, d3d = function(
229+
ICRS([41, 21] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc),
230+
ICRS([1, 2] * u.deg, [0, 0] * u.deg, distance=[1, 1] * u.kpc),
231+
search_limit,
232+
)
233+
assert idx1.size == 0
234+
assert idx2.size == 0
235+
assert d2d.size == 0
236+
assert d3d.size == 0
237+
assert idx1.dtype == int
238+
assert idx2.dtype == int
196239
assert d2d.unit == u.deg
197240
assert d3d.unit == u.kpc
198241

242+
243+
@pytest.mark.parametrize(
244+
"function,search_limit",
245+
[
246+
pytest.param(func, limit, id=func.__name__)
247+
for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg])
248+
],
249+
)
250+
@pytest.mark.parametrize(
251+
"sources,catalog",
252+
[
253+
pytest.param(
254+
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
255+
ICRS([1] * u.deg, [0] * u.deg, distance=[1] * u.kpc),
256+
id="empty_sources",
257+
),
258+
pytest.param(
259+
ICRS([1] * u.deg, [0] * u.deg, distance=[1] * u.kpc),
260+
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
261+
id="empty_catalog",
262+
),
263+
pytest.param(
264+
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
265+
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
266+
id="empty_both",
267+
),
268+
],
269+
)
270+
def test_search_around_empty_input(sources, catalog, function, search_limit):
199271
# Test when one or both of the coordinate arrays is empty, #4875
200-
empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc)
201-
idx1, idx2, d2d, d3d = search_around_sky(empty, coo2, 1 * u.arcsec)
202-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
203-
assert idx1.dtype == idx2.dtype == int
204-
assert d2d.unit == u.deg
205-
assert d3d.unit == u.kpc
206-
idx1, idx2, d2d, d3d = search_around_sky(coo1, empty, 1 * u.arcsec)
207-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
208-
assert idx1.dtype == idx2.dtype == int
209-
assert d2d.unit == u.deg
210-
assert d3d.unit == u.kpc
211-
empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc)
212-
idx1, idx2, d2d, d3d = search_around_sky(empty, empty[:], 1 * u.arcsec)
213-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
214-
assert idx1.dtype == idx2.dtype == int
215-
assert d2d.unit == u.deg
216-
assert d3d.unit == u.kpc
217-
idx1, idx2, d2d, d3d = search_around_3d(empty, coo2, 1 * u.m)
218-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
219-
assert idx1.dtype == idx2.dtype == int
220-
assert d2d.unit == u.deg
221-
assert d3d.unit == u.kpc
222-
idx1, idx2, d2d, d3d = search_around_3d(coo1, empty, 1 * u.m)
223-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
224-
assert idx1.dtype == idx2.dtype == int
225-
assert d2d.unit == u.deg
226-
assert d3d.unit == u.kpc
227-
idx1, idx2, d2d, d3d = search_around_3d(empty, empty[:], 1 * u.m)
228-
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
229-
assert idx1.dtype == idx2.dtype == int
272+
idx1, idx2, d2d, d3d = function(sources, catalog, search_limit)
273+
assert idx1.size == 0
274+
assert idx2.size == 0
275+
assert d2d.size == 0
276+
assert d3d.size == 0
277+
assert idx1.dtype == int
278+
assert idx2.dtype == int
230279
assert d2d.unit == u.deg
231280
assert d3d.unit == u.kpc
232281

282+
283+
@pytest.mark.parametrize(
284+
"function,search_limit",
285+
[
286+
pytest.param(func, limit, id=func.__name__)
287+
for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg])
288+
],
289+
)
290+
def test_search_around_no_dist_input_output_units(function, search_limit):
233291
# Test that input without distance units results in a
234292
# 'dimensionless_unscaled' unit
235-
cempty = SkyCoord(ra=[], dec=[], unit=u.deg)
236-
idx1, idx2, d2d, d3d = search_around_3d(cempty, cempty[:], 1 * u.m)
237-
assert d2d.unit == u.deg
238-
assert d3d.unit == u.dimensionless_unscaled
239-
idx1, idx2, d2d, d3d = search_around_sky(cempty, cempty[:], 1 * u.m)
293+
empty_sc = SkyCoord([], [], unit=u.deg)
294+
idx1, idx2, d2d, d3d = function(empty_sc, empty_sc[:], search_limit)
240295
assert d2d.unit == u.deg
241296
assert d3d.unit == u.dimensionless_unscaled
242297

243298

244-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
245299
def test_search_around_scalar():
246-
from astropy.coordinates import Angle, SkyCoord
247-
248300
cat = SkyCoord([1, 2, 3], [-30, 45, 8], unit="deg")
249301
target = SkyCoord("1.1 -30.1", unit="deg")
250302

@@ -260,10 +312,7 @@ def test_search_around_scalar():
260312
assert "search_around_3d" in str(excinfo.value)
261313

262314

263-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
264315
def test_match_catalog_empty():
265-
from astropy.coordinates import SkyCoord
266-
267316
sc1 = SkyCoord(1, 2, unit="deg")
268317
cat0 = SkyCoord([], [], unit="deg")
269318
cat1 = SkyCoord([1.1], [2.1], unit="deg")
@@ -290,11 +339,8 @@ def test_match_catalog_empty():
290339
assert "catalog" in str(excinfo.value)
291340

292341

293-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
294342
@pytest.mark.filterwarnings(r"ignore:invalid value encountered in.*:RuntimeWarning")
295343
def test_match_catalog_nan():
296-
from astropy.coordinates import Galactic, SkyCoord
297-
298344
sc1 = SkyCoord(1, 2, unit="deg")
299345
sc_with_nans = SkyCoord(1, np.nan, unit="deg")
300346

@@ -324,11 +370,7 @@ def test_match_catalog_nan():
324370
assert "Matching coordinates cannot contain" in str(excinfo.value)
325371

326372

327-
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
328373
def test_match_catalog_nounit():
329-
from astropy.coordinates import ICRS, CartesianRepresentation
330-
from astropy.coordinates.matching import match_coordinates_sky
331-
332374
i1 = ICRS([[1], [2], [3]], representation_type=CartesianRepresentation)
333375
i2 = ICRS([[1], [2], [4, 5]], representation_type=CartesianRepresentation)
334376
i, sep, sep3d = match_coordinates_sky(i1, i2)

0 commit comments

Comments
 (0)