Skip to content

Commit 3c569a4

Browse files
committed
faster sharrow for missing categoricals
1 parent db05787 commit 3c569a4

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

sharrow/aster.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,9 @@ def _replacement(
408408
if self.get_default or (
409409
topname == pref_topname and not self.swallow_errors
410410
):
411-
raise KeyError(f"{topname}..{attr}\nexpression={self.original_expr}")
411+
raise KeyError(
412+
f"{topname}..{attr}\nexpression={self.original_expr}"
413+
)
412414
# we originally raised a KeyError here regardless, but what if
413415
# we just give back the original node, and see if other spaces,
414416
# possibly fallback spaces, might work? If nothing works then
@@ -1010,6 +1012,16 @@ def visit_Compare(self, node):
10101012
f"\ncategories: {left_dictionary}",
10111013
stacklevel=2,
10121014
)
1015+
# at this point, the right value is not in the left's categories, so
1016+
# it is guaranteed to be not equal to any of the categories.
1017+
if isinstance(node.ops[0], ast.Eq):
1018+
result = ast.Constant(False)
1019+
elif isinstance(node.ops[0], ast.NotEq):
1020+
result = ast.Constant(True)
1021+
else:
1022+
raise ValueError(
1023+
f"unexpected operator {node.ops[0]}"
1024+
) from None
10131025
if right_decoded is not None:
10141026
result = ast.Compare(
10151027
left=left.slice,
@@ -1043,6 +1055,16 @@ def visit_Compare(self, node):
10431055
f"\ncategories: {right_dictionary}",
10441056
stacklevel=2,
10451057
)
1058+
# at this point, the left value is not in the right's categories, so
1059+
# it is guaranteed to be not equal to any of the categories.
1060+
if isinstance(node.ops[0], ast.Eq):
1061+
result = ast.Constant(False)
1062+
elif isinstance(node.ops[0], ast.NotEq):
1063+
result = ast.Constant(True)
1064+
else:
1065+
raise ValueError(
1066+
f"unexpected operator {node.ops[0]}"
1067+
) from None
10461068
if left_decoded is not None:
10471069
result = ast.Compare(
10481070
left=ast_Constant(left_decoded),

sharrow/tests/test_categorical.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,48 @@ def test_missing_categorical():
177177
a = a.isel(expressions=0)
178178
assert all(a == np.asarray([1, 0, 1, 1, 1, 1]))
179179

180+
expr = "df.TourMode2 != 'BAD'"
181+
with pytest.warns(UserWarning):
182+
f8 = tree.setup_flow({expr: expr}, with_root_node_name="df")
183+
a = f8.load_dataarray(dtype=np.int8)
184+
a = a.isel(expressions=0)
185+
assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))
186+
187+
expr = "'BAD' != df.TourMode2"
188+
with pytest.warns(UserWarning):
189+
f9 = tree.setup_flow({expr: expr}, with_root_node_name="df")
190+
a = f9.load_dataarray(dtype=np.int8)
191+
a = a.isel(expressions=0)
192+
assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))
193+
194+
expr = "(df.TourMode2 == 'BAD') * 2"
195+
with pytest.warns(UserWarning):
196+
fA = tree.setup_flow({expr: expr}, with_root_node_name="df")
197+
a = fA.load_dataarray(dtype=np.int8)
198+
a = a.isel(expressions=0)
199+
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))
200+
201+
expr = "(df.TourMode2 == 'BAD') * 2.2"
202+
with pytest.warns(UserWarning):
203+
fB = tree.setup_flow({expr: expr}, with_root_node_name="df")
204+
a = fB.load_dataarray(dtype=np.int8)
205+
a = a.isel(expressions=0)
206+
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))
207+
208+
expr = "np.exp(df.TourMode2 == 'BAD') * 2.2"
209+
with pytest.warns(UserWarning):
210+
fC = tree.setup_flow({expr: expr}, with_root_node_name="df")
211+
a = fC.load_dataarray(dtype=np.float32)
212+
a = a.isel(expressions=0)
213+
assert all(a == np.asarray([2.2, 2.2, 2.2, 2.2, 2.2, 2.2], dtype=np.float32))
214+
215+
expr = "(df.TourMode2 != 'BAD') * 2"
216+
with pytest.warns(UserWarning):
217+
fD = tree.setup_flow({expr: expr}, with_root_node_name="df")
218+
a = fD.load_dataarray(dtype=np.int8)
219+
a = a.isel(expressions=0)
220+
assert all(a == np.asarray([2, 2, 2, 2, 2, 2]))
221+
180222

181223
def test_categorical_indexing(tours_dataset: xr.Dataset, skims_dataset: xr.Dataset):
182224
tree = sharrow.DataTree(tours=tours_dataset)

0 commit comments

Comments
 (0)