Skip to content

Commit 5d72ac1

Browse files
authored
Merge pull request #4 from YosysHQ/jix/ivy_wip
2 parents 794aaa0 + 9660195 commit 5d72ac1

File tree

11 files changed

+394
-42
lines changed

11 files changed

+394
-42
lines changed

docs/source/task_loop/context.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ When reading a context variable, the value set by the first task--starting from
1212
:members:
1313

1414
.. autoclass:: InlineContextVar
15+
16+
.. autoclass:: TaskContextDict
17+
18+
.. todo:: Document `TaskContextDict` members

docs/source/task_loop/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ This page provides a general overview of the task loop and its concepts, see ind
1313
context
1414
logging
1515
process
16+
priority
1617

1718

1819
.. rubric:: Task Hierarchy and Dependencies

docs/source/task_loop/priority.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Task Priorities
2+
===============
3+
4+
.. automodule:: yosys_mau.task_loop.priority
5+
6+
.. autoclass:: JobPriorities
7+
:members:
8+
9+
.. autoclass:: PriorityScheduler
10+
:members:

docs/source/task_loop/process.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Running Subprocesses as Task
2-
============================
1+
Running Subprocesses as Tasks
2+
=============================
33

44
In the mau task loop, subprocesses run as tasks.
55
The output of a subprocess is made available using events.
@@ -16,6 +16,7 @@ Context Variables
1616
-----------------
1717

1818
.. autoclass:: ProcessContext
19+
:members: env
1920

2021
.. autoattribute:: cwd
2122
:annotation: = os.getcwd()

docs/source/task_loop/task.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ See :doc:`the task loop overview <index>` for general information about the task
1616

1717
.. automethod:: __init__
1818

19+
.. autoclass:: TaskGroup
1920

2021
Exceptions
2122
----------

src/yosys_mau/task_loop/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import context, logging, process
1+
from . import context, logging, priority, process
22
from ._task import (
33
ChildAborted,
44
ChildCancelled,
@@ -13,6 +13,7 @@
1313
TaskEvent,
1414
TaskEventStream,
1515
TaskFailed,
16+
TaskGroup,
1617
TaskLoopError,
1718
TaskLoopInterrupted,
1819
TaskStateChange,
@@ -21,6 +22,8 @@
2122
run_task_loop,
2223
)
2324

25+
priority = priority
26+
2427
task_context = context.task_context
2528

2629
Process = process.Process
@@ -38,6 +41,7 @@
3841
"run_task_loop",
3942
"current_task",
4043
"root_task",
44+
"TaskGroup",
4145
"TaskEvent",
4246
"TaskEventStream",
4347
"TaskLoopInterrupted",

src/yosys_mau/task_loop/_task.py

Lines changed: 134 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
Args = ParamSpec("Args")
2525

2626
_current_task: ContextVar[Task] = ContextVar(f"{__name__}._current_task")
27+
_in_sync_handler: ContextVar[bool] = ContextVar(f"{__name__}._in_sync_handler", default=False)
28+
_cancel_on_sync_handler_exit: ContextVar[bool] = ContextVar(
29+
f"{__name__}._in_sync_handler", default=False
30+
)
2731

2832

