Using torch.compile with einops
Alex Rogozhnikov edited this page Sep 17, 2024 · 4 revisions
PyTorch 2.0 introduced `torch.compile`, which 'compiles' Python code into graphs.
- if you use einops layers (`Rearrange`, `Reduce`, `EinMix`), no action is needed: they work with `torch.compile`, `torch.jit.script`, and `torch.jit.trace` out of the box
- if you use einops functions (`rearrange`, `reduce`, `repeat`, `einsum`, `pack`, `unpack`), you need to allow einops ops in the graph:
```python
from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1
allow_ops_in_compiled_graph()
```
`torch>=2.4` makes this call for you automatically.
If you use `torch.compile` without calling `allow_ops_in_compiled_graph` first, `torch.compile` will break the graph on einops functions, causing a significant slowdown.
In experiments with transformers (see https://github.com/arogozhnikov/einops/issues/250#issuecomment-1508138804 for details), `torch.compile` optimizes plain PyTorch, einops layers, and einops functions equally well. But it has to be informed that einops functions are 'well-behaved' and can be included in the graph, which is why you need to call `allow_ops_in_compiled_graph` first.