Synchronous Python callbacks from non-calling thread #32426

@kcdodd

Description

This issue reports a possible regression from #21980, which changed callbacks to execute synchronously so that they would not deplete the thread pool and cause a deadlock. For a while after that change, all callbacks appeared to run in the same thread that launched the computation, but recently I've noticed that callbacks sometimes run in a different thread.

It seems that the logic in PjRtLoadedExecutable still sets and respects kSynchronous. However, even though ThunkExecutor chooses "sequential" execution, any dependency that is not immediately ready causes the remaining work (including the callback) to be staged out in an async event. The main thread then runs ahead, back to the executable, where it blocks until the thunk completes. I'm pretty ignorant of how threads are managed in this case, but it seems the main thread is basically suspended until the computation finishes (waiting at the block-until-ready). Each time a callback is staged out, that thread becomes unavailable for work, so multiple callbacks can still deadlock. There are other reasons to prefer that callbacks always run in the calling thread, but those have more to do with interacting with non-jax/xla code from the callback itself.
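To illustrate the behavior I think I'm seeing, here is a toy model using only the Python standard library. It is not XLA code and makes no claim about how ThunkExecutor is actually implemented; `stage_callback` and the variable names are hypothetical. It relies on the fact that `Future.add_done_callback` runs the continuation immediately in the calling thread if the future is already done, and otherwise runs it in whichever thread later completes the future.

```python
import threading
from concurrent.futures import Future


def stage_callback(dependency, python_callback):
    """Toy model (hypothetical): run python_callback once dependency is ready.

    Returns a list that records the thread the callback ran on, plus an
    event that is set once the callback has finished.
    """
    ran_on = []
    finished = threading.Event()

    def continuation(_future):
        ran_on.append(threading.get_ident())
        python_callback()
        finished.set()

    # Runs inline if the dependency is already done; otherwise it is
    # deferred to the thread that eventually completes the dependency.
    dependency.add_done_callback(continuation)
    return ran_on, finished


caller = threading.get_ident()

# Case 1: the dependency is immediately ready, so the callback runs inline.
ready = Future()
ready.set_result(None)
ran_on_ready, finished_ready = stage_callback(ready, lambda: None)
finished_ready.wait()
same_thread_when_ready = ran_on_ready[0] == caller

# Case 2: the dependency is still pending, so the callback is staged out.
pending = Future()
ran_on_pending, finished_pending = stage_callback(pending, lambda: None)
worker = threading.Thread(target=pending.set_result, args=(None,))
worker.start()
finished_pending.wait()  # the caller just blocks here, like block-until-ready
worker.join()
same_thread_when_pending = ran_on_pending[0] == caller

print("ready dependency, callback on calling thread:", same_thread_when_ready)
print("pending dependency, callback on calling thread:", same_thread_when_pending)
```

In the second case the calling thread does nothing useful while it waits, and the callback runs on the worker thread. If the real executor behaves similarly, that would explain both the thread mismatch and why several staged-out callbacks could still tie up threads.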

Labels: bug (Something isn't working), err: Runtime (Runtime Error)
