Skip to content

Commit

Permalink
Merge pull request astropy#16277 from eerovaher/coord-matching-tests
Browse files Browse the repository at this point in the history
Rewrite an overlong coordinate matching test function
  • Loading branch information
pllim authored Apr 5, 2024
2 parents aa4dfa7 + 4947a2f commit 0a06e9a
Showing 1 changed file with 146 additions and 104 deletions.
250 changes: 146 additions & 104 deletions astropy/coordinates/tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,34 @@
from numpy import testing as npt

from astropy import units as u
from astropy.coordinates import matching
from astropy.coordinates import (
ICRS,
Angle,
CartesianRepresentation,
Galactic,
SkyCoord,
match_coordinates_3d,
match_coordinates_sky,
search_around_3d,
search_around_sky,
)
from astropy.tests.helper import assert_quantity_allclose as assert_allclose
from astropy.utils import NumpyRNGContext
from astropy.utils.compat.optional_deps import HAS_SCIPY

"""
These are the tests for coordinate matching.
Note that this requires scipy.
Coordinate matching can involve caching, so it is best to recreate the
coordinate objects in every test instead of trying to reuse module-level
variables.
"""

if not HAS_SCIPY:
pytest.skip("Coordinate matching requires scipy", allow_module_level=True)

@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_matching_function():
from astropy.coordinates import ICRS
from astropy.coordinates.matching import match_coordinates_3d

def test_matching_function():
# this only uses match_coordinates_3d because that's the actual implementation

cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree)
Expand All @@ -37,11 +49,7 @@ def test_matching_function():
npt.assert_array_less(d3d.value, 0.02)


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_matching_function_3d_and_sky():
from astropy.coordinates import ICRS
from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky

cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc)
ccatalog = ICRS(
[1, 2, 3, 4] * u.degree, [0, 0, 0, 0] * u.degree, distance=[1, 1, 1, 5] * u.kpc
Expand All @@ -64,16 +72,13 @@ def test_matching_function_3d_and_sky():
@pytest.mark.parametrize(
"functocheck, args, defaultkdtname, bothsaved",
[
(matching.match_coordinates_3d, [], "kdtree_3d", False),
(matching.match_coordinates_sky, [], "kdtree_sky", False),
(matching.search_around_3d, [1 * u.kpc], "kdtree_3d", True),
(matching.search_around_sky, [1 * u.deg], "kdtree_sky", False),
(match_coordinates_3d, [], "kdtree_3d", False),
(match_coordinates_sky, [], "kdtree_sky", False),
(search_around_3d, [1 * u.kpc], "kdtree_3d", True),
(search_around_sky, [1 * u.deg], "kdtree_sky", False),
],
)
@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_kdtree_storage(functocheck, args, defaultkdtname, bothsaved):
from astropy.coordinates import ICRS

def make_scs():
cmatch = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 2] * u.kpc)
ccatalog = ICRS(
Expand Down Expand Up @@ -119,12 +124,7 @@ def make_scs():
assert "KD" in e.value.args[0]


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy.")
def test_matching_method():
from astropy.coordinates import ICRS, SkyCoord
from astropy.coordinates.matching import match_coordinates_3d, match_coordinates_sky
from astropy.utils import NumpyRNGContext

with NumpyRNGContext(987654321):
cmatch = ICRS(
np.random.rand(20) * 360.0 * u.degree,
Expand Down Expand Up @@ -153,98 +153,150 @@ def test_matching_method():
assert len(idx1) == len(d2d1) == len(d3d1) == 20


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_search_around():
from astropy.coordinates import ICRS, SkyCoord
from astropy.coordinates.matching import search_around_3d, search_around_sky

coo1 = ICRS([4, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc)
coo2 = ICRS(
[1, 2, 3, 4] * u.degree, [0, 0, 0, 0] * u.degree, distance=[1, 1, 1, 5] * u.kpc
)

idx1_1deg, idx2_1deg, d2d_1deg, d3d_1deg = search_around_sky(
coo1, coo2, 1.01 * u.deg
)
idx1_0p05deg, idx2_0p05deg, d2d_0p05deg, d3d_0p05deg = search_around_sky(
coo1, coo2, 0.05 * u.deg
@pytest.mark.parametrize(
"search_limit,expected_idx1,expected_idx2,expected_d2d,expected_d3d",
[
pytest.param(
1.01 * u.deg,
[0, 0, 1, 1],
[2, 3, 1, 2],
[1, 0, 0.1, 0.9] * u.deg,
[0.01745307, 4.0, 4.0000019, 4.00015421] * u.kpc,
id="1.01_deg",
),
pytest.param(0.05 * u.deg, [0], [3], [0] * u.deg, [4] * u.kpc, id="0.05_deg"),
],
)
def test_search_around_sky(
search_limit, expected_idx1, expected_idx2, expected_d2d, expected_d3d
):
idx1, idx2, d2d, d3d = search_around_sky(
ICRS([4, 2.1] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc),
ICRS([1, 2, 3, 4] * u.deg, [0, 0, 0, 0] * u.deg, distance=[1, 1, 1, 5] * u.kpc),
search_limit,
)
npt.assert_array_equal(idx1, expected_idx1)
npt.assert_array_equal(idx2, expected_idx2)
assert_allclose(d2d, expected_d2d)
assert_allclose(d3d, expected_d3d)

assert list(zip(idx1_1deg, idx2_1deg)) == [(0, 2), (0, 3), (1, 1), (1, 2)]
assert_allclose(d2d_1deg[0], 1.0 * u.deg, atol=1e-14 * u.deg, rtol=0)
assert_allclose(d2d_1deg, [1, 0, 0.1, 0.9] * u.deg)

assert list(zip(idx1_0p05deg, idx2_0p05deg)) == [(0, 3)]

idx1_1kpc, idx2_1kpc, d2d_1kpc, d3d_1kpc = search_around_3d(coo1, coo2, 1 * u.kpc)
idx1_sm, idx2_sm, d2d_sm, d3d_sm = search_around_3d(coo1, coo2, 0.05 * u.kpc)
@pytest.mark.parametrize(
"search_limit,expected_idx1,expected_idx2,expected_d2d,expected_d3d",
[
pytest.param(
1 * u.kpc,
[0, 0, 0, 1],
[0, 1, 2, 3],
[3, 2, 1, 1.9] * u.deg,
[0.0523539, 0.03490481, 0.01745307, 0.16579868] * u.kpc,
id="1_kpc",
),
pytest.param(
0.05 * u.kpc,
[0, 0],
[1, 2],
[2, 1] * u.deg,
[0.03490481, 0.01745307] * u.kpc,
id="0.05_kpc",
),
],
)
def test_search_around_3d(
search_limit, expected_idx1, expected_idx2, expected_d2d, expected_d3d
):
idx1, idx2, d2d, d3d = search_around_3d(
ICRS([4, 2.1] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc),
ICRS([1, 2, 3, 4] * u.deg, [0, 0, 0, 0] * u.deg, distance=[1, 1, 1, 5] * u.kpc),
search_limit,
)
npt.assert_array_equal(idx1, expected_idx1)
npt.assert_array_equal(idx2, expected_idx2)
assert_allclose(d2d, expected_d2d)
assert_allclose(d3d, expected_d3d)

assert list(zip(idx1_1kpc, idx2_1kpc)) == [(0, 0), (0, 1), (0, 2), (1, 3)]
assert list(zip(idx1_sm, idx2_sm)) == [(0, 1), (0, 2)]
assert_allclose(d2d_sm, [2, 1] * u.deg)

@pytest.mark.parametrize(
"function,search_limit",
[
pytest.param(func, limit, id=func.__name__)
for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg])
],
)
def test_search_around_no_matches(function, search_limit):
# Test for the non-matches, #4877
coo1 = ICRS([4.1, 2.1] * u.degree, [0, 0] * u.degree, distance=[1, 5] * u.kpc)
idx1, idx2, d2d, d3d = search_around_sky(coo1, coo2, 1 * u.arcsec)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc
idx1, idx2, d2d, d3d = search_around_3d(coo1, coo2, 1 * u.m)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
idx1, idx2, d2d, d3d = function(
ICRS([41, 21] * u.deg, [0, 0] * u.deg, distance=[1, 5] * u.kpc),
ICRS([1, 2] * u.deg, [0, 0] * u.deg, distance=[1, 1] * u.kpc),
search_limit,
)
assert idx1.size == 0
assert idx2.size == 0
assert d2d.size == 0
assert d3d.size == 0
assert idx1.dtype == int
assert idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc


@pytest.mark.parametrize(
"function,search_limit",
[
pytest.param(func, limit, id=func.__name__)
for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg])
],
)
@pytest.mark.parametrize(
"sources,catalog",
[
pytest.param(
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
ICRS([1] * u.deg, [0] * u.deg, distance=[1] * u.kpc),
id="empty_sources",
),
pytest.param(
ICRS([1] * u.deg, [0] * u.deg, distance=[1] * u.kpc),
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
id="empty_catalog",
),
pytest.param(
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
ICRS(ra=[] * u.deg, dec=[] * u.deg, distance=[] * u.kpc),
id="empty_both",
),
],
)
def test_search_around_empty_input(sources, catalog, function, search_limit):
# Test when one or both of the coordinate arrays is empty, #4875
empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc)
idx1, idx2, d2d, d3d = search_around_sky(empty, coo2, 1 * u.arcsec)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc
idx1, idx2, d2d, d3d = search_around_sky(coo1, empty, 1 * u.arcsec)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc
empty = ICRS(ra=[] * u.degree, dec=[] * u.degree, distance=[] * u.kpc)
idx1, idx2, d2d, d3d = search_around_sky(empty, empty[:], 1 * u.arcsec)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc
idx1, idx2, d2d, d3d = search_around_3d(empty, coo2, 1 * u.m)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc
idx1, idx2, d2d, d3d = search_around_3d(coo1, empty, 1 * u.m)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc
idx1, idx2, d2d, d3d = search_around_3d(empty, empty[:], 1 * u.m)
assert idx1.size == idx2.size == d2d.size == d3d.size == 0
assert idx1.dtype == idx2.dtype == int
idx1, idx2, d2d, d3d = function(sources, catalog, search_limit)
assert idx1.size == 0
assert idx2.size == 0
assert d2d.size == 0
assert d3d.size == 0
assert idx1.dtype == int
assert idx2.dtype == int
assert d2d.unit == u.deg
assert d3d.unit == u.kpc


