Labels: feature request
Description
We have already seen several examples that use Taichi as part of a PyTorch program, for example:
- https://github.com/ailzhang/blog_code/blob/master/rwkv/benchmark.py
- https://github.com/ifsheldon/stannum
Is it also possible to integrate Taichi into JAX?
Taichi can generate highly optimized operators and is well suited to implementing operators that involve sparse computation. Being able to call Taichi kernels from a JAX program would interest a broad range of programmers.
I think the key to this integration is obtaining the address of the kernel that Taichi compiles. There are existing examples of launching a Triton-compiled GPU kernel from JAX, so the same approach may work for Taichi as well.