Skip to content

Commit bae62df

Browse files
authored
Use a balanced tree instead of unbalanced one (#1830)
**Use a balanced tree instead of an unbalanced one to prevent recursion error in `create_match_filter`** <!-- Closes #1776 --> ## Rationale for this change In the `create_match_filter` function, the previous implementation used `functools.reduce(operator.or_, filters)` to combine expressions. Because `reduce` folds from the left, this approach constructed a left-heavy, unbalanced tree whose depth grew linearly with the number of expressions, which could lead to a `RecursionError` when dealing with a large number of expressions (e.g., over 1,000). To address this, we've introduced the `_build_balanced_tree` function. This utility constructs a balanced binary tree of expressions, reducing the maximum depth to O(log n) and thereby preventing potential recursion errors. This makes expression construction more stable and scalable, especially when working with large datasets. ```python # Past behavior Or(*[A, B, C, D]) = Or(Or(Or(A, B), C), D) # New behavior Or(*[A, B, C, D]) = Or(Or(A, B), Or(C, D)) ``` ## Are these changes tested? Yes, existing tests cover the functionality of `Or`. Additional testing was done with large expression sets (e.g., 10,000 items) to ensure that balanced tree construction avoids recursion errors. ## Are there any user-facing changes? No, there are no user-facing changes. This is an internal implementation improvement that does not affect the public API. Closes #1759 Closes #1785 <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 4b15fb6 commit bae62df

File tree

5 files changed

+78
-31
lines changed

5 files changed

+78
-31
lines changed

pyiceberg/expressions/__init__.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
from __future__ import annotations
1919

2020
from abc import ABC, abstractmethod
21-
from functools import cached_property, reduce
21+
from functools import cached_property
2222
from typing import (
2323
Any,
24+
Callable,
2425
Generic,
2526
Iterable,
27+
Sequence,
2628
Set,
2729
Tuple,
2830
Type,
@@ -79,6 +81,45 @@ def __or__(self, other: BooleanExpression) -> BooleanExpression:
7981
return Or(self, other)
8082

8183

84+
def _build_balanced_tree(
85+
operator_: Callable[[BooleanExpression, BooleanExpression], BooleanExpression], items: Sequence[BooleanExpression]
86+
) -> BooleanExpression:
87+
"""
88+
Recursively constructs a balanced binary tree of BooleanExpressions using the provided binary operator.
89+
90+
This function is a safer and more scalable alternative to:
91+
reduce(operator_, items)
92+
93+
Using `reduce` creates a deeply nested, unbalanced tree (e.g., operator_(a, operator_(b, operator_(c, ...)))),
94+
which grows linearly with the number of items. This can lead to RecursionError exceptions in Python
95+
when the number of expressions is large (e.g., >1000).
96+
97+
In contrast, this function builds a balanced binary tree with logarithmic depth (O(log n)),
98+
helping avoid recursion issues and ensuring that expression trees remain stable, predictable,
99+
and safe to traverse — especially in tools like PyIceberg that operate on large logical trees.
100+
101+
Parameters:
102+
operator_ (Callable): A binary operator function (e.g., pyiceberg.expressions.Or, And) that takes two
103+
BooleanExpressions and returns a combined BooleanExpression.
104+
items (Sequence[BooleanExpression]): A sequence of BooleanExpression objects to combine.
105+
106+
Returns:
107+
BooleanExpression: The balanced combination of all input BooleanExpressions.
108+
109+
Raises:
110+
ValueError: If the input sequence is empty.
111+
"""
112+
if not items:
113+
raise ValueError("No expressions to combine")
114+
if len(items) == 1:
115+
return items[0]
116+
mid = len(items) // 2
117+
118+
left = _build_balanced_tree(operator_, items[:mid])
119+
right = _build_balanced_tree(operator_, items[mid:])
120+
return operator_(left, right)
121+
122+
82123
class Term(Generic[L], ABC):
83124
"""A simple expression that evaluates to a value."""
84125

@@ -214,7 +255,7 @@ class And(BooleanExpression):
214255

215256
def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore
216257
if rest:
217-
return reduce(And, (left, right, *rest))
258+
return _build_balanced_tree(And, (left, right, *rest))
218259
if left is AlwaysFalse() or right is AlwaysFalse():
219260
return AlwaysFalse()
220261
elif left is AlwaysTrue():
@@ -257,7 +298,7 @@ class Or(BooleanExpression):
257298

258299
def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression) -> BooleanExpression: # type: ignore
259300
if rest:
260-
return reduce(Or, (left, right, *rest))
301+
return _build_balanced_tree(Or, (left, right, *rest))
261302
if left is AlwaysTrue() or right is AlwaysTrue():
262303
return AlwaysTrue()
263304
elif left is AlwaysFalse():

pyiceberg/table/upsert_util.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
BooleanExpression,
2727
EqualTo,
2828
In,
29+
Or,
2930
)
3031

