Description
When executing the example from the README, `gs = gradient(x -> sum(tresnet(x)), tip);` throws `ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")`.

Minimal working example (a simple `Chain` substituted for the ResNet from the README):
```julia
using Flux
using Torch
using Torch: torch

# Small Flux model, moved to the Torch backend
net = Chain(
    Dense(10, 5, σ),
    Dense(5, 2),
    softmax)
tnet = net |> torch

# Input as a Torch tensor on GPU 0
ip = rand(Float32, 10, 1)
tip = tensor(ip, dev = 0)

gs = gradient(x -> sum(tnet(x)), tip)
```
Result:
```
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
  [1] check_broadcast_shape(#unused#::Tuple{}, Ashp::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:518
  [2] check_broadcast_axes
    @ ./broadcast.jl:523 [inlined]
  [3] check_broadcast_axes
    @ ./broadcast.jl:526 [inlined]
  [4] instantiate(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{}, typeof(*), Tuple{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Tensor{Float32, 2}}})
    @ Base.Broadcast ./broadcast.jl:269
  [5] materialize!
    @ ./broadcast.jl:894 [inlined]
  [6] materialize!
    @ ./broadcast.jl:891 [inlined]
  [7] ∇softmax!(out::Tensor{Float32, 0}, Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, x::Tensor{Float32, 2}, y::Tensor{Float32, 2}; dims::Int64)
    @ NNlib ~/.julia/packages/NNlib/TOStL/src/softmax.jl:70
  [8] ∇softmax(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, x::Tensor{Float32, 2}, y::Tensor{Float32, 2}; dims::Int64)
    @ NNlib ~/.julia/packages/NNlib/TOStL/src/softmax.jl:62
  [9] softmax_pullback
    @ ~/.julia/packages/NNlib/TOStL/src/softmax.jl:81 [inlined]
 [10] ZBack
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/chainrules.jl:77 [inlined]
 [11] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [12] (::typeof(∂(applychain)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [14] (::typeof(∂(applychain)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [16] (::typeof(∂(applychain)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:38 [inlined]
 [18] (::typeof(∂(λ)))(Δ::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[20]:1 [inlined]
 [20] (::typeof(∂(#3)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#41#42"{typeof(∂(#3))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:41
 [22] gradient(f::Function, args::Tensor{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [23] top-level scope
```
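Reading frame [7], `∇softmax!` receives `out::Tensor{Float32, 0}`: the buffer allocated for the softmax gradient is a zero-dimensional `Tensor`, so the 2-d `FillArrays.Fill` seed cannot be broadcast into it. Presumably the buffer loses its shape somewhere along the `Tensor` allocation path, but that is only a guess from the stacktrace. For comparison, here is a minimal sketch (plain Flux on the CPU, no Torch involved) showing that the same chain differentiates without error when the input stays an ordinary `Array`, which points at the `Tensor` code path rather than the model:

```julia
using Flux

# Same model as above, but never converted with |> torch
net = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)

ip = rand(Float32, 10, 1)

# Works: gs[1] is a 10×1 Matrix{Float32} (numerically ~0, since sum ∘ softmax is constant)
gs = gradient(x -> sum(net(x)), ip)
```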