Skip to content

Commit

Permalink
faster sharrow for missing categoricals
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed May 2, 2024
1 parent db05787 commit 3c569a4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
24 changes: 23 additions & 1 deletion sharrow/aster.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ def _replacement(
if self.get_default or (
topname == pref_topname and not self.swallow_errors
):
raise KeyError(f"{topname}..{attr}\nexpression={self.original_expr}")
raise KeyError(
f"{topname}..{attr}\nexpression={self.original_expr}"
)
# we originally raised a KeyError here regardless, but what if
# we just give back the original node, and see if other spaces,
# possibly fallback spaces, might work? If nothing works then
Expand Down Expand Up @@ -1010,6 +1012,16 @@ def visit_Compare(self, node):
f"\ncategories: {left_dictionary}",
stacklevel=2,
)
# at this point, the right value is not in the left's categories, so
# it cannot be equal to any of the categories; the comparison therefore
# folds to a constant: `==` is always False and `!=` is always True.
if isinstance(node.ops[0], ast.Eq):
result = ast.Constant(False)
elif isinstance(node.ops[0], ast.NotEq):
result = ast.Constant(True)
else:
raise ValueError(
f"unexpected operator {node.ops[0]}"
) from None
if right_decoded is not None:
result = ast.Compare(
left=left.slice,
Expand Down Expand Up @@ -1043,6 +1055,16 @@ def visit_Compare(self, node):
f"\ncategories: {right_dictionary}",
stacklevel=2,
)
# at this point, the left value is not in the right's categories, so
# it cannot be equal to any of the categories; the comparison therefore
# folds to a constant: `==` is always False and `!=` is always True.
if isinstance(node.ops[0], ast.Eq):
result = ast.Constant(False)
elif isinstance(node.ops[0], ast.NotEq):
result = ast.Constant(True)
else:
raise ValueError(
f"unexpected operator {node.ops[0]}"
) from None
if left_decoded is not None:
result = ast.Compare(
left=ast_Constant(left_decoded),
Expand Down
42 changes: 42 additions & 0 deletions sharrow/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,48 @@ def test_missing_categorical():
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 0, 1, 1, 1, 1]))

expr = "df.TourMode2 != 'BAD'"
with pytest.warns(UserWarning):
f8 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f8.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))

expr = "'BAD' != df.TourMode2"
with pytest.warns(UserWarning):
f9 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f9.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))

expr = "(df.TourMode2 == 'BAD') * 2"
with pytest.warns(UserWarning):
fA = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fA.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))

expr = "(df.TourMode2 == 'BAD') * 2.2"
with pytest.warns(UserWarning):
fB = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fB.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))

expr = "np.exp(df.TourMode2 == 'BAD') * 2.2"
with pytest.warns(UserWarning):
fC = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fC.load_dataarray(dtype=np.float32)
a = a.isel(expressions=0)
assert all(a == np.asarray([2.2, 2.2, 2.2, 2.2, 2.2, 2.2], dtype=np.float32))

expr = "(df.TourMode2 != 'BAD') * 2"
with pytest.warns(UserWarning):
fD = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fD.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([2, 2, 2, 2, 2, 2]))


def test_categorical_indexing(tours_dataset: xr.Dataset, skims_dataset: xr.Dataset):
tree = sharrow.DataTree(tours=tours_dataset)
Expand Down

0 comments on commit 3c569a4

Please sign in to comment.