24
24
Args = ParamSpec ("Args" )
25
25
26
26
_current_task : ContextVar [Task ] = ContextVar (f"{ __name__ } ._current_task" )
27
+ _in_sync_handler : ContextVar [bool ] = ContextVar (f"{ __name__ } ._in_sync_handler" , default = False )
28
+ _cancel_on_sync_handler_exit : ContextVar [bool ] = ContextVar (
29
+ f"{ __name__ } ._in_sync_handler" , default = False
30
+ )
27
31
28
32
29
33
@contextmanager
@@ -108,13 +112,17 @@ def __init__(
108
112
global_task_loop = self
109
113
110
114
async def wrapper ():
115
+ from . import priority
116
+
111
117
if handle_sigint :
112
118
asyncio .get_event_loop ().add_signal_handler (signal .SIGINT , self ._handle_sigint )
113
- job .global_client () # early setup of the job server client
119
+ job_client = job .global_client () # early setup of the job server client
114
120
115
121
RootTask (on_run = on_run )
116
122
self .root_task .name = "root"
117
123
124
+ priority .JobPriorities .scheduler = priority .PriorityScheduler (job_client )
125
+
118
126
try :
119
127
await self .root_task .finished
120
128
except BaseException as exc :
@@ -226,13 +234,23 @@ class Task:
226
234
__cancelled_by : Task | None
227
235
__cancellation_cause : BaseException | None
228
236
237
+ __restart_counter : int
238
+ __in_sync_handler : bool
239
+
229
240
discard : bool
230
- """If set to, the task will be discarded (automatically cancelled) when the last of the
241
+ """If set to `True` , the task will be discarded (automatically cancelled) when the last of the
231
242
tasks depending on it finishes (by failure or cancellation).
232
243
233
244
Defaults to `True`.
234
245
"""
235
246
247
+ restart_on_new_children : bool
248
+ """If set to `True`, new children can be added to the task even after it successfully finished.
249
+ When that happens the task is restarted, i.e. its state is set to ``pending`` again.
250
+
251
+ Defaults to `False`.
252
+ """
253
+
236
254
def __getitem__ (self , object : T ) -> T :
237
255
"""Wraps the given object in a proxy that performs all attribute accesses as if they were
238
256
done with this task as current task.
@@ -372,8 +390,10 @@ def __init__(
372
390
self .__cancelled_by = None
373
391
self .__cancellation_cause = None
374
392
self .__block_finish_counter = 0
393
+ self .__restart_counter = 0
375
394
376
395
self .discard = True
396
+ self .restart_on_new_children = False
377
397
378
398
if isinstance (self , RootTask ):
379
399
self .__parent = None
@@ -382,16 +402,15 @@ def __init__(
382
402
task_loop ().root_task = self
383
403
else :
384
404
self .__parent = current_task ()
385
-
386
- assert (
387
- self .__parent .state == "running"
388
- ), "cannot create child tasks before the parent task is running"
389
- # TODO allow this but make children block for their parent having started
390
-
405
+ if self .__parent .__state == "done" and self .__parent .restart_on_new_children :
406
+ self .__parent .__restart ()
391
407
self .__parent .__add_child (self )
392
408
393
409
self .name = self .__class__ .__name__ if name is None else name
394
410
411
+ with self .as_current_task ():
412
+ self .configure_task ()
413
+
395
414
self .__aio_main_task = asyncio .create_task (self .__task_main (), name = f"{ self .name } main" )
396
415
397
416
def __change_state (self , new_state : TaskState ) -> None :
@@ -410,7 +429,10 @@ def depends_on(self, task: Task) -> None:
410
429
), "cannot add dependencies after task has started"
411
430
self .__dependencies .add (task )
412
431
if task .state in ("preparing" , "pending" , "running" ):
413
- callback : Callable [[Any ], None ] = lambda _ : self .__dependency_finished (task )
432
+ restart_counter = task .__restart_counter
433
+ callback : Callable [[Any ], None ] = lambda _ : self .__dependency_finished (
434
+ task , restart_counter
435
+ )
414
436
task .__finished .add_done_callback (callback )
415
437
self .__pending_dependencies [task ] = callback
416
438
task .__reverse_dependencies .add (self )
@@ -442,35 +464,72 @@ def handle_error(self, handler: Callable[[BaseException], None]) -> None:
442
464
"""
443
465
current_task ().set_error_handler (self , handler )
444
466
467
+ def __restart (self ) -> None :
468
+ assert self .__state == "done"
469
+
470
+ self .__restart_counter += 1
471
+
472
+ self .__finished = asyncio .Future ()
473
+ self .__started = asyncio .Future ()
474
+ self .__cleaned_up = False
475
+
476
+ self .__change_state ("preparing" )
477
+
478
+ if self .__parent is not None :
479
+ self .__parent .__add_child (self )
480
+
481
+ self .__reverse_dependencies = StableSet ()
482
+
483
+ self .__aio_main_task = asyncio .create_task (self .__task_main (), name = f"{ self .name } main" )
484
+
445
485
def __add_child (self , task : Task ) -> None :
446
- assert self .state == "running" , "children can only be added to a running tasks"
486
+ assert self .state in (
487
+ "preparing" ,
488
+ "pending" ,
489
+ "running" ,
490
+ "waiting" ,
491
+ ), f"cannot create child tasks in state { self .state } "
447
492
self .__children .add (task )
448
493
if task .state in ("preparing" , "pending" , "running" ):
449
- callback : Callable [[Any ], None ] = lambda _ : self .__child_finished (task )
494
+ restart_counter = task .__restart_counter
495
+ callback : Callable [[Any ], None ] = lambda _ : self .__child_finished (task , restart_counter )
450
496
task .__finished .add_done_callback (callback )
451
497
self .__pending_children [task ] = callback
452
498
453
- def __dependency_finished (self , task : Task ) -> None :
454
- self .__pending_dependencies .pop (task )
455
- self .__propagate_failure (task , (DependencyFailed , DependencyCancelled ))
456
- self .__check_start ()
499
+ def __dependency_finished (self , task : Task , restart_counter : int ) -> None :
500
+ if task .__restart_counter == restart_counter :
501
+ self .__pending_dependencies .pop (task )
502
+ self .__propagate_failure (task , (DependencyFailed , DependencyCancelled ))
503
+ self .__check_start ()
504
+ elif self in task .__reverse_dependencies :
505
+ # The task was restarted and the dependency was added again, so ignore that it finished
506
+ # previously, we'll get notified again
507
+ pass
508
+ else :
509
+ # The task was restarted, so it didn't fail, but the dependency wasn't re-added, so
510
+ # don't propagate failure
511
+ self .__pending_dependencies .pop (task )
512
+ self .__check_start ()
457
513
458
- def __child_finished (self , task : Task ) -> None :
459
- self .__pending_children .pop (task )
460
- self .__propagate_failure (task , (ChildFailed , ChildCancelled ))
461
- self .__check_finish ()
514
+ def __child_finished (self , task : Task , restart_counter : int ) -> None :
515
+ if task .__restart_counter == restart_counter :
516
+ self .__pending_children .pop (task )
517
+ self .__propagate_failure (task , (ChildFailed , ChildCancelled ))
518
+ self .__check_finish ()
462
519
463
520
def __check_start (self ) -> None :
464
521
if self .state != "pending" :
465
522
return
523
+ if self .__parent is not None and self .__parent .state in ("preparing" , "pending" ):
524
+ return
466
525
if self .__pending_dependencies :
467
526
self .__lease = None
468
527
return
469
528
if self .__use_lease :
470
- # TODO wrap the raw lease in some logic that prefers passing leases within the hierarchy
471
- # before returning them to the job server
529
+ from . import priority
530
+
472
531
if self .__lease is None :
473
- self .__lease = job . global_client () .request_lease ()
532
+ self .__lease = priority . JobPriorities . scheduler .request_lease ()
474
533
if not self .__lease .ready :
475
534
self .__lease .add_ready_callback (self .__check_start )
476
535
return
@@ -519,6 +578,8 @@ def __propagate_failure(
519
578
if handler := self .__error_handlers .get (task ):
520
579
found = handler
521
580
581
+ ExceptionPropagation (task , exception , found is not None ).emit ()
582
+
522
583
if found is not None :
523
584
try :
524
585
found (exception )
@@ -534,17 +595,21 @@ def __propagate_failure(
534
595
async def __task_main (self ) -> None :
535
596
__prev_task = _current_task .set (self )
536
597
try :
537
- TaskStateChange (None , self .__state ).emit ()
598
+ if not self .__restart_counter :
599
+ TaskStateChange (None , self .__state ).emit ()
538
600
await self .on_prepare ()
539
601
self .__change_state ("pending" )
540
602
self .__check_start ()
541
603
await self .started
542
604
self .__change_state ("running" )
605
+ for child in self .__children :
606
+ child .__check_start ()
543
607
await self .on_run ()
544
- self . __lease = None
545
- self .__change_state ("waiting" )
546
- self .__check_finish ()
608
+ if not self . __finished . done ():
609
+ self .__change_state ("waiting" )
610
+ self .__check_finish ()
547
611
await self .finished
612
+ self .__lease = None
548
613
self .__change_state ("done" )
549
614
except Exception as exc :
550
615
self .__failed (exc )
@@ -579,7 +644,10 @@ def __cleanup(self):
579
644
580
645
self .__aio_main_task .cancel ()
581
646
if asyncio .current_task () == self .__aio_main_task :
582
- raise asyncio .CancelledError ()
647
+ if _in_sync_handler .get ():
648
+ _cancel_on_sync_handler_exit .set (True )
649
+ else :
650
+ raise asyncio .CancelledError ()
583
651
584
652
def __failed (self , exc : BaseException | None ) -> None :
585
653
if exc is None or self .is_finished :
@@ -606,6 +674,13 @@ def __failed(self, exc: BaseException | None) -> None:
606
674
607
675
self .__cleanup ()
608
676
677
+ def configure_task (self ):
678
+ """Invoked on construction with the task set as current task.
679
+
680
+ Can be used to override initialization in subclasses.
681
+ """
682
+ pass
683
+
609
684
async def on_prepare (self ) -> None :
610
685
"""Actions to perform right after the task is created, before scheduling it to run.
611
686
@@ -796,6 +871,7 @@ def __emit_event__(self, event: TaskEvent) -> None:
796
871
while current is not None :
797
872
for mro_item in type (event ).mro ():
798
873
sync_handlers = current .__event_sync_handlers .get (mro_item , ())
874
+
799
875
for handler in list (sync_handlers ):
800
876
handler (event )
801
877
@@ -843,11 +919,17 @@ def sync_handle_events(
843
919
self .__event_sync_handlers [event_type ] = StableSet ()
844
920
845
921
def wrapper (event : T_TaskEvent ):
922
+ token = _in_sync_handler .set (True )
846
923
try :
847
924
with self .as_current_task ():
848
925
handler (event )
849
926
except BaseException as exc :
850
927
self .__failed (exc )
928
+ finally :
929
+ _in_sync_handler .reset (token )
930
+ if not _in_sync_handler .get () and _cancel_on_sync_handler_exit .get ():
931
+ _cancel_on_sync_handler_exit .set (False )
932
+ raise asyncio .CancelledError ()
851
933
852
934
self .__event_sync_handlers [event_type ].add (wrapper )
853
935
@@ -876,6 +958,17 @@ def block_finishing(self) -> typing.Iterator[None]:
876
958
self .__check_finish ()
877
959
878
960
961
+ class TaskGroup (Task ):
962
+ """A task used to group child tasks.
963
+
964
+ This is normal `Task` initialized with `discard` set to `False` and `restart_on_new_children`
965
+ """
966
+
967
+ def configure_task (self ):
968
+ self .discard = False
969
+ self .restart_on_new_children = True
970
+
971
+
879
972
class RootTask (Task ):
880
973
pass
881
974
@@ -1105,3 +1198,17 @@ class TaskStateChange(DebugEvent):
1105
1198
1106
1199
def __repr__ (self ) -> str :
1107
1200
return f"{ self .source } : { self .previous_state } -> { self .state } "
1201
+
1202
+
1203
+ @dataclass
1204
+ class ExceptionPropagation (DebugEvent ):
1205
+ exc_source : Task
1206
+ exc : BaseException
1207
+ handler : bool
1208
+
1209
+ def __repr__ (self ) -> str :
1210
+ handled = " handled" if self .handler else ""
1211
+ return (
1212
+ f"{ self .source } :{ handled } { self .exc .__class__ .__name__ } exception "
1213
+ f"from { self .exc_source } : { self .exc } "
1214
+ )
0 commit comments