Add instruction for exporting inlined constant #8707

Status: Open. Wants to merge 2 commits into `master`.
27 changes: 25 additions & 2 deletions docs/source/features/stablehlo.md
There are 2 ways to accomplish this:
from torch.export import export
import torchvision
import torch
import torchax as tx
import torchax.export

resnet18 = torchvision.models.resnet18()
# Sample input is a tuple
print(stablehlo.mlir_module())
On the second-to-last line we used `jax.ShapeDtypeStruct` to specify the input shape and dtype.
You can also pass a numpy array here.
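As a minimal sketch of this point (the `double` function here is a made-up example, not part of the docs above): a `jax.ShapeDtypeStruct` and a concrete numpy array should produce equivalent lowerings, because only the shape and dtype of the argument are used.

```python
import jax
import jax.numpy as jnp
import numpy as np

def double(x):
    # Trivial example function to lower.
    return x * 2

# A spec carries only shape and dtype; no real data is needed to lower.
spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)
text_from_spec = jax.jit(double).lower(spec).as_text()

# A concrete numpy array works too; only its shape and dtype are used.
text_from_array = jax.jit(double).lower(np.zeros((2, 3), np.float32)).as_text()

# Both lowerings describe the same tensor type.
assert "tensor<2x3xf32>" in text_from_spec
assert "tensor<2x3xf32>" in text_from_array
```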

### Inline some weights in the generated StableHLO

You can inline some or all of your model's weights into the StableHLO graph as constants by exporting a separate function that calls your model.

The convention used by `jax.jit` is that all inputs of the `jit`-ed Python
function are exported as parameters; everything else the function closes over is
inlined as constants.

So in the example above, the exported function `jfunc` takes `weights` and `args`
as inputs, so they appear as parameters.

If you do this instead:

```python
def jfunc_inlined(args):
    return jfunc(weights, args)
```
and export and print the StableHLO for that function:

```python
print(jax.jit(jfunc_inlined).lower(
    (jax.ShapeDtypeStruct((4, 3, 224, 224), jnp.float32),)
).as_text())
```
Then you will see the weight values inlined as constants in the output.
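As a self-contained illustration of the same convention (a toy stand-in, not the resnet example above), compare lowering a function whose weights are an explicit input against one that captures them by closure:

```python
import jax
import jax.numpy as jnp
import numpy as np

weights = np.arange(6, dtype=np.float32).reshape(2, 3)

def f_param(w, x):
    # w is an explicit input, so it is exported as a parameter.
    return x @ w.T

def f_inlined(x):
    # weights is captured by closure, so it is inlined as a constant.
    return f_param(weights, x)

x_spec = jax.ShapeDtypeStruct((4, 3), jnp.float32)
w_spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)

param_text = jax.jit(f_param).lower(w_spec, x_spec).as_text()
inlined_text = jax.jit(f_inlined).lower(x_spec).as_text()

# Only the closure-capturing version bakes the weight values into the module.
assert "stablehlo.constant" in inlined_text
```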


## Preserving High-Level PyTorch Operations in StableHLO by generating `stablehlo.composite`