2933
@contextmanager
@@ -108,13 +112,17 @@ def __init__(
108112
global_task_loop = self
109113

110114
async def wrapper():
115+
from . import priority
116+
111117
if handle_sigint:
112118
asyncio.get_event_loop().add_signal_handler(signal.SIGINT, self._handle_sigint)
113-
job.global_client() # early setup of the job server client
119+
job_client = job.global_client() # early setup of the job server client
114120

115121
RootTask(on_run=on_run)
116122
self.root_task.name = "root"
117123

124+
priority.JobPriorities.scheduler = priority.PriorityScheduler(job_client)
125+
118126
try:
119127
await self.root_task.finished
120128
except BaseException as exc:
@@ -226,13 +234,23 @@ class Task:
226234
__cancelled_by: Task | None
227235
__cancellation_cause: BaseException | None
228236

237+
__restart_counter: int
238+
__in_sync_handler: bool
239+
229240
discard: bool
230-
"""If set to, the task will be discarded (automatically cancelled) when the last of the
241+
"""If set to `True`, the task will be discarded (automatically cancelled) when the last of the
231242
tasks depending on it finishes (by failure or cancellation).
232243
233244
Defaults to `True`.
234245
"""
235246

247+
restart_on_new_children: bool
248+
"""If set to `True`, new children can be added to the task even after it successfully finished.
249+
When that happens the task is restarted, i.e. its state is set to ``pending`` again.
250+
251+
Defaults to `False`.
252+
"""
253+
236254
def __getitem__(self, object: T) -> T:
237255
"""Wraps the given object in a proxy that performs all attribute accesses as if they were
238256
done with this task as current task.
@@ -372,8 +390,10 @@ def __init__(
372390
self.__cancelled_by = None
373391
self.__cancellation_cause = None
374392
self.__block_finish_counter = 0
393+
self.__restart_counter = 0
375394

376395
self.discard = True
396+
self.restart_on_new_children = False
377397

378398
if isinstance(self, RootTask):
379399
self.__parent = None
@@ -382,16 +402,15 @@ def __init__(
382402
task_loop().root_task = self
383403
else:
384404
self.__parent = current_task()
385-
386-
assert (
387-
self.__parent.state == "running"
388-
), "cannot create child tasks before the parent task is running"
389-
# TODO allow this but make children block for their parent having started
390-
405+
if self.__parent.__state == "done" and self.__parent.restart_on_new_children:
406+
self.__parent.__restart()
391407
self.__parent.__add_child(self)
392408

393409
self.name = self.__class__.__name__ if name is None else name
394410

411+
with self.as_current_task():
412+
self.configure_task()
413+
395414
self.__aio_main_task = asyncio.create_task(self.__task_main(), name=f"{self.name} main")
396415

397416
def __change_state(self, new_state: TaskState) -> None:
@@ -410,7 +429,10 @@ def depends_on(self, task: Task) -> None:
410429
), "cannot add dependencies after task has started"
411430
self.__dependencies.add(task)
412431
if task.state in ("preparing", "pending", "running"):
413-
callback: Callable[[Any], None] = lambda _: self.__dependency_finished(task)
432+
restart_counter = task.__restart_counter
433+
callback: Callable[[Any], None] = lambda _: self.__dependency_finished(
434+
task, restart_counter
435+
)
414436
task.__finished.add_done_callback(callback)
415437
self.__pending_dependencies[task] = callback
416438
task.__reverse_dependencies.add(self)
@@ -442,35 +464,72 @@ def handle_error(self, handler: Callable[[BaseException], None]) -> None:
442464
"""
443465
current_task().set_error_handler(self, handler)
444466

467+
def __restart(self) -> None:
468+
assert self.__state == "done"
469+
470+
self.__restart_counter += 1
471+
472+
self.__finished = asyncio.Future()
473+
self.__started = asyncio.Future()
474+
self.__cleaned_up = False
475+
476+
self.__change_state("preparing")
477+
478+
if self.__parent is not None:
479+
self.__parent.__add_child(self)
480+
481+
self.__reverse_dependencies = StableSet()
482+
483+
self.__aio_main_task = asyncio.create_task(self.__task_main(), name=f"{self.name} main")
484+
445485
def __add_child(self, task: Task) -> None:
446-
assert self.state == "running", "children can only be added to a running tasks"
486+
assert self.state in (
487+
"preparing",
488+
"pending",
489+
"running",
490+
"waiting",
491+
), f"cannot create child tasks in state {self.state}"
447492
self.__children.add(task)
448493
if task.state in ("preparing", "pending", "running"):
449-
callback: Callable[[Any], None] = lambda _: self.__child_finished(task)
494+
restart_counter = task.__restart_counter
495+
callback: Callable[[Any], None] = lambda _: self.__child_finished(task, restart_counter)
450496
task.__finished.add_done_callback(callback)
451497
self.__pending_children[task] = callback
452498

453-
def __dependency_finished(self, task: Task) -> None:
454-
self.__pending_dependencies.pop(task)
455-
self.__propagate_failure(task, (DependencyFailed, DependencyCancelled))
456-
self.__check_start()
499+
def __dependency_finished(self, task: Task, restart_counter: int) -> None:
500+
if task.__restart_counter == restart_counter:
501+
self.__pending_dependencies.pop(task)
502+
self.__propagate_failure(task, (DependencyFailed, DependencyCancelled))
503+
self.__check_start()
504+
elif self in task.__reverse_dependencies:
505+
# The task was restarted and the dependency was added again, so ignore that it finished
506+
# previously, we'll get notified again
507+
pass
508+
else:
509+
# The task was restarted, so it didn't fail, but the dependency wasn't re-added, so
510+
# don't propagate failure
511+
self.__pending_dependencies.pop(task)
512+
self.__check_start()
457513

458-
def __child_finished(self, task: Task) -> None:
459-
self.__pending_children.pop(task)
460-
self.__propagate_failure(task, (ChildFailed, ChildCancelled))
461-
self.__check_finish()
514+
def __child_finished(self, task: Task, restart_counter: int) -> None:
515+
if task.__restart_counter == restart_counter:
516+
self.__pending_children.pop(task)
517+
self.__propagate_failure(task, (ChildFailed, ChildCancelled))
518+
self.__check_finish()
462519

463520
def __check_start(self) -> None:
464521
if self.state != "pending":
465522
return
523+
if self.__parent is not None and self.__parent.state in ("preparing", "pending"):
524+
return
466525
if self.__pending_dependencies:
467526
self.__lease = None
468527
return
469528
if self.__use_lease:
470-
# TODO wrap the raw lease in some logic that prefers passing leases within the hierarchy
471-
# before returning them to the job server
529+
from . import priority
530+
472531
if self.__lease is None:
473-
self.__lease = job.global_client().request_lease()
532+
self.__lease = priority.JobPriorities.scheduler.request_lease()
474533
if not self.__lease.ready:
475534
self.__lease.add_ready_callback(self.__check_start)
476535
return
@@ -519,6 +578,8 @@ def __propagate_failure(
519578
if handler := self.__error_handlers.get(task):
520579
found = handler
521580

581+
ExceptionPropagation(task, exception, found is not None).emit()
582+
522583
if found is not None:
523584
try:
524585
found(exception)
@@ -534,17 +595,21 @@ def __propagate_failure(
534595
async def __task_main(self) -> None:
535596
__prev_task = _current_task.set(self)
536597
try:
537-
TaskStateChange(None, self.__state).emit()
598+
if not self.__restart_counter:
599+
TaskStateChange(None, self.__state).emit()
538600
await self.on_prepare()
539601
self.__change_state("pending")
540602
self.__check_start()
541603
await self.started
542604
self.__change_state("running")
605+
for child in self.__children:
606+
child.__check_start()
543607
await self.on_run()
544-
self.__lease = None
545-
self.__change_state("waiting")
546-
self.__check_finish()
608+
if not self.__finished.done():
609+
self.__change_state("waiting")
610+
self.__check_finish()
547611
await self.finished
612+
self.__lease = None
548613
self.__change_state("done")
549614
except Exception as exc:
550615
self.__failed(exc)
@@ -579,7 +644,10 @@ def __cleanup(self):
579644

580645
self.__aio_main_task.cancel()
581646
if asyncio.current_task() == self.__aio_main_task:
582-
raise asyncio.CancelledError()
647+
if _in_sync_handler.get():
648+
_cancel_on_sync_handler_exit.set(True)
649+
else:
650+
raise asyncio.CancelledError()
583651

584652
def __failed(self, exc: BaseException | None) -> None:
585653
if exc is None or self.is_finished:
@@ -606,6 +674,13 @@ def __failed(self, exc: BaseException | None) -> None:
606674

607675
self.__cleanup()
608676

677+
def configure_task(self):
678+
"""Invoked on construction with the task set as current task.
679+
680+
Can be used to override initialization in subclasses.
681+
"""
682+
pass
683+
609684
async def on_prepare(self) -> None:
610685
"""Actions to perform right after the task is created, before scheduling it to run.
611686
@@ -796,6 +871,7 @@ def __emit_event__(self, event: TaskEvent) -> None:
796871
while current is not None:
797872
for mro_item in type(event).mro():
798873
sync_handlers = current.__event_sync_handlers.get(mro_item, ())
874+
799875
for handler in list(sync_handlers):
800876
handler(event)
801877

@@ -843,11 +919,17 @@ def sync_handle_events(
843919
self.__event_sync_handlers[event_type] = StableSet()
844920

845921
def wrapper(event: T_TaskEvent):
922+
token = _in_sync_handler.set(True)
846923
try:
847924
with self.as_current_task():
848925
handler(event)
849926
except BaseException as exc:
850927
self.__failed(exc)
928+
finally:
929+
_in_sync_handler.reset(token)
930+
if not _in_sync_handler.get() and _cancel_on_sync_handler_exit.get():
931+
_cancel_on_sync_handler_exit.set(False)
932+
raise asyncio.CancelledError()
851933

852934
self.__event_sync_handlers[event_type].add(wrapper)
853935

@@ -876,6 +958,17 @@ def block_finishing(self) -> typing.Iterator[None]:
876958
self.__check_finish()
877959

878960

961+
class TaskGroup(Task):
962+
"""A task used to group child tasks.
963+
964+
This is normal `Task` initialized with `discard` set to `False` and `restart_on_new_children`
965+
"""
966+
967+
def configure_task(self):
968+
self.discard = False
969+
self.restart_on_new_children = True
970+
971+
879972
class RootTask(Task):
880973
pass
881974

@@ -1105,3 +1198,17 @@ class TaskStateChange(DebugEvent):
11051198

11061199
def __repr__(self) -> str:
11071200
return f"{self.source}: {self.previous_state} -> {self.state}"
1201+
1202+
1203+
@dataclass
1204+
class ExceptionPropagation(DebugEvent):
1205+
exc_source: Task
1206+
exc: BaseException
1207+
handler: bool
1208+
1209+
def __repr__(self) -> str:
1210+
handled = " handled" if self.handler else ""
1211+
return (
1212+
f"{self.source}:{handled} {self.exc.__class__.__name__} exception "
1213+
f"from {self.exc_source}: {self.exc}"
1214+
)

0 commit comments

Comments
 (0)