Skip to content

Commit 32ba9cd

Browse files
Rewritten Engine's terminate and terminate_epoch logic (#2645)
* Added test_engine_run_resume * Terminate/Terminate Single Epoch work on all EPOCH/ITERATION events * - terminate() work on all events, called on catched _EngineTerminateException - terminate_epoch work on iteration-based events, called on catched _EngineTerminateSingleEpochExpection - Fixed issue when attaching handlers on Events.TERMINATE_SINGLE_EPOCH * Updated docstring * Fixed issue with max_iters handling * Fixed issue with _EngineTerminateException handled as a general exception * Updated tests and docs Co-authored-by: Sadra Barikbin <[email protected]>
1 parent 48364bd commit 32ba9cd

File tree

3 files changed

+346
-56
lines changed

3 files changed

+346
-56
lines changed

ignite/engine/engine.py

Lines changed: 101 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,12 @@ def execute_something():
298298

299299
self._assert_allowed_event(event_name)
300300

301-
event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else ()
301+
event_args = () # type: Tuple[Any, ...]
302+
if event_name == Events.EXCEPTION_RAISED:
303+
event_args += (Exception(),)
304+
elif event_name == Events.TERMINATE_SINGLE_EPOCH:
305+
event_args += (0,)
306+
302307
try:
303308
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
304309
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
@@ -433,14 +438,28 @@ def fire_event(self, event_name: Any) -> None:
433438
return self._fire_event(event_name)
434439

