Skip to content

Enable unbounded dynamism in the XLA CPU C API #33092

@ipcamit

Description

@ipcamit

I am trying to run JAX-exported StableHLO modules from C++ using the PJRT C API CPU plugin. Modules with static shapes run fine, but modules with dynamic shapes (bounded or unbounded) fail to compile. Below is a minimal reproducible example.

  1. Build the CPU C API plugin to get `pjrt_c_api_cpu_plugin.so`:
bazelisk build  --enable_bzlmod  --incompatible_disallow_empty_glob=false  --strip=never  //xla/pjrt/c:pjrt_c_api_cpu
bazelisk build  --enable_bzlmod  --incompatible_disallow_empty_glob=false  --strip=never  //xla/pjrt/c:pjrt_c_api_cpu_plugin.so
  2. Compile the program below with `g++ -std=c++17 -I./include run_hlo.cc -ldl -o run_hlo` and run it as `./run_hlo ./path/to/lib/pjrt_c_api_cpu_plugin.so`.

These are the errors I get:

# dynamic error
[vector] Compile failed: during context [hlo verifier]: Unbounded dynamism is disabled for instruction: %x.1 = f64[?]{0} parameter(1), metadata={op_name="x"}

Failed after pipeline-start

# dynamic bound error
[vector] Compile failed: Custom call target PadToStatic is not implemented.

How can I run modules with a dynamic or bounded-dynamic batch size, as described in the StableHLO documentation [1][2]?

[1] https://openxla.org/stablehlo/tutorials/jax-export
[2] https://openxla.org/stablehlo/dynamism

Example program:

#include <dlfcn.h>

#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "include/pjrt_c_api.h"
#include "include/pjrt_c_api_cpu.h"

// Static-shape baseline: StableHLO (MLIR text) that adds two f64 scalars.
// main() compiles and executes this program as the control case.
constexpr std::string_view kAddScalarStableHlo = R"(
module {
  func.func @main(%x: tensor<f64>, %y: tensor<f64>) -> tensor<f64> {
    %sum = stablehlo.add %x, %y : tensor<f64>
    return %sum : tensor<f64>
  }
}
)";

// JAX-exported StableHLO that adds two rank-1 f64 tensors whose length is the
// unbounded dynamic dimension `?` (symbolic dimension 'n'). The
// shape_assertion custom calls check that n >= 1 and that both inputs have
// the same length. Per the reported output, the CPU plugin's HLO verifier
// rejects this program with "Unbounded dynamism is disabled".
constexpr std::string_view kAddVectorStableHlo = R"(
#loc = loc(unknown)
#loc2 = loc("x")
#loc3 = loc("y")
module @jit_add_vec attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?xf64> {mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64> {mhlo.layout_mode = "default"} loc(unknown)) -> (tensor<?xf64> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf64>) -> tensor<i32> loc(#loc7)
    %1 = stablehlo.convert %0 : (tensor<i32>) -> tensor<i64> loc(#loc7)
    %2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf64>) -> tensor<i32> loc(#loc7)
    %3 = stablehlo.convert %2 : (tensor<i32>) -> tensor<i64> loc(#loc7)
    %c = stablehlo.constant dense<1> : tensor<i64> loc(#loc)
    %4 = stablehlo.compare  GE, %1, %c,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc8)
    stablehlo.custom_call @shape_assertion(%4, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>) -> () loc(#loc9)
    %5 = stablehlo.compare  EQ, %3, %1,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
    stablehlo.custom_call @shape_assertion(%5, %3, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>, tensor<i64>) -> () loc(#loc11)
    %6 = call @_wrapped_jax_export_main(%1, %arg0, %arg1) : (tensor<i64>, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64> loc(#loc)
    return %6 : tensor<?xf64> loc(#loc)
  } loc(#loc)
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i64> {jax.global_constant = "n", mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64> {mhlo.layout_mode = "default"} loc("x"), %arg2: tensor<?xf64> {mhlo.layout_mode = "default"} loc("y")) -> (tensor<?xf64> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg1, %arg2 : tensor<?xf64> loc(#loc13)
    return %0 : tensor<?xf64> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home//XLA/my_pjrt/minimal_add.py":20:0)
#loc4 = loc("/home//XLA/my_pjrt/minimal_add.py":10:0)
#loc5 = loc("<module>"(#loc1))
#loc6 = loc("add_vec"(#loc4))
#loc7 = loc("/dimension_size[dimension=0]"(#loc5))
#loc8 = loc("/ge"(#loc5))
#loc9 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc10 = loc("/eq"(#loc5))
#loc11 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc12 = loc(callsite(#loc6 at #loc5))
#loc13 = loc("jit(add_vec)/jit(main)/add"(#loc12))
)";


// Same JAX-exported vector add, but the dynamic dimension carries an upper
// bound of 4 elements (#stablehlo.bounds<4>), i.e. bounded dynamism. Per the
// reported output, compiling it on the CPU plugin fails with
// "Custom call target PadToStatic is not implemented."
constexpr std::string_view kAddVectorBoundedStableHlo = R"(
#loc = loc(unknown)
#loc2 = loc("x")
#loc3 = loc("y")
module @jit_add_vec attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc(unknown)) -> (tensor<?xf64, #stablehlo.bounds<4>> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf64, #stablehlo.bounds<4>>) -> tensor<i32> loc(#loc7)
    %1 = stablehlo.convert %0 : (tensor<i32>) -> tensor<i64> loc(#loc7)
    %2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf64, #stablehlo.bounds<4>>) -> tensor<i32> loc(#loc7)
    %3 = stablehlo.convert %2 : (tensor<i32>) -> tensor<i64> loc(#loc7)
    %c = stablehlo.constant dense<1> : tensor<i64> loc(#loc)
    %4 = stablehlo.compare  GE, %1, %c,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc8)
    stablehlo.custom_call @shape_assertion(%4, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>) -> () loc(#loc9)
    %5 = stablehlo.compare  EQ, %3, %1,  SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
    stablehlo.custom_call @shape_assertion(%5, %3, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>, tensor<i64>) -> () loc(#loc11)
    %6 = call @_wrapped_jax_export_main(%1, %arg0, %arg1) : (tensor<i64>, tensor<?xf64, #stablehlo.bounds<4>>, tensor<?xf64, #stablehlo.bounds<4>>) -> tensor<?xf64, #stablehlo.bounds<4>> loc(#loc)
    return %6 : tensor<?xf64, #stablehlo.bounds<4>> loc(#loc)
  } loc(#loc)
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i64> {jax.global_constant = "n", mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc("x"), %arg2: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc("y")) -> (tensor<?xf64, #stablehlo.bounds<4>> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg1, %arg2 : tensor<?xf64, #stablehlo.bounds<4>> loc(#loc13)
    return %0 : tensor<?xf64, #stablehlo.bounds<4>> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home//XLA/my_pjrt/minimal_add.py":20:0)
#loc4 = loc("/home//XLA/my_pjrt/minimal_add.py":10:0)
#loc5 = loc("<module>"(#loc1))
#loc6 = loc("add_vec"(#loc4))
#loc7 = loc("/dimension_size[dimension=0]"(#loc5))
#loc8 = loc("/ge"(#loc5))
#loc9 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc10 = loc("/eq"(#loc5))
#loc11 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc12 = loc(callsite(#loc6 at #loc5))
#loc13 = loc("jit(add_vec)/jit(main)/add"(#loc12))
)";

// Hand-written serialized compile-options proto (not produced by protoc; the
// bytes originally came from a ChatGPT suggestion). It is passed verbatim as
// PJRT_Client_Compile_Args::compile_options.
//
// Protobuf wire-format decode of the bytes below. The field names are
// presumed from XLA's CompileOptionsProto / ExecutableBuildOptionsProto /
// DeviceAssignmentProto.
// NOTE(review): confirm the names against the XLA revision in use.
//   field 3 (executable_build_options?) {
//     field 4 = 1  (num_replicas?)
//     field 5 = 1  (num_partitions?)
//     field 9 (device_assignment?) {
//       field 1 = 1  (replica_count?)
//       field 2 = 1  (computation_count?)
//       field 3 (computation_devices?) { field 1 = 0 (replica_device_ids?) }
//     }
//   }
constexpr unsigned char kCompileOptionsProto[] = {
    0x1A, 0x0E,  // field 3, length-delimited, 14 bytes follow
    0x20, 0x01,  //   field 4, varint 1
    0x28, 0x01,  //   field 5, varint 1
    0x4A, 0x08,  //   field 9, length-delimited, 8 bytes follow
    0x08, 0x01,  //     field 1, varint 1
    0x10, 0x01,  //     field 2, varint 1
    0x1A, 0x02,  //     field 3, length-delimited, 2 bytes follow
    0x08, 0x00,  //       field 1, varint 0
};


int main(int argc, char** argv) {
  if (argc != 2) {
    std::cerr << "usage: " << argv[0]
              << " <pjrt_plugin.so>\n"
              << "Example: " << argv[0] << " ./lib/pjrt_c_api_cpu_plugin.so\n";
    return 1;
  }

  std::cout << "=== PJRT add kernel demo ===\n";

  // Step 1. Load PJRT plugin and fetch the API table.
  void* handle = dlopen(argv[1], RTLD_NOW | RTLD_LOCAL);
  if (!handle) std::cerr << "dlopen failed LINE 76\n";
  auto get_api =
      reinterpret_cast<const PJRT_Api* (*)()>(dlsym(handle, "GetPjrtApi"));
  if (!get_api) std::cerr << "dlsym(GetPjrtApi) failed\n";
  const PJRT_Api* api = get_api();
  std::cout << "[1] PJRT API struct size: " << api->struct_size << "\n";

  // Step 2. Create a PJRT client (CPU backend in this example).
  PJRT_Client_Create_Args cargs{};
  cargs.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
  auto err = api->PJRT_Client_Create(&cargs);
  PJRT_Client* client = cargs.client;
  std::cout << "[2] Client created\n";

  // Step 3. Grab the first addressable device.
  PJRT_Client_AddressableDevices_Args dargs{};
  dargs.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE;
  dargs.client = client;
  err = api->PJRT_Client_AddressableDevices(&dargs);
  if (dargs.num_addressable_devices == 0) std::cerr << "No devices found\n";
  PJRT_Device* device = dargs.addressable_devices[0];
  std::cout << "[3] Using device[0]\n";

  // Step 4. Compile the embedded StableHLO.
  PJRT_Program program{};
  program.struct_size = PJRT_Program_STRUCT_SIZE;
  program.code = const_cast<char*>(kAddScalarStableHlo.data());
  program.code_size = kAddScalarStableHlo.size();
  program.format = "mlir";
  program.format_size = 4;

  PJRT_Client_Compile_Args compile_args{};
  compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
  compile_args.client = client;
  compile_args.program = &program;
  compile_args.compile_options =
      reinterpret_cast<const char*>(kCompileOptionsProto);
  compile_args.compile_options_size = sizeof(kCompileOptionsProto);
  err = api->PJRT_Client_Compile(&compile_args);
  PJRT_LoadedExecutable* exec = compile_args.executable;
  std::cout << "[4] Compilation OK\n";

  // Step 5. Parse inputs and upload to device.
  auto lhs = std::vector<double>{2};
  auto rhs = std::vector<double>{3};
  std::cout << "[5] lhs = " << lhs[0] << ", rhs = " << rhs[0] << "\n";

  auto upload_scalar = [&](double value) -> PJRT_Buffer* {
    PJRT_Client_BufferFromHostBuffer_Args bargs{};
    bargs.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE;
    bargs.client = client;
    bargs.data = &value;
    bargs.type = PJRT_Buffer_Type_F64;
    bargs.dims = nullptr;  // scalar
    bargs.num_dims = 0;
    bargs.device = device;
    err = api->PJRT_Client_BufferFromHostBuffer(&bargs);
    if (bargs.done_with_host_buffer) {
      PJRT_Event_Await_Args await_args{};
      await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
      await_args.event = bargs.done_with_host_buffer;
      err = api->PJRT_Event_Await(&await_args);
      PJRT_Event_Destroy_Args destroy_args{};
      destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
      destroy_args.event = bargs.done_with_host_buffer;
      api->PJRT_Event_Destroy(&destroy_args);
    }
    return bargs.buffer;
  };

  PJRT_Buffer* in0 = upload_scalar(lhs[0]);
  PJRT_Buffer* in1 = upload_scalar(rhs[0]);

  // Step 6. Execute the compiled program.
  PJRT_Buffer* inputs[] = {in0, in1};
  PJRT_Buffer** argument_lists[] = {inputs};
  PJRT_Buffer* outputs[1] = {nullptr};
  PJRT_Buffer** output_lists[] = {outputs};

  PJRT_ExecuteOptions exec_opts{};
  exec_opts.struct_size = PJRT_ExecuteOptions_STRUCT_SIZE;

  PJRT_LoadedExecutable_Execute_Args exec_args{};
  exec_args.struct_size = PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE;
  exec_args.executable = exec;
  exec_args.options = &exec_opts;
  exec_args.argument_lists = argument_lists;
  exec_args.num_devices = 1;
  exec_args.num_args = 2;
  exec_args.output_lists = output_lists;
  exec_args.execute_device = device;
  err = api->PJRT_LoadedExecutable_Execute(&exec_args);
  std::cout << "[6] Execution complete\n";

  // Step 7. Copy result back to host.
  double result = 0.0;
  PJRT_Buffer_ToHostBuffer_Args to_host{};
  to_host.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
  to_host.src = outputs[0];
  to_host.dst = &result;
  to_host.dst_size = sizeof(result);
  err = api->PJRT_Buffer_ToHostBuffer(&to_host);
  if (to_host.event) {
    PJRT_Event_Await_Args await_args{};
    await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
    await_args.event = to_host.event;
    err = api->PJRT_Event_Await(&await_args);
    PJRT_Event_Destroy_Args destroy_args{};
    destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
    destroy_args.event = to_host.event;
    api->PJRT_Event_Destroy(&destroy_args);
  }

  std::cout << "[7] Result = " << result << "\n";
  std::cout << "=== done ===\n";

  // failure when compiling a dynamic vector kernel
  std::cout << "\n=== Attempting dynamic vector kernel ===\n";
  PJRT_Program vec_program{};
  vec_program.struct_size = PJRT_Program_STRUCT_SIZE;
  vec_program.code = const_cast<char*>(kAddVectorStableHlo.data());
  vec_program.code_size = kAddVectorStableHlo.size();


  vec_program.format = "mlir";
  vec_program.format_size = 4;

  PJRT_Client_Compile_Args vec_compile{};
  vec_compile.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
  vec_compile.client = client;
  vec_compile.program = &vec_program;
  vec_compile.compile_options =
      reinterpret_cast<const char*>(kCompileOptionsProto);
  vec_compile.compile_options_size = sizeof(kCompileOptionsProto);

  if (PJRT_Error* err = api->PJRT_Client_Compile(&vec_compile)) {
    PJRT_Error_Message_Args msg{};
    msg.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
    msg.error = err;
    api->PJRT_Error_Message(&msg);
    std::string message(msg.message, msg.message_size);
    std::cout << "[vector] Compile failed: " << message << "\n";

    PJRT_Error_Destroy_Args destroy{};
    destroy.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
    destroy.error = err;
    api->PJRT_Error_Destroy(&destroy);
  } else {
    std::cout << "[vector] Unexpectedly compiled successfully.\n";
  }


  std::cout << "\n=== Attempting  bounded dynamic vector kernel ===\n";
  vec_program.struct_size = PJRT_Program_STRUCT_SIZE;
  vec_program.code = const_cast<char*>(kAddVectorBoundedStableHlo.data());
  vec_program.code_size = kAddVectorBoundedStableHlo.size();


  vec_program.format = "mlir";
  vec_program.format_size = 4;

  vec_compile.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
  vec_compile.client = client;
  vec_compile.program = &vec_program;
  vec_compile.compile_options =
      reinterpret_cast<const char*>(kCompileOptionsProto);
  vec_compile.compile_options_size = sizeof(kCompileOptionsProto);

  if (PJRT_Error* err = api->PJRT_Client_Compile(&vec_compile)) {
    PJRT_Error_Message_Args msg{};
    msg.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
    msg.error = err;
    api->PJRT_Error_Message(&msg);
    std::string message(msg.message, msg.message_size);
    std::cout << "[vector] Compile failed: " << message << "\n";

    PJRT_Error_Destroy_Args destroy{};
    destroy.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
    destroy.error = err;
    api->PJRT_Error_Destroy(&destroy);
  } else {
    std::cout << "[vector] Unexpectedly compiled successfully.\n";
  }


  return 0;
}

Metadata

Metadata

Labels

CPURelated to XLA on CPUerr:BuildBuild or compilation failed

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions