Request description
Hi all,
After working through the JAX FFI tutorial, I tried to export the resulting function to MLIR (using jax.export). This produced the following MLIR code in the StableHLO dialect:
```mlir
#loc1 = loc("x")
module @jit_rms_norm attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<10x10xf32> loc("x")) -> (tensor<10x10xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @rms_norm(%arg0) {mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<10x10xf32>) -> tensor<10x10xf32> loc(#loc7)
    return %0 : tensor<10x10xf32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("/Users/tudoroancea/dev/piqp_jax/src/piqp_jax/rms_norm.py":37:11 to :39)
#loc3 = loc("/Users/tudoroancea/dev/piqp_jax/test/test_export_rms_norm.py":8:4 to 15:18)
#loc4 = loc("rms_norm"(#loc2))
#loc5 = loc("<module>"(#loc3))
#loc6 = loc(callsite(#loc4 at #loc5))
#loc7 = loc("jit(rms_norm)/jit(main)/ffi_call"(#loc6))
```

I was wondering whether there is a way to lower the `stablehlo.custom_call` to a standard MLIR equivalent (such as `func.call`) that calls an externally defined function, i.e. a symbol that only becomes available after linking against a shared library that defines it. Until now, I have lowered StableHLO to standard MLIR dialects with IREE, using `iree-opt --iree-stablehlo-input-transformation-pipeline`, but for this code it fails with the following error:
```
/Users/tudoroancea/dev/piqp_jax/src/piqp_jax/rms_norm.py:37:11: error: failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal
return call(x, eps=np.float32(eps))
^
/Users/tudoroancea/dev/piqp_jax/test/test_export_rms_norm.py:8:4: note: called from
jexport.export(
^
/Users/tudoroancea/dev/piqp_jax/src/piqp_jax/rms_norm.py:37:11: note: see current operation: %0 = "stablehlo.custom_call"(%arg0) <{call_target_name = "rms_norm", operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]}> {mhlo.backend_config = {eps = 9.99999974E-6 : f32}} : (tensor<10x10xf32>) -> tensor<10x10xf32>
return call(x, eps=np.float32(eps))
^
```
My end goal is to create a wrapper for a C/C++ library that we can both call from JAX via FFI and use in functions that are exported to MLIR and compiled.
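With that goal in mind, it may help to keep a pure-JAX version of the same normalization next to the FFI wrapper. It serves as a reference for testing the C/C++ results, and exporting it shows what the module looks like when it contains only standard StableHLO ops, which is the kind of input IREE's StableHLO pipeline is expected to handle. This is an illustrative sketch: `rms_norm_ref` is a hypothetical helper name, and normalizing over the last axis is an assumption about what the C++ kernel computes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export as jexport


def rms_norm_ref(x, eps=1e-5):
    # Hypothetical pure-JAX reference: RMS normalization over the last axis
    # (assumed to match the C++ kernel's semantics).
    eps = jnp.float32(eps).astype(x.dtype)
    scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x / scale


x = jnp.linspace(-1.0, 1.0, 100, dtype=jnp.float32).reshape(10, 10)
y = rms_norm_ref(x)

# Exporting the reference yields only standard StableHLO ops (no custom_call).
exported_ref = jexport.export(jax.jit(rms_norm_ref))(
    jax.ShapeDtypeStruct((10, 10), jnp.float32)
)
ref_text = exported_ref.mlir_module()
```

Once the C++ handler is registered, its output on the same `x` could be compared against `y` with `np.testing.assert_allclose`.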
Thank you in advance for your help.
Cheers,
Ted
P.S.: I am quite new to the world of XLA/MLIR/IREE, so forgive me if I didn't post this question in the right forum.