Status: Open
Labels: GPU (XLA on GPU), bug (Something isn't working), err: Runtime (Runtime Error), stat:awaiting openxla-eng (Awaiting response from openxla-eng)
Description
When running GPU executables compiled from HLO that contains bounded dynamic shapes, execution aborts with:
```
Check failed: extent <= base.size() (108 vs. 100)
slice extent 108 must be smaller than buffer #1 size 100
```
The issue reproduces with both AOT-compiled and JIT-compiled executables.
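The two numbers in the error match a simple size calculation. The static 5x5 f32 argument literal occupies 5 * 5 * 4 = 100 bytes. If, as assumed here, the GPU compiler sizes a dynamically shaped buffer as the bounded data plus one int32 size per dimension, then `f32[<=5, 5]` needs 100 + 2 * 4 = 108 bytes. The sketch below only reproduces that arithmetic; `DenseF32Bytes` and `DynamicF32Bytes` are hypothetical helpers, not XLA APIs.

```cpp
#include <cassert>
#include <cstdint>

// Bytes occupied by the dense data of an f32 array with `num_elements` elements.
constexpr int64_t DenseF32Bytes(int64_t num_elements) {
  return num_elements * static_cast<int64_t>(sizeof(float));
}

// ASSUMPTION: a dynamically shaped buffer holds the data padded to its bounds,
// followed by one int32 size per dimension (the dynamic-size metadata).
constexpr int64_t DynamicF32Bytes(int64_t bounded_elements, int64_t rank) {
  assert(bounded_elements >= 0 && rank >= 0);
  return DenseF32Bytes(bounded_elements) +
         rank * static_cast<int64_t>(sizeof(int32_t));
}

// Static argument literal f32[5, 5]: 25 * 4 = 100 bytes ("buffer #1 size 100").
static_assert(DenseF32Bytes(5 * 5) == 100, "static argument size");
// Dynamic parameter f32[<=5, 5]: 100 data bytes + 2 * 4 metadata bytes = 108
// ("slice extent 108").
static_assert(DynamicF32Bytes(5 * 5, 2) == 108, "dynamic parameter size");
```

Under this assumption the 8-byte gap is exactly the dynamic-size metadata that a static literal never carries.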
Reproduction Steps
- Modify and rebuild the test file `xla/service/gpu/gpu_aot_compilation_test.cc`.
- Add an HLO module that includes bounded dynamic dimensions, e.g.:
```cpp
const absl::string_view hlo_string = R"(
HloModule Test
ENTRY main {
  a = f32[<=5, 5]{1,0} parameter(0)
  ROOT b = f32[<=5, 5]{0,1} copy(a)
}
)";
```
- Compile ahead of time:
```cpp
AotCompilationOptions aot_options(compiler->PlatformId());
aot_options.set_executor(stream_exec);
TF_ASSERT_OK_AND_ASSIGN(
    std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
    compiler->CompileAheadOfTime(std::move(module), aot_options));
```
- Serialize and reload the AOT result (as in the full test under UT Code), then load and execute the executable:
```cpp
// `compiler` is already a pointer (backend().compiler()), so pass it directly.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> executable,
                        std::move(*aot_result).LoadExecutable(compiler, stream_exec));
std::unique_ptr<OpaqueExecutable> wrapped_executable =
    test_runner_as_hlo_runner().WrapExecutable(std::move(executable));
TF_ASSERT_OK_AND_ASSIGN(xla::Literal result,
                        test_runner_as_hlo_runner().ExecuteWithExecutable(
                            wrapped_executable.get(), {&argument}));
```
- Execution aborts with:
```
Check failed: extent <= base.size() (108 vs. 100)
slice extent 108 must be smaller than buffer #1 size 100
```
Expected behavior
- The program should execute successfully on GPU with bounded dynamic shapes,
or at least fail with a clear "unsupported" error instead of aborting on a CHECK failure.
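A minimal sketch of the "clear error" alternative, assuming the runtime knows both the size each parameter slice expects and the size of the buffer actually passed in: compare them before launching any kernel and return a message instead of CHECK-failing. `ParameterBuffer` and `ValidateArgumentSizes` are hypothetical names for illustration, not XLA code.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical description of one entry-parameter buffer.
struct ParameterBuffer {
  int64_t expected_bytes;  // size the executable's buffer slice assumes
  int64_t actual_bytes;    // size of the device buffer actually passed in
};

// Returns an error message for the first undersized argument, or std::nullopt
// if every argument is large enough. A runtime could surface this as an error
// status instead of hitting a CHECK while resolving kernel arguments.
std::optional<std::string> ValidateArgumentSizes(
    const std::vector<ParameterBuffer>& params) {
  for (size_t i = 0; i < params.size(); ++i) {
    const ParameterBuffer& p = params[i];
    assert(p.expected_bytes >= 0 && p.actual_bytes >= 0);
    if (p.actual_bytes < p.expected_bytes) {
      return "argument " + std::to_string(i) + " has " +
             std::to_string(p.actual_bytes) +
             " bytes but the executable expects " +
             std::to_string(p.expected_bytes) +
             " (static literal passed for a dynamic-shaped parameter?)";
    }
  }
  return std::nullopt;
}
```

With the numbers from this report, `ValidateArgumentSizes({ParameterBuffer{108, 100}})` would yield an error naming argument 0 rather than terminating the process.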
Actual behavior
- AOT compilation succeeds, but runtime execution crashes.
- JIT execution with similar HLO also crashes with the same buffer extent mismatch.
UT Code
```cpp
TEST_F(GpuAotCompilationTest, ExportAndLoadExecutable) {
  const absl::string_view hlo_string = R"(
HloModule Test
ENTRY main {
  a = f32[<=5, 5]{1,0} parameter(0)
  ROOT b = f32[<=5, 5]{0,1} copy(a)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto compiler = backend().compiler();
  auto name =
      absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value());
  TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                          se::PlatformManager::PlatformWithName(name));
  TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec,
                          platform->ExecutorForDevice(0));

  // Compile AOT.
  AotCompilationOptions aot_options(compiler->PlatformId());
  aot_options.set_executor(stream_exec);
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
      compiler->CompileAheadOfTime(std::move(module), aot_options));

  // Serialize-deserialize AOT compilation result.
  TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result,
                          aot_results[0]->SerializeAsString());
  // aot_results[0] is the GPU compiler's AOT result; log it for debugging.
  LOG(INFO) << "Debug string: " << aot_results[0]->DebugString();
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<AotCompilationResult> aot_result,
      compiler->LoadAotCompilationResult(serialized_aot_result));

  // Load Executable from AOT compilation result.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<Executable> executable,
      std::move(*aot_result).LoadExecutable(compiler, stream_exec));
  std::unique_ptr<OpaqueExecutable> wrapped_executable =
      test_runner_as_hlo_runner().WrapExecutable(std::move(executable));

  // Execute with a static 5x5 argument of ones; copy() should return it unchanged.
  const xla::Literal argument = xla::LiteralUtil::CreateR2<float>(
      {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1},
       {1, 1, 1, 1, 1}});
  TF_ASSERT_OK_AND_ASSIGN(
      xla::Literal result,
      test_runner_as_hlo_runner().ExecuteWithExecutable(
          wrapped_executable.get(), {&argument}));
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(
      result, xla::LiteralUtil::CreateR2<float>(
                  {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1},
                   {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}})));
}
```
UT test.log
Executing tests from //xla/service/gpu:gpu_aot_compilation_test
-----------------------------------------------------------------------------
Note: Randomizing tests' orders with a seed of 41023 .
[==========] Running 3 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 3 tests from GpuAotCompilationTest
[ RUN ] GpuAotCompilationTest.ExportAndLoadExecutable
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1761646825.427641 457517 cuda_dnn.cc:463] Loaded cuDNN version 90300
I0000 00:00:1761646825.485168 457517 gpu_aot_compilation_test.cc:79] Debug string: goo.gle/debugonly
hlo_module_with_config {
hlo_module {
name: "Test"
entry_computation_name: "main"
computations {
name: "main"
instructions {
name: "a"
opcode: "parameter"
shape {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 1
minor_to_major: 0
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
metadata {
}
frontend_attributes {
}
statistics_viz {
}
}
instructions {
name: "a.padded"
opcode: "custom-call"
shape {
element_type: TUPLE
tuple_shapes {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 1
minor_to_major: 0
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: false
is_dynamic_dimension: false
}
tuple_shapes {
element_type: S32
layout {
tail_padding_alignment_in_elements: 1
}
}
tuple_shapes {
element_type: S32
layout {
tail_padding_alignment_in_elements: 1
}
}
}
metadata {
}
custom_call_target: "PadToStatic"
id: 1
operand_ids: 0
feature_group_count: 1
precision_config {
}
batch_group_count: 1
frontend_attributes {
}
custom_call_api_version: API_VERSION_ORIGINAL
statistics_viz {
}
}
instructions {
name: "a.data.1"
opcode: "get-tuple-element"
shape {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 1
minor_to_major: 0
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: false
is_dynamic_dimension: false
}
metadata {
}
id: 2
operand_ids: 1
frontend_attributes {
}
statistics_viz {
}
}
instructions {
name: "a.size.1"
opcode: "get-tuple-element"
shape {
element_type: S32
layout {
tail_padding_alignment_in_elements: 1
}
}
metadata {
}
tuple_index: 1
id: 3
operand_ids: 1
frontend_attributes {
}
statistics_viz {
}
}
instructions {
name: "constant_1"
opcode: "constant"
shape {
element_type: S32
layout {
tail_padding_alignment_in_elements: 1
}
}
metadata {
}
literal {
shape {
element_type: S32
layout {
tail_padding_alignment_in_elements: 1
}
}
s32s: 5
}
id: 4
frontend_attributes {
}
statistics_viz {
}
}
instructions {
name: "custom-call"
opcode: "custom-call"
shape {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 0
minor_to_major: 1
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
metadata {
}
custom_call_target: "SliceToDynamic"
id: 5
operand_ids: 2
operand_ids: 3
operand_ids: 4
feature_group_count: 1
precision_config {
}
batch_group_count: 1
frontend_attributes {
}
custom_call_api_version: API_VERSION_ORIGINAL
statistics_viz {
}
}
program_shape {
parameters {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 1
minor_to_major: 0
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
result {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 0
minor_to_major: 1
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
parameter_names: "a"
}
root_id: 5
}
host_program_shape {
parameters {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 1
minor_to_major: 0
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
result {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 0
minor_to_major: 1
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
parameter_names: "p0"
}
schedule {
sequences {
key: 0
value {
instruction_ids: 4
instruction_ids: 0
instruction_ids: 1
instruction_ids: 3
instruction_ids: 2
instruction_ids: 5
}
}
}
input_output_alias {
}
is_dynamic: true
device_assignment {
replica_count: 1
computation_count: 1
computation_devices {
replica_device_ids: 0
}
}
buffer_donor {
}
frontend_attributes {
map {
key: "fingerprint_before_lhs"
value: "a905fb396bb09b6ad42563c2d4e4675d"
}
map {
key: "suggested_combiner_threshold"
value: "22603077737"
}
}
}
config {
entry_computation_layout {
parameters {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 1
minor_to_major: 0
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
result {
element_type: F32
dimensions: 5
dimensions: 5
layout {
minor_to_major: 0
minor_to_major: 1
tail_padding_alignment_in_elements: 1
}
is_dynamic_dimension: true
is_dynamic_dimension: false
}
parameter_names: "p0"
}
replica_count: 1
num_partitions: 1
intra_op_parallelism_threads: -1
debug_options {
xla_disable_hlo_passes: "constant_folding"
xla_backend_optimization_level: 3
xla_eliminate_hlo_implicit_broadcast: true
xla_cpu_multi_thread_eigen: true
xla_gpu_cuda_data_dir: "./cuda_sdk_lib"
xla_llvm_enable_alias_scope_metadata: true
xla_llvm_enable_noalias_metadata: true
xla_llvm_enable_invariant_load_metadata: true
xla_llvm_disable_expensive_passes: false
xla_cpu_use_onednn: false
xla_cpu_enable_fast_math: false
xla_gpu_enable_fast_min_max: false
xla_force_host_platform_device_count: 1
xla_hlo_evaluator_use_fast_path: true
xla_dump_hlo_as_html: false
xla_cpu_fast_math_honor_nans: true
xla_cpu_fast_math_honor_infs: true
xla_allow_excess_precision: true
xla_gpu_autotune_level: 4
xla_cpu_fast_math_honor_division: true
xla_cpu_fast_math_honor_functions: true
xla_dump_include_timestamp: false
xla_dump_max_hlo_modules: -1
xla_cpu_enable_xprof_traceme: false
xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found: false
xla_cpu_enable_fast_min_max: true
xla_multiheap_size_constraint_per_heap: -1
xla_dump_module_metadata: false
xla_dump_fusion_visualization: false
xla_gpu_strict_conv_algorithm_picker: true
xla_gpu_all_reduce_combine_threshold_bytes: 31457287
xla_gpu_nccl_termination_timeout_seconds: -1
xla_dump_hlo_as_long_text: true
xla_gpu_enable_shared_constants: true
xla_gpu_enable_cublaslt: false
xla_gpu_shape_checks: RUNTIME
xla_gpu_use_runtime_fusion: false
xla_dump_latency_hiding_schedule: false
xla_dump_enable_mlir_pretty_form: true
xla_gpu_enable_latency_hiding_scheduler: false
xla_partitioning_algorithm: PARTITIONING_ALGORITHM_NOOP
xla_gpu_enable_triton_gemm: true
xla_gpu_enable_cudnn_int8x32_convolution_reordering: true
xla_gpu_triton_gemm_any: true
xla_gpu_enable_while_loop_reduce_scatter_code_motion: false
xla_gpu_collective_inflation_factor: 1
xla_gpu_graph_min_graph_size: 5
xla_gpu_enable_reassociation_for_converted_ar: true
xla_gpu_pgle_profile_file_or_directory_path: ""
xla_gpu_all_gather_combine_threshold_bytes: 31457287
xla_gpu_reduce_scatter_combine_threshold_bytes: 31457287
xla_gpu_enable_highest_priority_async_stream: true
xla_gpu_enable_pipelined_all_reduce: false
xla_gpu_exhaustive_tiling_search: false
xla_gpu_auto_spmd_partitioning_memory_budget_gb: 0
xla_gpu_auto_spmd_partitioning_memory_budget_ratio: 1.1
xla_gpu_enable_pipelined_all_gather: false
xla_gpu_redzone_padding_bytes: 8388608
xla_gpu_enable_pipelined_reduce_scatter: true
xla_gpu_fused_attention_use_cudnn_rng: false
xla_gpu_copy_insertion_use_region_analysis: false
xla_gpu_collective_permute_decomposer_threshold: 9223372036854775807
xla_gpu_collect_cost_model_stats: false
xla_gpu_enable_split_k_autotuning: true
xla_gpu_enable_reduction_epilogue_fusion: true
xla_gpu_enable_pipelined_p2p: false
xla_gpu_cublas_fallback: true
xla_gpu_enable_while_loop_double_buffering: false
xla_gpu_filter_kernels_spilling_registers_on_autotuning: true
xla_debug_buffer_assignment_show_max: 15
xla_detailed_logging: true
xla_enable_dumping: true
xla_gpu_enable_all_gather_combine_by_dim: false
xla_gpu_enable_analytical_latency_estimator: false
xla_gpu_llvm_verification_level: 0
xla_gpu_enable_reduce_scatter_combine_by_dim: false
xla_gpu_enable_command_buffer: FUSION
xla_gpu_enable_command_buffer: CUBLAS
xla_gpu_enable_command_buffer: CUBLASLT
xla_gpu_enable_command_buffer: CUSTOM_CALL
xla_gpu_enable_command_buffer: CUDNN
xla_gpu_enable_cub_radix_sort: true
xla_gpu_memory_limit_slop_factor: 95
xla_gpu_target_config_filename: ""
xla_gpu_enable_cudnn_layer_norm: false
xla_gpu_threshold_for_windowed_einsum_mib: 100000
xla_gpu_enable_nccl_user_buffers: false
xla_gpu_enable_llvm_module_compilation_parallelism: false
xla_gpu_enable_libnvptxcompiler: false
xla_gpu_enable_nccl_comm_splitting: true
xla_gpu_nccl_collective_max_nchannels: 0
xla_gpu_nccl_p2p_max_nchannels: 0
xla_gpu_nccl_init_max_rank_per_root_ratio: 0
xla_gpu_multi_streamed_windowed_einsum: true
xla_gpu_gemm_rewrite_size_threshold: 100
xla_gpu_require_complete_aot_autotune_results: false
xla_gpu_cudnn_gemm_fusion_level: 0
xla_gpu_use_memcpy_local_p2p: false
xla_gpu_autotune_max_solutions: 0
xla_dump_large_constants: false
xla_gpu_verify_triton_fusion_numerics: false
xla_reduce_window_rewrite_base_length: 16
xla_gpu_enable_while_loop_unrolling: WHILE_LOOP_UNROLLING_AUTO_UNROLL
xla_gpu_enable_host_memory_offloading: false
xla_llvm_force_inline_before_split: true
xla_gpu_nccl_terminate_on_error: false
xla_gpu_shard_autotuning: true
xla_gpu_enable_approx_costly_collectives: false
xla_cpu_enable_concurrency_optimized_scheduler: true
xla_cpu_prefer_vector_width: 256
xla_gpu_per_fusion_autotune_cache_dir: ""
xla_cmd_buffer_trace_cache_size: 16
xla_gpu_temp_buffer_use_separate_color: false
xla_syntax_sugar_async_ops: false
xla_gpu_autotune_gemm_rtol: 0.1
xla_enable_command_buffers_during_profiling: false
xla_gpu_cudnn_gemm_max_plans: 5
xla_cpu_parallel_codegen_split_count: 32
xla_gpu_experimental_autotune_cache_mode: AUTOTUNE_CACHE_MODE_UPDATE
xla_gpu_executable_warn_stuck_timeout_seconds: 10
xla_gpu_executable_terminate_timeout_seconds: 30
xla_gpu_experimental_disable_binary_libraries: false
xla_ignore_channel_id: false
xla_gpu_dot_merger_threshold_mb: 32
xla_cpu_max_isa: ""
xla_gpu_experimental_enable_fusion_block_level_rewriter: false
xla_enable_fast_math: false
xla_gpu_experimental_parallel_collective_overlap_limit: 1
xla_cpu_copy_insertion_use_region_analysis: false
xla_gpu_operand_bytes_threshold_for_windowed_einsum: -1
xla_gpu_experimental_enable_triton_heroless_priority_fusion: false
xla_gpu_pgle_accuracy_checker: PGLE_STRICTNESS_LEVEL_WARN
xla_gpu_experimental_stream_annotation: true
xla_gpu_libnvjitlink_mode: LIB_NV_JIT_LINK_MODE_AUTO
xla_pjrt_allow_auto_layout_in_hlo: false
xla_gpu_enable_scatter_determinism_expander: false
xla_gpu_require_exclusive_lock: false
xla_gpu_generate_debug_info: false
xla_gpu_generate_line_info: false
xla_gpu_unsupported_enable_ragged_all_to_all_decomposer: false
xla_gpu_experimental_pipeline_parallelism_opt_level: PIPELINE_PARALLELISM_OPT_LEVEL_DISABLE
xla_gpu_fail_ptx_compilation_on_register_spilling: false
xla_gpu_collectives_use_persistent_cliques: false
xla_gpu_experimental_enable_triton_tma: false
xla_gpu_enable_analytical_sol_latency_estimator: true
xla_gpu_analytical_latency_estimator_options {
key: "chunk_prep_us"
value: "-1"
}
xla_gpu_analytical_latency_estimator_options {
key: "chunk_size_bytes"
value: "-1"
}
xla_gpu_analytical_latency_estimator_options {
key: "gpus_per_node"
value: "-1"
}
xla_gpu_analytical_latency_estimator_options {
key: "nccl_op_launch_us"
value: "-1"
}
xla_gpu_analytical_latency_estimator_options {
key: "nic_speed_gbps"
value: "-1"
}
xla_gpu_analytical_latency_estimator_options {
key: "rtt_us"
value: "-1"
}
xla_gpu_unsupported_annotate_with_emitter_loc: false
xla_cpu_use_xnnpack: true
xla_gpu_experimental_pack_dot_operands_along_k_dimension: true
xla_unsupported_crash_on_hlo_pass_fix_max_iterations: false
xla_cpu_experimental_xnn_grap
F0000 00:00:1761646825.524539 457517 buffer_allocations.cc:79] Check failed: extent <= base.size() (108 vs. 100) slice extent 108 must be smaller than buffer #1 size 100
*** Check failure stack trace: ***
@ 0x55873e10dff4 absl::lts_20250814::log_internal::LogMessage::SendToLog()
@ 0x55873e10df76 absl::lts_20250814::log_internal::LogMessage::Flush()
@ 0x55873c339fc4 xla::gpu::BufferAllocations::GetDeviceAddress()
@ 0x5587352f0f16 xla::gpu::KernelThunk::ExecuteOnStream()
@ 0x55873531587e xla::gpu::SequentialThunk::ExecuteOnStream()
@ 0x558735272a8c xla::gpu::(anonymous namespace)::ExecuteThunksImpl()
@ 0x558735270196 xla::gpu::GpuExecutable::ExecuteThunks()
@ 0x55873526e9ba xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
@ 0x55873526e47d xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
@ 0x55873cfe84e8 xla::Executable::ExecuteAsyncOnStreamWrapper()
@ 0x55873cfe8216 xla::Executable::ExecuteOnStreamWrapper()
@ 0x55873ccec144 xla::HloRunner::ExecuteWithExecutionInputs()
@ 0x55873ccebc9e xla::HloRunner::ExecuteWithDeviceBuffers()
@ 0x55873cceb064 xla::HloRunner::ExecuteWithExecutable()
@ 0x55873cd04c00 xla::HloRunnerInterface::ExecuteWithExecutable()
@ 0x5587344e2cf3 xla::gpu::GpuAotCompilationTest_ExportAndLoadExecutable_Test::TestBody()
@ 0x55873db8f52e testing::internal::HandleExceptionsInMethodIfSupported<>()
@ 0x55873db8f420 testing::Test::Run()
@ 0x55873db903a2 testing::TestInfo::Run()
@ 0x55873db914d5 testing::TestSuite::Run()
@ 0x55873dba24a1 testing::internal::UnitTestImpl::RunAllTests()
@ 0x55873dba1abe testing::internal::HandleExceptionsInMethodIfSupported<>()
@ 0x55873dba1939 testing::UnitTest::Run()
@ 0x55873db0fb35 main
@ 0x7f76e764b9d0 (unknown)
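The stack trace shows the CHECK firing in `BufferAllocations::GetDeviceAddress`, called from `KernelThunk::ExecuteOnStream`, while a kernel's 108-byte argument slice is resolved against a 100-byte allocation. The debug string also shows that the entry computation was rewritten into `PadToStatic`, then `get-tuple-element`, then `SliceToDynamic`. One plausible reading, assuming buffer #1 is the entry parameter and that the dynamic sizes are stored as int32 metadata right after the data, is that the `PadToStatic` kernel expects bytes the static argument never allocated. The host-side model below illustrates that read; `PaddedResult` and `PadToStaticModel` are hypothetical names, not XLA's implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <optional>
#include <vector>

// Result of a PadToStatic-style read: the bounded data plus per-dimension sizes.
struct PaddedResult {
  std::vector<float> data;     // bound_rows * cols elements
  std::vector<int32_t> sizes;  // dynamic size of each dimension
};

// ASSUMPTION: the buffer holds f32[bound_rows, cols] data followed by one int32
// size per dimension. Returns std::nullopt instead of reading past the end.
std::optional<PaddedResult> PadToStaticModel(const std::vector<uint8_t>& buffer,
                                             int64_t bound_rows, int64_t cols) {
  assert(bound_rows >= 0 && cols >= 0);
  const size_t num_elements = static_cast<size_t>(bound_rows * cols);
  const size_t data_bytes = num_elements * sizeof(float);
  const size_t rank = 2;
  const size_t needed = data_bytes + rank * sizeof(int32_t);
  if (buffer.size() < needed) {
    return std::nullopt;  // metadata would lie beyond the allocated buffer
  }
  PaddedResult result;
  result.data.resize(num_elements);
  std::memcpy(result.data.data(), buffer.data(), data_bytes);
  result.sizes.resize(rank);
  std::memcpy(result.sizes.data(), buffer.data() + data_bytes,
              rank * sizeof(int32_t));
  return result;
}
```

A 100-byte buffer (the static literal) is rejected, while a 108-byte buffer yields both the data and the dynamic sizes.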