Motivation
Since PyTorch has become the standard deep-learning framework for AI researchers and developers, we need a way to stay compatible with PyTorch. We have already introduced mindtorch, which uses MindSpore basic APIs to encapsulate a PyTorch-like API. However, this requires a huge amount of work and constant upgrades whenever torch ships new features.
A few weeks ago, the vLLM and JAX teams proposed a new library called torchax, which uses the __torch_dispatch__ mechanism to give torch a JAX backend while keeping JAX-native features.
How it Works
Here is a schematic diagram:

The full document is available here.
MindNLP/MindSpore tasks
__torch_dispatch__
ATen op mapping
Zero-copy tensor conversion
New device registration
Global switch
Distributed ops
MindSpore feature support
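To illustrate the "ATen op mapping" task, here is a minimal pure-Python sketch of the idea behind a __torch_dispatch__-based backend: each intercepted ATen op is looked up by name in a table and forwarded to a backend implementation. The table `ATEN_TO_BACKEND`, the `register` and `dispatch` helpers, and the list-based "backend" kernels are hypothetical stand-ins for MindSpore ops. Only the op names (`aten.add.Tensor`, `aten.mul.Tensor`, `aten.relu.default`) follow ATen's naming. A real implementation would receive `func` objects inside `__torch_dispatch__` and call MindSpore APIs instead of operating on Python lists.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a backend tensor: a flat list of floats.
Vec = List[float]

# Hypothetical registry mapping ATen overload names to backend kernels.
ATEN_TO_BACKEND: Dict[str, Callable[..., Vec]] = {}


def register(op_name: str) -> Callable[[Callable[..., Vec]], Callable[..., Vec]]:
    """Decorator that records a backend kernel for one ATen op name."""
    def wrap(fn: Callable[..., Vec]) -> Callable[..., Vec]:
        ATEN_TO_BACKEND[op_name] = fn
        return fn
    return wrap


@register("aten.add.Tensor")
def backend_add(a: Vec, b: Vec, alpha: float = 1.0) -> Vec:
    # aten.add.Tensor computes a + alpha * b elementwise.
    return [x + alpha * y for x, y in zip(a, b)]


@register("aten.mul.Tensor")
def backend_mul(a: Vec, b: Vec) -> Vec:
    # Elementwise product.
    return [x * y for x, y in zip(a, b)]


@register("aten.relu.default")
def backend_relu(a: Vec) -> Vec:
    # Clamp negatives to zero.
    return [max(x, 0.0) for x in a]


def dispatch(op_name: str, *args, **kwargs) -> Vec:
    """Route an intercepted op to its backend kernel; fail loudly if unmapped."""
    try:
        impl = ATEN_TO_BACKEND[op_name]
    except KeyError:
        raise NotImplementedError(f"no backend mapping for {op_name}") from None
    return impl(*args, **kwargs)


if __name__ == "__main__":
    x = [1.0, -2.0, 3.0]
    y = [0.5, 0.5, 0.5]
    # add with alpha=2 gives [2.0, -1.0, 4.0]; relu then gives [2.0, 0.0, 4.0]
    print(dispatch("aten.relu.default", dispatch("aten.add.Tensor", x, y, alpha=2.0)))
```

In this sketch, backend coverage grows by registering one entry per ATen op rather than re-implementing the full torch Python API surface, and any unmapped op raises immediately, which makes coverage gaps easy to spot.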