-
-
Notifications
You must be signed in to change notification settings - Fork 217
Closed
Labels
CUDA (All things GPU), second order (zygote over zygote, or otherwise)
Description
Hi,
I'm trying to implement a gradient penalty with Lux. It works fine on CPU but raises a "try/catch is not supported" error on GPU (CUDA). It seems to be linked to the `try`/`catch` here, but I'm not able to figure out what the problem could be.
using Zygote, Lux,LuxCUDA,Random,CUDA, cuDNN
# Minimal working example for the gradient-penalty issue: takes the gradient,
# with respect to the parameters, of a scalar that is itself built from a
# gradient with respect to the input (nested `gradient` calls).
#
# `dev` is applied with `|>` to the result of `Lux.setup` and to the input
# array, e.g. `cpu_device()` or `gpu_device()`.
function mwe(dev)
    rng = Random.default_rng()
    D = Dense(5, 1, relu)
    ps, st = Lux.setup(rng, D) |> dev
    x = rand(5, 2) |> dev

    # Penalty term: sum of squares of the gradient of the summed model output
    # with respect to the input `x`, for a given set of parameters.
    function g(params)
        input_grad = only(gradient(z -> sum(first(D(z, params, st))), x))
        return sum(abs2, input_grad)
    end

    # Outer gradient: differentiate the penalty term with respect to `ps`.
    return gradient(p -> g(p), ps)
end
(@v1.9) pkg> status Lux, Zygote, LuxCUDA
Status `~/.julia/environments/v1.9/Project.toml`
[b2108857] Lux v0.5.3
[d0bbae9a] LuxCUDA v0.3.0
[e88e6eb3] Zygote v0.6.63
julia> mwe(cpu_device())
((layer_1 = (weight = Float32[0.7818906 3.0653906 … 1.399884 -0.26843658; 0.0059568123 0.023353593 … 0.010664978 -0.0020450768; 0.47049373 1.8445637 … 0.8423642 -0.16152865], bias = Float32[0.0; 0.0; 0.0;;]), layer_2 = (weight = Float32[5.04391 -0.035778362 2.7771542], bias = nothing)),)
julia> mwe(gpu_device())
ERROR: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:128
[3] #Primal#31
@ ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:227 [inlined]
[4] Primal
@ ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:226 [inlined]
[5] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/reverse.jl:352
[6] _generate_pullback_via_decomposition(T::Type, world::Nothing)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/emit.jl:101
[7] _generate_pullback(::Type, ::Nothing, ::Type, ::Type, ::Vararg{Type})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:27
[8] #s86#1607
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:102 [inlined]
[9] var"#s86#1607"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[11] _pullback
@ ~/.julia/packages/CUDA/35NC6/src/compiler/execution.jl:310 [inlined]
[12] _pullback(::Zygote.Context{false}, ::typeof(cufunction), ::GPUArrays.var"#broadcast_kernel#26", ::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64}})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[13] macro expansion
@ ~/.julia/packages/CUDA/35NC6/src/compiler/execution.jl:104 [inlined]
[14] _pullback
@ ~/.julia/packages/CUDA/35NC6/src/gpuarrays.jl:17 [inlined]
[15] _pullback(::Zygote.Context{false}, ::CUDA.var"##launch_heuristic#1080", ::Int64, ::Int64, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#26", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[16] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[17] adjoint
@ ~/.julia/packages/Zygote/4rucm/src/lib/lib.jl:203 [inlined]
[18] adjoint(::Zygote.Context{false}, ::typeof(Core._apply_iterate), ::typeof(iterate), ::Function, ::Tuple{Int64, Int64, typeof(GPUArrays.launch_heuristic), CUDA.CuArrayBackend, GPUArrays.var"#broadcast_kernel#26"}, ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, Int64})
@ Zygote ./none:0
[19] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[20] _pullback
@ ~/.julia/packages/CUDA/35NC6/src/gpuarrays.jl:15 [inlined]
[21] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:elements, :elements_per_thread), Tuple{Int64, Int64}}, ::typeof(GPUArrays.launch_heuristic), ::CUDA.CuArrayBackend, ::GPUArrays.var"#broadcast_kernel#26", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}}, ::Int64)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[22] _pullback
@ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:65 [inlined]
[23] _pullback(::Zygote.Context{false}, ::typeof(GPUArrays._copyto!), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[24] _pullback
@ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:41 [inlined]
[25] _pullback
@ ./broadcast.jl:881 [inlined]
[26] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.materialize!), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(identity), Tuple{Float32}})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[27] _pullback
@ ./broadcast.jl:877 [inlined]
[28] _pullback
@ ~/.julia/packages/Zygote/4rucm/src/lib/broadcast.jl:369 [inlined]
[29] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#1453#1456"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, args::Float32)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[30] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
[31] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#4229#back#1457"{Zygote.var"#1453#1456"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, args::Float32)
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[32] _pullback
@ ./REPL[2]:7 [inlined]
...
Metadata
Metadata
Assignees
Labels
CUDA (All things GPU), second order (zygote over zygote, or otherwise)