@@ -298,7 +298,12 @@ def execute_something():
298
298
299
299
self ._assert_allowed_event (event_name )
300
300
301
- event_args = (Exception (),) if event_name == Events .EXCEPTION_RAISED else ()
301
+ event_args = () # type: Tuple[Any, ...]
302
+ if event_name == Events .EXCEPTION_RAISED :
303
+ event_args += (Exception (),)
304
+ elif event_name == Events .TERMINATE_SINGLE_EPOCH :
305
+ event_args += (0 ,)
306
+
302
307
try :
303
308
_check_signature (handler , "handler" , self , * (event_args + args ), ** kwargs )
304
309
self ._event_handlers [event_name ].append ((handler , (self ,) + args , kwargs ))
@@ -433,14 +438,28 @@ def fire_event(self, event_name: Any) -> None:
433
438
return self ._fire_event (event_name )
434
439
435
440
def terminate (self ) -> None :
436
- """Sends terminate signal to the engine, so that it terminates completely the run after
437
- the current iteration."""
441
+ """Sends terminate signal to the engine, so that it terminates completely the run. The run is
442
+ terminated after the event on which ``terminate`` method was called. The following events are triggered:
443
+
444
+ - ...
445
+ - Terminating event
446
+ - :attr:`~ignite.engine.events.Events.TERMINATE`
447
+ - :attr:`~ignite.engine.events.Events.COMPLETED`
448
+ """
438
449
self .logger .info ("Terminate signaled. Engine will stop after current iteration is finished." )
439
450
self .should_terminate = True
440
451
441
452
def terminate_epoch (self ) -> None :
442
- """Sends terminate signal to the engine, so that it terminates the current epoch
443
- after the current iteration."""
453
+ """Sends terminate signal to the engine, so that it terminates the current epoch. The run
454
+ continues from the next epoch. The following events are triggered:
455
+
456
+ - ...
457
+ - Event on which ``terminate_epoch`` method is called
458
+ - :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
459
+ - :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
460
+ - :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
461
+ - ...
462
+ """
444
463
self .logger .info (
445
464
"Terminate current epoch is signaled. "
446
465
"Current epoch iteration will stop after current iteration is finished."
@@ -742,33 +761,43 @@ def _internal_run(self) -> State:
742
761
self .should_terminate = self .should_terminate_single_epoch = False
743
762
self ._init_timers (self .state )
744
763
try :
745
- start_time = time .time ()
746
- self ._fire_event (Events .STARTED )
747
- while not self ._is_done (self .state ) and not self .should_terminate :
748
- self .state .epoch += 1
749
- self ._fire_event (Events .EPOCH_STARTED )
750
-
751
- if self ._dataloader_iter is None :
752
- self ._setup_engine ()
753
-
754
- time_taken = self ._run_once_on_dataset ()
755
- # time is available for handlers but must be update after fire
756
- self .state .times [Events .EPOCH_COMPLETED .name ] = time_taken
757
- handlers_start_time = time .time ()
758
- if self .should_terminate :
759
- self ._fire_event (Events .TERMINATE )
760
- else :
764
+ try :
765
+ start_time = time .time ()
766
+ self ._fire_event (Events .STARTED )
767
+ self ._maybe_terminate ()
768
+
769
+ while not self ._is_done (self .state ) and not self .should_terminate :
770
+ self .state .epoch += 1
771
+ handlers_start_time = time .time ()
772
+ self ._fire_event (Events .EPOCH_STARTED )
773
+ epoch_time_taken = time .time () - handlers_start_time
774
+ self ._maybe_terminate ()
775
+
776
+ if self ._dataloader_iter is None :
777
+ self ._setup_engine ()
778
+
779
+ epoch_time_taken += self ._run_once_on_dataset ()
780
+
781
+ # time is available for handlers but must be updated after fire
782
+ self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
783
+
784
+ handlers_start_time = time .time ()
761
785
self ._fire_event (Events .EPOCH_COMPLETED )
762
- time_taken += time .time () - handlers_start_time
763
- # update time wrt handlers
764
- self .state .times [Events .EPOCH_COMPLETED .name ] = time_taken
765
- hours , mins , secs = _to_hours_mins_secs (time_taken )
766
- self .logger .info (f"Epoch[{ self .state .epoch } ] Complete. Time taken: { hours :02d} :{ mins :02d} :{ secs :06.3f} " )
767
- if self .should_terminate :
768
- break
786
+ epoch_time_taken += time .time () - handlers_start_time
787
+ # update time wrt handlers
788
+ self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
789
+ self ._maybe_terminate ()
790
+
791
+ hours , mins , secs = _to_hours_mins_secs (epoch_time_taken )
792
+ self .logger .info (
793
+ f"Epoch[{ self .state .epoch } ] Complete. Time taken: { hours :02d} :{ mins :02d} :{ secs :06.3f} "
794
+ )
795
+
796
+ except _EngineTerminateException :
797
+ self ._fire_event (Events .TERMINATE )
769
798
770
799
time_taken = time .time () - start_time
771
- # time is available for handlers but must be update after fire
800
+ # time is available for handlers but must be updated after fire
772
801
self .state .times [Events .COMPLETED .name ] = time_taken
773
802
handlers_start_time = time .time ()
774
803
self ._fire_event (Events .COMPLETED )
@@ -786,6 +815,13 @@ def _internal_run(self) -> State:
786
815
self ._dataloader_iter = None
787
816
return self .state
788
817
818
+ def _maybe_terminate (self ) -> None :
819
+ if self .should_terminate :
820
+ raise _EngineTerminateException ()
821
+
822
+ if self .should_terminate_single_epoch :
823
+ raise _EngineTerminateSingleEpochException ()
824
+
789
825
def _run_once_on_dataset (self ) -> float :
790
826
start_time = time .time ()
791
827
@@ -805,8 +841,12 @@ def _run_once_on_dataset(self) -> float:
805
841
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
806
842
if self .last_event_name != Events .DATALOADER_STOP_ITERATION :
807
843
self ._fire_event (Events .GET_BATCH_STARTED )
844
+ self ._maybe_terminate ()
845
+
808
846
self .state .batch = next (self ._dataloader_iter )
809
847
self ._fire_event (Events .GET_BATCH_COMPLETED )
848
+ self ._maybe_terminate ()
849
+
810
850
iter_counter += 1
811
851
should_exit = False
812
852
except StopIteration :
@@ -835,29 +875,37 @@ def _run_once_on_dataset(self) -> float:
835
875
break
836
876
837
877
self ._fire_event (Events .DATALOADER_STOP_ITERATION )
838
- self ._setup_dataloader_iter ()
878
+ self ._maybe_terminate ()
839
879
880
+ self ._setup_dataloader_iter ()
840
881
should_exit = True
841
882
842
883
continue
843
884
844
885
self .state .iteration += 1
845
886
self ._fire_event (Events .ITERATION_STARTED )
887
+ self ._maybe_terminate ()
888
+
846
889
self .state .output = self ._process_function (self , self .state .batch )
847
890
self ._fire_event (Events .ITERATION_COMPLETED )
848
-
849
- if self .should_terminate or self .should_terminate_single_epoch :
850
- self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
851
- self .should_terminate_single_epoch = False
852
- self ._setup_dataloader_iter ()
853
- break
891
+ self ._maybe_terminate ()
854
892
855
893
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
856
894
break
857
895
858
896
if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
859
897
self .should_terminate = True
860
- break
898
+ raise _EngineTerminateException ()
899
+
900
+ except _EngineTerminateSingleEpochException :
901
+ self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
902
+ self .should_terminate_single_epoch = False
903
+ self ._setup_dataloader_iter ()
904
+
905
+ except _EngineTerminateException as e :
906
+ # we need to reraise this exception such that it is not handled
907
+ # as a general exception by the code below
908
+ raise e
861
909
862
910
except Exception as e :
863
911
self .logger .error (f"Current run is terminating due to exception: { e } " )
@@ -870,3 +918,19 @@ def _get_none_data_iter(size: int) -> Iterator:
870
918
# Sized iterator for data as None
871
919
for _ in range (size ):
872
920
yield None
921
+
922
+
923
+ class _EngineTerminateSingleEpochException (Exception ):
924
+ """
925
+ Exception associated with Terminate Single Epoch event
926
+ """
927
+
928
+ pass
929
+
930
+
931
+ class _EngineTerminateException (Exception ):
932
+ """
933
+ Exception associated with Terminate event
934
+ """
935
+
936
+ pass
0 commit comments