I'm comparing the speed of a simple dot product with a fixed matrix, either "naked" or inside a Flax layer. I noticed that with Flax the performance is massively worse: from approximately 350 microseconds to 5-10 ms. Here's a colab where I reproduce the issue. I compare a naked approach, Flax (both implicit and explicit), and a plain Python dataclass. Only with the Flax classes do I see the performance hit. Am I timing this wrong (I used .block_until_ready() everywhere), am I making an unfair comparison, or is it something in Flax?
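A minimal sketch of this kind of comparison (the linked colab isn't reproduced here, so the matrix shape, the `Dot` module, and the `timeit` helper below are illustrative assumptions):

```python
import time

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical sizes; the colab may use different shapes.
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (256, 256))  # the fixed matrix
x = jax.random.normal(key, (256,))

def naked(x):
    # "Naked" dot product against the fixed matrix.
    return jnp.dot(w, x)

class Dot(nn.Module):
    """Flax module wrapping the same fixed matrix as a parameter."""

    @nn.compact
    def __call__(self, x):
        kernel = self.param("kernel", lambda rng: w)
        return jnp.dot(kernel, x)

model = Dot()
params = model.init(key, x)

def timeit(fn, *args, n=100):
    fn(*args).block_until_ready()  # warm-up (and compile, if jitted)
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*args).block_until_ready()
    return (time.perf_counter() - t0) / n

print("naked:", timeit(naked, x))
print("flax :", timeit(model.apply, params, x))
```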
Unlike, for example, PyTorch, we haven't put much effort into reducing overhead. This has a few reasons: jax.jit on any of the examples you gave should lead to the same compiled program being generated, with exactly the same performance.

The reason the overhead is so high in this case is that, as in other frameworks, the compute is dispatched asynchronously from the Python thread. Normally the tensors are big enough to mask any overhead from the Python interpreter. Still …
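As a sketch of the jax.jit point, reusing the hypothetical names from the snippet above (naked, model, params, x, timeit), wrapping both variants in jax.jit should make their per-call timings essentially match, since both trace to the same compiled program:

```python
# Jitting either version compiles it once; after that, the per-call
# Python/Flax module-traversal overhead no longer dominates.
naked_jit = jax.jit(naked)
apply_jit = jax.jit(model.apply)

naked_jit(x).block_until_ready()           # trigger compilation once
apply_jit(params, x).block_until_ready()   # trigger compilation once

print("naked (jit):", timeit(naked_jit, x))
print("flax  (jit):", timeit(apply_jit, params, x))
```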