Skip to content

Commit a609e60

Browse files
authored
Merge pull request #2345 from astrofrog/fix-1d-wcs
Fix world coordinates for 1D WCS
2 parents ff81f84 + 06c6c50 commit a609e60

16 files changed

+119
-57
lines changed

glue/core/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import pandas as pd
55

66
from glue.core.coordinate_helpers import dependent_axes, pixel2world_single_axis
7-
from glue.utils import (shape_to_string, coerce_numeric,
8-
broadcast_to, categorical_ndarray)
7+
from glue.utils import shape_to_string, coerce_numeric, categorical_ndarray
98

109
try:
1110
import dask.array as da
@@ -330,7 +329,7 @@ def _calculate(self, view=None):
330329
world_coords = world_coords[tuple(final_slice)]
331330

332331
# We then broadcast the final array back to what it should be
333-
world_coords = broadcast_to(world_coords, tuple(final_shape))
332+
world_coords = np.broadcast_to(world_coords, tuple(final_shape))
334333

335334
# We apply the view if we weren't able to optimize before
336335
if optimize_view:

glue/core/component_link.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
world2pixel_single_axis)
1212
from glue.core.subset import InequalitySubsetState
1313
from glue.core.util import join_component_view
14-
from glue.utils import unbroadcast, broadcast_to
14+
from glue.utils import unbroadcast
1515
from glue.logger import logger
1616

1717
__all__ = ['ComponentLink', 'BinaryComponentLink', 'CoordinateComponentLink']
@@ -198,7 +198,7 @@ def compute(self, data, view=None):
198198
result.shape = args[0].shape
199199

200200
# Finally we broadcast the final result to desired shape
201-
result = broadcast_to(result, original_shape)
201+
result = np.broadcast_to(result, original_shape)
202202

203203
return result
204204

@@ -386,7 +386,7 @@ def using(self, *args):
386386
args2[f] = a
387387
for i in range(self.ndim):
388388
if args2[i] is None:
389-
args2[i] = broadcast_to(default[self.ndim - 1 - i], args[0].shape)
389+
args2[i] = np.broadcast_to(default[self.ndim - 1 - i], args[0].shape)
390390
args2 = tuple(args2)
391391

392392
if self.pixel2world:
@@ -487,7 +487,7 @@ def compute(self, data, view=None):
487487
if original_shape is None:
488488
return result
489489
else:
490-
return broadcast_to(result, original_shape)
490+
return np.broadcast_to(result, original_shape)
491491

492492
def __gluestate__(self, context):
493493
left = context.id(self._left)

glue/core/coordinate_helpers.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from astropy.wcs import WCS
33

4-
from glue.utils import unbroadcast, broadcast_to
4+
from glue.utils import unbroadcast
55
from glue.core.coordinates import LegacyCoordinates
66

77

@@ -53,9 +53,20 @@ def pixel2world_single_axis(wcs, *pixel, world_axis=None):
5353
pixel_new.append(p.flat[0])
5454
pixel = np.broadcast_arrays(*pixel_new)
5555

56-
result = wcs.pixel_to_world_values(*pixel)
56+
# In the case of 1D WCS, there is an astropy issue which prevents us from
57+
# passing arbitrary shapes - see https://github.com/astropy/astropy/issues/12154
58+
# Therefore, we ravel the values and reshape afterwards
5759

58-
return broadcast_to(result[world_axis], original_shape)
60+
if len(pixel) == 1 and pixel[0].ndim > 1:
61+
pixel_shape = pixel[0].shape
62+
result = wcs.pixel_to_world_values(pixel[0].ravel())
63+
result = result.reshape(pixel_shape)
64+
else:
65+
result = wcs.pixel_to_world_values(*pixel)
66+
if len(pixel) > 1:
67+
result = result[world_axis]
68+
69+
return np.broadcast_to(result, original_shape)
5970

6071

6172
def world2pixel_single_axis(wcs, *world, pixel_axis=None):
@@ -99,9 +110,20 @@ def world2pixel_single_axis(wcs, *world, pixel_axis=None):
99110
world_new.append(w.flat[0])
100111
world = np.broadcast_arrays(*world_new)
101112

