Crash in MLIR emitters for 4D convolution #32635

Description

@hawkinsp

Originally reported as jax-ml/jax#32485.
The following HLO module

HloModule jit_conv_general_dilated, entry_computation_layout={(f64[1,1,16,16,2,2]{5,4,3,2,1,0}, f64[1,1,3,3,1,1]{5,4,3,2,1,0})->f64[1,1,14,14,2,2]{5,4,3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.1 (args_0_.1: f64[1,1,16,16,2,2], args_1_.1: f64[1,1,3,3,1,1]) -> f64[1,1,14,14,2,2] {
  %args_0_.1 = f64[1,1,16,16,2,2]{5,4,3,2,1,0} parameter(0), metadata={op_name="args[0]"}
  %args_1_.1 = f64[1,1,3,3,1,1]{5,4,3,2,1,0} parameter(1), metadata={op_name="args[1]"}
  ROOT %conv_general_dilated.1 = f64[1,1,14,14,2,2]{5,4,3,2,1,0} convolution(%args_0_.1, %args_1_.1), window={size=3x3x1x1}, dim_labels=bf0123_oi0123->bf0123, metadata={op_name="jit(conv_general_dilated)/conv_general_dilated" source_file="third_party/py/jax/tests/lax_test.py" source_line=1074 source_end_line=1079 source_column=10 source_end_column=21}
}

causes the following crash:

*** Check failure stack trace: ***
    @     0x557c199f30a9  absl::log_internal::LogMessage::SendToLog()
    @     0x557c199f302e  absl::log_internal::LogMessage::Flush()
    @     0x557c199b6fc4  __assert_fail
    @     0x557c0ec2f8e3  xla::ApplyIndexingOp::fold()
    @     0x557c0ec1fd92  mlir::Op<>::foldHook<>()
    @     0x557c0ec1f438  mlir::RegisteredOperationName::Model<>::foldHook()
    @     0x557c15e75a89  mlir::Operation::fold()
    @     0x557c15e75e28  mlir::Operation::fold()
    @     0x557c15dece70  mlir::OpBuilder::tryFold()
    @     0x557c0ec0db49  mlir::OpBuilder::createOrFold<>()
    @     0x557c0ebde7eb  xla::emitters::ApplyIndexing()
    @     0x557c0ebf4693  llvm::function_ref<>::callback_fn<>()
    @     0x557c0ec0e4e0  llvm::function_ref<>::callback_fn<>()
    @     0x557c0ec0d5a9  llvm::function_ref<>::callback_fn<>()
    @     0x557c157b5123  mlir::scf::IfOp::build()
    @     0x557c157c34a3  mlir::scf::IfOp::create()
    @     0x557c0ec0d222  llvm::function_ref<>::callback_fn<>()
    @     0x557c157ae69a  mlir::scf::buildLoopNest()
    @     0x557c0ebe22ee  xla::emitters::(anonymous namespace)::EmitLoopNestImpl()
    @     0x557c0ebe1dd4  xla::emitters::EmitLoopNest()
    @     0x557c0ebf409f  xla::emitters::(anonymous namespace)::EmitDotLoop()
    @     0x557c0ebe5bc5  xla::emitters::(anonymous namespace)::HloToMlir()
    @     0x557c0ebe400e  xla::emitters::(anonymous namespace)::SubgraphConverter::EmitInstruction()
    @     0x557c0ebe10b0  xla::emitters::SubgraphToMlirFunction()
    @     0x557c0ebdcc7d  xla::emitters::EmitPartitionedComputations()
    @     0x557c0d899b6f  xla::emitters::LoopFusionKernelEmitter::EmitKernelDefinition()
    @     0x557bf622be12  xla::gpu::LoopFusion::CreateMLIRModule()
    @     0x557bf625ca0d  xla::gpu::EmitterBase::CreateLLVMModule()
    @     0x557bf625f6bd  std::__u::__function::__policy_func<>::__call_func<>()
    @     0x557bf7011d20  xla::gpu::KernelReuseCache::GetWithStatus()
    @     0x557bf7011a79  xla::gpu::KernelReuseCache::GetWithStatus()
    @     0x557bf625beb6  xla::gpu::EmitterBase::Emit()
    @     0x557bf5448cd7  xla::gpu::IrEmitterUnnested::EmitFusion()
    @     0x557bf544ed6a  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x557bf5433be3  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x557bf5421bc1  xla::gpu::CompileModuleToLlvmIr()
    @     0x557bf53f64dd  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x557bf53f9f80  xla::gpu::GpuCompiler::RunBackend()

Metadata

Labels
GPU (XLA on GPU)
