Add support for Dask versions >=2024.3.0 with dask expressions #288

Merged: 6 commits, Apr 24, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
# Includes dask[array,dataframe,distributed,diagnostics].
# dask distributed eases the creation of parallel dask clients.
# dask diagnostics is required to spin up the dashboard for profiling.
"dask[complete]<=2024.2.1",
"dask[complete]>=2024.3.0", # Includes dask expressions.
"hipscat>=0.2.8",
"pyarrow",
"deprecated",
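Starting with dask 2024.3.0 the expression-based (dask-expr) DataFrame backend is the default, which is why the pin flips from an upper bound to a lower bound and why the annotations throughout this PR move from dd.core.DataFrame to dd.DataFrame. A minimal sketch (column names are made up) showing that the public dd.DataFrame type still works for isinstance checks on either backend:

import dask.dataframe as dd
import pandas as pd

# Under dask-expr, collections are no longer instances of dd.core.DataFrame,
# but the public dd.DataFrame type matches on both the legacy and expr backends.
ddf = dd.from_pandas(pd.DataFrame({"ra": [10.0], "dec": [-30.0]}), npartitions=1)
print(isinstance(ddf, dd.DataFrame))  # True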
2 changes: 1 addition & 1 deletion src/lsdb/catalog/dataset/dataset.py
@@ -11,7 +11,7 @@ class Dataset:

def __init__(
self,
ddf: dd.core.DataFrame,
ddf: dd.DataFrame,
hc_structure: hc.catalog.Dataset,
):
"""Initialise a Catalog object.
5 changes: 4 additions & 1 deletion src/lsdb/catalog/dataset/healpix_dataset.py
@@ -58,7 +58,7 @@ def __init__(

def __getitem__(self, item):
result = self._ddf.__getitem__(item)
if isinstance(result, dd.core.DataFrame):
if isinstance(result, dd.DataFrame):
return self.__class__(result, self._ddf_pixel_map, self.hc_structure)
return result

@@ -168,6 +168,9 @@ def _construct_search_ddf(
Returns:
The catalog pixel map and the respective Dask DataFrame
"""
filtered_partitions = (
filtered_partitions if len(filtered_partitions) > 0 else [delayed(self._ddf._meta)]
)
divisions = get_pixels_divisions(filtered_pixels)
search_ddf = dd.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions)
search_ddf = cast(dd.DataFrame, search_ddf)
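The new guard exists because dd.from_delayed cannot build a collection from an empty list of partitions, so a search that matches no pixels falls back to a single delayed copy of the (empty) meta frame. A standalone sketch of the same pattern, with a hypothetical two-column meta:

import pandas as pd
import dask.dataframe as dd
from dask import delayed

meta = pd.DataFrame({"ra": pd.Series(dtype=float), "dec": pd.Series(dtype=float)})
filtered_partitions = []  # e.g. a cone search that matched no pixels

# Fall back to one delayed, zero-row partition so from_delayed has something to wrap.
if len(filtered_partitions) == 0:
    filtered_partitions = [delayed(meta)]

search_ddf = dd.from_delayed(filtered_partitions, meta=meta)
print(len(search_ddf))  # 0 rows, but still a valid single-partition collection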
2 changes: 1 addition & 1 deletion src/lsdb/dask/crossmatch_catalog_data.py
@@ -77,7 +77,7 @@ def crossmatch_catalog_data(
Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm
) = BuiltInCrossmatchAlgorithm.KD_TREE,
**kwargs,
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Cross-matches the data from two catalogs

Args:
4 changes: 2 additions & 2 deletions src/lsdb/dask/join_catalog_data.py
@@ -176,7 +176,7 @@ def perform_join_through(

def join_catalog_data_on(
left: Catalog, right: Catalog, left_on: str, right_on: str, suffixes: Tuple[str, str]
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Joins two catalogs spatially on a specified column

Args:
@@ -214,7 +214,7 @@ def join_catalog_data_on(

def join_catalog_data_through(
left: Catalog, right: Catalog, association: AssociationCatalog, suffixes: Tuple[str, str]
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Joins two catalogs with an association table

Args:
6 changes: 3 additions & 3 deletions src/lsdb/dask/merge_catalog_functions.py
@@ -6,7 +6,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from dask.delayed import Delayed
from dask.delayed import Delayed, delayed
from hipscat.catalog import PartitionInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN, healpix_to_hipscat_id
@@ -146,7 +146,7 @@ def filter_by_hipscat_index_to_pixel(dataframe: pd.DataFrame, order: int, pixel:

def construct_catalog_args(
partitions: List[Delayed], meta_df: pd.DataFrame, alignment: PixelAlignment
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Constructs the arguments needed to create a catalog from a list of delayed partitions

Args:
@@ -160,9 +160,9 @@ def construct_catalog_args(
"""
# generate dask df partition map from alignment
partition_map = get_partition_map_from_alignment_pixels(alignment.pixel_mapping)

# create dask df from delayed partitions
divisions = get_pixels_divisions(list(partition_map.keys()))
partitions = partitions if len(partitions) > 0 else [delayed(pd.DataFrame([]))]
ddf = dd.from_delayed(partitions, meta=meta_df, divisions=divisions)
ddf = cast(dd.DataFrame, ddf)
return ddf, partition_map, alignment
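For context on the divisions passed to from_delayed above: divisions list the index boundaries of the partitions (one more entry than the number of partitions), and supplying them tells Dask the frame is already sorted by _hipscat_index. A small sketch with made-up index values:

import pandas as pd
import dask.dataframe as dd
from dask import delayed

# Two hypothetical per-pixel partitions, already sorted by _hipscat_index.
part0 = pd.DataFrame({"id": [1, 2]}, index=pd.Index([0, 5], name="_hipscat_index"))
part1 = pd.DataFrame({"id": [3, 4]}, index=pd.Index([10, 20], name="_hipscat_index"))
meta = part0.iloc[:0]

# Each division is the smallest index of a partition, plus the largest index overall.
divisions = (0, 10, 20)
ddf = dd.from_delayed([delayed(part0), delayed(part1)], meta=meta, divisions=divisions)
print(ddf.npartitions, ddf.known_divisions)  # 2 True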
4 changes: 3 additions & 1 deletion src/lsdb/loaders/dataframe/from_dataframe_utils.py
@@ -23,7 +23,9 @@ def _generate_dask_dataframe(
Returns:
The catalog's Dask Dataframe and its total number of rows.
"""
schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else []
# Get one partition to infer the dataframe schema
one_partition = pixel_dfs[0] if len(pixel_dfs) > 0 else pd.DataFrame([])
schema = one_partition.iloc[:0, :].copy()
divisions = get_pixels_divisions(pixels)
delayed_dfs = [delayed(df) for df in pixel_dfs]
ddf = dd.from_delayed(delayed_dfs, meta=schema, divisions=divisions)
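The zero-row slice is the usual way to derive a meta frame: it keeps the column names and dtypes of a real partition without carrying any rows. A minimal sketch with one hypothetical partition:

import pandas as pd
import dask.dataframe as dd
from dask import delayed

pixel_dfs = [pd.DataFrame({"ra": [10.0], "dec": [-30.0]})]  # hypothetical partitions

# An empty slice of the first partition preserves columns and dtypes, which is
# the schema information dask needs as meta; fall back to an empty frame otherwise.
one_partition = pixel_dfs[0] if len(pixel_dfs) > 0 else pd.DataFrame([])
schema = one_partition.iloc[:0, :].copy()

ddf = dd.from_delayed([delayed(df) for df in pixel_dfs], meta=schema)
print(schema.dtypes.to_dict())  # both columns come out as float64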
30 changes: 16 additions & 14 deletions src/lsdb/loaders/dataframe/margin_catalog_generator.py
@@ -70,37 +70,39 @@ def create_catalog(self) -> MarginCatalog | None:
Returns:
Margin catalog object, or None if the margin is empty.
"""
ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map()
margin_pixels = list(ddf_pixel_map.keys())
if total_rows == 0:
pixels, partitions = self._get_margins()
if len(pixels) == 0:
return None
ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map(pixels, partitions)
margin_pixels = list(ddf_pixel_map.keys())
margin_catalog_info = self._create_catalog_info(total_rows)
margin_structure = hc.catalog.MarginCatalog(margin_catalog_info, margin_pixels)
return MarginCatalog(ddf, ddf_pixel_map, margin_structure)

def _generate_dask_df_and_map(self) -> Tuple[dd.DataFrame, Dict[HealpixPixel, int], int]:
def _get_margins(self):
combined_pixels = (
self.hc_structure.get_healpix_pixels() + self.hc_structure.generate_negative_tree_pixels()
)
margin_pairs_df = self._find_margin_pixel_pairs(combined_pixels)
margins_pixel_df = self._create_margins(margin_pairs_df)
pixels, partitions = list(margins_pixel_df.keys()), list(margins_pixel_df.values())
return pixels, partitions

def _generate_dask_df_and_map(
self, pixels, partitions
) -> Tuple[dd.DataFrame, Dict[HealpixPixel, int], int]:
"""Create the Dask Dataframe containing the data points in the margins
for the catalog, as well as the mapping of those HEALPix pixels to Dataframes

Args:
pixels (List[HealpixPixel]): HEALPix pixels of the margin partitions
partitions (List[pd.DataFrame]): margin data Dataframes, one per pixel

Returns:
Tuple containing the Dask Dataframe, the mapping of margin HEALPix
to the respective partitions and the total number of rows.
"""
healpix_pixels = self.hc_structure.get_healpix_pixels()
negative_pixels = self.hc_structure.generate_negative_tree_pixels()
combined_pixels = healpix_pixels + negative_pixels
margin_pairs_df = self._find_margin_pixel_pairs(combined_pixels)

# Compute points for each margin pixels
margins_pixel_df = self._create_margins(margin_pairs_df)
pixels, partitions = list(margins_pixel_df.keys()), list(margins_pixel_df.values())

# Generate pixel map ordered by _hipscat_index
pixel_order = get_pixel_argsort(pixels)
ordered_pixels = np.asarray(pixels)[pixel_order]
ordered_partitions = [partitions[i] for i in pixel_order]
ddf_pixel_map = {pixel: index for index, pixel in enumerate(ordered_pixels)}

# Generate the dask dataframe with the pixels and partitions
ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels)
return ddf, ddf_pixel_map, total_rows
2 changes: 1 addition & 1 deletion src/lsdb/loaders/hipscat/association_catalog_loader.py
@@ -24,5 +24,5 @@ def load_catalog(self) -> AssociationCatalog:
def _load_empty_dask_df_and_map(self, hc_catalog):
metadata_schema = self._load_parquet_metadata_schema(hc_catalog)
dask_meta_schema = metadata_schema.empty_table().to_pandas()
ddf = dd.from_pandas(dask_meta_schema, npartitions=0)
ddf = dd.from_pandas(dask_meta_schema, npartitions=1)
return ddf, {}
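The switch from npartitions=0 to npartitions=1 is needed because the expression backend expects at least one partition, so an empty association table is now represented as a single zero-row partition. A quick sketch with a made-up schema:

import pandas as pd
import dask.dataframe as dd

dask_meta_schema = pd.DataFrame({"primary_id": pd.Series(dtype="int64")})

# from_pandas with npartitions=1 produces one empty partition; the previous
# npartitions=0 relied on behavior of the legacy backend.
ddf = dd.from_pandas(dask_meta_schema, npartitions=1)
print(ddf.npartitions, len(ddf))  # 1 0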
5 changes: 4 additions & 1 deletion tests/conftest.py
@@ -188,7 +188,10 @@ def small_sky_order1_df(small_sky_order1_dir):

@pytest.fixture
def small_sky_source_df(test_data_dir):
return pd.read_csv(os.path.join(test_data_dir, "raw", "small_sky_source", "small_sky_source.csv"))
return pd.read_csv(
os.path.join(test_data_dir, "raw", "small_sky_source", "small_sky_source.csv"),
dtype={"band": "string[pyarrow]"},
)


@pytest.fixture
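The explicit dtype pins the band column to a pyarrow-backed string rather than a plain object column; recent dask versions convert string data to string[pyarrow] by default, so reading the expected frame the same way keeps test comparisons consistent (this reading of the intent is an assumption). A one-line sketch with a hypothetical CSV path:

import pandas as pd

# Hypothetical path; the point is forcing the string column onto the pyarrow backend.
df = pd.read_csv("small_sky_source.csv", dtype={"band": "string[pyarrow]"})
print(df["band"].dtype)  # string[pyarrow]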