@pytest.mark.parametrize(
"function,search_limit",
[
pytest.param(func, limit, id=func.__name__)
for func, limit in ([search_around_3d, 1 * u.m], [search_around_sky, 1 * u.deg])
],
)
def test_search_around_no_dist_input_output_units(function, search_limit):
# Test that input without distance units results in a
# 'dimensionless_unscaled' unit
cempty = SkyCoord(ra=[], dec=[], unit=u.deg)
idx1, idx2, d2d, d3d = search_around_3d(cempty, cempty[:], 1 * u.m)
assert d2d.unit == u.deg
assert d3d.unit == u.dimensionless_unscaled
idx1, idx2, d2d, d3d = search_around_sky(cempty, cempty[:], 1 * u.m)
empty_sc = SkyCoord([], [], unit=u.deg)
idx1, idx2, d2d, d3d = function(empty_sc, empty_sc[:], search_limit)
assert d2d.unit == u.deg
assert d3d.unit == u.dimensionless_unscaled


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_search_around_scalar():
from astropy.coordinates import Angle, SkyCoord

cat = SkyCoord([1, 2, 3], [-30, 45, 8], unit="deg")
target = SkyCoord("1.1 -30.1", unit="deg")

Expand All @@ -260,10 +312,7 @@ def test_search_around_scalar():
assert "search_around_3d" in str(excinfo.value)


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_match_catalog_empty():
from astropy.coordinates import SkyCoord

sc1 = SkyCoord(1, 2, unit="deg")
cat0 = SkyCoord([], [], unit="deg")
cat1 = SkyCoord([1.1], [2.1], unit="deg")
Expand All @@ -290,11 +339,8 @@ def test_match_catalog_empty():
assert "catalog" in str(excinfo.value)


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
@pytest.mark.filterwarnings(r"ignore:invalid value encountered in.*:RuntimeWarning")
def test_match_catalog_nan():
from astropy.coordinates import Galactic, SkyCoord

sc1 = SkyCoord(1, 2, unit="deg")
sc_with_nans = SkyCoord(1, np.nan, unit="deg")

Expand Down Expand Up @@ -324,11 +370,7 @@ def test_match_catalog_nan():
assert "Matching coordinates cannot contain" in str(excinfo.value)


@pytest.mark.skipif(not HAS_SCIPY, reason="Requires scipy")
def test_match_catalog_nounit():
from astropy.coordinates import ICRS, CartesianRepresentation
from astropy.coordinates.matching import match_coordinates_sky

i1 = ICRS([[1], [2], [3]], representation_type=CartesianRepresentation)
i2 = ICRS([[1], [2], [4, 5]], representation_type=CartesianRepresentation)
i, sep, sep3d = match_coordinates_sky(i1, i2)
Expand Down

0 comments on commit 0a06e9a

Please sign in to comment.