Skip to content

Commit

Permalink
Merge pull request #1694 from larrybradley/check-labels
Browse files Browse the repository at this point in the history
Add a check for valid labels in get_labels
  • Loading branch information
larrybradley authored Jan 18, 2024
2 parents e52697b + 6adac3c commit a59a7ab
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ API Changes
- The ``GridddedPSFModel`` string representations now include the
model ``flux``, ``x_0``, and ``y_0`` parameters. [#1680]

- ``photutils.segmentation``

- The ``SourceCatalog`` ``get_label`` and ``get_labels`` methods now
raise a ``ValueError`` if any of the input labels are invalid. [#1694]


1.10.0 (2023-11-21)
-------------------
Expand Down
1 change: 1 addition & 0 deletions photutils/segmentation/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,7 @@ def get_labels(self, labels):
A new `SourceCatalog` object containing only the sources with
the input ``labels``.
"""
self._segment_img.check_labels(labels)
sorter = np.argsort(self.labels)
indices = sorter[np.searchsorted(self.labels, labels, sorter=sorter)]
return self[indices]
Expand Down
2 changes: 1 addition & 1 deletion photutils/segmentation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def check_labels(self, labels):
# check if label is in the segmentation array
bad_labels.update(np.setdiff1d(labels, self.labels))

if bad_labels:
if bad_labels: # bad_labels is a set
if len(bad_labels) == 1:
raise ValueError(f'label {bad_labels} is invalid')
raise ValueError(f'labels {bad_labels} are invalid')
Expand Down
7 changes: 7 additions & 0 deletions photutils/segmentation/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def test_slicing(self):
obj1 = self.cat[0]
obj2 = obj1[0]

match = 'is invalid'
with pytest.raises(ValueError, match=match):
self.cat.get_label(1000)

with pytest.raises(ValueError, match=match):
self.cat.get_labels([1, 2, 1000])

def test_iter(self):
labels = []
for obj in self.cat:
Expand Down

0 comments on commit a59a7ab

Please sign in to comment.