Lowering stablehlo.custom_call created from Jax FFI #2820

@tudoroancea

Request description

Hi all,

After following the JAX FFI tutorial, I tried to export the resulting function to MLIR (using jax.export). This produced the following MLIR code in the stablehlo dialect:

#loc1 = loc("x")
module @jit_rms_norm attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<10x10xf32> loc("x")) -> (tensor<10x10xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.custom_call @rms_norm(%arg0) {mhlo.backend_config = {eps = 9.99999974E-6 : f32}, operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<10x10xf32>) -> tensor<10x10xf32> loc(#loc7)
    return %0 : tensor<10x10xf32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("/Users/tudoroancea/dev/piqp_jax/src/piqp_jax/rms_norm.py":37:11 to :39)
#loc3 = loc("/Users/tudoroancea/dev/piqp_jax/test/test_export_rms_norm.py":8:4 to 15:18)
#loc4 = loc("rms_norm"(#loc2))
#loc5 = loc("<module>"(#loc3))
#loc6 = loc(callsite(#loc4 at #loc5))
#loc7 = loc("jit(rms_norm)/jit(main)/ffi_call"(#loc6))

I was wondering whether there is a way to lower the stablehlo.custom_call to a standard MLIR equivalent (like func.call) that would call an externally defined function, whose symbol only becomes available after linking against the shared library that defines it. So far, I have lowered stablehlo to standard MLIR dialects using IREE with iree-opt --iree-stablehlo-input-transformation-pipeline, but for this code it results in the following error:

/Users/tudoroancea/dev/piqp_jax/src/piqp_jax/rms_norm.py:37:11: error: failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal
    return call(x, eps=np.float32(eps))
          ^
/Users/tudoroancea/dev/piqp_jax/test/test_export_rms_norm.py:8:4: note: called from
    jexport.export(
   ^
/Users/tudoroancea/dev/piqp_jax/src/piqp_jax/rms_norm.py:37:11: note: see current operation: %0 = "stablehlo.custom_call"(%arg0) <{call_target_name = "rms_norm", operand_layouts = [dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]}> {mhlo.backend_config = {eps = 9.99999974E-6 : f32}} : (tensor<10x10xf32>) -> tensor<10x10xf32>
    return call(x, eps=np.float32(eps))
          ^
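Since every custom_call target would have to be resolved against a symbol in the shared library, it can help to list the targets a module references before attempting any lowering. The helper below is a hypothetical, regex-based sketch (custom_call_targets is not part of JAX, StableHLO, or IREE) that handles both the pretty-printed form from the exported module and the generic form from the error message. It works on the text only and does not parse MLIR properly.

```python
import re

# Pretty form:  stablehlo.custom_call @rms_norm(%arg0) {...}
_PRETTY = re.compile(r'stablehlo\.custom_call\s+@([\w.$-]+)')
# Generic form: "stablehlo.custom_call"(%arg0) <{call_target_name = "rms_norm", ...}>
_GENERIC = re.compile(r'call_target_name\s*=\s*"([^"]+)"')


def custom_call_targets(mlir_text: str) -> list[str]:
    """Return the sorted, de-duplicated custom_call target names in the text."""
    names = _PRETTY.findall(mlir_text) + _GENERIC.findall(mlir_text)
    return sorted(set(names))


pretty = "%0 = stablehlo.custom_call @rms_norm(%arg0) {operand_layouts = []}"
generic = '%0 = "stablehlo.custom_call"(%arg0) <{call_target_name = "rms_norm"}>'
print(custom_call_targets(pretty + "\n" + generic))
```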

My end goal is to create a wrapper for a C/C++ library that can be called from JAX through FFI and also used inside functions that are exported to MLIR and compiled.

Thank you in advance for your help.

Cheers,

Ted

P.S.: I am quite new to the world of XLA/MLIR/IREE, so forgive me if I didn't post this question in the right forum.
