Complex numbers in Qwen Image pipeline incompatible with torch.compile #12668

@dxqb

Describe the bug

Note: This might be something for the MVP program #12635 if there's anyone who already has a deep understanding of rotary embeddings and complex numbers. I don't.

The Qwen image pipeline calls its rotary embedding function with use_real==False.

The function therefore operates on complex numbers.
If compiled, torch.compile warns about this:

venv/lib/python3.12/site-packages/torch/_inductor/lowering.py:1890: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.

Performance being worse than eager isn't a big deal, because this is not a performance-critical part of the model.
However, a subtle torch.compile bug turns the complex-number code path into random compile failures:
pytorch/pytorch#163876

Can the code path with real numbers be used instead?
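
To illustrate why a real-number path could in principle give the same result, here is a minimal sketch. It assumes the rotary embedding pairs adjacent channels (x[2i], x[2i+1]) into one complex number; the function names, shapes, and pairing layout are hypothetical and not taken from the diffusers source. It only shows that multiplying by cos θ + i·sin θ matches a 2D rotation written with real cos/sin tensors.

```python
import torch


def rope_complex(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Rotate channel pairs via complex multiplication (hypothetical helper)."""
    # angles: (seq, dim // 2) -> unit complex numbers cos(a) + i*sin(a)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    # View adjacent channel pairs (x[2i], x[2i+1]) as one complex number each.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2).type_as(x)


def rope_real(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Same rotation using only real tensors (hypothetical helper)."""
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    # (a + ib) * (cos + i*sin) = (a*cos - b*sin) + i*(a*sin + b*cos)
    out = torch.stack(
        (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1
    )
    return out.flatten(-2).type_as(x)


torch.manual_seed(0)
x = torch.randn(2, 5, 8)       # (batch, seq, dim), dim must be even
angles = torch.randn(5, 4)     # (seq, dim // 2)
max_diff = (rope_complex(x, angles) - rope_real(x, angles)).abs().max().item()
```

If the pipeline's real-number path uses a different channel pairing (for example, first half with second half instead of adjacent pairs), the outputs would differ even though both are valid rotations, so the layout would need to be checked before switching.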

Reproduction

I cannot provide reproduction code because the failure is random: it shows up mostly when a kernel is recompiled, but not consistently even then.
Multiple users are affected, though. It can be worked around by putting a torch.compiler.disable decorator on the function, but I don't like this solution because you then can no longer compile with fullgraph=True.

Logs

packed_predicted_flow = model.transformer(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/src/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 629, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/__init__.py", line 2380, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
    return aot_autograd(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
    compiled_fn = dispatch_and_compile()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 1604, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2250, in fw_compiler_base
    return inner_compile(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 896, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1456, in codegen_and_compile
    compiled_module = graph.compile_to_module()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2293, in compile_to_module
    return self._compile_to_module()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2299, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2238, in codegen
    self.scheduler.codegen()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4598, in codegen
    else self._codegen(self.nodes)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4750, in _codegen
    self.get_backend(device).codegen_node(node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 107, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1363, in codegen_node
    coalesce_analysis = analyze_memory_coalescing(node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 650, in analyze_memory_coalescing
    norm_read_writes = extract_normalized_read_writes(fused_node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 482, in extract_normalized_read_writes
    if any(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 483, in <genexpr>
    (isinstance(var, sympy.Expr) and not var.is_constant())
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 724, in is_constant
    b = expr._random(None, -1, 0, 1, 0)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 562, in _random
    nmag = abs(self.evalf(2, subs=reps))
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1654, in evalf
    result = evalf(self, prec + 4, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
    r = rf(x, prec, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in evalf_add
    terms = [evalf(arg, prec + 10, options) for arg in v.args]
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in <listcomp>
    terms = [evalf(arg, prec + 10, options) for arg in v.args]
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
    r = rf(x, prec, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 650, in evalf_mul
    result = evalf(arg, prec, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1493, in evalf
    x = x.subs(evalf_subs(prec, options['subs']))
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1171, in subs
    rv = rv._subs(old, new, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1285, in _subs
    rv = fallback(self, old, new)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1262, in fallback
    rv = self.func(*args)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 450, in __new__
    return cls._new_(*args, **options)  # type: ignore
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 472, in _new_
    result = super().__new__(cls, *args, **options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 309, in __new__
    evaluated = cls.eval(*args)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/utils/_sympy/functions.py", line 488, in eval
    assert p >= 0, p
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: -1470286036225387/1000000000000000

System Info

Various systems, all with torch 2.8

Who can help?

@DN6 @yiyixuxu @sayakpaul

Labels: bug (Something isn't working)
