TL;DR: Here's an implementation of setjmp() and longjmp() for Python 3.6+. The state of the program can be serialized to disk or over the network. There are a few other twists and Python flavors.
If you write data processing pipelines that take hours to run on your desktop, you live in fear of crashes. Maybe you've invested in some hand-woven checkpointing code.
This package helps you checkpoint your pipeline's work automatically. When your pipeline crashes, this package helps you restart it from the last checkpoint. As a bonus, if you modify your code to fix the crash, this package resumes it from a known good checkpoint that does not get affected by your code change.
Your code ends up looking like this:
import function_checkpointing as ckpt
def sub_stage1():
for step in range(5):
... do some work...
# Snapshot the state of this function to disk
ckpt.save_checkpoint(f"snapshot sub_stage1 {step}")
def processing():
sub_stage1() # Do some heavy lifting
... do some more work...
# Snapshot the state of this function to disk
ckpt.save_checkpoint("snapshot 2")
... do some more work...
You can drop a checkpoint anywhere in your code. You can then call
resume_from_checkpoint
to resume your code with the state it had at the
corresponding save_checkpoint
call:
ckpt.resume_from_checkpoint("snapshot 2")
The function save_checkpoint
snapshots to disk the state of your program as
of when it's called to disk. When resume_from_checkpoint
restores your
program from this checkpoint, save_checkpoint
appears to just be returning with
an empty return value.
A disclaimer before we get too far: It's nearly impossible to automatically fully checkpoint a program. That would require checkpointing the state of the resources your program depends on too, like the disk files and database servers it accesses, and the state of the GPU. That becomes very hard unless all these external resources provide a standardized snapshotting mechanism. You may need to store some of the state manually. This particular package also has some intentional design limitations.
-
Deep Learning: You probably don't need this package if you just want to checkpoint an off-the-shelf deep learning training loop after every training epoch. The state of these loops is already well encapsulated in the state of the optimizer, the model, and the epoch counter. These can be painlessly serialized and deserialized. But if you have a complicated training loop with several different expensive phases, this package might be worth your while.
-
Implement "Undo" in your application / let the checkpoint become the document: Supporting "undo" in complicated applications means writing your program so that the document captures your program's entire state. But you could implement undo as just a jump to an old checkpoint. Once you save checkpoints, you don't need to also create or save a document: The checkpoint becomes the document and users can just move around checkpoints instead of document files.
-
Implement a reversible debugger: These are all the rage now. Each breakpoint can become a checkpoint, and you can jump back and forth between them.
-
You love call/cc: Maybe you miss call-with-current-continuation and you want it in Python. You can implement that with
save_jump
andjump
described below. -
Checkpointing as a game primitive. This probably won't make sense immediately, so maybe revisit this paragraph after you understand the structure of the stack during a jump in the next section. Imagine a video game where a returnable longjump as described in the next section is part of the gameplay. The old state of the world is allowed to nest in the new state of the world.
The main entry points of this package are save_checkpoint
and
resume_from_checkpoint
. Under the hood, these functions write and read from
pickle files, and call functions named save_jump
and jump
respectively. These
two functions are useful in their own right. If you grok them, you grok this package.
save_jump
takes no argument. It bundles up and returns the chain
of frames that led up to this call. This bundle contains the local variables,
the Python bytecode interpreter stack, and the bytecode instruction pointer at
the location where save_jump
was called. It also includes this information
for the parent frame that called save_jump
, and its caller's frame, all the
way up to the topmost frame. It then returns this stashed bundle of stack
frame as a Python object (which you can inspect and modify if you wish).
But there's a twist: save_jump
can return twice. The above is what happens
when you're taking a snapshot. But save_jump
can also return when you're
restoring a snapshot. In that case, the return value of save_jump
is an empty
list. You can use this to check if your program is being restored or if it's
checkpointing:
c = ckpt.save_jump()
if not c:
# We're resuming from a checkpoint. Reopen connection to the database, or whatever.
db = connect_to_database()
The semantics of save_jump
are similar to those of the POSIX setjmp() function.
jump
takes one argument: a bundle returned by save_jump
. It restores the
state of the frame stack to the state that save_jump
captured.
But instead of rooting the stack frames at the topmost frame, it roots it
under its own stack frame. That means when the topmost stack frame that
originated the sequence of calls that culminated in save_jump
finishes,
control returns to the stack frame that called jump
.
This is an important difference with the POSIX longjmp()
function, which
never returns.
Practically, this means the function foo below returns twice, an unusual concept in most programming languages:
import function_checkpointing.save_restore as save_restore
def foo():
c = save_restore.save_jump()
if c:
print('saved checkpoint')
save_restore.jump(c)
else:
print('restored from checkpoint')
print('foo returns')
This function prints:
saved checkpoint
restored from checkpoint
foo returns
foo returns
To appreciate this, let's call this function from another function:
def caller():
print('caller->foo')
foo()
print('foo->caller')
This function returns twice as well. Calling it produces
caller->foo
saved checkpoint
restored from checkpoint
foo returns
foo->caller
foo returns
foo->caller
Here's an illustration of what's happening. Calling caller results in a call to
foo, which results in a call to save_jump
. This diagram captures the snapshot
captured in the bundle c
:
As execution proceed down foo, we reach a call to jump
. Jump reroots the
call chain captured in c
under itself and fast forwards the Python interpreter
to the leaf frame:
There are now two instances of foo in the call stack. The rightmost frame,
foo, continues executing until it returns. Then caller
finishes executing,
then the top level module executes, and finaly, the jump
function returns. We
have reached the end of the leftmost foo, so foo returns. Notice that foo
returned twice.
If you're not enjoying this mindbending fact, you can just terminate your program
after jump returns by calling sys.exit
immediately after jump
.
There is a superficial resemblance between save_jump
and the Python yield
statement: both stash the current stack frame. But then yield
forces the
interpreter to exit the stack frame and returns control to the caller's stack
frame, essentially acting as return
. On the other hand, save_jump
saves
the entire chain of stack frames, and allows execution to remain in the current
frame. save_jump
appears as a normal function call and not at all like
return
.
Similarly, there is a superficial resemblance between jump
and calling a
generator's send
method. Calling send
resumes the generator frame's
execution at its last yield
. But jump
does more than that: it resumes
execution of the entire call chain leading up to save_jump
, not just the
single stack frame where save_jump
was called.
My earlier attempt at an automatic
checkpointing library used
generators. The practical limitation with generators is that you could only
snapshot one function (the generator), not the call chain that resulted in the
generator being called. In my previous attempt, if you wanted to snapshot
multiple layers of the call chain, you'd have to use yield from
at each
layer, which imposed an inconvenient structure on your code.
Normally, you'd want to resume your program from the latest checkpoint it created before it crashed. But after a crash you might modify your code to fix the problem that caused the crash. In such cases, instead of resuming from the latest checkpoint, this package helps you identify the latest checkpoint that incurred no code change, and to resume from that checkpoint.
To detect where your code has changed, the package provides a run-time profiler that logs the function calls in your code. The call log is stored to disk along with each checkpoint. When it's time to restore the state of the program, we identify the latest checkpoint whose call log involves no calls to any modified function.
A few functions give you control over this automatic restart mechanism.
start_call_tracing
turns on the call tracer.save_checkpoint_and_call_log
is a variant ofsave_checkpoint
that stores the call log along with the checkpoint.resume_from_last_unchanged_checkpoint
is a variant ofresume_from_checkpoint
that identifies and loads the relevant checkpoint.
See examples/detect_code_change.py for an example.
Saving the Python interpreter state is tricky. The CPython interpreter has effectively two stacks: the interpreter stack, which the bytecode instructions operate on, and the C stack, which the bytecode interpreter uses to keep track of its own state. This package manipulates the state of both stacks.
In my previous attempt, I documented why this problem is hard. The chief concern is that obtaining the top of the CPython stack seems impossible (corroborated by this excellent answer on Stackoverflow). This package does this by relying on some known facts about the depth of the python stack: When a function is called, its arguments are on the stack. We can inspect the number of arguments of the function from the particular flavor of CALL bytecode being used, and increment our estimate of the stack depth accordingly. Yet more object can be on the stack: if an exception is being handled or the caller is in the middle of a loop, the stack contains respectively the exception triplet and the loop iterator. We count the level of nesting of exceptions and loops to count these objects too.
To manipulate the C stack, this package guides the CPython bytcode
interpreter, _PyEval_EvalFrameDefault
, to call itself recursively. These
nested calls cause the C stack to attain the state it had when save_jump
was
called.
This is a hobby project, not a software engineering artifact. Its limitations are commensurate:
-
Requires Python 3.6+: It relies on PEP523, which is only available in 3.6 and later.
-
Only tested with Python 3.6 and 3.7. There's no fundamental limitation here as far as I know. These just happen to be the Pythons I have.
-
Does not snapshot global variables: Again, no fundamental limitation here as far as I know. Saving globals is simultaneously relatively straightforward and not particular urgent for me, so it's not yet implemented (it might make sense to implement it as an auxiliary package to keep this one simple).
-
Might crash. Your code could segfault. I'll fix the package if you send me a bug report. Instead of developing the perfect checkpointer with exhaustive test cases, I built the simplest checkpointer for myself, and I refine it as my usage outgrows it.
I learned how to use Cython to do this manipulation from this very helpful repo. I also copied code from it.
This Stackoverflow answer does a good job of explaining why this problem was hard to solve.