diff --git a/sharrow/aster.py b/sharrow/aster.py index fe970ee..8762b27 100755 --- a/sharrow/aster.py +++ b/sharrow/aster.py @@ -408,7 +408,9 @@ def _replacement( if self.get_default or ( topname == pref_topname and not self.swallow_errors ): - raise KeyError(f"{topname}..{attr}\nexpression={self.original_expr}") + raise KeyError( + f"{topname}..{attr}\nexpression={self.original_expr}" + ) # we originally raised a KeyError here regardless, but what if # we just give back the original node, and see if other spaces, # possibly fallback spaces, might work? If nothing works then @@ -1010,6 +1012,16 @@ def visit_Compare(self, node): f"\ncategories: {left_dictionary}", stacklevel=2, ) + # at this point, the right value is not in the left's categories, so + # it is guaranteed to be not equal to any of the categories. + if isinstance(node.ops[0], ast.Eq): + result = ast.Constant(False) + elif isinstance(node.ops[0], ast.NotEq): + result = ast.Constant(True) + else: + raise ValueError( + f"unexpected operator {node.ops[0]}" + ) from None if right_decoded is not None: result = ast.Compare( left=left.slice, @@ -1043,6 +1055,16 @@ def visit_Compare(self, node): f"\ncategories: {right_dictionary}", stacklevel=2, ) + # at this point, the left value is not in the right's categories, so + # it is guaranteed to be not equal to any of the categories. + if isinstance(node.ops[0], ast.Eq): + result = ast.Constant(False) + elif isinstance(node.ops[0], ast.NotEq): + result = ast.Constant(True) + else: + raise ValueError( + f"unexpected operator {node.ops[0]}" + ) from None if left_decoded is not None: result = ast.Compare( left=ast_Constant(left_decoded), diff --git a/sharrow/tests/test_categorical.py b/sharrow/tests/test_categorical.py index c5ab0b1..cd2a4e0 100644 --- a/sharrow/tests/test_categorical.py +++ b/sharrow/tests/test_categorical.py @@ -177,6 +177,48 @@ def test_missing_categorical(): a = a.isel(expressions=0) assert all(a == np.asarray([1, 0, 1, 1, 1, 1])) + expr = "df.TourMode2 != 'BAD'" + with pytest.warns(UserWarning): + f8 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f8.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([1, 1, 1, 1, 1, 1])) + + expr = "'BAD' != df.TourMode2" + with pytest.warns(UserWarning): + f9 = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = f9.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([1, 1, 1, 1, 1, 1])) + + expr = "(df.TourMode2 == 'BAD') * 2" + with pytest.warns(UserWarning): + fA = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = fA.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 0, 0])) + + expr = "(df.TourMode2 == 'BAD') * 2.2" + with pytest.warns(UserWarning): + fB = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = fB.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 0, 0, 0, 0, 0])) + + expr = "np.exp(df.TourMode2 == 'BAD') * 2.2" + with pytest.warns(UserWarning): + fC = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = fC.load_dataarray(dtype=np.float32) + a = a.isel(expressions=0) + assert all(a == np.asarray([2.2, 2.2, 2.2, 2.2, 2.2, 2.2], dtype=np.float32)) + + expr = "(df.TourMode2 != 'BAD') * 2" + with pytest.warns(UserWarning): + fD = tree.setup_flow({expr: expr}, with_root_node_name="df") + a = fD.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([2, 2, 2, 2, 2, 2])) + def test_categorical_indexing(tours_dataset: xr.Dataset, skims_dataset: xr.Dataset): tree = sharrow.DataTree(tours=tours_dataset)