102-
result = wcs.world_to_pixel_values(*world)
113+
# In the case of 1D WCS, there is an astropy issue which prevents us from
114+
# passing arbitrary shapes - see https://github.com/astropy/astropy/issues/12154
115+
# Therefore, we ravel the values and reshape afterwards
116+
117+
if len(world) == 1 and world[0].ndim > 1:
118+
world_shape = world[0].shape
119+
result = wcs.world_to_pixel_values(world[0].ravel())
120+
result = result.reshape(world_shape)
121+
else:
122+
result = wcs.world_to_pixel_values(*world)
123+
if len(world) > 1:
124+
result = result[pixel_axis]
103125

104-
return broadcast_to(result[pixel_axis], original_shape)
126+
return np.broadcast_to(result, original_shape)
105127

106128

107129
def world_axis(wcs, data, *, pixel_axis=None, world_axis=None):

glue/core/coordinates.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,16 @@ def __init__(self):
9090
super().__init__(pixel_n_dim=10, world_n_dim=10)
9191

9292
def pixel_to_world_values(self, *pixel):
93-
return pixel
93+
if len(pixel) == 1:
94+
return pixel[0]
95+
else:
96+
return pixel
9497

9598
def world_to_pixel_values(self, *world):
96-
return world
99+
if len(world) == 1:
100+
return world[0]
101+
else:
102+
return world
97103

98104

99105
class IdentityCoordinates(Coordinates):
@@ -102,10 +108,16 @@ def __init__(self, n_dim=None):
102108
super().__init__(pixel_n_dim=n_dim, world_n_dim=n_dim)
103109

104110
def pixel_to_world_values(self, *pixel):
105-
return pixel
111+
if self.pixel_n_dim == 1:
112+
return pixel[0]
113+
else:
114+
return pixel
106115

107116
def world_to_pixel_values(self, *world):
108-
return world
117+
if self.world_n_dim == 1:
118+
return world[0]
119+
else:
120+
return world
109121

110122
@property
111123
def axis_correlation_matrix(self):

glue/core/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from glue.core.joins import get_mask_with_key_joins
2929
from glue.config import settings, data_translator, subset_state_translator
3030
from glue.utils import (compute_statistic, unbroadcast, iterate_chunks,
31-
datetime64_to_mpl, broadcast_to, categorical_ndarray,
31+
datetime64_to_mpl, categorical_ndarray,
3232
format_choices, random_views_for_dask_array)
3333
from glue.core.coordinate_helpers import axis_label
3434

