Sharrow updates (#52)
* add expr to error message

* extra logging

* fix deprecation

* ruffen

* dask_scheduler

* faster sharrow for missing categoricals
jpn-- authored May 2, 2024
1 parent 085ba42 commit 29aaf6c
Showing 4 changed files with 93 additions and 5 deletions.
24 changes: 23 additions & 1 deletion sharrow/aster.py
@@ -408,7 +408,9 @@ def _replacement(
if self.get_default or (
topname == pref_topname and not self.swallow_errors
):
raise KeyError(f"{topname}..{attr}")
raise KeyError(
f"{topname}..{attr}\nexpression={self.original_expr}"
)
# we originally raised a KeyError here regardless, but what if
# we just give back the original node, and see if other spaces,
# possibly fallback spaces, might work? If nothing works then
@@ -1010,6 +1012,16 @@ def visit_Compare(self, node):
f"\ncategories: {left_dictionary}",
stacklevel=2,
)
# at this point, the right value is not in the left's categories, so
# it is guaranteed to be not equal to any of the categories.
if isinstance(node.ops[0], ast.Eq):
result = ast.Constant(False)
elif isinstance(node.ops[0], ast.NotEq):
result = ast.Constant(True)
else:
raise ValueError(
f"unexpected operator {node.ops[0]}"
) from None
if right_decoded is not None:
result = ast.Compare(
left=left.slice,
@@ -1043,6 +1055,16 @@ def visit_Compare(self, node):
f"\ncategories: {right_dictionary}",
stacklevel=2,
)
# at this point, the left value is not in the right's categories, so
# it is guaranteed to be not equal to any of the categories.
if isinstance(node.ops[0], ast.Eq):
result = ast.Constant(False)
elif isinstance(node.ops[0], ast.NotEq):
result = ast.Constant(True)
else:
raise ValueError(
f"unexpected operator {node.ops[0]}"
) from None
if left_decoded is not None:
result = ast.Compare(
left=ast_Constant(left_decoded),
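
The new branches in visit_Compare fold a comparison against a value that is missing from the categorical's categories into a constant: == can never be true, != is always true. A minimal standalone sketch of that folding using only the standard ast module (fold_missing_category_compare and the category list are illustrative, not sharrow API):

import ast

def fold_missing_category_compare(node: ast.Compare, categories: list) -> ast.AST:
    # Only handle a single comparison like `x == 'BAD'` or `x != 'BAD'`.
    if len(node.ops) != 1 or len(node.comparators) != 1:
        return node
    right = node.comparators[0]
    if isinstance(right, ast.Constant) and right.value not in categories:
        # The literal cannot match any category, so the result is known statically.
        if isinstance(node.ops[0], ast.Eq):
            return ast.Constant(False)
        if isinstance(node.ops[0], ast.NotEq):
            return ast.Constant(True)
    return node

tree = ast.parse("TourMode2 == 'BAD'", mode="eval")
folded = fold_missing_category_compare(tree.body, ["WALK", "CAR", "BUS"])
print(ast.dump(folded))  # Constant(value=False)
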
30 changes: 27 additions & 3 deletions sharrow/shared_memory.py
@@ -3,6 +3,7 @@
import logging
import os
import pickle
import time

import dask
import dask.array as da
@@ -247,7 +248,9 @@ def release_shared_memory(self):
def delete_shared_memory_files(key):
delete_shared_memory_files(key)

def to_shared_memory(self, key=None, mode="r+", _dupe=True):
def to_shared_memory(
self, key=None, mode="r+", _dupe=True, dask_scheduler="threads"
):
"""
Load this Dataset into shared memory.
@@ -262,9 +265,13 @@ def to_shared_memory(self, key=None, mode="r+", _dupe=True):
An identifying key for this shared memory. Use the same key
in `from_shared_memory` to recreate this Dataset elsewhere.
mode : {‘r+’, ‘r’, ‘w+’, ‘c’}, optional
This methid returns a copy of the Dataset in shared memory.
This method returns a copy of the Dataset in shared memory.
If memmapped, that copy can be opened in various modes.
See numpy.memmap() for details.
dask_scheduler : str, default 'threads'
The scheduler to use when loading dask arrays into shared memory.
Typically "threads" for multi-threaded reads or "synchronous"
for single-threaded reads. See dask.compute() for details.
Returns
-------
@@ -287,6 +294,7 @@ def to_shared_memory(self, key=None, mode="r+", _dupe=True):
def emit(k, a, is_coord):
nonlocal names, wrappers, sizes, position
if sparse is not None and isinstance(a.data, sparse.GCXS):
logger.info(f"preparing sparse array {a.name}")
wrappers.append(
{
"sparse": True,
@@ -308,6 +316,7 @@ def emit(k, a, is_coord):
)
a_nbytes = a.data.nbytes
else:
logger.info(f"preparing dense array {a.name}")
wrappers.append(
{
"dims": a.dims,
@@ -335,19 +344,23 @@ def emit(k, a, is_coord):
emit(k, a, False)

mem = create_shared_memory_array(key, size=position)

logger.info("declaring shared memory buffer")
if key.startswith("memmap:"):
buffer = memoryview(mem)
else:
buffer = mem.buf

tasks = []
task_names = []
for w in wrappers:
_is_sparse = w.get("sparse", False)
_size = w["nbytes"]
_name = w["name"]
_pos = w["position"]
a = self._obj[_name]
if _is_sparse:
logger.info(f"running load task: {_name} ({si_units(_size)})")
ad = a.data
_size_d = w["data.nbytes"]
_size_i = w["indices.nbytes"]
@@ -373,19 +386,30 @@ def emit(k, a, is_coord):
mem_arr_i[:] = ad.indices[:]
mem_arr_p[:] = ad.indptr[:]
else:
logger.info(f"preparing load task: {_name} ({si_units(_size)})")
mem_arr = np.ndarray(
shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size]
)
if isinstance(a, xr.DataArray) and isinstance(a.data, da.Array):
tasks.append(da.store(a.data, mem_arr, lock=False, compute=False))
task_names.append(_name)
else:
mem_arr[:] = a[:]
if tasks:
dask.compute(tasks, scheduler="threads")
t = time.time()
logger.info(f"running {len(tasks)} dask data load tasks")
if dask_scheduler == "synchronous":
for task, task_name in zip(tasks, task_names):
logger.info(f"running load task: {task_name}")
dask.compute(task, scheduler=dask_scheduler)
else:
dask.compute(tasks, scheduler=dask_scheduler)
logger.info(f"completed dask data load in {time.time()-t:.3f} seconds")

if key.startswith("memmap:"):
mem.flush()

logger.info("storing metadata in shared memory")
create_shared_list(
[pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key
)
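
For context on the new dask_scheduler argument: "threads" (the default) hands the whole list of deferred store tasks to dask.compute at once, while "synchronous" walks the tasks one at a time so each array's load can be logged as it happens. A standalone sketch of that choice, with made-up array names and plain numpy targets standing in for the shared-memory views built inside to_shared_memory():

import time

import dask
import dask.array as da
import numpy as np

# Dask-backed "variables" and preallocated destination buffers.
sources = {name: da.random.random((2000, 2000), chunks=500) for name in ("skim_a", "skim_b")}
targets = {name: np.empty((2000, 2000)) for name in sources}

# Mirror the diff: build deferred store tasks, then run them under the chosen scheduler.
tasks = [da.store(src, targets[name], lock=False, compute=False) for name, src in sources.items()]
task_names = list(sources)

dask_scheduler = "synchronous"  # or "threads", the default
t = time.time()
if dask_scheduler == "synchronous":
    # One task at a time, so per-array progress is visible as each one completes.
    for task, task_name in zip(tasks, task_names):
        print(f"running load task: {task_name}")
        dask.compute(task, scheduler=dask_scheduler)
else:
    dask.compute(tasks, scheduler=dask_scheduler)
print(f"completed data load in {time.time() - t:.3f} seconds")
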
2 changes: 1 addition & 1 deletion sharrow/sparse.py
@@ -163,7 +163,7 @@ def apply_mapper(x):
raise ImportError("sparse is not installed")

sparse_data = sparse.GCXS(
sparse.COO((i_, j_), data, shape=shape), compressed_axes=(0,)
sparse.COO(np.stack((i_, j_)), data, shape=shape), compressed_axes=(0,)
)
self._obj[f"_s_{name}"] = xr.DataArray(
sparse_data,
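
On the sparse.py change: the "fix deprecation" bullet appears to refer to how coordinates are passed to sparse.COO, whose first argument is a single (ndim, nnz) coordinate array; np.stack((i_, j_)) builds that array from the per-axis index vectors instead of passing them as a tuple. A minimal sketch with made-up indices:

import numpy as np
import sparse

# Row/column indices and values for a 3x3 matrix with two nonzero entries.
i_ = np.array([0, 2])
j_ = np.array([1, 2])
data = np.array([5.0, 7.0])

coords = np.stack((i_, j_))  # shape (ndim, nnz) == (2, 2)
coo = sparse.COO(coords, data, shape=(3, 3))
gcxs = sparse.GCXS(coo, compressed_axes=(0,))  # same compressed layout as in the diff
print(gcxs.todense())
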
42 changes: 42 additions & 0 deletions sharrow/tests/test_categorical.py
@@ -177,6 +177,48 @@ def test_missing_categorical():
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 0, 1, 1, 1, 1]))

expr = "df.TourMode2 != 'BAD'"
with pytest.warns(UserWarning):
f8 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f8.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))

expr = "'BAD' != df.TourMode2"
with pytest.warns(UserWarning):
f9 = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = f9.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([1, 1, 1, 1, 1, 1]))

expr = "(df.TourMode2 == 'BAD') * 2"
with pytest.warns(UserWarning):
fA = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fA.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))

expr = "(df.TourMode2 == 'BAD') * 2.2"
with pytest.warns(UserWarning):
fB = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fB.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 0, 0, 0, 0, 0]))

expr = "np.exp(df.TourMode2 == 'BAD') * 2.2"
with pytest.warns(UserWarning):
fC = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fC.load_dataarray(dtype=np.float32)
a = a.isel(expressions=0)
assert all(a == np.asarray([2.2, 2.2, 2.2, 2.2, 2.2, 2.2], dtype=np.float32))

expr = "(df.TourMode2 != 'BAD') * 2"
with pytest.warns(UserWarning):
fD = tree.setup_flow({expr: expr}, with_root_node_name="df")
a = fD.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([2, 2, 2, 2, 2, 2]))


def test_categorical_indexing(tours_dataset: xr.Dataset, skims_dataset: xr.Dataset):
tree = sharrow.DataTree(tours=tours_dataset)
