-
Notifications
You must be signed in to change notification settings - Fork 12
Open
Description
I'm trying to get this code working:
using Tilde, Umlaut, Yota, ChainRulesCore, ChainRules
# Reverse rule for `tuple`, adapted from https://github.com/dfdx/Yota.jl/issues/113.
# This is type piracy (this script owns neither `rrule` nor `tuple`), so it
# should eventually live in ChainRules itself.
function ChainRulesCore.rrule(::typeof(tuple), args...)
    y = tuple(args...)
    function tuple_pullback(dy)
        # The cotangent may arrive lazily as a thunk; materialize it first.
        dy_val = unthunk(dy)
        # A zero cotangent for the whole tuple means a zero cotangent for every
        # argument; collecting/splatting it cannot be relied on to produce one
        # entry per argument, so expand it explicitly.
        if dy_val isa AbstractZero
            return (NoTangent(), map(_ -> dy_val, args)...)
        end
        # Otherwise the cotangent is tuple-like: one entry per argument, in order.
        # Splatting iterates it directly, avoiding the intermediate `collect`.
        return (NoTangent(), dy_val...)
    end
    return y, tuple_pullback
end
# Minimal model: `p` is drawn from `Uniform()`, and `x` from `Bernoulli(p)`.
m = @model begin
    p ~ Uniform()
    x ~ Bernoulli(p)
end
# Condition the model on the observation `x = true`
# (NOTE(review): assumes `|` is Tilde's conditioning operator — confirm).
post = m() | (x = true,)
# Log-density of the conditioned model as a scalar function of `p`.
ℓ(p) = logdensityof(post, (p=p,))
# Works fine
ℓ(0.2)
# Throws an error: the AssertionError from Yota's `back!` reported in this issue.
grad(ℓ, 0.2)
The error I get is:
ERROR: AssertionError: zval isa Number || zval isa AbstractArray
Stacktrace:
[1] back!(tape::Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/git/Yota.jl/src/grad.jl:200
[2] gradtape!(tape::Tape{Yota.GradCtx}; seed::Symbol)
@ Yota ~/git/Yota.jl/src/grad.jl:223
[3] #gradtape#90
@ ~/git/Yota.jl/src/grad.jl:245 [inlined]
[4] make_rrule(f::Function, args::Accessors.PropertyLens{:ℓ})
@ Yota ~/git/Yota.jl/src/cr_api.jl:128
[5] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Accessors.PropertyLens{:ℓ})
@ Yota ~/git/Yota.jl/src/cr_api.jl:170
[6] rrule(#unused#::Yota.YotaRuleConfig, #unused#::typeof(Core._apply_iterate), #unused#::typeof(iterate), f::typeof(∘), args::Tuple{Accessors.PropertyLens{:ℓ}})
@ Yota ~/git/Yota.jl/src/rulesets.jl:29
[7] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/git/Umlaut.jl/src/tape.jl:194
[8] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/tape.jl:179
[9] record_primitive!(::Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/git/Yota.jl/src/grad.jl:49
[10] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:194
[11] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Variable})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[12] trace(f::Function, args::Accessors.PropertyLens{:ℓ}; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:347
[13] #gradtape#90
@ ~/git/Yota.jl/src/grad.jl:244 [inlined]
[14] make_rrule(f::Function, args::Accessors.PropertyLens{:ℓ})
@ Yota ~/git/Yota.jl/src/cr_api.jl:128
[15] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Accessors.PropertyLens{:ℓ})
@ Yota ~/git/Yota.jl/src/cr_api.jl:170
[16] rrule(#unused#::Yota.YotaRuleConfig, #unused#::typeof(Core._apply_iterate), #unused#::typeof(iterate), f::typeof(CompositionsBase.opcompose), args::Tuple{Accessors.PropertyLens{:ℓ}})
@ Yota ~/git/Yota.jl/src/rulesets.jl:29
[17] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/git/Umlaut.jl/src/tape.jl:194
[18] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/tape.jl:179
[19] record_primitive!(::Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Yota ~/git/Yota.jl/src/grad.jl:49
[20] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:194
[21] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[22] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
[23] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[24] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
[25] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[26] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
[27] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[28] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
[29] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[30] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
[31] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Variable})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
[32] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Umlaut ~/git/Umlaut.jl/src/trace.jl:347
[33] #gradtape#90
@ ~/git/Yota.jl/src/grad.jl:244 [inlined]
[34] grad(f::typeof(ℓ), args::Float64; seed::Int64)
@ Yota ~/git/Yota.jl/src/grad.jl:315
[35] grad(f::typeof(ℓ), args::Float64)
@ Yota ~/git/Yota.jl/src/grad.jl:307
[36] top-level scope
@ REPL[17]:1
By adding a `@show` in `back!`, I see that the problem in this case is that `zval` is `Accessors.@optic _.ℓ`, which has type `Accessors.PropertyLens{:ℓ}`.
Is this a bug? If not, could you give me some advice — for example, how can I tell whether this means I need another `rrule`? Teach a man to fish, etc. :)
Metadata
Metadata
Assignees
Labels
No labels