Skip to content

nested AD on CUDA array #1450

@FerreolS

Description

@FerreolS

Hi,
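I'm trying to implement a gradient penalty with Lux. It works fine on the CPU but raises a "try/catch is not supported" error on the GPU (CUDA). It seems to be linked to a try/catch reached here, under CUDA.jl's `cufunction` (see frames [11]–[12] of the stacktrace below), but I'm not able to figure out what the problem could be.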

using Zygote, Lux,LuxCUDA,Random,CUDA, cuDNN
"""
    mwe(dev)

Minimal reproducer for nested (Zygote-over-Zygote) AD with Lux.

Builds a `Dense(5, 1, relu)` layer and moves its parameters/state and a random
5×2 input to `dev`. It then takes the gradient, with respect to the parameters,
of a gradient penalty: the squared L2 norm of the gradient of the summed layer
output with respect to the input.

With `cpu_device()` this returns a one-element tuple holding the parameter
gradient. With `gpu_device()` (CUDA, Zygote v0.6.63, Lux v0.5.3) it errors with
"try/catch is not supported" while Zygote differentiates through the inner
broadcast's CUDA kernel launch.
"""
function mwe(dev)
    rng = Random.default_rng()
    layer = Dense(5, 1, relu)
    ps, st = Lux.setup(rng, layer) |> dev
    # NOTE(review): `rand` yields Float64 while the printed parameters are Float32;
    # presumably irrelevant to the failure, but worth confirming.
    input = rand(5, 2) |> dev

    # Gradient penalty for a given parameter set. The inner `gradient` call is
    # the first AD level; differentiating this function w.r.t. `params` below
    # is the second level.
    function penalty(params)
        forward(z) = sum(first(layer(z, params, st)))
        grad_input = only(gradient(forward, input))
        return sum(abs2, grad_input)
    end

    return gradient(p -> penalty(p), ps)
end
(@v1.9) pkg> status Lux, Zygote, LuxCUDA
Status `~/.julia/environments/v1.9/Project.toml`
  [b2108857] Lux v0.5.3
  [d0bbae9a] LuxCUDA v0.3.0
  [e88e6eb3] Zygote v0.6.63

julia> mwe(cpu_device())

((layer_1 = (weight = Float32[0.7818906 3.0653906  1.399884 -0.26843658; 0.0059568123 0.023353593  0.010664978 -0.0020450768; 0.47049373 1.8445637  0.8423642 -0.16152865], bias = Float32[0.0; 0.0; 0.0;;]), layer_2 = (weight = Float32[5.04391 -0.035778362 2.7771542], bias = nothing)),)

julia> mwe(gpu_device())
ERROR: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:128
  [3] #Primal#31
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:227 [inlined]
  [4] Primal
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:226 [inlined]
  [5] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:352
  [6] _generate_pullback_via_decomposition(T::Type, world::Nothing)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/emit.jl:101
  [7] _generate_pullback(::Type, ::Nothing, ::Type, ::Type, ::Vararg{Type})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:27
  [8] #s86#1607
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:102 [inlined]
  [9] var"#s86#1607"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [11] _pullback
    @ ~/.julia/packages/CUDA/35NC6/src/compiler/execution.jl:310 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::typeof(cufunction), ::GPUArrays.var"#broadcast_kernel#26", ::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [13] macro expansion
    @ ~/.julia/packages/CUDA/35NC6/src/compiler/execution.jl:104 [inlined]
 [14] _pullback
    @ ~/.julia/packages/CUDA/35NC6/src/gpuarrays.jl:17 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::CUDA.var"##launch_heuristic#1080", ::Int64, ::Int64, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#26", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [16] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [17] adjoint
    @ ~/.julia/packages/Zygote/4rucm/src/lib/lib.jl:203 [inlined]
 [18] adjoint(::Zygote.Context{false}, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{Int64, Int64, typeof(GPUArrays.launch_heuristic), CUDA.CuArrayBackend, GPUArrays.var"#broadcast_kernel#26"}, ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64})
    @ Zygote ./none:0
 [19] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [20] _pullback
    @ ~/.julia/packages/CUDA/35NC6/src/gpuarrays.jl:15 [inlined]
 [21] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:elements, :elements_per_thread), Tuple{Int64, Int64}}, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#26", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [22] _pullback
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:65 [inlined]
 [23] _pullback(::Zygote.Context{false}, ::typeof(GPUArrays._copyto!), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [24] _pullback
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:41 [inlined]
 [25] _pullback
    @ ./broadcast.jl:881 [inlined]
 [26] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.materialize!), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [27] _pullback
    @ ./broadcast.jl:877 [inlined]
 [28] _pullback
    @ ~/.julia/packages/Zygote/4rucm/src/lib/broadcast.jl:369 [inlined]
 [29] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#1453#1456"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [30] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [31] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#4229#back#1457"{Zygote.var"#1453#1456"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [32] _pullback
    @ ./REPL[2]:7 [inlined]
...

Metadata

Metadata

Assignees

No one assigned

    Labels

    CUDA (All things GPU), second order (zygote over zygote, or otherwise)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions