
Commit 7cf9ccd

Move most killing logic to process
The killing process is convoluted because it is performed partially in `tasks.py:Waiting` and partially in `process.py:Process`. The architecture tried to split the kill into two parts: one responsible for cancelling the job in the scheduler (`tasks.py:Waiting`) and one responsible for killing the process and transitioning it to the KILLED state (`process.py:Process`). Here is a summary of these two steps:

Killing the plumpy process (`process.py:Process`)
    Event: KillMessage (sent through RabbitMQ by verdi)
    kill -> self.runner.controller.kill_process  # sends the kill message

Killing the scheduler job (`calcjobs/tasks.py:Waiting`, the task running the actual CalcJob)
    Events: CalcJobMonitorAction.KILL (through monitoring), KillInterruption (through verdi)
    execute -> _kill_job -> task_kill_job -> do_kill -> execmanager.kill_calculation

In this PR I move most of the killing logic to the process to simplify the design. This is required to fix a bug that appears when two kill commands are sent. The first kill command sends the KillInterruption (within `process.py:Process`, as part of the logic in the parent class) to `tasks.py:Waiting`, which receives it and starts cancelling the scheduler job. Since this cancellation is only triggered inside a try-except block catching the `KillInterruption`, it cannot be repeated when the user invokes a second kill command. This bug was introduced by PR TODO (the one that introduced force-kill), because that PR also started to fix the timeout issue (`verdi process kill` ignoring the timeout). Moving all killing logic to the process, as done in this PR, solves the problem: the cancellation of the scheduler job is now reinvoked entirely from the process class, in the function that is invoked when a worker receives a kill message through RMQ. I put very verbose comments in for the review and will remove them later.

I must say the kill process does not seem well tested, as I barely had to adapt the tests. The tests in `test_work_chain.py` need some adaptation to also be able to kill a scheduler job in a dummy manner.
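To illustrate why the old design can react to only one kill, here is a minimal, self-contained sketch (all names are stand-ins; `asyncio.CancelledError` plays the role of plumpy's `KillInterruption`): the scheduler-job cancellation lives in an `except` block, so it can run at most once per task.

import asyncio


async def waiting_execute() -> None:
    try:
        await asyncio.sleep(3600)  # stands in for the long-running transport task
    except asyncio.CancelledError:
        # The first kill lands here and would cancel the scheduler job; a second kill has
        # nothing left to interrupt, because execution has already left the `try` block.
        print('cancelling scheduler job (this branch can only run once)')


async def main() -> None:
    task = asyncio.create_task(waiting_execute())
    await asyncio.sleep(0)  # let the task start and suspend on the sleep
    task.cancel()           # first "kill": triggers the except branch
    await task
    task.cancel()           # second "kill": the task is already done, so this is a no-op


asyncio.run(main())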
1 parent e257b3c commit 7cf9ccd

File tree

6 files changed: +284 -92 lines changed

src/aiida/engine/processes/calcjobs/tasks.py

Lines changed: 9 additions & 7 deletions
@@ -543,7 +543,7 @@ async def execute(self) -> plumpy.process_states.State:  # type: ignore[override
                    monitor_result = await self._monitor_job(node, transport_queue, self.monitors)

                    if monitor_result and monitor_result.action is CalcJobMonitorAction.KILL:
-                        await self._kill_job(node, transport_queue)
+                        await self.kill_job(node, transport_queue)
                        job_done = True

                    if monitor_result and not monitor_result.retrieve:
@@ -582,7 +582,6 @@ async def execute(self) -> plumpy.process_states.State:  # type: ignore[override
        except TransportTaskException as exception:
            raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
        except plumpy.process_states.KillInterruption as exception:
-            await self._kill_job(node, transport_queue)
            node.set_process_status(str(exception))
            return self.retrieve(monitor_result=self._monitor_result)
        except (plumpy.futures.CancelledError, asyncio.CancelledError):
@@ -594,10 +593,13 @@ async def execute(self) -> plumpy.process_states.State:  # type: ignore[override
        else:
            node.set_process_status(None)
            return result
-        finally:
-            # If we were trying to kill but we didn't deal with it, make sure it's set here
-            if self._killing and not self._killing.done():
-                self._killing.set_result(False)
+        # PR_COMMENT We do not use the KillInterruption anymore to kill the job here as we kill the job where the KillInterruption is sent
+        # TODO remove
+        # finally:
+        #     # If we were trying to kill but we didn't deal with it, make sure it's set here
+        #     #if self._killing and not self._killing.done():
+        #     #    self._killing.set_result(False)
+        #     pass

    async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorResult | None:
        """Process job monitors if any were specified as inputs."""
@@ -622,7 +624,7 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR

        return monitor_result

-    async def _kill_job(self, node, transport_queue) -> None:
+    async def kill_job(self, node, transport_queue) -> None:
        """Kill the job."""
        await self._launch_task(task_kill_job, node, transport_queue)
        if self._killing is not None:

src/aiida/engine/processes/process.py

Lines changed: 149 additions & 36 deletions
@@ -52,6 +52,7 @@
from aiida.common.links import LinkType
from aiida.common.log import LOG_LEVEL_REPORT
from aiida.orm.implementation.utils import clean_value
+from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode
from aiida.orm.utils import serialize

from .builder import ProcessBuilder
@@ -329,50 +330,162 @@ def load_instance_state(

        self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

+    async def _launch_task(self, coro, *args, **kwargs):
+        """Launch a coroutine as a task, making sure to make it interruptable."""
+        import functools
+
+        from aiida.engine.utils import interruptable_task
+
+        task_fn = functools.partial(coro, *args, **kwargs)
+        try:
+            self._task = interruptable_task(task_fn)
+            result = await self._task
+            return result
+        finally:
+            self._task = None
+
    def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[bool, plumpy.futures.Future]:
        """Kill the process and all the children calculations it called

        :param msg: message
        """
-        self.node.logger.info(f'Request to kill Process<{self.node.pk}>')
-
-        had_been_terminated = self.has_terminated()
-
-        result = super().kill(msg_text, force_kill)
+        # breakpoint()
+        if self.killed():
+            self.node.logger.info(f'Request to kill Process<{self.node.pk}> but process has already been killed.')
+            return True
+        elif self.has_terminated():
+            self.node.logger.info(f'Request to kill Process<{self.node.pk}> but process has already terminated.')
+            return False
+        self.node.logger.info(f'Request to kill Process<{self.node.pk}>.')
+
+        # PR_COMMENT We need to kill the children now before because we transition to kill after the first kill
+        # This became buggy in the last PR by allowing the user to reusing killing commands (if _killing do
+        # nothing). Since we want to now allow the user to resend killing commands with different options we
+        # have to kill first the children, or we still kill the children even when this process has been
+        # killed. Otherwise you have the problematic scenario: Process is killed but did not kill the
+        # children yet, kill timeouts, we kill again, but the parent process is already killed so it will
+        # never enter this code
+        #
+        # TODO if tests just pass it could mean that this is not well tested, need to check if there is a test
+
+        # TODO
+        # this blocks worker and it cannot be unblocked
+        # need async await maybe
+
+        killing = []
+        # breakpoint()
+        for child in self.node.called:
+            if self.runner.controller is None:
+                self.logger.info('no controller available to kill child<%s>', child.pk)
+                continue
+            try:
+                # we block for sending message

-        # Only kill children if we could be killed ourselves
-        if result is not False and not had_been_terminated:
-            killing = []
-            for child in self.node.called:
-                if self.runner.controller is None:
-                    self.logger.info('no controller available to kill child<%s>', child.pk)
-                    continue
-                try:
-                    result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
-                    result = asyncio.wrap_future(result)  # type: ignore[arg-type]
-                    if asyncio.isfuture(result):
-                        killing.append(result)
-                except ConnectionClosed:
-                    self.logger.info('no connection available to kill child<%s>', child.pk)
-                except UnroutableError:
-                    self.logger.info('kill signal was unable to reach child<%s>', child.pk)
-
-            if asyncio.isfuture(result):
-                # We ourselves are waiting to be killed so add it to the list
-                killing.append(result)
-
-            if killing:
+                # result = self.loop.run_until_complete(coro)
+                # breakpoint()
+                result = self.runner.controller.kill_process(
+                    child.pk, msg_text=f'Killed by parent<{self.node.pk}>', force_kill=force_kill
+                )
+                from plumpy.futures import unwrap_kiwi_future
+
+                killing.append(unwrap_kiwi_future(result))
+                breakpoint()
+                # result = unwrapped_future.result(timeout=5)
+                # result = asyncio.wrap_future(result)  # type: ignore[arg-type]
+                # PR_COMMENT I commented out, we wrap it before to an asyncio future why the if check?
+                # if asyncio.isfuture(result):
+                #     killing.append(result)
+            except ConnectionClosed:
+                self.logger.info('no connection available to kill child<%s>', child.pk)
+            except UnroutableError:
+                self.logger.info('kill signal was unable to reach child<%s>', child.pk)
+
+        # TODO need to check this part, might be overengineered
+        # if asyncio.isfuture(result):
+        #     # We ourselves are waiting to be killed so add it to the list
+        #     killing.append(result)
+
+        ####### KILL TWO
+        if not force_kill:
+            # asyncio.send(continue_kill)
+            # return
+            for pending_future in killing:
+                # breakpoint()
+                result = pending_future.result()
                # We are waiting for things to be killed, so return the 'gathered' future
-                kill_future = plumpy.futures.gather(*killing)
-                result = self.loop.create_future()

-                def done(done_future: plumpy.futures.Future):
-                    is_all_killed = all(done_future.result())
-                    result.set_result(is_all_killed)
-
-                kill_future.add_done_callback(done)
-
-        return result
+        # kill_future = plumpy.futures.gather(*killing)
+        # result = self.loop.create_future()
+        # breakpoint()
+
+        # def done(done_future: plumpy.futures.Future):
+        #     is_all_killed = all(done_future.result())
+        #     result.set_result(is_all_killed)
+
+        # kill_future.add_done_callback(done)
+
+        # PR_COMMENT We do not do this anymore. The original idea was to resend the killing interruption so the state
+        # can continue freeing its resources using an EBM with new parameters as the user can change these
+        # between kills by changing the config parameters. However this was not working properly because the
+        # process state goes only the first time it receives a KillInterruption into the EBM. This is because
+        # the EBM is activated within try-catch block.
+        # try:
+        #     do_work() # <-- now we send the interrupt exception
+        # except KillInterruption:
+        #     cancel_scheduler_job_in_ebm # <-- if we cancel it will just stop this
+        #
+        # Not sure why I did not detect this during my tries. We could also do a while loop of interrupts
+        # but I think it is generally not good design that the process state cancels the scheduler job while
+        # here we kill the process. It adds another actor responsible for killing the process correctly
+        # making it more complex than necessary.
+        #
+        # Cancel any old killing command to send a new one
+        # if self._killing:
+        #     self._killing.cancel()
+
+        # Send kill interruption to the tasks in the event loop so they stop
+        # This is not blocking, so the interruption is happening concurrently
+        if self._stepping:
+            # Ask the step function to pause by setting this flag and giving the
+            # caller back a future
+            interrupt_exception = plumpy.process_states.KillInterruption(msg_text, force_kill)
+            # PR COMMENT we do not set interrupt action because plumpy is very smart it uses the interrupt action to set
+            # next state in the stepping, but we do not want to step to the next state through the plumpy
+            # state machine, we want to control this here and only here
+            # self._set_interrupt_action_from_exception(interrupt_exception)
+            # self._killing = self._interrupt_action
+            self._state.interrupt(interrupt_exception)
+            # return cast(plumpy.futures.CancellableAction, self._interrupt_action)

+        # Kill jobs from scheduler associated with this process.
+        # This is blocking so we only continue when the scheduler job has been killed.
+        if not force_kill and isinstance(self.node, CalcJobNode):
+            # TODO put this function into more common place
+            from .calcjobs.tasks import task_kill_job
+
+            # if already killing we have triggered the Interruption
+            coro = self._launch_task(task_kill_job, self.node, self.runner.transport)
+            task = asyncio.create_task(coro)
+            # task_kill_job is raising an error if not successful, e.g. EBM fails.
+            # PR COMMENT we just return False and write why the kill fails, it does not make sense to me to put the
+            # process to excepted. Maybe you fix your internet connection and want to try it again.
+            # We have force-kill now if the user wants to enforce a killing
+            try:
+                # breakpoint()
+                self.loop.run_until_complete(task)
+                # breakpoint()
+            except Exception as exc:
+                self.node.logger.error(f'While cancelling job error was raised: {exc!s}')
+                # breakpoint()
+                return False
+
+        # Transition to killed process state
+        # This is blocking so we only continue when we are in killed state
+        msg = plumpy.process_comms.MessageBuilder.kill(text=msg_text, force_kill=force_kill)
+        new_state = self._create_state_instance(plumpy.process_states.ProcessState.KILLED, msg=msg)
+        self.transition_to(new_state)
+
+        return True

    @override
    def out(self, output_port: str, value: Any = None) -> None:
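As a reading aid for the review, here is a condensed, self-contained sketch of the ordering the new `kill` aims for (stub class only, not the real `Process` API): forward the kill to the children first, interrupt the running state, block on cancelling the scheduler job, and only then transition to KILLED; a failed cancellation is reported as `False` instead of excepting the process.

import asyncio


class SketchProcess:
    """Hypothetical stand-in for a process; only the control flow of ``kill`` is modelled."""

    def __init__(self, children=None):
        self.state = 'waiting'
        self.children = children or []

    async def cancel_scheduler_job(self):
        await asyncio.sleep(0)  # stands in for task_kill_job going through the transport queue

    def kill(self, force_kill: bool = False) -> bool:
        if self.state == 'killed':
            return True  # a repeated kill of an already killed process is a no-op
        for child in self.children:
            child.kill(force_kill=force_kill)  # 1) forward the kill to the children first
        self.state = 'interrupted'  # 2) interrupt whatever state is currently running
        if not force_kill:
            try:
                asyncio.run(self.cancel_scheduler_job())  # 3) block until the scheduler job is cancelled
            except Exception:
                return False  # report the failure instead of excepting the process
        self.state = 'killed'  # 4) only now transition to KILLED
        return True


parent = SketchProcess(children=[SketchProcess()])
assert parent.kill() is True
assert parent.kill() is True  # a second kill is accepted and simply reports success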

src/aiida/workchain.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# TODO this class needs to be removed
+
+from aiida.engine import ToContext, WorkChain
+from aiida.orm import Bool
+
+
+class MainWorkChain(WorkChain):
+    @classmethod
+    def define(cls, spec):
+        super().define(spec)
+        spec.input('kill', default=lambda: Bool(False))
+        spec.outline(cls.submit_child, cls.check)
+
+    def submit_child(self):
+        return ToContext(child=self.submit(SubWorkChain, kill=self.inputs.kill))
+
+    def check(self):
+        raise RuntimeError('should have been aborted by now')
+
+
+class SubWorkChain(WorkChain):
+    @classmethod
+    def define(cls, spec):
+        super().define(spec)
+        spec.input('kill', default=lambda: Bool(False))
+        spec.outline(cls.begin, cls.check)
+
+    def begin(self):
+        """If the Main should be killed, pause the child to give the Main a chance to call kill on its children"""
+        if self.inputs.kill:
+            self.pause()
+
+    def check(self):
+        raise RuntimeError('should have been aborted by now')
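A hypothetical sketch of how these dummy work chains might be driven (assumes a configured profile and a running daemon; `kill_processes` is, to my understanding, the Python API behind `verdi process kill`):

from aiida.engine import submit
from aiida.engine.processes.control import kill_processes
from aiida.orm import Bool

from aiida.workchain import MainWorkChain

# The child pauses itself in ``begin`` (see above), giving the parent time to be killed.
node = submit(MainWorkChain, inputs={'kill': Bool(True)})
# Killing the parent is expected to propagate to the paused child.
kill_processes([node])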

tests/cmdline/commands/test_process.py

Lines changed: 8 additions & 13 deletions
@@ -17,6 +17,7 @@
from pathlib import Path

import pytest
+from tests.conftest import await_condition

from aiida import get_profile
from aiida.cmdline.commands import cmd_process
@@ -116,18 +117,6 @@ def fork_worker(func, func_args):
    client.increase_workers(nb_workers)


-def await_condition(condition: t.Callable, timeout: int = 1) -> t.Any:
-    """Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise."""
-    start_time = time.time()
-
-    while not (result := condition()):
-        if time.time() - start_time > timeout:
-            raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.')
-        time.sleep(0.1)
-
-    return result
-
-
@pytest.mark.requires_rmq
@pytest.mark.usefixtures('started_daemon_client')
def test_process_kill_failing_transport(
@@ -213,7 +202,7 @@ def make_a_builder(sleep_seconds=0):

@pytest.mark.requires_rmq
@pytest.mark.usefixtures('started_daemon_client')
-def test_process_kill_failng_ebm(
+def test_process_kill_failing_ebm(
    fork_worker_context, submit_and_await, aiida_code_installed, run_cli_command, monkeypatch
):
    """9) Kill a process that is paused after EBM (5 times failed). It should be possible to kill it normally.
@@ -232,6 +221,7 @@ def make_a_builder(sleep_seconds=0):

    kill_timeout = 10

+    # TODO instead of mocking it why didn't we just set the paramaters to 1 second?
    monkeypatch_args = ('aiida.engine.utils.exponential_backoff_retry', MockFunctions.mock_exponential_backoff_retry)
    with fork_worker_context(monkeypatch.setattr, monkeypatch_args):
        node = submit_and_await(make_a_builder(), ProcessState.WAITING)
@@ -242,6 +232,11 @@ def make_a_builder(sleep_seconds=0):
        )

        run_cli_command(cmd_process.process_kill, [str(node.pk), '--wait'])
+        # It should *not* be killable after the EBM expected
+        await_condition(lambda: not node.is_killed, timeout=kill_timeout)
+
+        # It should be killable with the force kill option
+        run_cli_command(cmd_process.process_kill, [str(node.pk), '-F', '--wait'])
        await_condition(lambda: node.is_killed, timeout=kill_timeout)
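The monkeypatched `MockFunctions.mock_exponential_backoff_retry` is defined in the test module and not shown in this diff; as a rough, hypothetical stand-in, it can be thought of as an EBM replacement that gives up immediately:

import asyncio


async def mock_exponential_backoff_retry(fct, *args, **kwargs):
    """Hypothetical stand-in: fail immediately instead of retrying with exponential backoff."""
    raise ConnectionError('simulated transport failure, backoff exhausted')


# Any task routed through the patched EBM now fails straight away:
try:
    asyncio.run(mock_exponential_backoff_retry(lambda: None))
except ConnectionError as exc:
    print(exc)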

tests/conftest.py

Lines changed: 12 additions & 0 deletions
@@ -953,3 +953,15 @@ def cat_path() -> Path:
    run_process = subprocess.run(['which', 'cat'], capture_output=True, check=True)
    path = run_process.stdout.decode('utf-8').strip()
    return Path(path)
+
+def await_condition(condition: t.Callable, timeout: int = 1) -> t.Any:
+    """Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise."""
+    import time
+    start_time = time.time()
+
+    while not (result := condition()):
+        if time.time() - start_time > timeout:
+            raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.')
+        time.sleep(0.1)
+
+    return result
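Example usage of the relocated helper (run from the repository root so that `tests.conftest` is importable; the trivial condition is only for illustration):

import time

from tests.conftest import await_condition

start = time.time()
# Returns the truthy result of the condition, or raises RuntimeError once ``timeout`` is exceeded.
assert await_condition(lambda: time.time() - start > 0.2, timeout=1) is True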
