Skip to content

Commit

Permalink
shared memory pre-init
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- committed May 9, 2024
1 parent b6affe4 commit 05e9090
Showing 1 changed file with 44 additions and 12 deletions.
56 changes: 44 additions & 12 deletions sharrow/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,13 @@ def delete_shared_memory_files(key):
delete_shared_memory_files(key)

def to_shared_memory(
self, key=None, mode="r+", _dupe=True, dask_scheduler="threads"
self,
key=None,
mode="r+",
_dupe=True,
dask_scheduler="threads",
pre_init=False,
load=True,
):
"""
Load this Dataset into shared memory.
Expand Down Expand Up @@ -351,6 +357,10 @@ def emit(k, a, is_coord):
else:
buffer = mem.buf

if pre_init:
# gross init with all zeros
buffer[:] = b"\0" * len(buffer)

tasks = []
task_names = []
for w in wrappers:
Expand Down Expand Up @@ -395,16 +405,10 @@ def emit(k, a, is_coord):
task_names.append(_name)
else:
mem_arr[:] = a[:]
if tasks:
t = time.time()
logger.info(f"running {len(tasks)} dask data load tasks")
if dask_scheduler == "synchronous":
for task, task_name in zip(tasks, task_names):
logger.info(f"running load task: {task_name}")
dask.compute(task, scheduler=dask_scheduler)
else:
dask.compute(tasks, scheduler=dask_scheduler)
logger.info(f"completed dask data load in {time.time()-t:.3f} seconds")
if tasks and load:
self.tasks = tasks
self.task_names = task_names
self.run_tasks(dask_scheduler=dask_scheduler)

if key.startswith("memmap:"):
mem.flush()
Expand All @@ -413,7 +417,35 @@ def emit(k, a, is_coord):
create_shared_list(
[pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key
)
return type(self).from_shared_memory(key, own_data=mem, mode=mode)
result = type(self).from_shared_memory(key, own_data=mem, mode=mode)
if tasks and not load:
# attach the deferred (not yet run) tasks to the result
result.shm.tasks = tasks
result.shm.task_names = task_names
result.shm._buffer = buffer
result.shm._position = position
return result

def run_tasks(self, dask_scheduler="threads"):
    """
    Run any deferred dask tasks attached to this object.

    The tasks are read from ``self.tasks`` and their display names from
    ``self.task_names``.  If there is no ``tasks`` attribute this is a
    no-op.  After a successful run both attributes are removed, so a
    repeated call does nothing and stale names can never be paired with
    a later set of tasks.  If a task raises, the attributes are left in
    place so the caller can inspect or retry them.

    Parameters
    ----------
    dask_scheduler : str, default "threads"
        Passed to ``dask.compute`` as ``scheduler``.  With
        "synchronous", the tasks are computed one at a time and each
        task name is logged before it runs; otherwise all tasks are
        handed to a single ``dask.compute`` call.
    """
    tasks = getattr(self, "tasks", None)
    if tasks is None:
        return

    if tasks:
        # Pad missing names rather than letting ``zip`` truncate: a short
        # (or absent) name list must never cause tasks to be skipped.
        names = list(getattr(self, "task_names", None) or ())
        names.extend(["untitled"] * (len(tasks) - len(names)))

        t = time.time()
        logger.info(f"running {len(tasks)} dask data load tasks")
        if dask_scheduler == "synchronous":
            for task, task_name in zip(tasks, names):
                logger.info(f"running load task: {task_name}")
                dask.compute(task, scheduler=dask_scheduler)
        else:
            dask.compute(tasks, scheduler=dask_scheduler)
        logger.info(f"completed tasks in {time.time() - t:.3f} seconds")

    # Only reached when every task completed: drop the task list and its
    # names together so they cannot get out of sync.
    for attr in ("tasks", "task_names"):
        try:
            delattr(self, attr)
        except AttributeError:
            pass

@property
def shared_memory_key(self):
Expand Down

0 comments on commit 05e9090

Please sign in to comment.