@@ -445,9 +445,9 @@ def get_data(self, cid, view=None):
445445
shape = tuple(-1 if i == cid.axis else 1 for i in range(self.ndim))
446446
pix = np.arange(self.shape[cid.axis], dtype=float).reshape(shape)
447447
if view is None:
448-
return broadcast_to(pix, self.shape)
448+
return np.broadcast_to(pix, self.shape)
449449
else:
450-
return broadcast_to(pix, self.shape)[view]
450+
return np.broadcast_to(pix, self.shape)[view]
451451
elif cid in self.world_component_ids:
452452
comp = self._world_components[cid]
453453
elif cid in self._externally_derivable_components:
@@ -1822,7 +1822,7 @@ def compute_statistic(self, statistic, cid, subset_state=None, axis=None,
18221822
if isinstance(axis, int):
18231823
axis = [axis]
18241824
final_shape = [mask.shape[i] for i in range(mask.ndim) if i not in axis]
1825-
return broadcast_to(np.nan, final_shape)
1825+
return np.broadcast_to(np.nan, final_shape)
18261826
else:
18271827
data = self.get_data(cid, view=view)
18281828
mask = None

glue/core/fixed_resolution_buffer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from glue.core.exceptions import IncompatibleAttribute, IncompatibleDataException
44
from glue.core.component import DaskComponent
55
from glue.core.coordinate_helpers import dependent_axes
6-
from glue.utils import unbroadcast, broadcast_to, broadcast_arrays_minimal
6+
from glue.utils import unbroadcast, broadcast_arrays_minimal
77

88
# TODO: cache needs to be updated when links are removed/changed
99

@@ -73,7 +73,7 @@ def translate_pixel(data, pixel_coords, target_cid):
7373
shape = values_all[0].shape
7474
values_all = broadcast_arrays_minimal(*values_all)
7575
results = link._using(*values_all)
76-
result = broadcast_to(results, shape)
76+
result = np.broadcast_to(results, shape)
7777
else:
7878
result = None
7979
return result, sorted(set(dimensions_all))
@@ -222,7 +222,7 @@ def compute_fixed_resolution_buffer(data, bounds, target_data=None, target_cid=N
222222
invalid_all |= invalid
223223

224224
# Broadcast back to the original shape and add to the list
225-
translated_coords.append(broadcast_to(translated_coord, original_shape))
225+
translated_coords.append(np.broadcast_to(translated_coord, original_shape))
226226

227227
# Also keep track of all the dimensions that contributed to this coordinate
228228
dimensions_all.extend(dimensions)

glue/core/subset.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from glue.core.decorators import memoize
1515
from glue.core.visual import VisualAttributes
1616
from glue.config import settings
17-
from glue.utils import (view_shape, broadcast_to, floodfill, combine_slices,
18-
polygon_line_intersections, categorical_ndarray, iterate_chunks)
17+
from glue.utils import (categorical_ndarray, combine_slices, floodfill, iterate_chunks,
18+
polygon_line_intersections, view_shape)
1919

2020

2121
__all__ = ['Subset', 'SubsetState', 'RoiSubsetStateNd', 'RoiSubsetState', 'CategoricalROISubsetState',
@@ -458,7 +458,7 @@ def to_mask(self, data, view=None):
458458
Any object that returns a valid view for a Numpy array.
459459
"""
460460
shp = view_shape(data.shape, view)
461-
return broadcast_to(False, shp)
461+
return np.broadcast_to(False, shp)
462462

463463
@contract(returns='isinstance(SubsetState)')
464464
def copy(self):
@@ -1327,7 +1327,7 @@ def to_mask(self, data, view=None):
13271327

13281328
if order is None:
13291329
# We use broadcast_to for minimal memory usage
1330-
return broadcast_to(False, shape)
1330+
return np.broadcast_to(False, shape)
13311331
else:
13321332
# Reorder slices
13331333
slices = [self.slices[idx] for idx in order]
@@ -1350,7 +1350,7 @@ def to_mask(self, data, view=None):
13501350
elif np.isscalar(view[i]):
13511351
beg, end, stp = slices[i].indices(data.shape[i])
13521352
if view[i] < beg or view[i] >= end or (view[i] - beg) % stp != 0:
1353-
return broadcast_to(False, shape)
1353+
return np.broadcast_to(False, shape)
13541354
elif isinstance(view[i], slice):
13551355
if view[i].step is not None and view[i].step < 0:
13561356
beg, end, step = view[i].indices(data.shape[i])

glue/core/tests/test_component.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
from unittest.mock import MagicMock
88

9+
from astropy.wcs import WCS
10+
911
from glue import core
1012
from glue.tests.helpers import requires_astropy
1113

@@ -386,3 +388,18 @@ def test_update_cid_used_in_derived():
386388
np.testing.assert_equal(data['b'], [4, 5, 2])
387389
data.update_id(data.id['a'], ComponentID('x'))
388390
np.testing.assert_equal(data['b'], [4, 5, 2])
391+
392+
393+
def test_coordinate_component_1d_coord():
394+
395+
# Regression test for a bug that caused incorrect world coordinate values
396+
# for 1D coordinates.
397+
398+
wcs = WCS(naxis=1)
399+
wcs.wcs.ctype = ['FREQ']
400+
wcs.wcs.crpix = [1]
401+
wcs.wcs.crval = [1]
402+
wcs.wcs.cdelt = [1]
403+
404+
data = Data(flux=np.random.random(5), coords=wcs, label='data')
405+
np.testing.assert_equal(data['Frequency'], [1, 2, 3, 4, 5])

glue/core/tests/test_coordinates.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,24 @@ def test_pixel2world_single_axis():
363363
assert_allclose(pixel2world_single_axis(coord, x, y, z, world_axis=2), [1.5, 1.5, 1.5])
364364

365365

366+
def test_pixel2world_single_axis_1d():
367+
368+
# Regression test for issues that occurred for 1D WCSes
369+
370+
coord = WCSCoordinates(naxis=1)
371+
coord.wcs.ctype = ['FREQ']
372+
coord.wcs.crpix = [1]
373+
coord.wcs.crval = [1]
374+
coord.wcs.cdelt = [1]
375+
376+
x = np.array([0.2, 0.4, 0.6])
377+
expected = np.array([1.2, 1.4, 1.6])
378+
379+
assert_allclose(pixel2world_single_axis(coord, x, world_axis=0), expected)
380+
assert_allclose(pixel2world_single_axis(coord, x.reshape((1, 3)), world_axis=0), expected.reshape((1, 3)))
381+
assert_allclose(pixel2world_single_axis(coord, x.reshape((3, 1)), world_axis=0), expected.reshape((3, 1)))
382+
383+
366384
def test_affine():
367385

368386
matrix = np.array([[2, 3, -1], [1, 2, 2], [0, 0, 1]])

glue/core/tests/test_data.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from astropy.utils import NumpyRNGContext
88

99
from glue import core
10-
from glue.utils import broadcast_to
1110

1211
from ..component import Component, DerivedComponent, CategoricalComponent, DateTimeComponent
1312
from ..component_id import ComponentID
@@ -819,10 +818,16 @@ def world_axis_names(self):
819818
return ['Custom {0}'.format(axis) for axis in range(3)]
820819

821820
def world_to_pixel_values(self, *world):
822-
return tuple([0.4 * w for w in world])
821+
if self.pixel_n_dim == 1:
822+
return 0.4 * world[0]
823+
else:
824+
return tuple([0.4 * w for w in world])
823825

824826
def pixel_to_world_values(self, *pixel):
825-
return tuple([2.5 * p for p in pixel])
827+
if self.world_n_dim == 1:
828+
return 2.5 * pixel[0]
829+
else:
830+
return tuple([2.5 * p for p in pixel])
826831

827832
data1.coords = CustomCoordinates()
828833

@@ -930,10 +935,10 @@ def test_compute_statistic_empty_subset():
930935
assert_equal(result, np.nan)
931936

932937
result = data.compute_statistic('maximum', data.id['x'], subset_state=subset_state, axis=1)
933-
assert_equal(result, broadcast_to(np.nan, (30, 40)))
938+
assert_equal(result, np.broadcast_to(np.nan, (30, 40)))
934939

935940
result = data.compute_statistic('median', data.id['x'], subset_state=subset_state, axis=(1, 2))
936-
assert_equal(result, broadcast_to(np.nan, (30)))
941+
assert_equal(result, np.broadcast_to(np.nan, (30)))
937942

938943
result = data.compute_statistic('sum', data.id['x'], subset_state=subset_state, axis=(0, 1, 2))
939944
assert_equal(result, np.nan)

glue/core/tests/test_links.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def test_1d_world_link():
1919
dc.add_link(LinkSame(d2.world_component_ids[0], d1.id['x']))
2020

2121
assert d2.world_component_ids[0] in d1.externally_derivable_components
22+
2223
np.testing.assert_array_equal(d1[d2.world_component_ids[0]], x)
2324
np.testing.assert_array_equal(d1[d2.pixel_component_ids[0]], x)
2425

glue/utils/array.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numpy import nanmin, nanmax, nanmean, nanmedian, nansum # noqa
1010

1111
__all__ = ['unique', 'shape_to_string', 'view_shape', 'stack_view',
12-
'coerce_numeric', 'check_sorted', 'broadcast_to', 'unbroadcast',
12+
'coerce_numeric', 'check_sorted', 'unbroadcast',
1313
'iterate_chunks', 'combine_slices', 'format_minimal', 'compute_statistic',
1414
'categorical_ndarray', 'index_lookup', 'ensure_numerical',
1515
'broadcast_arrays_minimal', 'random_views_for_dask_array']
@@ -201,18 +201,6 @@ def pretty_number(numbers):
201201
return result
202202

203203

204-
def broadcast_to(array, shape):
205-
"""
206-
Compatibility function - can be removed once we support only Numpy 1.10
207-
and above
208-
"""
209-
try:
210-
return np.broadcast_to(array, shape)
211-
except AttributeError:
212-
array = np.asarray(array)
213-
return np.broadcast_arrays(array, np.ones(shape, array.dtype))[0]
214-
215-
216204
def find_chunk_shape(shape, n_max=None):
217205
"""
218206
Given the shape of an n-dimensional array, and the maximum number of

glue/utils/geometry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from glue.utils import unbroadcast, broadcast_to
3+
from glue.utils import unbroadcast
44

55
__all__ = ['points_inside_poly', 'polygon_line_intersections', 'floodfill', 'rotation_matrix_2d']
66

@@ -82,7 +82,7 @@ def points_inside_poly(x, y, vx, vy):
8282
inside[keep][~good] = False
8383

8484
inside = inside.reshape(reduced_shape)
85-
inside = broadcast_to(inside, original_shape)
85+
inside = np.broadcast_to(inside, original_shape)
8686

8787
return inside
8888

0 commit comments

Comments
 (0)