Documentation's example using Enzyme does not work on GPU #2621

Description

@camilodlt

Hello,

I'm having trouble running the Enzyme example from the documentation. The only change compared to the docs is that I move everything to the GPU (both the model and the inputs).

using CUDA
using Flux
using Enzyme

model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);
dup_model = Duplicated(model |> gpu);

x1 = randn32(28 * 28, 1) |> gpu;
y1 = [i == 3 for i in 0:9] |> gpu;
grads_f = Flux.gradient((m, x, y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))
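
For reference, this is the CPU example from the docs as I read it; the snippet above only adds the |> gpu moves:

using Flux
using Enzyme

model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);
dup_model = Duplicated(model);    # allocates space for the gradient
x1 = randn32(28 * 28, 1);         # dummy input
y1 = [i == 3 for i in 0:9];       # dummy one-hot target
grads_f = Flux.gradient((m, x, y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))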

The last call takes a long time and eventually throws an error:

ERROR: "Error cannot store inactive but differentiable variable Float32[0.42446053; -0.43670258; -0.22128746; -0.09921402; -1.5923102; -0.50225735; 0.7328375; -1.5166384; -0.22234721; 0.96836793; 1.4810076; -0.28374726; 2.0655832; 0.22402526; -2.1271694; 0.96447814; -0.8850093; -1.1225328; 
[...]
-0.84058493;;] into active tuple"
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:15 [inlined]
  [2] create_shadow_ret
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:3 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:85 [inlined]
  [4] runtime_newstruct_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Ptr{…}, ::Nothing, ::Char, ::Nothing, ::Char, ::Nothing, ::Int64, ::Nothing, ::Int64, ::Nothing, ::Int64, ::Nothing, ::CUDA.CuRefValue{…}, ::CUDA.CuRefValue{…}, ::CuArray{…}, ::CuArray{…}, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CuArray{…}, ::Nothing, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CUDA.CuRefValue{…}, ::CUDA.CuRefValue{…}, ::CuArray{…}, ::CuArray{…}, ::Type{…}, ::Nothing, ::Int64, ::Nothing, ::CUDA.CUBLAS.cublasComputeType_t, ::CUDA.CUBLAS.cublasComputeType_t, ::CUDA.CUBLAS.cublasGemmAlgo_t, ::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ez9it/src/rules/typeunstablerules.jl:357
  [5] cublasGemmEx
    @ ~/.julia/packages/GPUToolbox/XaIIx/src/ccalls.jl:33 [inlined]
  [6] #gemmEx!#1222
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:1251 [inlined]
  [7] augmented_julia__gemmEx__1222_386193wrap
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:0
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5713 [inlined]
  [9] enzyme_call
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5247 [inlined]
 [10] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5186 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/rules/jitrules.jl:447 [inlined]
 [12] runtime_generic_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::CUDA.CUBLAS.var"##gemmEx!#1222", ::Nothing, ::CUDA.CUBLAS.cublasGemmAlgo_t, ::Nothing, ::typeof(CUDA.CUBLAS.gemmEx!), ::Nothing, ::Char, ::Nothing, ::Char, ::Nothing, ::Bool, ::Nothing, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::Nothing, ::Bool, ::Nothing, ::CuArray{…}, ::CuArray{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ez9it/src/rules/jitrules.jl:574
 [13] gemmEx!
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/wrappers.jl:1230 [inlined]
 [14] generic_matmatmul!
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/linalg.jl:251
 [15] generic_matmatmul!
    @ ~/.julia/packages/CUDA/Wfi8S/lib/cublas/linalg.jl:226 [inlined]
 [16] _mul!
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
 [17] mul!
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
 [18] mul!
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
 [19] *
    @ ~/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:124 [inlined]
 [20] Dense
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:199
 [21] macro expansion
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:68 [inlined]
 [22] _applychain
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:68 [inlined]
 [23] Chain
    @ ~/.julia/packages/Flux/uRn8o/src/layers/basic.jl:65 [inlined]
 [24] #3
    @ ./REPL[8]:1 [inlined]
 [25] diffejulia__3_32568_inner_652wrap
    @ ./REPL[8]:0
 [26] macro expansion
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5713 [inlined]
 [27] enzyme_call
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5247 [inlined]
 [28] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/ez9it/src/compiler.jl:5122 [inlined]
 [29] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/ez9it/src/Enzyme.jl:517
 [30] _enzyme_gradient(::Function, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
    @ FluxEnzymeExt ~/.julia/packages/Flux/uRn8o/ext/FluxEnzymeExt/FluxEnzymeExt.jl:50
 [31] gradient(::Function, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
    @ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:122
 [32] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

Neither dup_model = Duplicated(model |> gpu) nor dup_model = Duplicated(model) |> gpu works.
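
In case it helps triage: the "cannot store inactive but differentiable variable ... into active tuple" message comes from Enzyme's type-unstable rules around the cublasGemmEx ccall, which looks like a runtime-activity issue to me. Since Flux.gradient does not expose the Enzyme mode, here is a minimal sketch of calling Enzyme directly with runtime activity enabled (this is an assumption on my part; I have not confirmed it avoids the error):

loss(m, x, y) = sum(abs2, m(x) .- y)

# Assumption: set_runtime_activity is the relevant knob for this error class.
Enzyme.autodiff(Enzyme.set_runtime_activity(Enzyme.Reverse),
                loss, Active,
                dup_model, Const(x1), Const(y1))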

Julia info:

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 16 × Intel(R) Xeon(R) W-11955M CPU @ 2.60GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, tigerlake)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  LD_GOLD = /home/camilo/miniconda3/envs/pytorch3/bin/x86_64-conda-linux-gnu-ld.gold
  JULIA_CONDAPKG_BACKEND = Current
  JULIA_CONDAPKG_OFFLINE = true

CUDA info:

CUDA toolchain:
- runtime 13.0, artifact installation
- driver 550.163.1 for 13.0
- compiler 13.0

CUDA libraries:
- CUBLAS: 13.0.2
- CURAND: 10.4.0
- CUFFT: 12.0.0
- CUSOLVER: 12.0.4
- CUSPARSE: 12.6.3
- CUPTI: 2025.3.1 (API 130001.0.0)
- NVML: 12.0.0+550.163.1

Julia packages:
- CUDA: 5.8.3
- CUDA_Driver_jll: 13.0.1+0
- CUDA_Compiler_jll: 0.2.1+0
- CUDA_Runtime_jll: 0.19.1+0

Toolchain:
- Julia: 1.11.6
- LLVM: 16.0.6

1 device:
  0: NVIDIA RTX A5000 Laptop GPU (sm_86, 12.816 GiB / 16.000 GiB available)

Flux info: Flux v0.16.5
