This repository has been archived by the owner on May 11, 2023. It is now read-only.

feat: Allowing the use of Hyper-Efficient Kernel Operator Library. #56

Open
adam-hartshorne opened this issue Mar 25, 2023 · 0 comments
Labels
enhancement New feature or request

Comments

@adam-hartshorne

I wondered if you were aware of / had considered making use of a third-party hyper-efficient kernel library such as KeOps or Triton.

KeOps currently has bindings for PyTorch, and you can use it with GPyTorch for defining kernels. Although there is no JAX support at present, I believe it should be possible via the custom_call functionality now available in JAX for interfacing with C++ libraries. KeOps is both extremely fast and capable of processing massive matrix operations that would not normally fit into memory.
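To make the memory argument concrete: the core trick behind KeOps is that a reduction like `K @ b` never needs the full N x M kernel matrix in memory, only one block of rows at a time. Here is a minimal NumPy sketch of that blockwise-reduction idea (this is an illustration of the technique, not the KeOps API, and the function name and block size are my own):

```python
import numpy as np

def gaussian_kernel_matvec(x, y, b, block_size=128):
    """Compute K @ b where K[i, j] = exp(-||x_i - y_j||^2),
    processing rows of K in blocks so the full N x M kernel
    matrix is never materialized at once."""
    out = np.empty(len(x))
    for start in range(0, len(x), block_size):
        xb = x[start:start + block_size]                      # (B, D) block of rows
        sq = ((xb[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (B, M) squared distances
        out[start:start + block_size] = np.exp(-sq) @ b       # reduce this block immediately
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 3))
y = rng.normal(size=(400, 3))
b = rng.normal(size=400)

# Matches the dense computation, but peak memory is O(B * M), not O(N * M).
dense = np.exp(-((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)) @ b
assert np.allclose(gaussian_kernel_matvec(x, y, b), dense)
```

KeOps does this symbolically on the GPU (via its LazyTensor abstraction) with fused CUDA reductions, which is where the speed comes from; the sketch only shows why the memory footprint stays small.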

Alternatively, there is already a JAX wrapper around Triton (triton-lang.org), a language created by OpenAI for high-performance kernel computations. I believe JAX already leverages Triton for some internal operations. Effectively, it lets you define a custom kernel much as you would with CUDA, but with added advantages (one being that it integrates trivially with JAX).
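The key programming model in Triton is that each program instance computes one tile of the output, accumulating over tiles of the inner dimension. A real Triton kernel needs a GPU, but the tiling pattern itself can be sketched in plain NumPy (the function name and tile size below are my own, for illustration only):

```python
import numpy as np

def tiled_matmul(a, b, tile=32):
    """Tile-by-tile matrix multiply: each (i, j) output tile is
    accumulated over tiles of the inner dimension, mirroring how
    a Triton kernel assigns one program instance per output block."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2
    c = np.zeros((m, n))
    for i in range(0, m, tile):          # loop over output row tiles
        for j in range(0, n, tile):      # loop over output column tiles
            acc = np.zeros((min(tile, m - i), min(tile, n - j)))
            for p in range(0, k, tile):  # accumulate over inner-dimension tiles
                acc += a[i:i + tile, p:p + tile] @ b[p:p + tile, j:j + tile]
            c[i:i + tile, j:j + tile] = acc
    return c

rng = np.random.default_rng(1)
a = rng.normal(size=(70, 50))
b = rng.normal(size=(50, 90))
assert np.allclose(tiled_matmul(a, b), a @ b)
```

In Triton the two outer loops become the launch grid and the accumulator lives in fast on-chip memory, which is what makes fused custom kernels both simple to write and fast.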

@adam-hartshorne adam-hartshorne added the enhancement New feature or request label Mar 25, 2023
1 participant