Unlike PyTorch, for example, we haven't put much effort into reducing overhead. There are a few reasons for this:

  1. Python is slow, so fast code eventually tends to move to C++, which makes the implementation much more complex.
  2. JAX itself has large dispatch overheads.
  3. Most importantly, the dispatch overheads disappear once you compile your code. So using jax.jit on any of the examples you gave should lead to the same compiled program being generated, with exactly the same performance.
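As a rough sketch of point 3 (the function `f` here is just an illustrative stand-in, not one of the examples from the original issue), wrapping a chain of small ops in jax.jit fuses them into a single compiled XLA program, so the per-op Python dispatch cost is paid once at trace time rather than on every call:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Several small ops: run eagerly, each incurs dispatch overhead.
    return jnp.sin(x) + jnp.cos(x) * 2.0

# Compiles the whole function into one XLA program on first call.
f_jit = jax.jit(f)

x = jnp.ones((4,))
print(f_jit(x))  # same values as f(x), but a single compiled dispatch
```

After the first (tracing) call, repeated calls to `f_jit` skip per-op dispatch entirely.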

The reason the overhead is so high in this case is that, as in other frameworks, the compute runs asynchronously from the Python thread. Normally the tensors are big enough to mask any overhead from the Python interpreter. Still j…
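A consequence of this asynchronous dispatch is that naive timing from Python can be misleading: the call returns as soon as the work is enqueued, not when it finishes. A minimal sketch of how one might observe this (timings will of course vary by backend):

```python
import time
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

# This mostly measures the time to *enqueue* the matmul, because the
# computation proceeds asynchronously from the Python thread.
start = time.perf_counter()
y = jnp.dot(x, x)
enqueue_time = time.perf_counter() - start

# block_until_ready() waits until the result is actually computed,
# so this measures the real compute time.
start = time.perf_counter()
y = jnp.dot(x, x).block_until_ready()
compute_time = time.perf_counter() - start

print(f"enqueue: {enqueue_time:.6f}s, compute: {compute_time:.6f}s")
```

This is why JAX benchmarks conventionally call `.block_until_ready()` on the result before stopping the clock.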

jheek Mar 4, 2021
Maintainer

Answer selected by marcvanzee
This discussion was converted from issue #1080 on March 04, 2021 13:26.