1
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
2
# SPDX-License-Identifier: Apache-2.0
3
3
"""
4
4
DSL nodes for the LogicalPlan of polars.
34
34
from cudf_polars .utils .versions import POLARS_VERSION_GT_112
35
35
36
36
if TYPE_CHECKING :
37
- from collections .abc import Callable , Hashable , MutableMapping , Sequence
37
+ from collections .abc import Callable , Hashable , Iterable , MutableMapping , Sequence
38
38
from typing import Literal
39
39
40
+ from polars .polars import _expr_nodes as pl_expr
41
+
40
42
from cudf_polars .typing import Schema
41
43
42
44
@@ -1019,7 +1021,27 @@ class ConditionalJoin(IR):
1019
1021
__slots__ = ("ast_predicate" , "options" , "predicate" )
1020
1022
_non_child = ("schema" , "predicate" , "options" )
1021
1023
predicate : expr .Expr
1022
- options : tuple
1024
+ """Expression predicate to join on"""
1025
+ options : tuple [
1026
+ tuple [
1027
+ str ,
1028
+ pl_expr .Operator | Iterable [pl_expr .Operator ],
1029
+ ],
1030
+ bool ,
1031
+ tuple [int , int ] | None ,
1032
+ str ,
1033
+ bool ,
1034
+ Literal ["none" , "left" , "right" , "left_right" , "right_left" ],
1035
+ ]
1036
+ """
1037
+ tuple of options:
1038
+ - predicates: tuple of ir join type (eg. ie_join) and (In)Equality conditions
1039
+ - join_nulls: do nulls compare equal?
1040
+ - slice: optional slice to perform after joining.
1041
+ - suffix: string suffix for right columns if names match
1042
+ - coalesce: should key columns be coalesced (only makes sense for outer joins)
1043
+ - maintain_order: which DataFrame row order to preserve, if any
1044
+ """
1023
1045
1024
1046
def __init__ (
1025
1047
self , schema : Schema , predicate : expr .Expr , options : tuple , left : IR , right : IR
@@ -1029,22 +1051,24 @@ def __init__(
1029
1051
self .options = options
1030
1052
self .children = (left , right )
1031
1053
self .ast_predicate = to_ast (predicate )
1032
- _ , join_nulls , zlice , suffix , coalesce = self .options
1054
+ _ , join_nulls , zlice , suffix , coalesce , maintain_order = self .options
1033
1055
# Preconditions from polars
1034
1056
assert not join_nulls
1035
1057
assert not coalesce
1058
+ assert maintain_order == "none"
1036
1059
if self .ast_predicate is None :
1037
1060
raise NotImplementedError (
1038
1061
f"Conditional join with predicate { predicate } "
1039
1062
) # pragma: no cover; polars never delivers expressions we can't handle
1040
- self ._non_child_args = (self .ast_predicate , zlice , suffix )
1063
+ self ._non_child_args = (self .ast_predicate , zlice , suffix , maintain_order )
1041
1064
1042
1065
@classmethod
1043
1066
def do_evaluate (
1044
1067
cls ,
1045
1068
predicate : plc .expressions .Expression ,
1046
1069
zlice : tuple [int , int ] | None ,
1047
1070
suffix : str ,
1071
+ maintain_order : Literal ["none" , "left" , "right" , "left_right" , "right_left" ],
1048
1072
left : DataFrame ,
1049
1073
right : DataFrame ,
1050
1074
) -> DataFrame :
@@ -1088,6 +1112,7 @@ class Join(IR):
1088
1112
tuple [int , int ] | None ,
1089
1113
str ,
1090
1114
bool ,
1115
+ Literal ["none" , "left" , "right" , "left_right" , "right_left" ],
1091
1116
]
1092
1117
"""
1093
1118
tuple of options:
@@ -1096,6 +1121,7 @@ class Join(IR):
1096
1121
- slice: optional slice to perform after joining.
1097
1122
- suffix: string suffix for right columns if names match
1098
1123
- coalesce: should key columns be coalesced (only makes sense for outer joins)
1124
+ - maintain_order: which DataFrame row order to preserve, if any
1099
1125
"""
1100
1126
1101
1127
def __init__ (
@@ -1113,6 +1139,9 @@ def __init__(
1113
1139
self .options = options
1114
1140
self .children = (left , right )
1115
1141
self ._non_child_args = (self .left_on , self .right_on , self .options )
1142
+ # TODO: Implement maintain_order
1143
+ if options [5 ] != "none" :
1144
+ raise NotImplementedError ("maintain_order not implemented yet" )
1116
1145
if any (
1117
1146
isinstance (e .value , expr .Literal )
1118
1147
for e in itertools .chain (self .left_on , self .right_on )
@@ -1222,12 +1251,13 @@ def do_evaluate(
1222
1251
tuple [int , int ] | None ,
1223
1252
str ,
1224
1253
bool ,
1254
+ Literal ["none" , "left" , "right" , "left_right" , "right_left" ],
1225
1255
],
1226
1256
left : DataFrame ,
1227
1257
right : DataFrame ,
1228
1258
) -> DataFrame :
1229
1259
"""Evaluate and return a dataframe."""
1230
- how , join_nulls , zlice , suffix , coalesce = options
1260
+ how , join_nulls , zlice , suffix , coalesce , _ = options
1231
1261
if how == "cross" :
1232
1262
# Separate implementation, since cross_join returns the
1233
1263
# result, not the gather maps
0 commit comments