Description
Here is a list of some operations that did not work for me. I am curious about the errors that mention ChainRules in their messages. For instance, in the sum example, I guess we are tracing too deep into the sum implementation, because a higher-level sum rule already exists:
https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9
@dfdx Maybe the trace used for gradtape should have an is_primitive that checks whether the signature is covered by an rrule? Actually, I realize that Yota.is_primitive !== Ghost.is_primitive and there is already such a rule. I think the real issue is what should happen when tracing starts with a call that is already primitive. The best design is not obvious. Currently, such a call is entered anyway, which is why, e.g., sum([1.0]) fails.
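To make the design question concrete, here is a toy model in plain Julia. It is only a sketch under stated assumptions, not Yota's or Ghost's actual API: RULE_TABLE, has_rule, Op, enter and toy_trace are all hypothetical names, and the rule table merely stands in for "signatures covered by an rrule". It contrasts recording a top-level call that already has a rule as a single op with entering it anyway.

```julia
# Hypothetical toy model, NOT the Yota/Ghost API.
# RULE_TABLE stands in for "signatures covered by an rrule"; entries are illustrative.
const RULE_TABLE = Set{Any}([
    Tuple{typeof(sum), Vector{Float64}},
    Tuple{typeof(sin), Float64},
])

# True when the call signature of f(args...) has an entry in the toy rule table.
has_rule(f, args...) = Tuple{typeof(f), map(typeof, args)...} in RULE_TABLE

# A recorded operation on the toy tape: the function and its argument values.
struct Op
    f::Any
    args::Tuple
end

# Toy "body" of sum(xs): the single inner call that the sum([1.0]) error reports.
enter(::typeof(sum), xs::Vector{Float64}) = [Op(mapreduce, (identity, Base.add_sum, xs))]
enter(f, args...) = error("toy tracer: no body defined for $f")

# check_top = true:  a top-level call that has a rule is recorded as one op.
# check_top = false: the top-level call is always entered (the behaviour described above).
function toy_trace(f, args...; check_top::Bool = true)
    if check_top && has_rule(f, args...)
        return [Op(f, args)]
    end
    return enter(f, args...)
end

xs = [1.0]
tape_a = toy_trace(sum, xs)                     # one Op whose f is sum itself
tape_b = toy_trace(sum, xs; check_top = false)  # one Op whose f is mapreduce
println(has_rule(tape_b[1].f, tape_b[1].args...))  # false: no rule for the inner call
```

With check_top = false the tracer enters sum and ends up with a mapreduce call that has no entry in the table, which mirrors the mapreduce failure reported for sum([1.0]) below.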
################################################################################
Yota.gradtape(sum, [1.0])
fails
No deriative rule found for op %42 = mapreduce(identity, add_sum, %2)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(sum, abs2, [1.0])
fails
No deriative rule found for op %30 = mapreduce(%2, add_sum, %3)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(abs2), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(identity, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(sin, 1.0)
fails
MethodError: Cannot `convert` an object of type Float64 to an object of type Ghost.Variable
Closest candidates are:
convert(::Type{T}, ::T) where T at essentials.jl:205
Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:22
################################################################################
Yota.gradtape(*, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(*, 1.0, 2.0)
fails
No deriative rule found for op %4 = mul_float(%2, %3)::Float64, try defining it using ChainRules.rrule(::Core.IntrinsicFunction, ::Float64, ::Float64) = ...