435440
def terminate(self) -> None:
436-
"""Sends terminate signal to the engine, so that it terminates completely the run after
437-
the current iteration."""
441+
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
442+
terminated after the event on which ``terminate`` method was called. The following events are triggered:
443+
444+
- ...
445+
- Terminating event
446+
- :attr:`~ignite.engine.events.Events.TERMINATE`
447+
- :attr:`~ignite.engine.events.Events.COMPLETED`
448+
"""
438449
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
439450
self.should_terminate = True
440451

441452
def terminate_epoch(self) -> None:
442-
"""Sends terminate signal to the engine, so that it terminates the current epoch
443-
after the current iteration."""
453+
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
454+
continues from the next epoch. The following events are triggered:
455+
456+
- ...
457+
- Event on which ``terminate_epoch`` method is called
458+
- :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
459+
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
460+
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
461+
- ...
462+
"""
444463
self.logger.info(
445464
"Terminate current epoch is signaled. "
446465
"Current epoch iteration will stop after current iteration is finished."
@@ -742,33 +761,43 @@ def _internal_run(self) -> State:
742761
self.should_terminate = self.should_terminate_single_epoch = False
743762
self._init_timers(self.state)
744763
try:
745-
start_time = time.time()
746-
self._fire_event(Events.STARTED)
747-
while not self._is_done(self.state) and not self.should_terminate:
748-
self.state.epoch += 1
749-
self._fire_event(Events.EPOCH_STARTED)
750-
751-
if self._dataloader_iter is None:
752-
self._setup_engine()
753-
754-
time_taken = self._run_once_on_dataset()
755-
# time is available for handlers but must be update after fire
756-
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
757-
handlers_start_time = time.time()
758-
if self.should_terminate:
759-
self._fire_event(Events.TERMINATE)
760-
else:
764+
try:
765+
start_time = time.time()
766+
self._fire_event(Events.STARTED)
767+
self._maybe_terminate()
768+
769+
while not self._is_done(self.state) and not self.should_terminate:
770+
self.state.epoch += 1
771+
handlers_start_time = time.time()
772+
self._fire_event(Events.EPOCH_STARTED)
773+
epoch_time_taken = time.time() - handlers_start_time
774+
self._maybe_terminate()
775+
776+
if self._dataloader_iter is None:
777+
self._setup_engine()
778+
779+
epoch_time_taken += self._run_once_on_dataset()
780+
781+
# time is available for handlers but must be updated after fire
782+
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
783+
784+
handlers_start_time = time.time()
761785
self._fire_event(Events.EPOCH_COMPLETED)
762-
time_taken += time.time() - handlers_start_time
763-
# update time wrt handlers
764-
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
765-
hours, mins, secs = _to_hours_mins_secs(time_taken)
766-
self.logger.info(f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
767-
if self.should_terminate:
768-
break
786+
epoch_time_taken += time.time() - handlers_start_time
787+
# update time wrt handlers
788+
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
789+
self._maybe_terminate()
790+
791+
hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
792+
self.logger.info(
793+
f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
794+
)
795+
796+
except _EngineTerminateException:
797+
self._fire_event(Events.TERMINATE)
769798

770799
time_taken = time.time() - start_time
771-
# time is available for handlers but must be update after fire
800+
# time is available for handlers but must be updated after fire
772801
self.state.times[Events.COMPLETED.name] = time_taken
773802
handlers_start_time = time.time()
774803
self._fire_event(Events.COMPLETED)
@@ -786,6 +815,13 @@ def _internal_run(self) -> State:
786815
self._dataloader_iter = None
787816
return self.state
788817

818+
def _maybe_terminate(self) -> None:
819+
if self.should_terminate:
820+
raise _EngineTerminateException()
821+
822+
if self.should_terminate_single_epoch:
823+
raise _EngineTerminateSingleEpochException()
824+
789825
def _run_once_on_dataset(self) -> float:
790826
start_time = time.time()
791827

@@ -805,8 +841,12 @@ def _run_once_on_dataset(self) -> float:
805841
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
806842
if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
807843
self._fire_event(Events.GET_BATCH_STARTED)
844+
self._maybe_terminate()
845+
808846
self.state.batch = next(self._dataloader_iter)
809847
self._fire_event(Events.GET_BATCH_COMPLETED)
848+
self._maybe_terminate()
849+
810850
iter_counter += 1
811851
should_exit = False
812852
except StopIteration:
@@ -835,29 +875,37 @@ def _run_once_on_dataset(self) -> float:
835875
break
836876

837877
self._fire_event(Events.DATALOADER_STOP_ITERATION)
838-
self._setup_dataloader_iter()
878+
self._maybe_terminate()
839879

880+
self._setup_dataloader_iter()
840881
should_exit = True
841882

842883
continue
843884

844885
self.state.iteration += 1
845886
self._fire_event(Events.ITERATION_STARTED)
887+
self._maybe_terminate()
888+
846889
self.state.output = self._process_function(self, self.state.batch)
847890
self._fire_event(Events.ITERATION_COMPLETED)
848-
849-
if self.should_terminate or self.should_terminate_single_epoch:
850-
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
851-
self.should_terminate_single_epoch = False
852-
self._setup_dataloader_iter()
853-
break
891+
self._maybe_terminate()
854892

855893
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
856894
break
857895

858896
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
859897
self.should_terminate = True
860-
break
898+
raise _EngineTerminateException()
899+
900+
except _EngineTerminateSingleEpochException:
901+
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
902+
self.should_terminate_single_epoch = False
903+
self._setup_dataloader_iter()
904+
905+
except _EngineTerminateException as e:
906+
# we need to reraise this exception such that it is not handled
907+
# as a general exception by the code below
908+
raise e
861909

862910
except Exception as e:
863911
self.logger.error(f"Current run is terminating due to exception: {e}")
@@ -870,3 +918,19 @@ def _get_none_data_iter(size: int) -> Iterator:
870918
# Sized iterator for data as None
871919
for _ in range(size):
872920
yield None
921+
922+
923+
class _EngineTerminateSingleEpochException(Exception):
924+
"""
925+
Exception associated with Terminate Single Epoch event
926+
"""
927+
928+
pass
929+
930+
931+
class _EngineTerminateException(Exception):
932+
"""
933+
Exception associated with Terminate event
934+
"""
935+
936+
pass

ignite/engine/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ class CustomEvents(EventEnum):
302302
"""triggered when the run is about to end completely, after receiving terminate() call."""
303303
TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"
304304
"""triggered when the run is about to end the current epoch,
305-
after receiving a terminate_epoch() or terminate() call."""
305+
after receiving a terminate_epoch() call."""
306306

307307
def __or__(self, other: Any) -> "EventsList":
308308
return EventsList() | self | other

0 commit comments

Comments
 (0)