Skip to content

AssertionError in back! #118

@cscherrer

Description

@cscherrer

I'm trying to get this code working:

using Tilde, Umlaut, Yota, ChainRulesCore, ChainRules

# From https://github.com/dfdx/Yota.jl/issues/113
# Guess this should really go into ChainRules
#
# Reverse-mode rule for `tuple(args...)`. Defining it here is type piracy
# (neither `rrule` nor `tuple` is owned by this code), which is another reason
# it belongs upstream in ChainRules.
#
# The pullback returns `NoTangent()` for `tuple` itself, followed by exactly one
# cotangent per element of `args`.
function ChainRulesCore.rrule(::typeof(tuple), args...)
    y = tuple(args...)
    function tuple_pullback(dy)
        # Cotangents may arrive lazily as thunks; materialize before inspecting.
        Δ = ChainRulesCore.unthunk(dy)
        # A zero cotangent (ZeroTangent/NoTangent) iterates as a single element,
        # so splatting it would produce the wrong number of argument cotangents.
        # Repeat it once per argument instead.
        if Δ isa ChainRulesCore.AbstractZero
            return (NoTangent(), map(_ -> Δ, args)...)
        end
        # A Tuple or `Tangent{<:Tuple}` iterates element-wise, so it can be
        # splatted directly without allocating a `Vector{Any}` via `collect`.
        return (NoTangent(), Δ...)
    end
    return y, tuple_pullback
end

# Two-variable model: `p` drawn from `Uniform()`, and `x` drawn from `Bernoulli(p)`.
m = @model begin
    p ~ Uniform()
    x ~ Bernoulli(p)
end

# Condition the model on the observation `x = true`.
post = m() | (x = true,)

# Log-density of the conditioned model as a function of the scalar `p`.
ℓ(p) = logdensityof(post, (p=p,))

# Works fine
ℓ(0.2)

# Throws an error
grad(ℓ, 0.2)

The error I get is:

ERROR: AssertionError: zval isa Number || zval isa AbstractArray
Stacktrace:
  [1] back!(tape::Tape{Yota.GradCtx}; seed::Symbol)
    @ Yota ~/git/Yota.jl/src/grad.jl:200
  [2] gradtape!(tape::Tape{Yota.GradCtx}; seed::Symbol)
    @ Yota ~/git/Yota.jl/src/grad.jl:223
  [3] #gradtape#90
    @ ~/git/Yota.jl/src/grad.jl:245 [inlined]
  [4] make_rrule(f::Function, args::Accessors.PropertyLens{:ℓ})
    @ Yota ~/git/Yota.jl/src/cr_api.jl:128
  [5] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Accessors.PropertyLens{:ℓ})
    @ Yota ~/git/Yota.jl/src/cr_api.jl:170
  [6] rrule(#unused#::Yota.YotaRuleConfig, #unused#::typeof(Core._apply_iterate), #unused#::typeof(iterate), f::typeof(∘), args::Tuple{Accessors.PropertyLens{:ℓ}})
    @ Yota ~/git/Yota.jl/src/rulesets.jl:29
  [7] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:194
  [8] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:179
  [9] record_primitive!(::Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/git/Yota.jl/src/grad.jl:49
 [10] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:194
 [11] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Variable})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [12] trace(f::Function, args::Accessors.PropertyLens{:ℓ}; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:347
 [13] #gradtape#90
    @ ~/git/Yota.jl/src/grad.jl:244 [inlined]
 [14] make_rrule(f::Function, args::Accessors.PropertyLens{:ℓ})
    @ Yota ~/git/Yota.jl/src/cr_api.jl:128
 [15] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Accessors.PropertyLens{:ℓ})
    @ Yota ~/git/Yota.jl/src/cr_api.jl:170
 [16] rrule(#unused#::Yota.YotaRuleConfig, #unused#::typeof(Core._apply_iterate), #unused#::typeof(iterate), f::typeof(CompositionsBase.opcompose), args::Tuple{Accessors.PropertyLens{:ℓ}})
    @ Yota ~/git/Yota.jl/src/rulesets.jl:29
 [17] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:194
 [18] mkcall(::Function, ::Yota.YotaRuleConfig, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:179
 [19] record_primitive!(::Tape{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Yota ~/git/Yota.jl/src/grad.jl:49
 [20] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:194
 [21] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [22] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
 [23] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [24] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
 [25] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [26] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
 [27] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [28] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
 [29] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [30] record_or_recurse!(::Umlaut.Tracer{Yota.GradCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:203
 [31] trace!(::Umlaut.Tracer{Yota.GradCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Variable})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:221
 [32] trace(f::Function, args::Float64; ctx::Yota.GradCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:347
 [33] #gradtape#90
    @ ~/git/Yota.jl/src/grad.jl:244 [inlined]
 [34] grad(f::typeof(ℓ), args::Float64; seed::Int64)
    @ Yota ~/git/Yota.jl/src/grad.jl:315
 [35] grad(f::typeof(ℓ), args::Float64)
    @ Yota ~/git/Yota.jl/src/grad.jl:307
 [36] top-level scope
    @ REPL[17]:1

By adding a `@show` in `back!`, I found that in this case `zval` is `Accessors.@optic _.ℓ`, which has type `Accessors.PropertyLens{:ℓ}`.

Is this a bug? If not, could you give me some advice — for example, how can I tell whether this means I need another `rrule`? Teach a man to fish, etc. :)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions