-
Notifications
You must be signed in to change notification settings - Fork 696
Open
Labels
CPU — Related to XLA on CPU; err:Build — Build or compilation failed
Description
I am trying to run JAX exported stablehlo modules in C++, using the C API CPU plugin. I can run modules with static shapes, but for dynamic shapes (bound or unbound) I get errors. Here is what I did with a minimal reproducible sample.
- compiled the CPU C API plugins to get
pjrt_c_api_cpu_plugin.so
bazelisk build --enable_bzlmod --incompatible_disallow_empty_glob=false --strip=never //xla/pjrt/c:pjrt_c_api_cpu
bazelisk build --enable_bzlmod --incompatible_disallow_empty_glob=false --strip=never //xla/pjrt/c:pjrt_c_api_cpu_plugin.so
- compiled the program below as
g++ -std=c++17 -I./include run_hlo.cc -ldl -o run_hlo; and run it as ./run_hlo ./path/to/lib/pjrt_c_api_cpu_plugin.so
This is the error I got:
# dynamic error
[vector] Compile failed: during context [hlo verifier]: Unbounded dynamism is disabled for instruction: %x.1 = f64[?]{0} parameter(1), metadata={op_name="x"}
Failed after pipeline-start
# dynamic bound error
[vector] Compile failed: Custom call target PadToStatic is not implemented.
How can I run with a dynamic / dynamic-bound batch size, as highlighted in the StableHLO documentation? [1][2]
[1] https://openxla.org/stablehlo/tutorials/jax-export
[2] https://openxla.org/stablehlo/dynamism
Example program:
#include <dlfcn.h>

#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

#include "include/pjrt_c_api.h"
#include "include/pjrt_c_api_cpu.h"
// StableHLO module with fully static shapes: @main adds two f64 scalars.
// This is the baseline case that compiles and executes successfully through
// the CPU PJRT plugin (steps 4-7 of main()).
constexpr std::string_view kAddScalarStableHlo = R"(
module {
func.func @main(%x: tensor<f64>, %y: tensor<f64>) -> tensor<f64> {
%sum = stablehlo.add %x, %y : tensor<f64>
return %sum : tensor<f64>
}
}
)";
// JAX-exported StableHLO with an UNBOUNDED dynamic leading dimension
// (tensor<?xf64>, jax.uses_shape_polymorphism = true): element-wise add of
// two length-n vectors, plus the @shape_assertion custom calls that JAX's
// shape-polymorphism export emits to validate that both inputs share the
// same dimension 'n' >= 1. Compiling this module reproduces the
// "Unbounded dynamism is disabled for instruction" verifier error quoted
// in the description above.
constexpr std::string_view kAddVectorStableHlo = R"(
#loc = loc(unknown)
#loc2 = loc("x")
#loc3 = loc("y")
module @jit_add_vec attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<?xf64> {mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64> {mhlo.layout_mode = "default"} loc(unknown)) -> (tensor<?xf64> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf64>) -> tensor<i32> loc(#loc7)
%1 = stablehlo.convert %0 : (tensor<i32>) -> tensor<i64> loc(#loc7)
%2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf64>) -> tensor<i32> loc(#loc7)
%3 = stablehlo.convert %2 : (tensor<i32>) -> tensor<i64> loc(#loc7)
%c = stablehlo.constant dense<1> : tensor<i64> loc(#loc)
%4 = stablehlo.compare GE, %1, %c, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc8)
stablehlo.custom_call @shape_assertion(%4, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>) -> () loc(#loc9)
%5 = stablehlo.compare EQ, %3, %1, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
stablehlo.custom_call @shape_assertion(%5, %3, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>, tensor<i64>) -> () loc(#loc11)
%6 = call @_wrapped_jax_export_main(%1, %arg0, %arg1) : (tensor<i64>, tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64> loc(#loc)
return %6 : tensor<?xf64> loc(#loc)
} loc(#loc)
func.func private @_wrapped_jax_export_main(%arg0: tensor<i64> {jax.global_constant = "n", mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64> {mhlo.layout_mode = "default"} loc("x"), %arg2: tensor<?xf64> {mhlo.layout_mode = "default"} loc("y")) -> (tensor<?xf64> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.add %arg1, %arg2 : tensor<?xf64> loc(#loc13)
return %0 : tensor<?xf64> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc1 = loc("/home//XLA/my_pjrt/minimal_add.py":20:0)
#loc4 = loc("/home//XLA/my_pjrt/minimal_add.py":10:0)
#loc5 = loc("<module>"(#loc1))
#loc6 = loc("add_vec"(#loc4))
#loc7 = loc("/dimension_size[dimension=0]"(#loc5))
#loc8 = loc("/ge"(#loc5))
#loc9 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc10 = loc("/eq"(#loc5))
#loc11 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc12 = loc(callsite(#loc6 at #loc5))
#loc13 = loc("jit(add_vec)/jit(main)/add"(#loc12))
)";
// Same JAX-exported vector-add module as kAddVectorStableHlo, but with a
// BOUNDED dynamic dimension: every tensor<?xf64> carries
// #stablehlo.bounds<4>, i.e. the dynamic size is at most 4. Compiling this
// module reproduces the "Custom call target PadToStatic is not implemented"
// error quoted in the description above.
constexpr std::string_view kAddVectorBoundedStableHlo = R"(
#loc = loc(unknown)
#loc2 = loc("x")
#loc3 = loc("y")
module @jit_add_vec attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc(unknown)) -> (tensor<?xf64, #stablehlo.bounds<4>> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf64, #stablehlo.bounds<4>>) -> tensor<i32> loc(#loc7)
%1 = stablehlo.convert %0 : (tensor<i32>) -> tensor<i64> loc(#loc7)
%2 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf64, #stablehlo.bounds<4>>) -> tensor<i32> loc(#loc7)
%3 = stablehlo.convert %2 : (tensor<i32>) -> tensor<i64> loc(#loc7)
%c = stablehlo.constant dense<1> : tensor<i64> loc(#loc)
%4 = stablehlo.compare GE, %1, %c, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc8)
stablehlo.custom_call @shape_assertion(%4, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>) -> () loc(#loc9)
%5 = stablehlo.compare EQ, %3, %1, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1> loc(#loc10)
stablehlo.custom_call @shape_assertion(%5, %3, %1) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i64>, tensor<i64>) -> () loc(#loc11)
%6 = call @_wrapped_jax_export_main(%1, %arg0, %arg1) : (tensor<i64>, tensor<?xf64, #stablehlo.bounds<4>>, tensor<?xf64, #stablehlo.bounds<4>>) -> tensor<?xf64, #stablehlo.bounds<4>> loc(#loc)
return %6 : tensor<?xf64, #stablehlo.bounds<4>> loc(#loc)
} loc(#loc)
func.func private @_wrapped_jax_export_main(%arg0: tensor<i64> {jax.global_constant = "n", mhlo.layout_mode = "default"} loc(unknown), %arg1: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc("x"), %arg2: tensor<?xf64, #stablehlo.bounds<4>> {mhlo.layout_mode = "default"} loc("y")) -> (tensor<?xf64, #stablehlo.bounds<4>> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.add %arg1, %arg2 : tensor<?xf64, #stablehlo.bounds<4>> loc(#loc13)
return %0 : tensor<?xf64, #stablehlo.bounds<4>> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc1 = loc("/home//XLA/my_pjrt/minimal_add.py":20:0)
#loc4 = loc("/home//XLA/my_pjrt/minimal_add.py":10:0)
#loc5 = loc("<module>"(#loc1))
#loc6 = loc("add_vec"(#loc4))
#loc7 = loc("/dimension_size[dimension=0]"(#loc5))
#loc8 = loc("/ge"(#loc5))
#loc9 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'n'. Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {0} from specification 'n' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc10 = loc("/eq"(#loc5))
#loc11 = loc("/shape_assertion[error_message=Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[0] (= {0}) and the specification 'n' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (n,),args[1].shape = (n,). Obtained dimension variables: 'n' = {1} from specification 'n' for dimension args[0].shape[0] (= {1}), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.]"(#loc5))
#loc12 = loc(callsite(#loc6 at #loc5))
#loc13 = loc("jit(add_vec)/jit(main)/add"(#loc12))
)";
// Hand-encoded serialized CompileOptions protobuf, passed verbatim as
// PJRT_Client_Compile_Args::compile_options (originally assembled from
// ChatGPT suggestions, not generated by the proto library).
// NOTE(review): presumably this encodes single-replica / single-partition
// build options — TODO: verify the wire bytes against
// xla/pjrt/proto/compile_options.proto, or generate them with the real
// proto API instead of trusting raw constants.
constexpr unsigned char kCompileOptionsProto[] = {
0x1A, 0x0E, 0x20, 0x01, 0x28, 0x01, 0x4A, 0x08,
0x08, 0x01, 0x10, 0x01, 0x1A, 0x02, 0x08, 0x00,
};
int main(int argc, char** argv) {
if (argc != 2) {
std::cerr << "usage: " << argv[0]
<< " <pjrt_plugin.so>\n"
<< "Example: " << argv[0] << " ./lib/pjrt_c_api_cpu_plugin.so\n";
return 1;
}
std::cout << "=== PJRT add kernel demo ===\n";
// Step 1. Load PJRT plugin and fetch the API table.
void* handle = dlopen(argv[1], RTLD_NOW | RTLD_LOCAL);
if (!handle) std::cerr << "dlopen failed LINE 76\n";
auto get_api =
reinterpret_cast<const PJRT_Api* (*)()>(dlsym(handle, "GetPjrtApi"));
if (!get_api) std::cerr << "dlsym(GetPjrtApi) failed\n";
const PJRT_Api* api = get_api();
std::cout << "[1] PJRT API struct size: " << api->struct_size << "\n";
// Step 2. Create a PJRT client (CPU backend in this example).
PJRT_Client_Create_Args cargs{};
cargs.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
auto err = api->PJRT_Client_Create(&cargs);
PJRT_Client* client = cargs.client;
std::cout << "[2] Client created\n";
// Step 3. Grab the first addressable device.
PJRT_Client_AddressableDevices_Args dargs{};
dargs.struct_size = PJRT_Client_AddressableDevices_Args_STRUCT_SIZE;
dargs.client = client;
err = api->PJRT_Client_AddressableDevices(&dargs);
if (dargs.num_addressable_devices == 0) std::cerr << "No devices found\n";
PJRT_Device* device = dargs.addressable_devices[0];
std::cout << "[3] Using device[0]\n";
// Step 4. Compile the embedded StableHLO.
PJRT_Program program{};
program.struct_size = PJRT_Program_STRUCT_SIZE;
program.code = const_cast<char*>(kAddScalarStableHlo.data());
program.code_size = kAddScalarStableHlo.size();
program.format = "mlir";
program.format_size = 4;
PJRT_Client_Compile_Args compile_args{};
compile_args.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
compile_args.client = client;
compile_args.program = &program;
compile_args.compile_options =
reinterpret_cast<const char*>(kCompileOptionsProto);
compile_args.compile_options_size = sizeof(kCompileOptionsProto);
err = api->PJRT_Client_Compile(&compile_args);
PJRT_LoadedExecutable* exec = compile_args.executable;
std::cout << "[4] Compilation OK\n";
// Step 5. Parse inputs and upload to device.
auto lhs = std::vector<double>{2};
auto rhs = std::vector<double>{3};
std::cout << "[5] lhs = " << lhs[0] << ", rhs = " << rhs[0] << "\n";
auto upload_scalar = [&](double value) -> PJRT_Buffer* {
PJRT_Client_BufferFromHostBuffer_Args bargs{};
bargs.struct_size = PJRT_Client_BufferFromHostBuffer_Args_STRUCT_SIZE;
bargs.client = client;
bargs.data = &value;
bargs.type = PJRT_Buffer_Type_F64;
bargs.dims = nullptr; // scalar
bargs.num_dims = 0;
bargs.device = device;
err = api->PJRT_Client_BufferFromHostBuffer(&bargs);
if (bargs.done_with_host_buffer) {
PJRT_Event_Await_Args await_args{};
await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
await_args.event = bargs.done_with_host_buffer;
err = api->PJRT_Event_Await(&await_args);
PJRT_Event_Destroy_Args destroy_args{};
destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
destroy_args.event = bargs.done_with_host_buffer;
api->PJRT_Event_Destroy(&destroy_args);
}
return bargs.buffer;
};
PJRT_Buffer* in0 = upload_scalar(lhs[0]);
PJRT_Buffer* in1 = upload_scalar(rhs[0]);
// Step 6. Execute the compiled program.
PJRT_Buffer* inputs[] = {in0, in1};
PJRT_Buffer** argument_lists[] = {inputs};
PJRT_Buffer* outputs[1] = {nullptr};
PJRT_Buffer** output_lists[] = {outputs};
PJRT_ExecuteOptions exec_opts{};
exec_opts.struct_size = PJRT_ExecuteOptions_STRUCT_SIZE;
PJRT_LoadedExecutable_Execute_Args exec_args{};
exec_args.struct_size = PJRT_LoadedExecutable_Execute_Args_STRUCT_SIZE;
exec_args.executable = exec;
exec_args.options = &exec_opts;
exec_args.argument_lists = argument_lists;
exec_args.num_devices = 1;
exec_args.num_args = 2;
exec_args.output_lists = output_lists;
exec_args.execute_device = device;
err = api->PJRT_LoadedExecutable_Execute(&exec_args);
std::cout << "[6] Execution complete\n";
// Step 7. Copy result back to host.
double result = 0.0;
PJRT_Buffer_ToHostBuffer_Args to_host{};
to_host.struct_size = PJRT_Buffer_ToHostBuffer_Args_STRUCT_SIZE;
to_host.src = outputs[0];
to_host.dst = &result;
to_host.dst_size = sizeof(result);
err = api->PJRT_Buffer_ToHostBuffer(&to_host);
if (to_host.event) {
PJRT_Event_Await_Args await_args{};
await_args.struct_size = PJRT_Event_Await_Args_STRUCT_SIZE;
await_args.event = to_host.event;
err = api->PJRT_Event_Await(&await_args);
PJRT_Event_Destroy_Args destroy_args{};
destroy_args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
destroy_args.event = to_host.event;
api->PJRT_Event_Destroy(&destroy_args);
}
std::cout << "[7] Result = " << result << "\n";
std::cout << "=== done ===\n";
// failure when compiling a dynamic vector kernel
std::cout << "\n=== Attempting dynamic vector kernel ===\n";
PJRT_Program vec_program{};
vec_program.struct_size = PJRT_Program_STRUCT_SIZE;
vec_program.code = const_cast<char*>(kAddVectorStableHlo.data());
vec_program.code_size = kAddVectorStableHlo.size();
vec_program.format = "mlir";
vec_program.format_size = 4;
PJRT_Client_Compile_Args vec_compile{};
vec_compile.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
vec_compile.client = client;
vec_compile.program = &vec_program;
vec_compile.compile_options =
reinterpret_cast<const char*>(kCompileOptionsProto);
vec_compile.compile_options_size = sizeof(kCompileOptionsProto);
if (PJRT_Error* err = api->PJRT_Client_Compile(&vec_compile)) {
PJRT_Error_Message_Args msg{};
msg.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
msg.error = err;
api->PJRT_Error_Message(&msg);
std::string message(msg.message, msg.message_size);
std::cout << "[vector] Compile failed: " << message << "\n";
PJRT_Error_Destroy_Args destroy{};
destroy.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
destroy.error = err;
api->PJRT_Error_Destroy(&destroy);
} else {
std::cout << "[vector] Unexpectedly compiled successfully.\n";
}
std::cout << "\n=== Attempting bounded dynamic vector kernel ===\n";
vec_program.struct_size = PJRT_Program_STRUCT_SIZE;
vec_program.code = const_cast<char*>(kAddVectorBoundedStableHlo.data());
vec_program.code_size = kAddVectorBoundedStableHlo.size();
vec_program.format = "mlir";
vec_program.format_size = 4;
vec_compile.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE;
vec_compile.client = client;
vec_compile.program = &vec_program;
vec_compile.compile_options =
reinterpret_cast<const char*>(kCompileOptionsProto);
vec_compile.compile_options_size = sizeof(kCompileOptionsProto);
if (PJRT_Error* err = api->PJRT_Client_Compile(&vec_compile)) {
PJRT_Error_Message_Args msg{};
msg.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
msg.error = err;
api->PJRT_Error_Message(&msg);
std::string message(msg.message, msg.message_size);
std::cout << "[vector] Compile failed: " << message << "\n";
PJRT_Error_Destroy_Args destroy{};
destroy.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
destroy.error = err;
api->PJRT_Error_Destroy(&destroy);
} else {
std::cout << "[vector] Unexpectedly compiled successfully.\n";
}
return 0;
}
Metadata
Metadata
Assignees
Labels
CPU — Related to XLA on CPU; err:Build — Build or compilation failed