Skip to content

Commit

Permalink
Merge pull request #4 from YosysHQ/jix/ivy_wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jix authored Sep 11, 2023
2 parents 794aaa0 + 9660195 commit 5d72ac1
Show file tree
Hide file tree
Showing 11 changed files with 394 additions and 42 deletions.
4 changes: 4 additions & 0 deletions docs/source/task_loop/context.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ When reading a context variable, the value set by the first task--starting from
:members:

.. autoclass:: InlineContextVar

.. autoclass:: TaskContextDict

.. todo:: Document `TaskContextDict` members
1 change: 1 addition & 0 deletions docs/source/task_loop/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ This page provides a general overview of the task loop and its concepts, see ind
context
logging
process
priority


.. rubric:: Task Hierarchy and Dependencies
Expand Down
10 changes: 10 additions & 0 deletions docs/source/task_loop/priority.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Task Priorities
===============

.. automodule:: yosys_mau.task_loop.priority

.. autoclass:: JobPriorities
:members:

.. autoclass:: PriorityScheduler
:members:
5 changes: 3 additions & 2 deletions docs/source/task_loop/process.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Running Subprocesses as Task
============================
Running Subprocesses as Tasks
=============================

In the mau task loop, subprocesses run as tasks.
The output of a subprocess is made available using events.
Expand All @@ -16,6 +16,7 @@ Context Variables
-----------------

.. autoclass:: ProcessContext
:members: env

.. autoattribute:: cwd
:annotation: = os.getcwd()
Expand Down
1 change: 1 addition & 0 deletions docs/source/task_loop/task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ See :doc:`the task loop overview <index>` for general information about the task

.. automethod:: __init__

.. autoclass:: TaskGroup

Exceptions
----------
Expand Down
6 changes: 5 additions & 1 deletion src/yosys_mau/task_loop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import context, logging, process
from . import context, logging, priority, process
from ._task import (
ChildAborted,
ChildCancelled,
Expand All @@ -13,6 +13,7 @@
TaskEvent,
TaskEventStream,
TaskFailed,
TaskGroup,
TaskLoopError,
TaskLoopInterrupted,
TaskStateChange,
Expand All @@ -21,6 +22,8 @@
run_task_loop,
)

priority = priority

task_context = context.task_context

Process = process.Process
Expand All @@ -38,6 +41,7 @@
"run_task_loop",
"current_task",
"root_task",
"TaskGroup",
"TaskEvent",
"TaskEventStream",
"TaskLoopInterrupted",
Expand Down
161 changes: 134 additions & 27 deletions src/yosys_mau/task_loop/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
Args = ParamSpec("Args")

_current_task: ContextVar[Task] = ContextVar(f"{__name__}._current_task")
_in_sync_handler: ContextVar[bool] = ContextVar(f"{__name__}._in_sync_handler", default=False)
_cancel_on_sync_handler_exit: ContextVar[bool] = ContextVar(
f"{__name__}._in_sync_handler", default=False
)


@contextmanager
Expand Down Expand Up @@ -108,13 +112,17 @@ def __init__(
global_task_loop = self

async def wrapper():
from . import priority

if handle_sigint:
asyncio.get_event_loop().add_signal_handler(signal.SIGINT, self._handle_sigint)
job.global_client() # early setup of the job server client
job_client = job.global_client() # early setup of the job server client

RootTask(on_run=on_run)
self.root_task.name = "root"

priority.JobPriorities.scheduler = priority.PriorityScheduler(job_client)

try:
await self.root_task.finished
except BaseException as exc:
Expand Down Expand Up @@ -226,13 +234,23 @@ class Task:
__cancelled_by: Task | None
__cancellation_cause: BaseException | None

__restart_counter: int
__in_sync_handler: bool

discard: bool
"""If set to, the task will be discarded (automatically cancelled) when the last of the
"""If set to `True`, the task will be discarded (automatically cancelled) when the last of the
tasks depending on it finishes (by failure or cancellation).
Defaults to `True`.
"""

restart_on_new_children: bool
"""If set to `True`, new children can be added to the task even after it successfully finished.
When that happens the task is restarted, i.e. its state is set to ``pending`` again.
Defaults to `False`.
"""

def __getitem__(self, object: T) -> T:
"""Wraps the given object in a proxy that performs all attribute accesses as if they were
done with this task as current task.
Expand Down Expand Up @@ -372,8 +390,10 @@ def __init__(
self.__cancelled_by = None
self.__cancellation_cause = None
self.__block_finish_counter = 0
self.__restart_counter = 0

self.discard = True
self.restart_on_new_children = False

if isinstance(self, RootTask):
self.__parent = None
Expand All @@ -382,16 +402,15 @@ def __init__(
task_loop().root_task = self
else:
self.__parent = current_task()

assert (
self.__parent.state == "running"
), "cannot create child tasks before the parent task is running"
# TODO allow this but make children block for their parent having started

if self.__parent.__state == "done" and self.__parent.restart_on_new_children:
self.__parent.__restart()
self.__parent.__add_child(self)

self.name = self.__class__.__name__ if name is None else name

with self.as_current_task():
self.configure_task()

self.__aio_main_task = asyncio.create_task(self.__task_main(), name=f"{self.name} main")

def __change_state(self, new_state: TaskState) -> None:
Expand All @@ -410,7 +429,10 @@ def depends_on(self, task: Task) -> None:
), "cannot add dependencies after task has started"
self.__dependencies.add(task)
if task.state in ("preparing", "pending", "running"):
callback: Callable[[Any], None] = lambda _: self.__dependency_finished(task)
restart_counter = task.__restart_counter
callback: Callable[[Any], None] = lambda _: self.__dependency_finished(
task, restart_counter
)
task.__finished.add_done_callback(callback)
self.__pending_dependencies[task] = callback
task.__reverse_dependencies.add(self)
Expand Down Expand Up @@ -442,35 +464,72 @@ def handle_error(self, handler: Callable[[BaseException], None]) -> None:
"""
current_task().set_error_handler(self, handler)

def __restart(self) -> None:
assert self.__state == "done"

self.__restart_counter += 1

self.__finished = asyncio.Future()
self.__started = asyncio.Future()
self.__cleaned_up = False

self.__change_state("preparing")

if self.__parent is not None:
self.__parent.__add_child(self)

self.__reverse_dependencies = StableSet()

self.__aio_main_task = asyncio.create_task(self.__task_main(), name=f"{self.name} main")

def __add_child(self, task: Task) -> None:
assert self.state == "running", "children can only be added to a running tasks"
assert self.state in (
"preparing",
"pending",
"running",
"waiting",
), f"cannot create child tasks in state {self.state}"
self.__children.add(task)
if task.state in ("preparing", "pending", "running"):
callback: Callable[[Any], None] = lambda _: self.__child_finished(task)
restart_counter = task.__restart_counter
callback: Callable[[Any], None] = lambda _: self.__child_finished(task, restart_counter)
task.__finished.add_done_callback(callback)
self.__pending_children[task] = callback

def __dependency_finished(self, task: Task) -> None:
self.__pending_dependencies.pop(task)
self.__propagate_failure(task, (DependencyFailed, DependencyCancelled))
self.__check_start()
def __dependency_finished(self, task: Task, restart_counter: int) -> None:
if task.__restart_counter == restart_counter:
self.__pending_dependencies.pop(task)
self.__propagate_failure(task, (DependencyFailed, DependencyCancelled))
self.__check_start()
elif self in task.__reverse_dependencies:
# The task was restarted and the dependency was added again, so ignore that it finished
# previously, we'll get notified again
pass
else:
# The task was restarted, so it didn't fail, but the dependency wasn't re-added, so
# don't propagate failure
self.__pending_dependencies.pop(task)
self.__check_start()

def __child_finished(self, task: Task) -> None:
self.__pending_children.pop(task)
self.__propagate_failure(task, (ChildFailed, ChildCancelled))
self.__check_finish()
def __child_finished(self, task: Task, restart_counter: int) -> None:
if task.__restart_counter == restart_counter:
self.__pending_children.pop(task)
self.__propagate_failure(task, (ChildFailed, ChildCancelled))
self.__check_finish()

def __check_start(self) -> None:
if self.state != "pending":
return
if self.__parent is not None and self.__parent.state in ("preparing", "pending"):
return
if self.__pending_dependencies:
self.__lease = None
return
if self.__use_lease:
# TODO wrap the raw lease in some logic that prefers passing leases within the hierarchy
# before returning them to the job server
from . import priority

if self.__lease is None:
self.__lease = job.global_client().request_lease()
self.__lease = priority.JobPriorities.scheduler.request_lease()
if not self.__lease.ready:
self.__lease.add_ready_callback(self.__check_start)
return
Expand Down Expand Up @@ -519,6 +578,8 @@ def __propagate_failure(
if handler := self.__error_handlers.get(task):
found = handler

ExceptionPropagation(task, exception, found is not None).emit()

if found is not None:
try:
found(exception)
Expand All @@ -534,17 +595,21 @@ def __propagate_failure(
async def __task_main(self) -> None:
__prev_task = _current_task.set(self)
try:
TaskStateChange(None, self.__state).emit()
if not self.__restart_counter:
TaskStateChange(None, self.__state).emit()
await self.on_prepare()
self.__change_state("pending")
self.__check_start()
await self.started
self.__change_state("running")
for child in self.__children:
child.__check_start()
await self.on_run()
self.__lease = None
self.__change_state("waiting")
self.__check_finish()
if not self.__finished.done():
self.__change_state("waiting")
self.__check_finish()
await self.finished
self.__lease = None
self.__change_state("done")
except Exception as exc:
self.__failed(exc)
Expand Down Expand Up @@ -579,7 +644,10 @@ def __cleanup(self):

self.__aio_main_task.cancel()
if asyncio.current_task() == self.__aio_main_task:
raise asyncio.CancelledError()
if _in_sync_handler.get():
_cancel_on_sync_handler_exit.set(True)
else:
raise asyncio.CancelledError()

def __failed(self, exc: BaseException | None) -> None:
if exc is None or self.is_finished:
Expand All @@ -606,6 +674,13 @@ def __failed(self, exc: BaseException | None) -> None:

self.__cleanup()

def configure_task(self):
"""Invoked on construction with the task set as current task.
Can be used to override initialization in subclasses.
"""
pass

async def on_prepare(self) -> None:
"""Actions to perform right after the task is created, before scheduling it to run.
Expand Down Expand Up @@ -796,6 +871,7 @@ def __emit_event__(self, event: TaskEvent) -> None:
while current is not None:
for mro_item in type(event).mro():
sync_handlers = current.__event_sync_handlers.get(mro_item, ())

for handler in list(sync_handlers):
handler(event)

Expand Down Expand Up @@ -843,11 +919,17 @@ def sync_handle_events(
self.__event_sync_handlers[event_type] = StableSet()

def wrapper(event: T_TaskEvent):
token = _in_sync_handler.set(True)
try:
with self.as_current_task():
handler(event)
except BaseException as exc:
self.__failed(exc)
finally:
_in_sync_handler.reset(token)
if not _in_sync_handler.get() and _cancel_on_sync_handler_exit.get():
_cancel_on_sync_handler_exit.set(False)
raise asyncio.CancelledError()

self.__event_sync_handlers[event_type].add(wrapper)

Expand Down Expand Up @@ -876,6 +958,17 @@ def block_finishing(self) -> typing.Iterator[None]:
self.__check_finish()


class TaskGroup(Task):
"""A task used to group child tasks.
This is normal `Task` initialized with `discard` set to `False` and `restart_on_new_children`
"""

def configure_task(self):
self.discard = False
self.restart_on_new_children = True


class RootTask(Task):
pass

Expand Down Expand Up @@ -1105,3 +1198,17 @@ class TaskStateChange(DebugEvent):

def __repr__(self) -> str:
return f"{self.source}: {self.previous_state} -> {self.state}"


@dataclass
class ExceptionPropagation(DebugEvent):
exc_source: Task
exc: BaseException
handler: bool

def __repr__(self) -> str:
handled = " handled" if self.handler else ""
return (
f"{self.source}:{handled} {self.exc.__class__.__name__} exception "
f"from {self.exc_source}: {self.exc}"
)
Loading

0 comments on commit 5d72ac1

Please sign in to comment.