3132

@@ -39,7 +40,12 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
3940
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
4041
]
4142

42-
return AlwaysFalse() if len(filters) == 0 else functools.reduce(operator.or_, filters)
43+
if len(filters) == 0:
44+
return AlwaysFalse()
45+
elif len(filters) == 1:
46+
return filters[0]
47+
else:
48+
return Or(*filters)
4349

4450

4551
def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:

tests/expressions/test_expressions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -591,11 +591,11 @@ def test_negate(lhs: BooleanExpression, rhs: BooleanExpression) -> None:
591591
[
592592
(
593593
And(ExpressionA(), ExpressionB(), ExpressionA()),
594-
And(And(ExpressionA(), ExpressionB()), ExpressionA()),
594+
And(ExpressionA(), And(ExpressionB(), ExpressionA())),
595595
),
596596
(
597597
Or(ExpressionA(), ExpressionB(), ExpressionA()),
598-
Or(Or(ExpressionA(), ExpressionB()), ExpressionA()),
598+
Or(ExpressionA(), Or(ExpressionB(), ExpressionA())),
599599
),
600600
(Not(Not(ExpressionA())), ExpressionA()),
601601
],

tests/expressions/test_visitors.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,14 @@ def test_boolean_expression_visitor() -> None:
230230
"NOT",
231231
"OR",
232232
"EQUALTO",
233-
"OR",
234233
"NOTEQUALTO",
235234
"OR",
235+
"OR",
236236
"EQUALTO",
237237
"NOT",
238-
"AND",
239238
"NOTEQUALTO",
240239
"AND",
240+
"AND",
241241
]
242242

243243

@@ -335,28 +335,28 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch
335335
),
336336
),
337337
And(
338-
And(
339-
BoundIn(
340-
BoundReference(
341-
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
342-
accessor=Accessor(position=0, inner=None),
343-
),
344-
{literal("bar"), literal("baz")},
338+
BoundIn(
339+
BoundReference(
340+
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
341+
accessor=Accessor(position=0, inner=None),
345342
),
343+
{literal("bar"), literal("baz")},
344+
),
345+
And(
346346
BoundEqualTo[int](
347347
BoundReference(
348348
field=NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
349349
accessor=Accessor(position=1, inner=None),
350350
),
351351
literal(1),
352352
),
353-
),
354-
BoundEqualTo(
355-
BoundReference(
356-
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
357-
accessor=Accessor(position=0, inner=None),
353+
BoundEqualTo(
354+
BoundReference(
355+
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
356+
accessor=Accessor(position=0, inner=None),
357+
),
358+
literal("baz"),
358359
),
359-
literal("baz"),
360360
),
361361
),
362362
),
@@ -408,28 +408,28 @@ def test_and_expression_binding(
408408
),
409409
),
410410
Or(
411+
BoundIn(
412+
BoundReference(
413+
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
414+
accessor=Accessor(position=0, inner=None),
415+
),
416+
{literal("bar"), literal("baz")},
417+
),
411418
Or(
412419
BoundIn(
413420
BoundReference(
414421
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
415422
accessor=Accessor(position=0, inner=None),
416423
),
417-
{literal("bar"), literal("baz")},
424+
{literal("bar")},
418425
),
419426
BoundIn(
420427
BoundReference(
421428
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
422429
accessor=Accessor(position=0, inner=None),
423430
),
424-
{literal("bar")},
425-
),
426-
),
427-
BoundIn(
428-
BoundReference(
429-
field=NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
430-
accessor=Accessor(position=0, inner=None),
431+
{literal("baz")},
431432
),
432-
{literal("baz")},
433433
),
434434
),
435435
),

tests/io/test_pyarrow_visitor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -836,5 +836,5 @@ def test_expression_to_complementary_pyarrow(
836836
# Notice an isNan predicate on a str column is automatically converted to always false and removed from Or and thus will not appear in the pc.expr.
837837
assert (
838838
repr(result)
839-
== """<pyarrow.compute.Expression (((invert((((((string_field == "hello") and (float_field > 100)) or (is_nan(float_field) and (double_field == 0))) or (float_field > 100)) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>"""
839+
== """<pyarrow.compute.Expression (((invert(((((string_field == "hello") and (float_field > 100)) or ((is_nan(float_field) and (double_field == 0)) or (float_field > 100))) and invert(is_null(double_field, {nan_is_null=false})))) or is_null(float_field, {nan_is_null=false})) or is_null(string_field, {nan_is_null=false})) or is_nan(double_field))>"""
840840
)

0 commit comments

Comments
 (0)