160 changes: 122 additions & 38 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -335,7 +335,7 @@ def estimate_inmemory_data_size(self) -> Optional[int]:
total_size += file_metadata.total_byte_size
return total_size * self._encoding_ratio

def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
def get_read_tasks(self, parallelism: int, limit: Optional[int] = None) -> List[ReadTask]:
# NOTE: We override the base class FileBasedDatasource.get_read_tasks()
# method in order to leverage pyarrow's ParquetDataset abstraction,
# which simplifies partitioning logic. We still use
@@ -362,60 +362,124 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
pq_metadata,
)

# Apply limit pushdown by selecting fragments and calculating row limits.
fragments_with_limits: List[Tuple[SerializedFragment, Optional[int]]] = []
rows_left = limit

if limit is not None:
for fragment, meta in zip(pq_fragments, pq_metadata):
if rows_left <= 0:
break

                fragment_rows = meta.num_rows if meta else None
                if fragment_rows is None:
                    # The limit can't be applied accurately when the row count is
                    # unknown, so include this and all remaining fragments in full
                    # and stop applying the limit.
                    # TODO: Consider estimating row counts so the limit can still
                    # be applied.
                    logger.warning(
                        f"Row count for fragment {fragment._data} is unknown. "
                        "Limit pushdown may be inaccurate."
                    )
                    fragments_with_limits.append((fragment, None))
                    rows_left = 0  # Stop applying the limit.
                    remaining_fragments = pq_fragments[len(fragments_with_limits):]
                    fragments_with_limits.extend([(f, None) for f in remaining_fragments])
                    break

if fragment_rows <= rows_left:
fragments_with_limits.append((fragment, None)) # Read full fragment
rows_left -= fragment_rows
else:
fragments_with_limits.append((fragment, rows_left)) # Read partial fragment
rows_left = 0
break # Limit reached

            # Update the fragment/path/metadata lists to match the limited selection.
            # The selected fragments are always a prefix of pq_fragments, so the
            # parallel pq_paths and pq_metadata lists can simply be truncated.
            pq_fragments = [f for f, _ in fragments_with_limits]
            pq_paths = pq_paths[:len(pq_fragments)]
            pq_metadata = pq_metadata[:len(pq_fragments)]

        if limit is not None and rows_left > 0:
            logger.warning(
                f"Limit pushdown expected {limit} rows but only found "
                f"{limit - rows_left}. The dataset may contain fewer rows than the "
                "limit, or the fragment metadata may be inaccurate."
            )

# Distribute selected fragments among tasks
read_tasks = []
for fragments, paths, metadata in zip(
        # NOTE: The per-fragment limits in fragments_with_limits are not used here;
        # read_fragments enforces the limit across the fragments in each task.
for fragments_split, paths_split, metadata_split in zip(
np.array_split(pq_fragments, parallelism),
np.array_split(pq_paths, parallelism),
np.array_split(pq_metadata, parallelism),
):
if len(fragments) <= 0:
if len(fragments_split) <= 0:
continue

            # Estimate task metadata from the (possibly limited) fragments assigned to
            # this task; the actual rows read may differ once the limit is enforced.
            # NOTE: These estimates are not applied to ``meta`` below.
task_num_rows = 0
task_size_bytes = 0
known_rows = True
known_size = True
for frag_meta in metadata_split:
if frag_meta and frag_meta.num_rows is not None:
task_num_rows += frag_meta.num_rows
else:
known_rows = False
                if frag_meta and frag_meta.total_byte_size is not None:
                    task_size_bytes += frag_meta.total_byte_size
else:
known_size = False

final_task_num_rows = task_num_rows if known_rows else None
final_task_size_bytes = int(task_size_bytes * self._encoding_ratio) if known_size else None

            # The limit applies globally rather than per task, so adjusting this task's
            # metadata for it is non-trivial. For now the metadata describes the full
            # fragments assigned to the task; the limit itself is enforced during the read.
meta = self._meta_provider(
paths,
paths_split,
self._inferred_schema,
num_fragments=len(fragments),
prefetched_metadata=metadata,
num_fragments=len(fragments_split),
prefetched_metadata=metadata_split,
)
if meta.size_bytes is not None:
meta.size_bytes = int(meta.size_bytes * self._encoding_ratio)
# TODO: Adjust meta.num_rows based on the effective limit for this task?
# Requires knowing how many rows are assigned to *this* task under the limit.
# For now, meta.num_rows might overestimate if limit is applied.

# If there is a filter operation, reset the calculated row count,
# since the resulting row count is unknown.
if self._to_batches_kwargs.get("filter") is not None:
meta.num_rows = None

if meta.size_bytes is not None:
meta.size_bytes = int(meta.size_bytes * self._encoding_ratio)
            # Determine the row limit for this task. Computing an exact per-task limit
            # would require knowing how many rows precede this task under the global
            # limit, so for now pass the overall limit and rely on read_fragments to
            # enforce it within the task. TODO: Split the limit across tasks so the
            # combined output cannot exceed the requested limit.
            task_limit = limit

(
block_udf,
to_batches_kwargs,
default_read_batch_size_rows,
data_columns,
partition_columns,
read_schema,
include_paths,
partitioning,
) = (
self._block_udf,
self._to_batches_kwargs,
self._default_read_batch_size_rows,
self._data_columns,
self._partition_columns,
self._read_schema,
self._include_paths,
self._partitioning,
)
read_tasks.append(
ReadTask(
lambda f=fragments: read_fragments(
block_udf,
to_batches_kwargs,
default_read_batch_size_rows,
data_columns,
partition_columns,
read_schema,
f,
include_paths,
partitioning,
# Pass task_limit to read_fragments
lambda frags=fragments_split, t_limit=task_limit: read_fragments(
self._block_udf,
self._to_batches_kwargs,
self._default_read_batch_size_rows,
self._data_columns,
self._partition_columns,
self._read_schema,
frags, # Pass the fragments for this task
self._include_paths,
self._partitioning,
task_rows_limit=t_limit, # New argument
),
meta,
)
@@ -445,6 +509,7 @@ def read_fragments(
serialized_fragments: List[SerializedFragment],
include_paths: bool,
partitioning: Partitioning,
task_rows_limit: Optional[int] = None, # New argument
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa
@@ -459,10 +524,16 @@

import pyarrow as pa

logger.debug(f"Reading {len(fragments)} parquet fragments")
logger.debug(f"Reading {len(fragments)} parquet fragments, limit={task_rows_limit}")
use_threads = to_batches_kwargs.pop("use_threads", False)
batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)

rows_yielded_total = 0

for fragment in fragments:
if task_rows_limit is not None and rows_yielded_total >= task_rows_limit:
break # Stop processing fragments if limit already reached

partitions = {}
if partitioning is not None:
parse = PathPartitionParser(partitioning)
Expand Down Expand Up @@ -491,6 +562,16 @@ def get_batch_iterable():
for batch in iterate_with_retry(
get_batch_iterable, "load batch", match=ctx.retried_io_errors
):
if task_rows_limit is not None:
rows_needed = task_rows_limit - rows_yielded_total
if rows_needed <= 0:
break # Limit reached within this fragment
if batch.num_rows > rows_needed:
batch = batch.slice(0, rows_needed)
                if batch.num_rows == 0:
                    continue  # Skip empty batches rather than yielding them.

rows_yielded_total += batch.num_rows
table = pa.Table.from_batches([batch], schema=schema)
if include_paths:
table = table.append_column("path", [[fragment.path]] * len(table))
@@ -503,6 +584,9 @@ def get_batch_iterable():
yield block_udf(table)
else:
yield table

if task_rows_limit is not None and rows_yielded_total >= task_rows_limit:
break # Stop processing batches for this fragment if limit reached


def _deserialize_fragments_with_retry(fragments):
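For context, a standalone sketch (not part of this diff) of the row-budget pattern that read_fragments applies above, written against plain pyarrow. The helper name take_rows is illustrative, not a Ray API.

import pyarrow as pa

def take_rows(batches, limit):
    """Yield record batches until ``limit`` rows have been produced, slicing the last one."""
    yielded = 0
    for batch in batches:
        needed = limit - yielded
        if needed <= 0:
            break
        if batch.num_rows > needed:
            batch = batch.slice(0, needed)
        yielded += batch.num_rows
        yield batch

table = pa.table({"x": list(range(10))})
limited = pa.Table.from_batches(take_rows(table.to_batches(max_chunksize=4), limit=6))
assert limited.num_rows == 6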
@@ -32,6 +32,7 @@ def __init__(
self._mem_size = mem_size
self._concurrency = concurrency
self._detected_parallelism = None
self._limit: Optional[int] = None

def set_detected_parallelism(self, parallelism: int):
"""
41 changes: 37 additions & 4 deletions python/ray/data/_internal/logical/rules/limit_pushdown.py
@@ -51,12 +51,45 @@ def _apply_limit_pushdown(self, op: LogicalOperator) -> LogicalOperator:
# We should remove this case once we refactor Read op to no longer
# be an AbstractOneToOne op.
if isinstance(current_op, Limit):
limit_op_copy = copy.copy(current_op)
# Current logic moves limit *past* applicable OneToOne ops.
# We want to potentially push it *into* a Read op.

# Original limit op and its immediate input.
limit_op = current_op
upstream_op = limit_op.input_dependency

# Check if the upstream operator is a Read op.
if isinstance(upstream_op, Read):
# Push limit into the Read operator.
new_limit = limit_op._limit
if upstream_op._limit is not None:
# If Read already has a limit (e.g., from a previous fusion/pushdown),
# take the minimum.
new_limit = min(new_limit, upstream_op._limit)
upstream_op._limit = new_limit

# Remove the Limit operator from the DAG.
# Connect Read op directly to the downstream operators of Limit.
downstream_ops = limit_op.output_dependencies
upstream_op._output_dependencies = downstream_ops
for downstream_op in downstream_ops:
                    # Replace the Limit op with the Read op in the downstream operator's inputs.
input_idx = downstream_op._input_dependencies.index(limit_op)
downstream_op._input_dependencies[input_idx] = upstream_op

                # The Limit op has been removed from the DAG; skip the original
                # pushdown logic for it and continue with the next operator in the
                # traversal.
                continue

# --- Original Pushdown Logic (modified to use limit_op, upstream_op) ---
# If upstream is not Read, apply the existing pushdown logic.
limit_op_copy = copy.copy(limit_op)

# Traverse up the DAG until we reach the first operator that meets
# one of the conditions above, which will serve as the new input
# into the Limit operator.
new_input_into_limit = current_op.input_dependency
new_input_into_limit = upstream_op # Start traversal from upstream_op
ops_between_new_input_and_limit: List[LogicalOperator] = []
while (
isinstance(new_input_into_limit, AbstractOneToOne)
@@ -85,12 +118,12 @@ def _apply_limit_pushdown(self, op: LogicalOperator) -> LogicalOperator:
nodes.append(curr_op)

# Link the Limit operator to its new input operator.
for limit_output_op in current_op.output_dependencies:
for limit_output_op in limit_op.output_dependencies:
limit_output_op._input_dependencies = [
ops_between_new_input_and_limit[0]
]
last_op = ops_between_new_input_and_limit[0]
last_op._output_dependencies = current_op.output_dependencies
last_op._output_dependencies = limit_op.output_dependencies

return current_op

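As a self-contained illustration of the rewrite above, here is a minimal sketch of splicing a Limit node out of a doubly-linked operator DAG and folding its limit into the upstream Read node. The Node class and push_limit_into_read helper are toy stand-ins, not Ray's operator classes.

from typing import List, Optional

class Node:
    """Toy DAG node with doubly-linked input/output edges."""
    def __init__(self, name: str):
        self.name = name
        self.limit: Optional[int] = None
        self.inputs: List["Node"] = []
        self.outputs: List["Node"] = []

def push_limit_into_read(limit_node: Node) -> None:
    read = limit_node.inputs[0]
    # Fold the limit into the Read node, keeping the smaller of the two.
    read.limit = (
        limit_node.limit if read.limit is None else min(read.limit, limit_node.limit)
    )
    # Reconnect the Read node directly to the Limit node's consumers.
    read.outputs = limit_node.outputs
    for downstream in limit_node.outputs:
        downstream.inputs[downstream.inputs.index(limit_node)] = read

read_op, limit_op, map_op = Node("Read"), Node("Limit"), Node("Map")
limit_op.limit = 100
read_op.outputs, limit_op.inputs = [limit_op], [read_op]
limit_op.outputs, map_op.inputs = [map_op], [limit_op]
push_limit_into_read(limit_op)
assert map_op.inputs == [read_op] and read_op.outputs == [map_op] and read_op.limit == 100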
6 changes: 5 additions & 1 deletion python/ray/data/_internal/planner/plan_read_op.py
@@ -73,7 +73,9 @@ def get_input_data(target_max_block_size) -> List[RefBundle]:
assert (
parallelism is not None
), "Read parallelism must be set by the optimizer before execution"
read_tasks = op._datasource_or_legacy_reader.get_read_tasks(parallelism)
read_tasks = op._datasource_or_legacy_reader.get_read_tasks(
parallelism, limit=op._limit
)
_warn_on_high_parallelism(parallelism, len(read_tasks))

ret = []
@@ -109,6 +111,8 @@ def do_read(blocks: Iterable[ReadTask], _: TaskContext) -> Iterable[Block]:
transform_fns: List[MapTransformFn] = [
# First, execute the read tasks.
BlockMapTransformFn(do_read),
# TODO(Clark): Add limit enforcement here if the datasource couldn't enforce it?
# This might require passing the limit into do_read or the ReadTask.
]
transform_fns.append(BuildOutputBlocksMapTransformFn.for_blocks())
map_transformer = MapTransformer(transform_fns)
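Related to the TODO above about enforcing the limit when the datasource cannot: one possible follow-up, sketched here, is to split the global limit into exact per-task budgets when per-task row counts are known. split_limit is a hypothetical helper, not part of this diff.

from typing import List, Optional

def split_limit(task_row_counts: List[int], limit: Optional[int]) -> List[Optional[int]]:
    """Assign each task a row budget so the budgets sum to at most ``limit``."""
    if limit is None:
        return [None] * len(task_row_counts)
    budgets: List[Optional[int]] = []
    rows_left = limit
    for num_rows in task_row_counts:
        take = min(num_rows, rows_left)
        budgets.append(take)
        rows_left -= take
    return budgets

# Three tasks holding 50, 80, and 100 rows under a global limit of 120 rows:
assert split_limit([50, 80, 100], 120) == [50, 70, 0]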
10 changes: 7 additions & 3 deletions python/ray/data/datasource/datasource.py
@@ -47,12 +47,14 @@ def estimate_inmemory_data_size(self) -> Optional[int]:
"""
raise NotImplementedError

def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
def get_read_tasks(self, parallelism: int, limit: Optional[int] = None) -> List["ReadTask"]:
"""Execute the read and return read tasks.

Args:
parallelism: The requested read parallelism. The number of read
tasks should equal to this value if possible.
limit: The maximum number of rows to read, or None for no limit.
                Datasources should honor this when possible for efficiency.

Returns:
A list of read tasks that can be executed to read blocks from the
@@ -96,13 +98,14 @@ def estimate_inmemory_data_size(self) -> Optional[int]:
"""
raise NotImplementedError

def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
def get_read_tasks(self, parallelism: int, limit: Optional[int] = None) -> List["ReadTask"]:
"""Execute the read and return read tasks.

Args:
parallelism: The requested read parallelism. The number of read
tasks should equal to this value if possible.
read_args: Additional kwargs to pass to the datasource impl.
limit: The maximum number of rows to read, or None for no limit.

Returns:
A list of read tasks that can be executed to read blocks from the
@@ -119,7 +122,8 @@ def __init__(self, datasource: Datasource, **read_args):
def estimate_inmemory_data_size(self) -> Optional[int]:
return None

def get_read_tasks(self, parallelism: int) -> List["ReadTask"]:
def get_read_tasks(self, parallelism: int, limit: Optional[int] = None) -> List["ReadTask"]:
# Legacy prepare_read doesn't support limit, so we ignore it here.
return self._datasource.prepare_read(parallelism, **self._read_args)


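For reference, a minimal end-to-end example of the workload this change targets, using only public ray.data APIs. The S3 path is illustrative, and the pushdown behavior described in the comment is what this diff aims to provide rather than a documented guarantee.

import ray

ds = ray.data.read_parquet("s3://my-bucket/events/").limit(1000)
# With limit pushdown, the optimizer folds the limit into the Read operator, so
# only enough Parquet fragments to cover roughly 1,000 rows are planned and read.
print(ds.count())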