I'm trying to fix inference issues within a Zygote gradient for DynamicExpressions.jl, with the goal of using fast AD in SymbolicRegression.jl. Right now, `Zygote.gradient` infers `Any` as its return type, and I can't figure out why. The weird thing is that inference works fine on the internal functions (which have a custom chain rule); only the outermost wrapper function fails to infer.
Context: for reference, you can get the same version of the package I'm debugging with the following command. I tried to produce a smaller MWE but failed to reduce it further than what I show below.
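Basically, I have a recursive binary tree structure `Node{T}` (docs). I don't want Zygote to walk through the whole tree and turn it into a tuple, which would be hugely inefficient, so instead I have a custom `NodeTangent` type (in `src/ChainRules.jl`).

I then have a chain rule for evaluation that returns this `NodeTangent`, defined as follows:

This actually works fine. I get correct derivatives, inference looks good, and it returns a `NodeTangent`, which prevents Zygote from walking the tree.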
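For readers unfamiliar with the pattern, here is a minimal, dependency-free sketch of what a "flat" tree tangent can look like: it keeps a reference to the primal tree plus one gradient entry per constant, rather than a nested tangent mirroring every node. Everything in it (`ToyNode`, `ToyNodeTangent`, `count_constants`, `toy_zero_tangent`, `leaf`, `branch`) is hypothetical and is not the package's actual `NodeTangent`. A real version would presumably subtype `ChainRulesCore.AbstractTangent`; that is left out so the snippet runs without any packages.

```julia
# Hypothetical toy tree: `constant === nothing` marks operator and feature
# nodes; leaves holding a value are the constants we differentiate against.
struct ToyNode{T}
    constant::Union{T,Nothing}
    children::Vector{ToyNode{T}}
end

# Number of constants in the tree, counted depth-first.
count_constants(n::ToyNode) =
    (n.constant === nothing ? 0 : 1) + sum(count_constants, n.children; init=0)

# Hypothetical flat tangent: primal tree reference + one entry per constant.
struct ToyNodeTangent{N<:ToyNode,A<:AbstractVector}
    tree::N
    gradient::A
end

# Zero tangent sized to the number of constants in `tree`.
toy_zero_tangent(tree::ToyNode{T}) where {T} =
    ToyNodeTangent(tree, zeros(T, count_constants(tree)))

# Accumulation: contributions for the same tree add elementwise.
function Base.:+(a::ToyNodeTangent, b::ToyNodeTangent)
    a.tree === b.tree || throw(ArgumentError("tangents refer to different trees"))
    return ToyNodeTangent(a.tree, a.gradient .+ b.gradient)
end

# Scaling by a scalar (e.g. an incoming cotangent).
Base.:*(s::Number, t::ToyNodeTangent) = ToyNodeTangent(t.tree, s .* t.gradient)

# Small helpers to build x1 * cos(x2 - 3.2); only the 3.2 leaf is a constant.
leaf(c) = ToyNode{Float64}(c, ToyNode{Float64}[])
branch(children...) = ToyNode{Float64}(nothing, collect(ToyNode{Float64}, children))

tree = branch(leaf(nothing), branch(branch(leaf(nothing), leaf(3.2))))
g = ToyNodeTangent(tree, [-0.5])
total = toy_zero_tangent(tree) + g + 2 * g
total.gradient  # [-1.5]
```

The point of the design is that accumulating or scaling a tangent touches one vector, independent of tree depth, and the tree itself is never destructured into tuples.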
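However, when I then try to use my new `Expression` type, which is nothing but a `Node{T}` plus a named tuple of operators and variable names, I no longer get successful inference. Here is the wrapper method for evaluation: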
```julia
function eval_tree_array(
    ex::AbstractExpression,
    cX::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
)
    return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
end
```
So it basically just unpacks `ex -> ex.tree` and `ex -> ex.metadata.operators`.
Now, say I try to take the gradient of this instead. Unlike the internal `eval_tree_array` call, I do not define a custom chain rule for this one, since the wrapper call is simple.
```julia
julia> ex = Expression(tree; operators, variable_names=["x1", "x2"])
x1 * cos(x2 - 3.2)

julia> Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
((tree = NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]), metadata = nothing),)

julia> Test.@inferred Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
ERROR: return type Tuple{@NamedTuple{tree::NodeTangent{Float64, Node{Float64}, Vector{Float64}}, metadata::Nothing}} does not match inferred return type Tuple{Any}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ REPL[25]:1
```
Even though all `eval_tree_array(::AbstractExpression, ...)` is doing is a few `getproperty` calls before passing the results to a call that I know works, inference on this wrapper fails.
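As a general illustration (not a diagnosis of this particular bug), the sketch below shows how a wrapper that "only unpacks fields" can turn an otherwise inferrable inner call into a non-concrete result: if a field reached by `getproperty` has an abstract declared type, everything downstream of it loses precision. All names (`ToyTree`, `ToyExpr`, `LooseExpr`, `inner_eval`, `outer_eval`, `unpack_tree`, `unpack_ops`) are hypothetical, and the example involves no AD at all. `Test.@inferred` can confirm the primal is fine, as it is here for `ToyExpr`; it says nothing about the types Zygote constructs for the reverse pass, where the same kind of imprecision could appear in the tangents built for the wrapper's fields.

```julia
using Test

# Hypothetical inner evaluator: infers fine when `ops` has a concrete type.
struct ToyTree{T}
    c::T
end
inner_eval(tree::ToyTree, X::AbstractMatrix, ops) =
    (ops.op.(tree.c .* vec(sum(X; dims=1))), true)

# Wrapper with concretely parameterized fields.
struct ToyExpr{N,M<:NamedTuple}
    tree::N
    metadata::M
end

# Same wrapper, but `metadata` is declared with the abstract type `NamedTuple`.
struct LooseExpr{N}
    tree::N
    metadata::NamedTuple
end

# The wrapper only unpacks fields before calling the inner function.
unpack_tree(ex) = ex.tree
unpack_ops(ex) = ex.metadata.operators
outer_eval(ex, X) = inner_eval(unpack_tree(ex), X, unpack_ops(ex))

X = ones(2, 1)
meta = (operators = (op = cos,),)
tight = ToyExpr(ToyTree(3.2), meta)
loose = LooseExpr(ToyTree(3.2), meta)

@inferred outer_eval(tight, X)  # passes: every field type is concrete
# `@inferred outer_eval(loose, X)` would throw: `ex.metadata.operators` is
# inferred as `Any`, so the result of `inner_eval` is no longer concrete.
```

Both wrappers return the same value at runtime; only what the compiler can prove about it differs.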
Questions:

1. Any guesses as to where the issue comes from?
2. Do I need to define a custom tangent type for `Expression`, and what's the actual interface for `AbstractTangent`?
   - For the record, I did try creating a `zero_tangent(::Expression)`, but this didn't seem to fix the issue. Maybe there's some other function I need to define?
   - Perhaps I need to declare `NodeTangent` for some other function symbols so that Zygote doesn't try descending at the outermost call? If so, what methods need to be implemented?
3. (General) How does one go about debugging type inference issues in Zygote when type inference on the primal is fine? I can't seem to use Cthulhu.jl effectively, though perhaps I am descending the wrong tree.
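One generic way to bisect an inference failure like this is to ask the compiler for the inferred return type of each stage separately and stop at the first one that is not concrete. With Zygote, the stages might be, for example, the forward call and the `back` closure from `y, back = Zygote.pullback(f, ex)`. The toy below uses only Base so that it runs anywhere; `Box`, `scale`, `unpack`, `pipeline`, and `first_unstable_stage` are hypothetical names. It relies on `Base.return_types`, which reports what inference concludes without running the code. Once a stage is singled out, pointing Cthulhu's `@descend` at that one call is a much narrower starting point than descending from `Zygote.gradient`.

```julia
# Hypothetical toy pipeline: an inner function that infers fine, wrapped by
# an unpacking step that reads an untyped field.
struct Box
    payload          # no type annotation, so the field is typed `Any`
end

scale(x) = 2x                   # inner step: infers Float64 for a Float64 input
unpack(b::Box) = b.payload      # unpacking step: infers Any
pipeline(b::Box) = scale(unpack(b))

# Print the inferred return type of each stage, in order, and return the name
# of the first stage whose inferred type is not concrete (or `nothing`).
function first_unstable_stage(stages)
    for (name, f, argtypes) in stages
        # `return_types` yields one entry per matching method; join them.
        rt = reduce(typejoin, Base.return_types(f, argtypes); init=Union{})
        println(rpad(name, 10), " => ", rt)
        isconcretetype(rt) || return name
    end
    return nothing
end

stages = [
    ("scale", scale, (Float64,)),
    ("unpack", unpack, (Box,)),
    ("pipeline", pipeline, (Box,)),
]

first_unstable_stage(stages)
```

Here the inner `scale` infers `Float64` while the unpacking step is the first to yield `Any`, which mirrors the situation of an internal function that infers cleanly inside a wrapper that does not.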
X-post from https://discourse.julialang.org/t/problems-with-ad-inference-on-wrapper-function/116454?u=milescranmer in the hope that I can find the right person to help debug this. This is very important for my work, so I can collect as much info as you need. Please ask away.