This repository has been archived by the owner on May 11, 2023. It is now read-only.
feat: Allow the use of a hyper-efficient kernel operator library #56
Labels
enhancement
New feature or request
I wondered if you were aware of, or had considered, making use of a third-party hyper-efficient kernel library such as KeOps or Triton.
KeOps currently has bindings for PyTorch, and it can be used with GPyTorch to define kernels. Although there is no JAX support at present, I believe it should be possible to add it via the custom_call functionality now available in JAX for interfacing with C++ libraries. KeOps is both extremely fast and capable of processing massive matrix operations that would not normally fit into memory.
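The out-of-memory claim comes from KeOps evaluating symbolic kernel expressions in tiles rather than materialising the full kernel matrix. The principle can be sketched in plain NumPy (this is an illustration of the idea under that assumption, not KeOps' actual implementation; the function name and chunking scheme are mine):

```python
import numpy as np

def gaussian_kernel_matvec(x, y, v, lengthscale=1.0, chunk=1024):
    """Compute K(x, y) @ v without materialising the full kernel matrix.

    Rows of K are produced and reduced in chunks, mimicking the tiled
    reduction a library like KeOps performs on the GPU. Peak memory is
    O(chunk * m) rather than O(n * m) for the dense kernel matrix.
    """
    n = x.shape[0]
    out = np.empty((n, v.shape[1]))
    for start in range(0, n, chunk):
        xb = x[start:start + chunk]                               # (b, d)
        sq = ((xb[:, None, :] - y[None, :, :]) ** 2).sum(-1)      # (b, m)
        out[start:start + chunk] = np.exp(-0.5 * sq / lengthscale**2) @ v
    return out
```

For small problems the result matches the dense computation exactly; the benefit appears when n and m are large enough that the n-by-m kernel matrix no longer fits in memory.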
Alternatively, there is already a JAX wrapper around Triton (triton-lang.org), a language created by OpenAI for high-performance kernel computations. I believe JAX already leverages Triton for some internal operations. Effectively, it allows you to define a custom kernel much as you can with CUDA, but with added advantages (one being that you can trivially integrate it with JAX).