
Commit 5af2681

Expand column statistics to limit by pixels (#472)

* Expand column statistics to limit by pixels, and be more friendly for nested parquet.
* Oops - black formatting.
* Plumb more include_pixels.
* Bad merge - missing import.
* Another missing import?

1 parent 0887668 · commit 5af2681

24 files changed: +218 −31 lines

src/hats/catalog/healpix_dataset/healpix_dataset.py

Lines changed: 32 additions & 0 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import warnings
 from pathlib import Path

 import astropy.units as u
@@ -16,6 +17,7 @@
 from hats.catalog.partition_info import PartitionInfo
 from hats.inspection import plot_pixels
 from hats.inspection.visualize_catalog import plot_moc
+from hats.io.parquet_metadata import aggregate_column_statistics
 from hats.pixel_math import HealpixPixel
 from hats.pixel_math.box_filter import generate_box_moc, wrap_ra_angles
 from hats.pixel_math.validators import (
@@ -256,3 +258,33 @@ def plot_moc(self, **kwargs):
         plot_args = {"title": default_title}
         plot_args.update(kwargs)
         return plot_moc(self.moc, **plot_args)
+
+    def aggregate_column_statistics(
+        self,
+        exclude_hats_columns: bool = True,
+        exclude_columns: list[str] = None,
+        include_columns: list[str] = None,
+        include_pixels: list[HealpixPixel] = None,
+    ):
+        """Read footer statistics in parquet metadata, and report on global min/max values.
+
+        Args:
+            exclude_hats_columns (bool): exclude HATS spatial and partitioning fields
+                from the statistics. Defaults to True.
+            exclude_columns (List[str]): additional columns to exclude from the statistics.
+            include_columns (List[str]): if specified, only return statistics for the column
+                names provided. Defaults to None, and returns all non-hats columns.
+            include_pixels (List[HealpixPixel]): if specified, only return statistics for
+                the partitions matching the provided pixels. Defaults to None, using all
+                on-disk partitions.
+        """
+        if not self.on_disk:
+            warnings.warn("Calling aggregate_column_statistics on an in-memory catalog. No results.")
+            return pd.DataFrame()
+
+        if include_pixels is None:
+            include_pixels = self.get_healpix_pixels()
+        return aggregate_column_statistics(
+            self.catalog_base_dir / "dataset" / "_metadata",
+            exclude_hats_columns=exclude_hats_columns,
+            exclude_columns=exclude_columns,
+            include_columns=include_columns,
+            include_pixels=include_pixels,
+        )
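For orientation before the lower-level changes, here is a minimal usage sketch of the new method. It is not part of the commit; the catalog path and pixel values are illustrative, echoing the repository's small_sky_order1 test fixture.

    import hats
    from hats.pixel_math import HealpixPixel

    # Load a HATS catalog from disk; an in-memory catalog warns and returns an empty frame.
    catalog = hats.read_hats("tests/data/small_sky_order1")

    # Aggregate footer statistics across every partition (the default).
    all_stats = catalog.aggregate_column_statistics()

    # Limit the aggregation to specific HEALPix partitions.
    pixel_stats = catalog.aggregate_column_statistics(include_pixels=[HealpixPixel(1, 45)])
    print(pixel_stats[["min_value", "max_value", "null_count", "row_count"]])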

src/hats/io/parquet_metadata.py

Lines changed: 27 additions & 22 deletions

@@ -11,6 +11,7 @@

 from hats.io import file_io, paths
 from hats.io.file_io.file_pointer import get_upath
+from hats.pixel_math.healpix_pixel import HealpixPixel
 from hats.pixel_math.healpix_pixel_function import get_pixel_argsort


@@ -131,6 +132,7 @@ def aggregate_column_statistics(
     exclude_hats_columns: bool = True,
     exclude_columns: list[str] = None,
     include_columns: list[str] = None,
+    include_pixels: list[HealpixPixel] = None,
 ):
     """Read footer statistics in parquet metadata, and report on global min/max values.

@@ -157,6 +159,7 @@ def aggregate_column_statistics(
     column_names = [
         first_row_group.column(col).path_in_schema for col in range(0, first_row_group.num_columns)
     ]
+    column_names = [name.removesuffix(".list.element") for name in column_names]
     good_column_indexes = [
         index
         for index, name in enumerate(column_names)
@@ -166,42 +169,43 @@ def aggregate_column_statistics(
     if not good_column_indexes:
         return pd.DataFrame()
     column_names = [column_names[i] for i in good_column_indexes]
-    extrema = [
-        (
-            (None, None, 0)
-            if first_row_group.column(col).statistics is None
-            else (
-                first_row_group.column(col).statistics.min,
-                first_row_group.column(col).statistics.max,
-                first_row_group.column(col).statistics.null_count,
-            )
-        )
-        for col in good_column_indexes
-    ]
+    extrema = None

-    for row_group_index in range(1, num_row_groups):
+    for row_group_index in range(0, num_row_groups):
         row_group = total_metadata.row_group(row_group_index)
+        if include_pixels is not None:
+            pixel = paths.get_healpix_from_path(row_group.column(0).file_path)
+            if pixel not in include_pixels:
+                continue
         row_stats = [
             (
-                (None, None, 0)
+                (None, None, 0, 0)
                 if row_group.column(col).statistics is None
                 else (
                     row_group.column(col).statistics.min,
                     row_group.column(col).statistics.max,
                     row_group.column(col).statistics.null_count,
+                    row_group.column(col).num_values,
                 )
             )
             for col in good_column_indexes
         ]
+        if extrema is None:
+            extrema = row_stats
         ## This is annoying, but avoids extra copies, or none comparison.
-        extrema = [
-            (
-                (_nonemin(extrema[col][0], row_stats[col][0])),
-                (_nonemax(extrema[col][1], row_stats[col][1])),
-                extrema[col][2] + row_stats[col][2],
-            )
-            for col in range(0, len(good_column_indexes))
-        ]
+        else:
+            extrema = [
+                (
+                    (_nonemin(extrema[col][0], row_stats[col][0])),
+                    (_nonemax(extrema[col][1], row_stats[col][1])),
+                    extrema[col][2] + row_stats[col][2],
+                    extrema[col][3] + row_stats[col][3],
+                )
+                for col in range(0, len(good_column_indexes))
+            ]
+
+    if extrema is None:
+        return pd.DataFrame()

     stats_lists = np.array(extrema).T

@@ -211,6 +215,7 @@ def aggregate_column_statistics(
             "min_value": stats_lists[0],
             "max_value": stats_lists[1],
             "null_count": stats_lists[2],
+            "row_count": stats_lists[3],
         }
     ).set_index("column_names")
     return frame
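Two details in the function above are easy to miss. Nested (list-typed) parquet columns report path_in_schema values like "lc.mag.list.element", so the new removesuffix pass lets callers use the shorter "lc.mag" form. And because a row group can lack statistics entirely, per-column extrema are folded together with None-safe helpers. The sketch below illustrates both mechanics in isolation; the _nonemin/_nonemax bodies shown here are assumptions standing in for the module's private helpers, which are not part of this diff.

    # Sketch only: name normalization and None-safe folding, as used above.

    def _nonemin(left, right):
        # Min that treats None as "no statistic" rather than a comparable value.
        if left is None:
            return right
        if right is None:
            return left
        return min(left, right)

    def _nonemax(left, right):
        # Max with the same None-handling convention.
        if left is None:
            return right
        if right is None:
            return left
        return max(left, right)

    # Nested list columns carry a ".list.element" suffix in the schema path.
    assert "lc.mag.list.element".removesuffix(".list.element") == "lc.mag"

    # Folding (min, max, null_count, row_count) tuples across two row groups,
    # where the first row group carried no statistics:
    a = (None, None, 0, 0)
    b = (15.0, 21.0, 2, 2358)
    folded = (_nonemin(a[0], b[0]), _nonemax(a[1], b[1]), a[2] + b[2], a[3] + b[3])
    assert folded == (15.0, 21.0, 2, 2358)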

tests/conftest.py

Lines changed: 6 additions & 0 deletions

@@ -17,6 +17,7 @@
 SMALL_SKY_NPIX_AS_DIR_NAME = "small_sky_npix_as_dir"
 SMALL_SKY_ORDER1_DIR_NAME = "small_sky_order1"
 SMALL_SKY_SOURCE_OBJECT_INDEX_DIR_NAME = "small_sky_source_object_index"
+SMALL_SKY_NESTED_DIR_NAME = "small_sky_nested"

 TEST_DIR = os.path.dirname(__file__)

@@ -53,6 +54,11 @@ def small_sky_order1_dir(test_data_dir):
     return test_data_dir / SMALL_SKY_ORDER1_DIR_NAME


+@pytest.fixture
+def small_sky_nested_dir(test_data_dir):
+    return test_data_dir / SMALL_SKY_NESTED_DIR_NAME
+
+
 @pytest.fixture
 def small_sky_source_object_index_dir(test_data_dir):
     return test_data_dir / SMALL_SKY_SOURCE_OBJECT_INDEX_DIR_NAME

tests/data/generate_data.ipynb

Lines changed: 35 additions & 0 deletions

@@ -41,6 +41,7 @@
 "from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo\n",
 "from hats.catalog.dataset.table_properties import TableProperties\n",
 "from hats.io.file_io import remove_directory\n",
+"import lsdb\n",
 "from hats.pixel_math.spatial_index import healpix_to_spatial_index\n",
 "\n",
 "tmp_path = tempfile.TemporaryDirectory()\n",
@@ -446,6 +447,40 @@
   " runner.pipeline_with_client(args, client)"
  ]
 },
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Nested catalog: small_sky_nested\n",
+  "\n",
+  "Nests light curves from `small_sky_source` into the `small_sky_order1` object catalog."
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "remove_directory(\"./small_sky_nested\")\n",
+  "\n",
+  "small_sky_object = lsdb.read_hats(\"small_sky_order1\")\n",
+  "small_sky_source = lsdb.read_hats(\"small_sky_source\")\n",
+  "small_sky_nested = small_sky_object.join_nested(\n",
+  "    small_sky_source, left_on=\"id\", right_on=\"object_id\", nested_column_name=\"lc\"\n",
+  ")\n",
+  "small_sky_nested = small_sky_nested.map_partitions(\n",
+  "    lambda df, p: df.assign(Norder=p.order, Npix=p.pixel, Dir=p.pixel // 10000), include_pixel=True\n",
+  ")\n",
+  "lsdb.io.to_hats(\n",
+  "    small_sky_nested,\n",
+  "    base_catalog_path=\"small_sky_nested\",\n",
+  "    catalog_name=\"small_sky_nested\",\n",
+  "    histogram_order=5,\n",
+  "    overwrite=True,\n",
+  ")"
+ ]
+},
 {
  "cell_type": "code",
  "execution_count": null,
15 binary files changed (contents not shown in this view; the last listed is 32.3 KB).
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+Norder,Npix
+2,176
+2,177
+2,178
+2,179
+2,180
+2,181
+2,182
+2,183
+2,184
+2,185
+2,186
+2,187
+1,47
1 binary file changed (104 KB; contents not shown).
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+#HATS catalog
+obs_collection=small_sky_nested
+dataproduct_type=object
+hats_nrows=131
+hats_col_ra=ra
+hats_col_dec=dec
+hats_npix_suffix=.parquet
+hats_max_rows=1000000
+hats_order=1
+moc_sky_fraction=0.08333
+hats_builder=hats-import v0.4.6.dev1+gf00cd7a
+hats_creation_date=2025-03-05T16\:16UTC
+hats_estsize=27
+hats_release_date=2024-09-18
+hats_version=v0.1

tests/hats/catalog/test_catalog.py

Lines changed: 17 additions & 0 deletions

@@ -93,17 +93,34 @@ def test_load_catalog_small_sky_order1(small_sky_order1_dir):


 def test_aggregate_column_statistics(small_sky_order1_dir):
+    def assert_column_stat_as_floats(
+        result_frame, column_name, min_value=None, max_value=None, row_count=None
+    ):
+        assert column_name in result_frame.index
+        data_stats = result_frame.loc[column_name]
+        assert float(data_stats["min_value"]) >= min_value
+        assert float(data_stats["max_value"]) <= max_value
+        assert int(data_stats["null_count"]) == 0
+        assert int(data_stats["row_count"]) == row_count
+
     cat = read_hats(small_sky_order1_dir)

     result_frame = cat.aggregate_column_statistics()
     assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-69.5, max_value=-25.5, row_count=131)

     result_frame = cat.aggregate_column_statistics(exclude_hats_columns=False)
+    assert_column_stat_as_floats(result_frame, "Norder", min_value=1, max_value=1, row_count=131)
     assert len(result_frame) == 9

     result_frame = cat.aggregate_column_statistics(include_columns=["ra", "dec"])
     assert len(result_frame) == 2

+    filtered_catalog = cat.filter_by_cone(315, -66.443, 0.1)
+    result_frame = filtered_catalog.aggregate_column_statistics()
+    assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-69.5, max_value=-47.5, row_count=42)
+

 def test_aggregate_column_statistics_inmemory(catalog_info, catalog_pixels):
     catalog = Catalog(catalog_info, catalog_pixels)
tests/hats/io/test_parquet_metadata.py

Lines changed: 72 additions & 9 deletions

@@ -8,6 +8,7 @@

 from hats.io import file_io, paths
 from hats.io.parquet_metadata import aggregate_column_statistics, write_parquet_metadata
+from hats.pixel_math.healpix_pixel import HealpixPixel


 def test_write_parquet_metadata(tmp_path, small_sky_dir, small_sky_schema, check_parquet_schema):
@@ -136,6 +137,75 @@ def test_aggregate_column_statistics(small_sky_order1_dir):
     assert len(result_frame) == 0


+def assert_column_stat_as_floats(
+    result_frame, column_name, min_value=None, max_value=None, null_count=0, row_count=None
+):
+    assert column_name in result_frame.index
+    data_stats = result_frame.loc[column_name]
+    assert float(data_stats["min_value"]) >= min_value
+    assert float(data_stats["max_value"]) <= max_value
+    assert int(data_stats["null_count"]) == null_count
+    assert int(data_stats["row_count"]) == row_count
+
+
+def test_aggregate_column_statistics_with_pixel(small_sky_order1_dir):
+    partition_info_file = paths.get_parquet_metadata_pointer(small_sky_order1_dir)
+
+    result_frame = aggregate_column_statistics(partition_info_file)
+    assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-69.5, max_value=-25.5, row_count=131)
+
+    result_frame = aggregate_column_statistics(partition_info_file, include_pixels=[HealpixPixel(1, 45)])
+    assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-60.5, max_value=-25.5, row_count=29)
+
+    result_frame = aggregate_column_statistics(partition_info_file, include_pixels=[HealpixPixel(1, 47)])
+    assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-36.5, max_value=-25.5, row_count=18)
+
+    result_frame = aggregate_column_statistics(
+        partition_info_file, include_pixels=[HealpixPixel(1, 45), HealpixPixel(1, 47)]
+    )
+    assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-60.5, max_value=-25.5, row_count=47)
+
+    result_frame = aggregate_column_statistics(partition_info_file, include_pixels=[HealpixPixel(1, 4)])
+    assert len(result_frame) == 0
+
+
+def test_aggregate_column_statistics_with_nested(small_sky_nested_dir):
+    partition_info_file = paths.get_parquet_metadata_pointer(small_sky_nested_dir)
+
+    ## Will have 13 returned columns (5 object and 8 light curve)
+    ## Since object_dec is copied from object.dec, the min/max are the same,
+    ## but there are MANY more rows of light curve dec values.
+    result_frame = aggregate_column_statistics(partition_info_file)
+    assert len(result_frame) == 13
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-69.5, max_value=-25.5, row_count=131)
+    assert_column_stat_as_floats(
+        result_frame, "lc.object_dec", min_value=-69.5, max_value=-25.5, row_count=16135
+    )
+
+    ## Only peeking at a single pixel, we should see the same dec min/max as
+    ## we see above for the flat object table.
+    result_frame = aggregate_column_statistics(partition_info_file, include_pixels=[HealpixPixel(1, 47)])
+    assert len(result_frame) == 13
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-36.5, max_value=-25.5, row_count=18)
+    assert_column_stat_as_floats(
+        result_frame, "lc.source_id", min_value=70008, max_value=87148, row_count=2358
+    )
+    assert_column_stat_as_floats(result_frame, "lc.mag", min_value=15, max_value=21, row_count=2358)
+
+    ## Test that we can request light curve columns, using the shorter name
+    ## e.g. full path in the file is "lc.source_id.list.element"
+    result_frame = aggregate_column_statistics(
+        partition_info_file, include_columns=["ra", "dec", "lc.source_ra", "lc.source_dec", "lc.mag"]
+    )
+    assert len(result_frame) == 5
+    assert_column_stat_as_floats(result_frame, "dec", min_value=-69.5, max_value=-25.5, row_count=131)
+    assert_column_stat_as_floats(result_frame, "lc.mag", min_value=15, max_value=21, row_count=16135)
+
+
 def test_aggregate_column_statistics_with_nulls(tmp_path):
     file_io.make_directory(tmp_path / "dataset")

@@ -160,12 +230,5 @@ def test_aggregate_column_statistics_with_nulls(tmp_path):
     result_frame = aggregate_column_statistics(tmp_path / "dataset" / "_metadata", exclude_hats_columns=False)
     assert len(result_frame) == 2

-    data_stats = result_frame.loc["data"]
-    assert data_stats["min_value"] == -1
-    assert data_stats["max_value"] == 2
-    assert data_stats["null_count"] == 4
-
-    data_stats = result_frame.loc["Npix"]
-    assert data_stats["min_value"] == 1
-    assert data_stats["max_value"] == 6
-    assert data_stats["null_count"] == 4
+    assert_column_stat_as_floats(result_frame, "data", min_value=-1, max_value=2, null_count=4, row_count=6)
+    assert_column_stat_as_floats(result_frame, "Npix", min_value=1, max_value=6, null_count=4, row_count=6)

0 commit comments