From 1b542bca424b34908fbcf8c634d8abccef3b7146 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 27 Jun 2024 19:11:16 +0100 Subject: [PATCH 01/74] testing out n-arity implementation with SizedVector --- Project.toml | 1 + src/DynamicExpressions.jl | 2 +- src/Node.jl | 162 ++++++++++++++------------------ src/NodeUtils.jl | 20 ++-- src/OperatorEnumConstruction.jl | 8 +- src/ParametricExpression.jl | 10 +- 6 files changed, 90 insertions(+), 113 deletions(-) diff --git a/Project.toml b/Project.toml index 20faf9f0..bb7c8e90 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 0eebeadc..be8c76f6 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -107,5 +107,5 @@ end @ignore include("../test/runtests.jl") include("precompile.jl") -do_precompilation(; mode=:precompile) +# do_precompilation(; mode=:precompile) end diff --git a/src/Node.jl b/src/Node.jl index 303c8a93..eeb5e246 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -1,6 +1,7 @@ module NodeModule using DispatchDoctor: @unstable +using StaticArrays: SizedVector import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined @@ -8,25 +9,30 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 """ - AbstractNode + AbstractNode{D,shared} -Abstract type for binary trees. Must have the following fields: +Abstract type for D-arity trees. If `shared`, the node type +permits graph-like structures. 
Must have the following fields: - `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1, then `l` needs to be defined as the left child. If 2, then `r` also needs to be defined as the right child. -- `l::AbstractNode`: Left child of the current node. Should only be +- `children`: A collection of D children nodes. + +# Deprecated fields + +- `l::AbstractNode{D}`: Left child of the current node. Should only be defined if `degree >= 1`; otherwise, leave it undefined (see the the constructors of [`Node{T}`](@ref) for an example). Don't use `nothing` to represent an undefined value as it will incur a large performance penalty. -- `r::AbstractNode`: Right child of the current node. Should only +- `r::AbstractNode{D}`: Right child of the current node. Should only be defined if `degree == 2`. """ -abstract type AbstractNode end +abstract type AbstractNode{D,shared} end """ - AbstractExpressionNode{T} <: AbstractNode + AbstractExpressionNode{T,D} <: AbstractNode{D} Abstract type for nodes that represent an expression. Along with the fields required for `AbstractNode`, @@ -67,11 +73,25 @@ You likely do not need to, but you could choose to override the following: - `with_type_parameters` """ -abstract type AbstractExpressionNode{T} <: AbstractNode end +abstract type AbstractExpressionNode{T,D,shared} <: AbstractNode{D,shared} end + +mutable struct GeneralNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # If operator, this is the index of the operator in the degree-specific operator enum + children::SizedVector{D,GeneralNode{T,D,shared}} # Children nodes + + ################# + ## Constructors: + ################# + GeneralNode{_T,_D,_shared}() where {_T,_D,_shared} = new{_T,_D,_shared}() +end #! 
format: off """ - Node{T} <: AbstractExpressionNode{T} + Node{T} <: AbstractExpressionNode{T,2} Node defines a symbolic expression stored in a binary tree. A single `Node` instance is one "node" of this tree, and @@ -81,58 +101,37 @@ nodes, you can evaluate or print a given expression. # Fields - `degree::UInt8`: Degree of the node. 0 for constants, 1 for - unary operators, 2 for binary operators. + unary operators, 2 for binary operators, etc. Maximum of `D`. - `constant::Bool`: Whether the node is a constant. - `val::T`: Value of the node. If `degree==0`, and `constant==true`, this is the value of the constant. It has a type specified by the overall type of the `Node` (e.g., `Float64`). - `feature::UInt16`: Index of the feature to use in the - case of a feature node. Only used if `degree==0` and `constant==false`. - Only defined if `degree == 0 && constant == false`. + case of a feature node. Only defined if `degree == 0 && constant == false`. - `op::UInt8`: If `degree==1`, this is the index of the operator in `operators.unaops`. If `degree==2`, this is the index of the operator in `operators.binops`. In other words, this is an enum of the operators, and is dependent on the specific `OperatorEnum` object. Only defined if `degree >= 1` -- `l::Node{T}`: Left child of the node. Only defined if `degree >= 1`. - Same type as the parent node. -- `r::Node{T}`: Right child of the node. Only defined if `degree == 2`. - Same type as the parent node. This is to be passed as the right - argument to the binary operator. +- `children::SizedArray{D,Node{T,D}}`: Children of the node. 
Only defined up to `degree` # Constructors - Node([T]; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) - Node{T}(; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator=default_allocator) + Node([T]; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator) + Node{T}(; val=nothing, feature=nothing, op=nothing, children=nothing, allocator=default_allocator) Create a new node in an expression tree. If `T` is not specified in either the type or the -first argument, it will be inferred from the value of `val` passed or `l` and/or `r`. -If it cannot be inferred from these, it will default to `Float32`. - -The `children` keyword can be used instead of `l` and `r` and should be a tuple of children. This -is to permit the use of splatting in constructors. +first argument, it will be inferred from the value of `val` passed or the children. +The `children` keyword is used to pass in a collection of children nodes. You may also construct nodes via the convenience operators generated by creating an `OperatorEnum`. You may also choose to specify a default memory allocator for the node other than simply `Node{T}()` in the `allocator` keyword argument. """ -mutable struct Node{T} <: AbstractExpressionNode{T} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - # ------------------- (possibly undefined below) - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops - l::Node{T} # Left child node. Only defined for degree=1 or degree=2. - r::Node{T} # Right child node. Only defined for degree=2. 
+const Node{T} = GeneralNode{T,2,false} - ################# - ## Constructors: - ################# - Node{_T}() where {_T} = new{_T}() -end """ GraphNode{T} <: AbstractExpressionNode{T} @@ -146,7 +145,7 @@ be performed with this assumption, to preserve structure of the graph. ```julia julia> operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos, sin] - ); + ); julia> x = GraphNode(feature=1) x1 @@ -165,18 +164,7 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes are created simply by using the same node in multiple places when constructing or setting properties. """ -mutable struct GraphNode{T} <: AbstractExpressionNode{T} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. - constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - # ------------------- (possibly undefined below) - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops - l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2. - r::GraphNode{T} # Right child node. Only defined for degree=2. - - GraphNode{_T}() where {_T} = new{_T}() -end +const GraphNode{T} = GeneralNode{T,2,true} ################################################################################ #! format: on @@ -184,49 +172,41 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -@unstable constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper +function max_degree(::Type{N}) where {N<:AbstractExpressionNode} + return (N isa UnionAll ? 
N.body : N).parameters[2] +end + @unstable constructorof(::Type{<:Node}) = Node @unstable constructorof(::Type{<:GraphNode}) = GraphNode -function with_type_parameters(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} - return constructorof(N){T} -end with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} -function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} - return with_type_parameters(N, T)() -end default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}() default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}() """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false -preserve_sharing(::Union{Type{<:Node},Node}) = false -preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true +function preserve_sharing( + ::Union{Type{<:G},G} +) where {shared,G<:GeneralNode{T,D,shared} where {T,D}} + return shared +end include("base.jl") #! 
format: off @inline function (::Type{N})( - ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, + ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, children=nothing, allocator::F=default_allocator, ) where {T1,N<:AbstractExpressionNode,F} - validate_not_all_defaults(N, val, feature, op, l, r, children) - if children !== nothing - @assert l === nothing && r === nothing - if length(children) == 1 - return node_factory(N, T1, val, feature, op, only(children), nothing, allocator) - else - return node_factory(N, T1, val, feature, op, children..., allocator) - end - end - return node_factory(N, T1, val, feature, op, l, r, allocator) + validate_not_all_defaults(N, val, feature, op, children) + return node_factory(N, T1, val, feature, op, children, allocator) end -function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {N<:AbstractExpressionNode} +function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode} return nothing end -function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) where {T,N<:AbstractExpressionNode{T}} - if val === nothing && feature === nothing && op === nothing && l === nothing && r === nothing && children === nothing +function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {T,N<:AbstractExpressionNode{T}} + if val === nothing && feature === nothing && op === nothing && children === nothing error( "Encountered the call for $N() inside the generic constructor. " * "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?" 
@@ -236,7 +216,7 @@ function validate_not_all_defaults(::Type{N}, val, feature, op, l, r, children) end """Create a constant leaf.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator::F, + ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) n = allocator(N, T) @@ -247,7 +227,7 @@ end end """Create a variable leaf, to store data.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, ::Nothing, allocator::F, + ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) n = allocator(N, T) @@ -256,28 +236,22 @@ end n.feature = feature return n end -"""Create a unary operator node.""" +"""Create an operator node.""" @inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, ::Nothing, allocator::F, -) where {N,T1,T2,F} - @assert l isa N - T = T2 # Always prefer existing nodes, so we don't mess up references from conversion + ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F, +) where {N,F,D2} + D = max_degree(N) + @assert D2 <= D + T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion + NT = with_type_parameters(N, T) n = allocator(N, T) - n.degree = 1 + n.degree = D2 n.op = op - n.l = l - return n -end -"""Create a binary operator node.""" -@inline function node_factory( - ::Type{N}, ::Type{T1}, ::Nothing, ::Nothing, op::Integer, l::AbstractExpressionNode{T2}, r::AbstractExpressionNode{T3}, allocator::F, -) where {N,T1,T2,T3,F} - T = promote_type(T2, T3) - n = allocator(N, T) - n.degree = 2 - n.op = op - n.l = T2 === T ? l : convert(with_type_parameters(N, T), l) - n.r = T3 === T ? 
r : convert(with_type_parameters(N, T), r) + ar = SizedVector{D,NT}(undef) + for i in 1:D2 + ar[i] = children[i] + end + n.children = ar return n end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index ae331f26..17b4e898 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -1,9 +1,11 @@ module NodeUtilsModule +using StaticArrays: MVector import Compat: Returns import ..NodeModule: AbstractNode, AbstractExpressionNode, + GeneralNode, Node, preserve_sharing, constructorof, @@ -98,18 +100,18 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. -struct NodeIndex{T} <: AbstractNode +struct NodeIndex{T,D} <: AbstractNode{D,false} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) - l::NodeIndex{T} # Left child node. Only defined for degree=1 or degree=2. - r::NodeIndex{T} # Right child node. Only defined for degree=2. 
- - NodeIndex(::Type{_T}) where {_T} = new{_T}(0, zero(_T)) - NodeIndex(::Type{_T}, val) where {_T} = new{_T}(0, convert(_T, val)) - NodeIndex(::Type{_T}, l::NodeIndex) where {_T} = new{_T}(1, zero(_T), l) - function NodeIndex(::Type{_T}, l::NodeIndex, r::NodeIndex) where {_T} - return new{_T}(2, zero(_T), l, r) + children::MVector{D,NodeIndex{T,D}} + + NodeIndex(::Type{_T}, ::Type{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) + NodeIndex(::Type{_T}, ::Type{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) + function NodeIndex(::Type{_T}, ::Type{_D}, children::Vararg{Any,_D2}) where {_T,_D,_D2} + _children = MVector{_D,NodeIndex{_T,_D}}(undef) + _children[begin:_D2] = children + return new{_T,_D}(1, zero(_T), _children) end end # Sharing is never needed for NodeIndex, diff --git a/src/OperatorEnumConstruction.jl b/src/OperatorEnumConstruction.jl index 919dea10..3c17ecb7 100644 --- a/src/OperatorEnumConstruction.jl +++ b/src/OperatorEnumConstruction.jl @@ -141,7 +141,7 @@ function _extend_unary_operator(f::Symbol, type_requirements, internal) $_constructorof(N)(T; val=$($f)(l.val)) else latest_op_idx = $($lookup_op)($($f), Val(1)) - $_constructorof(N)(; op=latest_op_idx, l) + $_constructorof(N)(; op=latest_op_idx, children=(l,)) end end end @@ -168,7 +168,7 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters, $_constructorof(N)(T; val=$($f)(l.val, r.val)) else latest_op_idx = $($lookup_op)($($f), Val(2)) - $_constructorof(N)(; op=latest_op_idx, l, r) + $_constructorof(N)(; op=latest_op_idx, children=(l, r)) end end function $($f)( @@ -179,7 +179,7 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters, else latest_op_idx = $($lookup_op)($($f), Val(2)) $_constructorof(N)(; - op=latest_op_idx, l, r=$_constructorof(N)(T; val=r) + op=latest_op_idx, children=(l, $_constructorof(N)(T; val=r)) ) end end @@ -191,7 +191,7 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters, else 
latest_op_idx = $($lookup_op)($($f), Val(2)) $_constructorof(N)(; - op=latest_op_idx, l=$_constructorof(N)(T; val=l), r + op=latest_op_idx, children=($_constructorof(N)(T; val=l), r) ) end end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 383af13d..d7960943 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -1,6 +1,7 @@ module ParametricExpressionModule using DispatchDoctor: @stable, @unstable +using StaticArrays: MVector using ..OperatorEnumModule: AbstractOperatorEnum using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce @@ -29,7 +30,7 @@ import ..ExpressionModule: import ..ParseModule: parse_leaf """A type of expression node that also stores a parameter index""" -mutable struct ParametricNode{T} <: AbstractExpressionNode{T} +mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} degree::UInt8 constant::Bool # if true => constant; if false, then check `is_parameter` val::T @@ -39,11 +40,10 @@ mutable struct ParametricNode{T} <: AbstractExpressionNode{T} parameter::UInt16 # Stores index of per-class parameter op::UInt8 - l::ParametricNode{T} - r::ParametricNode{T} + children::MVector{D,ParametricNode{T,D}} # Children nodes - function ParametricNode{_T}() where {_T} - n = new{_T}() + function ParametricNode{_T,_D,_shared}() where {_T,_D,_shared} + n = new{_T,_D,_shared}() n.is_parameter = false n.parameter = UInt16(0) return n From ab7c65ca810672fcacc65ba826b899e3bab21ad1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 11:52:01 +0100 Subject: [PATCH 02/74] wip --- Project.toml | 1 - src/Node.jl | 108 +++++++++++++++++++++++++++------------------------ 2 files changed, 57 insertions(+), 52 deletions(-) diff --git a/Project.toml b/Project.toml index b2216e58..3f5df1a5 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = 
"9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] diff --git a/src/Node.jl b/src/Node.jl index eeb5e246..d1102168 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -1,7 +1,6 @@ module NodeModule using DispatchDoctor: @unstable -using StaticArrays: SizedVector import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined @@ -9,15 +8,14 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 """ - AbstractNode{D,shared} + AbstractNode{D} -Abstract type for D-arity trees. If `shared`, the node type -permits graph-like structures. Must have the following fields: +Abstract type for D-arity trees. Must have the following fields: - `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1, then `l` needs to be defined as the left child. If 2, then `r` also needs to be defined as the right child. -- `children`: A collection of D children nodes. +- `children`: A collection of D references to children nodes. # Deprecated fields @@ -29,7 +27,7 @@ permits graph-like structures. Must have the following fields: - `r::AbstractNode{D}`: Right child of the current node. Should only be defined if `degree == 2`. """ -abstract type AbstractNode{D,shared} end +abstract type AbstractNode{D} end """ AbstractExpressionNode{T,D} <: AbstractNode{D} @@ -73,25 +71,27 @@ You likely do not need to, but you could choose to override the following: - `with_type_parameters` """ -abstract type AbstractExpressionNode{T,D,shared} <: AbstractNode{D,shared} end - -mutable struct GeneralNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} - degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. 
- constant::Bool # false if variable - val::T # If is a constant, this stores the actual value - feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index. - op::UInt8 # If operator, this is the index of the operator in the degree-specific operator enum - children::SizedVector{D,GeneralNode{T,D,shared}} # Children nodes - - ################# - ## Constructors: - ################# - GeneralNode{_T,_D,_shared}() where {_T,_D,_shared} = new{_T,_D,_shared}() +abstract type AbstractExpressionNode{T,D} <: AbstractNode{D} end + +for N in (:Node, :GraphNode) + @eval mutable struct $N{T,D} <: AbstractExpressionNode{T,D} + degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. + constant::Bool # false if variable + val::T # If is a constant, this stores the actual value + feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. + op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum + children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes + + ################# + ## Constructors: + ################# + $N{_T,_D}() where {_T,_D} = new{_T,_D}() + end end #! format: off """ - Node{T} <: AbstractExpressionNode{T,2} + Node{T,D} <: AbstractExpressionNode{T,D} Node defines a symbolic expression stored in a binary tree. A single `Node` instance is one "node" of this tree, and @@ -113,7 +113,7 @@ nodes, you can evaluate or print a given expression. operator in `operators.binops`. In other words, this is an enum of the operators, and is dependent on the specific `OperatorEnum` object. Only defined if `degree >= 1` -- `children::SizedArray{D,Node{T,D}}`: Children of the node. Only defined up to `degree` +- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. 
Only defined up to `degree` # Constructors @@ -130,13 +130,13 @@ You may also construct nodes via the convenience operators generated by creating You may also choose to specify a default memory allocator for the node other than simply `Node{T}()` in the `allocator` keyword argument. """ -const Node{T} = GeneralNode{T,2,false} +Node """ - GraphNode{T} <: AbstractExpressionNode{T} + GraphNode{T,D} <: AbstractExpressionNode{T,D} -Exactly the same as [`Node{T}`](@ref), but with the assumption that some +Exactly the same as [`Node{T,D}`](@ref), but with the assumption that some nodes will be shared. All copies of this graph-like structure will be performed with this assumption, to preserve structure of the graph. @@ -164,7 +164,7 @@ This has the same constructors as [`Node{T}`](@ref). Shared nodes are created simply by using the same node in multiple places when constructing or setting properties. """ -const GraphNode{T} = GeneralNode{T,2,true} +GraphNode ################################################################################ #! 
format: on @@ -177,30 +177,41 @@ function max_degree(::Type{N}) where {N<:AbstractExpressionNode} end @unstable constructorof(::Type{<:Node}) = Node +@unstable constructorof(::Type{<:Node{T,D} where T}) where {D} = Node{T,D} where T @unstable constructorof(::Type{<:GraphNode}) = GraphNode +@unstable constructorof(::Type{<:GraphNode{T,D} where T}) where {D} = GraphNode{T,D} where T -with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T} -with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T} +with_type_parameters(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D} +with_type_parameters(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D} -default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}() -default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}() +default_allocator(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}() +default_allocator(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}() """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false -function preserve_sharing( - ::Union{Type{<:G},G} -) where {shared,G<:GeneralNode{T,D,shared} where {T,D}} - return shared -end +preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true include("base.jl") #! format: off @inline function (::Type{N})( - ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, children=nothing, allocator::F=default_allocator, + ::Type{T1}=Undefined; kws... 
) where {T1,N<:AbstractExpressionNode,F} - validate_not_all_defaults(N, val, feature, op, children) - return node_factory(N, T1, val, feature, op, children, allocator) +end +@inline function (::Type{N})( + ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, +) where {T1,D,N<:AbstractExpressionNode{T,D} where T,F} + _children = if l !== nothing && r === nothing + @assert children === nothing + (l,) + elseif l !== nothing && r !== nothing + @assert children === nothing + (l, r) + else + children + end + validate_not_all_defaults(N, val, feature, op, _children) + return node_factory(N, T1, val, feature, op, _children, allocator) end function validate_not_all_defaults(::Type{N}, val, feature, op, children) where {N<:AbstractExpressionNode} return nothing @@ -209,7 +220,7 @@ function validate_not_all_defaults(::Type{N}, val, feature, op, children) where if val === nothing && feature === nothing && op === nothing && children === nothing error( "Encountered the call for $N() inside the generic constructor. " - * "Did you forget to define `$(Base.typename(N).wrapper){T}() where {T} = new{T}()`?" + * "Did you forget to define `$(Base.typename(N).wrapper){T,D}() where {T,D} = new{T,D}()`?" 
) end return nothing @@ -219,7 +230,7 @@ end ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) - n = allocator(N, T) + n = allocator(N, T, D) n.degree = 0 n.constant = true n.val = convert(T, val) @@ -230,7 +241,7 @@ end ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) - n = allocator(N, T) + n = allocator(N, T, D) n.degree = 0 n.constant = false n.feature = feature @@ -239,19 +250,14 @@ end """Create an operator node.""" @inline function node_factory( ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F, -) where {N,F,D2} - D = max_degree(N) - @assert D2 <= D +) where {D,N<:AbstractExpressionNode{T where T,D},F,D2} T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion - NT = with_type_parameters(N, T) - n = allocator(N, T) + NT = with_type_parameters(N, T, D) + n = allocator(N, T, D) n.degree = D2 n.op = op - ar = SizedVector{D,NT}(undef) - for i in 1:D2 - ar[i] = children[i] - end - n.children = ar + n.children + # map(Ref, children) return n end From 3ed6b41067c48f8e374f0c198ce3d6203fab8185 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 12:38:39 +0100 Subject: [PATCH 03/74] fix: various aspects of degree interface --- src/Node.jl | 95 +++++++++++++++++++++++++------------ src/NodeUtils.jl | 33 +++++++------ src/ParametricExpression.jl | 9 ++-- 3 files changed, 86 insertions(+), 51 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index d1102168..b3731f85 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -85,7 +85,7 @@ for N in (:Node, :GraphNode) ################# ## Constructors: ################# - $N{_T,_D}() where {_T,_D} = new{_T,_D}() + $N{_T,_D}() where {_T,_D} = new{_T,_D::Int}() end end @@ -166,26 +166,62 @@ when constructing or setting properties. 
""" GraphNode +@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol) + if k == :l + # TODO: Should a depwarn be raised here? Or too slow? + return getfield(n, :children)[1][] + elseif k == :r + return getfield(n, :children)[2][] + else + return getfield(n, k) + end +end +@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v) + if k == :l + getfield(n, :children)[1][] = v + elseif k == :r + getfield(n, :children)[2][] = v + elseif k == :degree + setfield!(n, :degree, convert(UInt8, v)) + elseif k == :constant + setfield!(n, :constant, convert(Bool, v)) + elseif k == :feature + setfield!(n, :feature, convert(UInt16, v)) + elseif k == :op + setfield!(n, :op, convert(UInt8, v)) + elseif k == :val + setfield!(n, :val, convert(eltype(n), v)) + elseif k == :children + setfield!(n, :children, v) + else + error("Invalid property: $k") + end +end + ################################################################################ #! format: on Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -function max_degree(::Type{N}) where {N<:AbstractExpressionNode} - return (N isa UnionAll ? 
N.body : N).parameters[2] -end +max_degree(::Type{<:AbstractNode}) = 2 # Default +max_degree(::Type{<:AbstractNode{D}}) where {D} = D + +@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T} +@unstable constructorof(::Type{N}) where {N<:GraphNode} = + GraphNode{T,max_degree(N)} where {T} -@unstable constructorof(::Type{<:Node}) = Node -@unstable constructorof(::Type{<:Node{T,D} where T}) where {D} = Node{T,D} where T -@unstable constructorof(::Type{<:GraphNode}) = GraphNode -@unstable constructorof(::Type{<:GraphNode{T,D} where T}) where {D} = GraphNode{T,D} where T +with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)} +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T} + return GraphNode{T,max_degree(N)} +end -with_type_parameters(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D} -with_type_parameters(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D} +# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} +# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} -default_allocator(::Type{<:Node}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = Node{T,D}() -default_allocator(::Type{<:GraphNode}, ::Type{T}, ::Val{D}=Val(2)) where {T,D} = GraphNode{T,D}() +function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T} + return with_type_parameters(N, T)() +end """Trait declaring whether nodes share children or not.""" preserve_sharing(::Union{Type{<:AbstractNode},AbstractNode}) = false @@ -194,13 +230,9 @@ preserve_sharing(::Union{Type{<:GraphNode},GraphNode}) = true include("base.jl") #! format: off -@inline function (::Type{N})( - ::Type{T1}=Undefined; kws... 
-) where {T1,N<:AbstractExpressionNode,F} -end @inline function (::Type{N})( ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, -) where {T1,D,N<:AbstractExpressionNode{T,D} where T,F} +) where {T1,N<:AbstractExpressionNode{T} where T,F} _children = if l !== nothing && r === nothing @assert children === nothing (l,) @@ -230,7 +262,7 @@ end ::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,T2,F} T = node_factory_type(N, T1, T2) - n = allocator(N, T, D) + n = allocator(N, T) n.degree = 0 n.constant = true n.val = convert(T, val) @@ -241,7 +273,7 @@ end ::Type{N}, ::Type{T1}, ::Nothing, feature::Integer, ::Nothing, ::Nothing, allocator::F, ) where {N,T1,F} T = node_factory_type(N, T1, DEFAULT_NODE_TYPE) - n = allocator(N, T, D) + n = allocator(N, T) n.degree = 0 n.constant = false n.feature = feature @@ -249,15 +281,16 @@ end end """Create an operator node.""" @inline function node_factory( - ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::NTuple{D2}, allocator::F, -) where {D,N<:AbstractExpressionNode{T where T,D},F,D2} + ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F, +) where {N<:AbstractExpressionNode,F} T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion - NT = with_type_parameters(N, T, D) - n = allocator(N, T, D) + D2 = length(children) + @assert D2 <= max_degree(N) + NT = with_type_parameters(N, T) + n = allocator(N, T) n.degree = D2 n.op = op - n.children - # map(Ref, children) + n.children = ntuple(i -> i <= D2 ? 
Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N))) return n end @@ -298,14 +331,14 @@ function (::Type{N})( return N(; feature=i) end -function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2} - return Node{promote_type(T1, T2)} +function Base.promote_rule(::Type{Node{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D} + return Node{promote_type(T1, T2),D} end -function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{Node{T2}}) where {T1,T2} - return GraphNode{promote_type(T1, T2)} +function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{Node{T2,D}}) where {T1,T2,D} + return GraphNode{promote_type(T1, T2),D} end -function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2} - return GraphNode{promote_type(T1, T2)} +function Base.promote_rule(::Type{GraphNode{T1,D}}, ::Type{GraphNode{T2,D}}) where {T1,T2,D} + return GraphNode{promote_type(T1, T2),D} end # TODO: Verify using this helps with garbage collection diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 37e60596..2b49441c 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -1,11 +1,9 @@ module NodeUtilsModule -using StaticArrays: MVector import Compat: Returns import ..NodeModule: AbstractNode, AbstractExpressionNode, - GeneralNode, Node, preserve_sharing, constructorof, @@ -145,17 +143,20 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. -struct NodeIndex{T,D} <: AbstractNode{D,false} +struct NodeIndex{T,D} <: AbstractNode{D} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. 
val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) - children::MVector{D,NodeIndex{T,D}} - - NodeIndex(::Type{_T}, ::Type{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) - NodeIndex(::Type{_T}, ::Type{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) - function NodeIndex(::Type{_T}, ::Type{_D}, children::Vararg{Any,_D2}) where {_T,_D,_D2} - _children = MVector{_D,NodeIndex{_T,_D}}(undef) - _children[begin:_D2] = children + children::NTuple{D,Base.RefValue{NodeIndex{T,D}}} + + NodeIndex(::Type{_T}, ::Val{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) + NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) + function NodeIndex( + ::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2} + ) where {_T,_D,_D2} + _children = ntuple( + i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D) + ) return new{_T,_D}(1, zero(_T), _children) end end @@ -163,20 +164,22 @@ end # as we trace over the node we are indexing on. preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false -function index_constant_nodes(tree::AbstractExpressionNode, ::Type{T}=UInt16) where {T} +function index_constant_nodes( + tree::AbstractExpressionNode{Ti,D} where {Ti}, ::Type{T}=UInt16 +) where {D,T} # Essentially we copy the tree, replacing the values # with indices constant_index = Ref(T(0)) return tree_mapreduce( t -> if t.constant - NodeIndex(T, (constant_index[] += T(1))) + NodeIndex(T, Val(D), (constant_index[] += T(1))) else - NodeIndex(T) + NodeIndex(T, Val(D)) end, t -> nothing, - (_, c...) -> NodeIndex(T, c...), + (_, c...) 
-> NodeIndex(T, Val(D), c...), tree, - NodeIndex{T}; + NodeIndex{T,D}; ) end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index b67f0490..4b6d4b50 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -1,7 +1,6 @@ module ParametricExpressionModule using DispatchDoctor: @stable, @unstable -using StaticArrays: MVector using ChainRulesCore: ChainRulesCore as CRC, NoTangent, @thunk using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum @@ -35,7 +34,7 @@ import ..ValueInterfaceModule: count_scalar_constants, pack_scalar_constants!, unpack_scalar_constants """A type of expression node that also stores a parameter index""" -mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} +mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D} degree::UInt8 constant::Bool # if true => constant; if false, then check `is_parameter` val::T @@ -45,10 +44,10 @@ mutable struct ParametricNode{T,D,shared} <: AbstractExpressionNode{T,D,shared} parameter::UInt16 # Stores index of per-class parameter op::UInt8 - children::MVector{D,ParametricNode{T,D}} # Children nodes + children::NTuple{D,Base.RefValue{ParametricNode{T,D}}} # Children nodes - function ParametricNode{_T,_D,_shared}() where {_T,_D,_shared} - n = new{_T,_D,_shared}() + function ParametricNode{_T,_D}() where {_T,_D} + n = new{_T,_D}() n.is_parameter = false n.parameter = UInt16(0) return n From b5285f7c355129deed855d12632749d32f00915f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 21:50:12 +0100 Subject: [PATCH 04/74] fix: segfault in NodeIndex See https://github.com/JuliaLang/julia/issues/55076 for details --- src/DynamicExpressions.jl | 2 +- src/Node.jl | 2 +- src/NodeUtils.jl | 13 +++++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 927adb78..c2d71489 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -117,5 
+117,5 @@ end @ignore include("../test/runtests.jl") include("precompile.jl") -# do_precompilation(; mode=:precompile) +do_precompilation(; mode=:precompile) end diff --git a/src/Node.jl b/src/Node.jl index b3731f85..0d1ab048 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -3,7 +3,7 @@ module NodeModule using DispatchDoctor: @unstable import ..OperatorEnumModule: AbstractOperatorEnum -import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined +import ..UtilsModule: deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 2b49441c..87c4ba60 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -143,23 +143,28 @@ end ## Assign index to nodes of a tree # This will mirror a Node struct, rather # than adding a new attribute to Node. -struct NodeIndex{T,D} <: AbstractNode{D} +mutable struct NodeIndex{T,D} <: AbstractNode{D} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) children::NTuple{D,Base.RefValue{NodeIndex{T,D}}} - NodeIndex(::Type{_T}, ::Val{_D}) where {_T,_D} = new{_T,_D}(0, zero(_T)) - NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} = new{_T,_D}(0, convert(_T, val)) + function NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} + return new{_T,_D}( + 0, convert(_T, val), ntuple(_ -> Ref{NodeIndex{_T,_D}}(), Val(_D)) + ) + end function NodeIndex( ::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2} ) where {_T,_D,_D2} _children = ntuple( i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D) ) - return new{_T,_D}(1, zero(_T), _children) + return new{_T,_D}(convert(UInt8, _D2), zero(_T), _children) end end +NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T)) + # Sharing is never needed for NodeIndex, # as we trace over the node we are indexing on. 
preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false From 8707d24726e42d1f4d0d36339062d95070b4bb9d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 21:50:45 +0100 Subject: [PATCH 05/74] refactor: no more need for `memoize_on` --- src/Utils.jl | 97 --------------------------------------------- src/base.jl | 83 +++++++++++++++++++++++++++----------- test/test_graphs.jl | 70 -------------------------------- 3 files changed, 60 insertions(+), 190 deletions(-) diff --git a/src/Utils.jl b/src/Utils.jl index bd3326e2..691de70a 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -12,103 +12,6 @@ macro return_on_false2(flag, retval, retval2) ) end -""" - @memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode) - ... - end - -This macro takes a function definition and creates a second version of the -function with an additional `id_map` argument. When passed this argument (an -IdDict()), it will use use the `id_map` to avoid recomputing the same value -for the same node in a tree. Use this to automatically create functions that -work with trees that have shared child nodes. - -Can optionally take a `postprocess` function, which will be applied to the -result of the function before returning it, taking the result as the -first argument and a boolean for whether the result was memoized as the -second argument. This is useful for functions that need to count the number -of unique nodes in a tree, for example. -""" -macro memoize_on(tree, args...) - if length(args) ∉ (1, 2) - error("Expected 2 or 3 arguments to @memoize_on") - end - postprocess = length(args) == 1 ? :((r, _) -> r) : args[1] - def = length(args) == 1 ? 
args[1] : args[2] - idmap_def = _memoize_on(tree, postprocess, def) - - return quote - $(esc(def)) # The normal function - $(esc(idmap_def)) # The function with an id_map argument - end -end -function _memoize_on(tree::Symbol, postprocess, def) - sdef = splitdef(def) - - # Add an id_map argument - push!(sdef[:args], :(id_map::AbstractDict)) - - f_name = sdef[:name] - - # Forward id_map argument to all calls of the same function - # within the function body: - sdef[:body] = postwalk(sdef[:body]) do ex - if @capture(ex, f_(args__)) - if f == f_name - return Expr(:call, f, args..., :id_map) - end - end - return ex - end - - # Wrap the function body in a get!(id_map, tree) do ... end block: - @gensym key is_memoized result body - sdef[:body] = quote - $key = objectid($tree) - $is_memoized = haskey(id_map, $key) - function $body() - return $(sdef[:body]) - end - $result = if $is_memoized - @inbounds(id_map[$key]) - else - id_map[$key] = $body() - end - return $postprocess($result, $is_memoized) - end - - return combinedef(sdef) -end - -""" - @with_memoize(call, id_map) - -This simple macro simply puts the `id_map` -into the call, to be consistent with the `@memoize_on` macro. - -``` -@with_memoize(_copy_node(tree), IdDict{Any,Any}()) -```` - -is converted to - -``` -_copy_node(tree, IdDict{Any,Any}()) -``` - -""" -macro with_memoize(def, id_map) - idmap_def = _add_idmap_to_call(def, id_map) - return quote - $(esc(idmap_def)) - end -end - -function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr}) - @assert def.head == :call - return Expr(:call, def.args[1], def.args[2:end]..., id_map) -end - @inline function fill_similar(value::T, array, args...) where {T} out_array = similar(array, args...) 
fill!(out_array, value) diff --git a/src/base.jl b/src/base.jl index 7d0c0041..32d67fd6 100644 --- a/src/base.jl +++ b/src/base.jl @@ -25,7 +25,7 @@ import Base: using DispatchDoctor: @unstable using Compat: @inline, Returns -using ..UtilsModule: @memoize_on, @with_memoize, Undefined +using ..UtilsModule: Undefined """ tree_mapreduce( @@ -89,38 +89,76 @@ function tree_mapreduce( f_leaf::F1, f_branch::F2, op::G, - tree::AbstractNode, + tree::AbstractNode{D}, result_type::Type{RT}=Undefined; f_on_shared::H=(result, is_shared) -> result, - break_sharing::Val=Val(false), -) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT} - - # Trick taken from here: - # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5 - # to speed up recursive closure - @memoize_on t f_on_shared function inner(inner, t) - if t.degree == 0 - return @inline(f_leaf(t)) - elseif t.degree == 1 - return @inline(op(@inline(f_branch(t)), inner(inner, t.l))) - else - return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r))) - end - end - - sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false) + break_sharing::Val{BS}=Val(false), +) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS} + sharing = preserve_sharing(typeof(tree)) && !break_sharing RT == Undefined && sharing && throw(ArgumentError("Need to specify `result_type` if nodes are shared..")) if sharing && RT != Undefined - d = allocate_id_map(tree, RT) - return @with_memoize inner(inner, tree) d + id_map = allocate_id_map(tree, RT) + reducer = TreeMapreducer(Val(D), id_map, f_leaf, f_branch, op, f_on_shared) + return reducer(tree) + else + reducer = TreeMapreducer(Val(D), nothing, f_leaf, f_branch, op, f_on_shared) + return reducer(tree) + end +end + +struct TreeMapreducer{D,ID,F1<:Function,F2<:Function,G<:Function,H<:Function} + max_degree::Val{D} + id_map::ID + f_leaf::F1 + f_branch::F2 + op::G + f_on_shared::H +end + +@generated function 
(mapreducer::TreeMapreducer{MAX_DEGREE,ID})( + tree::AbstractNode +) where {MAX_DEGREE,ID} + base_expr = quote + d = tree.degree + Base.Cartesian.@nif( + $(MAX_DEGREE + 1), + d_p_one -> (d_p_one - 1) == d, + d_p_one -> if d_p_one == 1 + mapreducer.f_leaf(tree) + else + mapreducer.op( + mapreducer.f_branch(tree), + Base.Cartesian.@ntuple( + d_p_one - 1, i -> mapreducer(tree.children[i][]) + )..., + ) + end + ) + end + if ID <: Nothing + # No sharing of nodes (is a tree, not a graph) + return base_expr else - return inner(inner, tree) + # Otherwise, we need to cache results in `id_map` + # according to `objectid` of the node + return quote + key = objectid(tree) + is_cached = haskey(mapreducer.id_map, key) + if is_cached + return mapreducer.f_on_shared(@inbounds(mapreducer.id_map[key]), true) + else + res = $base_expr + mapreducer.id_map[key] = res + return mapreducer.f_on_shared(res, false) + end + end end end + function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} d = Dict{UInt,RT}() # Preallocate maximum storage (counting with duplicates is fast) @@ -128,7 +166,6 @@ function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT} sizehint!(d, N) return d end - # TODO: Raise Julia issue for this. # Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here! # I think it's because `setindex!` is declared with `@nospecialize` in IdDict. 
diff --git a/test/test_graphs.jl b/test/test_graphs.jl index 2f31c4ed..c25a3ab6 100644 --- a/test/test_graphs.jl +++ b/test/test_graphs.jl @@ -120,76 +120,6 @@ end @test expr_eql(ex, true_ex) end - - @testset "@memoize_on" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree ((x, _) -> x) function _copy_node( - tree::Node{T} - )::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - true_ex = quote - function _copy_node(tree::Node{T})::Node{T} where {T} - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l)) - else - Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r)) - end - end - function _copy_node(tree::Node{T}, id_map::AbstractDict;)::Node{T} where {T} - key = objectid(tree) - is_memoized = haskey(id_map, key) - function body() - return begin - if tree.degree == 0 - if tree.constant - Node(; val=copy(tree.val)) - else - Node(T; feature=copy(tree.feature)) - end - elseif tree.degree == 1 - Node(copy(tree.op), _copy_node(tree.l, id_map)) - else - Node( - copy(tree.op), - _copy_node(tree.l, id_map), - _copy_node(tree.r, id_map), - ) - end - end - end - result = if is_memoized - begin - $(Expr(:inbounds, true)) - local val = id_map[key] - $(Expr(:inbounds, :pop)) - val - end - else - id_map[key] = body() - end - return (((x, _) -> begin - x - end)(result, is_memoized)) - end - end - @test expr_eql(ex, true_ex) - end end @testset "Operations on graphs" begin From 2a0bd054578c88f47b70b61ad0141e40c8e6ce47 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 8 Jul 2024 22:11:29 +0100 Subject: [PATCH 06/74] fix: various aspects of degree interface --- src/DynamicExpressions.jl 
| 1 + src/Node.jl | 2 +- src/NodeUtils.jl | 11 +++++++++++ src/base.jl | 2 +- test/test_base.jl | 4 ++-- test/test_custom_node_type.jl | 22 +++++++++++++--------- test/test_equality.jl | 4 ++-- test/test_extra_node_fields.jl | 25 ++++++++++++++++--------- test/test_graphs.jl | 13 +------------ test/test_parse.jl | 4 ++-- 10 files changed, 50 insertions(+), 38 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index c2d71489..b9aa4b4f 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -47,6 +47,7 @@ import .NodeModule: constructorof, with_type_parameters, preserve_sharing, + max_degree, leaf_copy, branch_copy, leaf_hash, diff --git a/src/Node.jl b/src/Node.jl index 0d1ab048..275fb496 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -219,7 +219,7 @@ end # with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} # with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} -function default_allocator(::Type{N}, ::Type{T}) where {N<:Union{Node,GraphNode},T} +function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} return with_type_parameters(N, T)() end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 87c4ba60..392df3b7 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -165,6 +165,17 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D} end NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T)) +@inline function Base.getproperty(n::NodeIndex, k::Symbol) + if k == :l + # TODO: Should a depwarn be raised here? Or too slow? + return getfield(n, :children)[1][] + elseif k == :r + return getfield(n, :children)[2][] + else + return getfield(n, k) + end +end + # Sharing is never needed for NodeIndex, # as we trace over the node we are indexing on. 
preserve_sharing(::Union{Type{<:NodeIndex},NodeIndex}) = false diff --git a/src/base.jl b/src/base.jl index 32d67fd6..6f2a8fbd 100644 --- a/src/base.jl +++ b/src/base.jl @@ -94,7 +94,7 @@ function tree_mapreduce( f_on_shared::H=(result, is_shared) -> result, break_sharing::Val{BS}=Val(false), ) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS} - sharing = preserve_sharing(typeof(tree)) && !break_sharing + sharing = preserve_sharing(typeof(tree)) && !BS RT == Undefined && sharing && diff --git a/test/test_base.jl b/test/test_base.jl index b14894b1..f7e7a483 100644 --- a/test/test_base.jl +++ b/test/test_base.jl @@ -32,11 +32,11 @@ end @testset "collect" begin ctree = copy(tree) - @test typeof(first(collect(ctree))) == Node{Float64} + @test typeof(first(collect(ctree))) <: Node{Float64} @test objectid(first(collect(ctree))) == objectid(ctree) @test objectid(first(collect(ctree))) == objectid(ctree) @test objectid(first(collect(ctree))) == objectid(ctree) - @test typeof(collect(ctree)) == Vector{Node{Float64}} + @test typeof(collect(ctree)) <: Vector{<:Node{Float64}} @test length(collect(ctree)) == 24 @test sum((t -> (t.degree == 0 && t.constant) ? 
t.val : 0.0).(collect(ctree))) ≈ 11.6 end diff --git a/test/test_custom_node_type.jl b/test/test_custom_node_type.jl index 3fc333bc..57a3706c 100644 --- a/test/test_custom_node_type.jl +++ b/test/test_custom_node_type.jl @@ -1,16 +1,21 @@ using DynamicExpressions using Test -mutable struct MyCustomNode{A,B} <: AbstractNode +mutable struct MyCustomNode{A,B} <: AbstractNode{2} degree::Int val1::A val2::B - l::MyCustomNode{A,B} - r::MyCustomNode{A,B} + children::NTuple{2,Base.RefValue{MyCustomNode{A,B}}} MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2) - MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l) - MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r) + function MyCustomNode(val1, val2, l) + return new{typeof(val1),typeof(val2)}( + 1, val1, val2, (Ref(l), Ref{MyCustomNode{typeof(val1),typeof(val2)}}()) + ) + end + function MyCustomNode(val1, val2, l, r) + return new{typeof(val1),typeof(val2)}(2, val1, val2, (Ref(l), Ref(r))) + end end node1 = MyCustomNode(1.0, 2) @@ -24,7 +29,7 @@ node2 = MyCustomNode(1.5, 3, node1) @test typeof(node2) == MyCustomNode{Float64,Int} @test node2.degree == 1 -@test node2.l.degree == 0 +@test node2.children[1][].degree == 0 @test count_depth(node2) == 2 @test count_nodes(node2) == 2 @@ -37,14 +42,13 @@ node2 = MyCustomNode(1.5, 3, node1, node1) @test count(t -> t.degree == 0, node2) == 2 # If we have a bad definition, it should get caught with a helpful message -mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T} +mutable struct MyCustomNode2{T} <: AbstractExpressionNode{T,2} degree::UInt8 constant::Bool val::T feature::UInt16 op::UInt8 - l::MyCustomNode2{T} - r::MyCustomNode2{T} + children::NTuple{2,Base.RefValue{MyCustomNode2{T}}} end @test_throws ErrorException MyCustomNode2() diff --git a/test/test_equality.jl b/test/test_equality.jl index 220e63c3..7e9b845b 100644 --- a/test/test_equality.jl +++ b/test/test_equality.jl @@ -45,8 +45,8 @@ 
modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2) f64_tree = GraphNode{Float64}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) f32_tree = GraphNode{Float32}(x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)) -@test typeof(f64_tree) == GraphNode{Float64} -@test typeof(f32_tree) == GraphNode{Float32} +@test typeof(f64_tree) <: GraphNode{Float64} +@test typeof(f32_tree) <: GraphNode{Float32} @test convert(GraphNode{Float64}, f32_tree) == f64_tree diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 467c6226..60b35595 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -2,24 +2,31 @@ using Test using DynamicExpressions -using DynamicExpressions: constructorof +using DynamicExpressions: constructorof, max_degree -mutable struct FrozenNode{T} <: AbstractExpressionNode{T} +mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D} degree::UInt8 constant::Bool val::T frozen::Bool # Extra field! feature::UInt16 op::UInt8 - l::FrozenNode{T} - r::FrozenNode{T} + children::NTuple{D,Base.RefValue{FrozenNode{T,D}}} - function FrozenNode{_T}() where {_T} - n = new{_T}() + function FrozenNode{_T,_D}() where {_T,_D} + n = new{_T,_D}() n.frozen = false return n end end +function DynamicExpressions.constructorof(::Type{N}) where {N<:FrozenNode} + return FrozenNode{T,max_degree(N)} where {T} +end +function DynamicExpressions.with_type_parameters( + ::Type{N}, ::Type{T} +) where {T,N<:FrozenNode} + return FrozenNode{T,max_degree(N)} +end function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T} out = if t.constant constructorof(typeof(t))(; val=t.val) @@ -56,7 +63,7 @@ function DynamicExpressions.leaf_equal(a::FrozenNode, b::FrozenNode) end end -n = let n = FrozenNode{Float64}() +n = let n = FrozenNode{Float64,2}() n.degree = 0 n.constant = true n.val = 0.0 @@ -92,5 +99,5 @@ ex = parse_expression( @test string_tree(ex) == "x + sin(y + 2.1)" @test ex.tree.frozen == false -@test ex.tree.r.frozen 
== true -@test ex.tree.r.l.frozen == false +@test ex.tree.children[2][].frozen == true +@test ex.tree.children[2][].children[1][].frozen == false diff --git a/test/test_graphs.jl b/test/test_graphs.jl index c25a3ab6..55ab4d79 100644 --- a/test/test_graphs.jl +++ b/test/test_graphs.jl @@ -109,17 +109,6 @@ end :(_convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}())), ) end - - @testset "@with_memoize" begin - ex = @macroexpand DynamicExpressions.UtilsModule.@with_memoize( - _convert(Node{T1}, tree), IdDict{Node{T2},Node{T1}}() - ) - true_ex = quote - _convert(Node{T1}, tree, IdDict{Node{T2},Node{T1}}()) - end - - @test expr_eql(ex, true_ex) - end end @testset "Operations on graphs" begin @@ -283,7 +272,7 @@ end x = GraphNode(Float32; feature=1) tree = x + 1.0 @test tree.l === x - @test typeof(tree) === GraphNode{Float32} + @test typeof(tree) <: GraphNode{Float32} # Detect error from Float32(1im) @test_throws InexactError x + 1im diff --git a/test/test_parse.jl b/test/test_parse.jl index c9b40d0c..8d9c351d 100644 --- a/test/test_parse.jl +++ b/test/test_parse.jl @@ -108,7 +108,7 @@ end variable_names = ["x"], ) - @test typeof(ex.tree) === Node{Any} + @test typeof(ex.tree) <: Node{Any} @test typeof(ex.metadata.operators) <: GenericOperatorEnum s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "[1, 2, 3] * tan(cos(5.0 + x))" @@ -184,7 +184,7 @@ end s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "(x * 2.5) - cos(y)" end - @test contains(logged_out, "Node{Float32}") + @test contains(logged_out, "Node{Float32") end @testitem "Helpful errors for missing operator" begin From 1e672bc3a930c6c79726b9f0852e0044934e9f95 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 9 Jul 2024 10:09:32 +0100 Subject: [PATCH 07/74] fix: some issues with D-degree ParametricNode --- src/ParametricExpression.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 
4b6d4b50..47e9bbeb 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -8,7 +8,7 @@ using ..NodeModule: AbstractExpressionNode, Node, tree_mapreduce using ..ExpressionModule: AbstractExpression, Metadata using ..ChainRulesModule: NodeTangent -import ..NodeModule: constructorof, preserve_sharing, leaf_copy, leaf_hash, leaf_equal +import ..NodeModule: constructorof, max_degree, preserve_sharing, leaf_copy, leaf_hash, leaf_equal import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, @@ -96,10 +96,10 @@ end ############################################################################### # Abstract expression node interface ########################################## ############################################################################### -@unstable constructorof(::Type{<:ParametricNode}) = ParametricNode +@unstable constructorof(::Type{N}) where {N<:ParametricNode} = ParametricNode{T,max_degree(N)} where {T} @unstable constructorof(::Type{<:ParametricExpression}) = ParametricExpression -@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode -default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T} +@unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T} +default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T,2} preserve_sharing(::Union{Type{<:ParametricNode},ParametricNode}) = false # TODO: Change this? 
function leaf_copy(t::ParametricNode{T}) where {T} out = if t.constant From 5505506215896dc0479c73003f9cffcdf744a377 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 3 May 2025 16:55:17 -0400 Subject: [PATCH 08/74] fix: a few merge issues --- src/Node.jl | 9 +++++---- src/ParametricExpression.jl | 7 +++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 4fc40a8d..1d5a441f 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -82,6 +82,8 @@ for N in (:Node, :GraphNode) ## Constructors: ################# $N{_T,_D}() where {_T,_D} = new{_T,_D::Int}() + $N{_T}() where {_T} = $N{_T,2}() + # TODO: Test with this disabled to spot any unintended uses end end @@ -207,14 +209,13 @@ max_degree(::Type{<:AbstractNode{D}}) where {D} = D @unstable constructorof(::Type{N}) where {N<:GraphNode} = GraphNode{T,max_degree(N)} where {T} -with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} = Node{T,max_degree(N)} +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} + return Node{T,max_degree(N)} +end function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T} return GraphNode{T,max_degree(N)} end -# with_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} -# with_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} - function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} return with_type_parameters(N, T)() end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 7f99d8e4..50929634 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -63,6 +63,10 @@ mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D} n.parameter = UInt16(0) return n end + # TODO: Test with this disabled to spot any unintended uses + function ParametricNode{_T}() where {_T} + return ParametricNode{_T,2}() + end end """ @@ -111,6 +115,9 @@ end @unstable constructorof(::Type{N}) where {N<:ParametricNode} = 
ParametricNode{T,max_degree(N)} where {T} @unstable constructorof(::Type{<:ParametricExpression}) = ParametricExpression +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ParametricNode,T} + return ParametricNode{T,max_degree(N)} +end @unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T} default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T,2} preserve_sharing(::Union{Type{<:ParametricNode},ParametricNode}) = false # TODO: Change this? From 051459859a84ca9706aaafb7df3b14487bc2b1d3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 3 May 2025 17:32:24 -0400 Subject: [PATCH 09/74] fix: `setproperty!` for tuple children --- src/Node.jl | 70 +++++++++++++++++++++---------------- src/ParametricExpression.jl | 5 ++- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 1d5a441f..93a29ce3 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -86,6 +86,8 @@ for N in (:Node, :GraphNode) # TODO: Test with this disabled to spot any unintended uses end end +# TODO: If we can't reach the same speed, we should make a Node2 type +# that is specialized for 2-arity nodes. #! format: off """ @@ -164,38 +166,46 @@ when constructing or setting properties. """ GraphNode -@inline function Base.getproperty(n::Union{Node,GraphNode}, k::Symbol) - if k == :l - # TODO: Should a depwarn be raised here? Or too slow? 
- return getfield(n, :children)[1][] - elseif k == :r - return getfield(n, :children)[2][] - else - return getfield(n, k) - end -end -@inline function Base.setproperty!(n::Union{Node,GraphNode}, k::Symbol, v) - if k == :l - getfield(n, :children)[1][] = v - elseif k == :r - getfield(n, :children)[2][] = v - elseif k == :degree - setfield!(n, :degree, convert(UInt8, v)) - elseif k == :constant - setfield!(n, :constant, convert(Bool, v)) - elseif k == :feature - setfield!(n, :feature, convert(UInt16, v)) - elseif k == :op - setfield!(n, :op, convert(UInt8, v)) - elseif k == :val - setfield!(n, :val, convert(eltype(n), v)) - elseif k == :children - setfield!(n, :children, v) - else - error("Invalid property: $k") - end +macro make_accessors(node_type) + esc(quote + @inline function Base.getproperty(n::$node_type, k::Symbol) + if k == :l + # TODO: Should a depwarn be raised here? Or too slow? + return getfield(n, :children)[1][] + elseif k == :r + return getfield(n, :children)[2][] + else + return getfield(n, k) + end + end + @inline function Base.setproperty!(n::$node_type, k::Symbol, v) + if k == :l + if isdefined(n, :children) + getfield(n, :children)[1][] = v + else + r = Ref(v) + setfield!(n, :children, (r, Ref{typeof(n)}())) + r + end + elseif k == :r + # TODO: Remove this assert once we know that this is safe + @assert isdefined(n, :children) + getfield(n, :children)[2][] = v + else + T = fieldtype(typeof(n), k) + if v isa T + setfield!(n, k, v) + else + setfield!(n, k, convert(T, v)) + end + end + end + end) end +@make_accessors Node +@make_accessors GraphNode + ################################################################################ #! format: on diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 50929634..3db7666c 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -18,7 +18,8 @@ import ..NodeModule: leaf_convert, leaf_hash, leaf_equal, - set_node! 
+ set_node!, + @make_accessors import ..NodePreallocationModule: copy_into!, allocate_container import ..NodeUtilsModule: count_constant_nodes, @@ -69,6 +70,8 @@ mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D} end end +@make_accessors ParametricNode + """ ParametricExpression{T,N<:ParametricNode{T},D<:NamedTuple} <: AbstractExpression{T,N} From 0079acfcbc36cdee5eecc1d02313e9d0185b3272 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 3 May 2025 18:00:42 -0400 Subject: [PATCH 10/74] fix: constructor from explicit eltype --- src/Node.jl | 6 ++++-- test/test_base_2.jl | 1 + test/test_parametric_expression.jl | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 93a29ce3..2783c42d 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -299,16 +299,18 @@ end end @inline function node_factory_type(::Type{N}, ::Type{T1}, ::Type{T2}) where {N,T1,T2} - if T1 === Undefined && N isa UnionAll + if T1 === Undefined && !defines_eltype(N) T2 elseif T1 === Undefined eltype(N) - elseif N isa UnionAll + elseif !defines_eltype(N) T1 else eltype(N) end end +defines_eltype(::Type{<:AbstractExpressionNode}) = false +defines_eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = true #! 
format: on function (::Type{N})( diff --git a/test/test_base_2.jl b/test/test_base_2.jl index 2c8dc6d2..e0ab645b 100644 --- a/test/test_base_2.jl +++ b/test/test_base_2.jl @@ -12,6 +12,7 @@ using DynamicExpressions, Random x = Node{Float64}(; feature=1) + @test x isa Node{Float64} # We can also create values, using `val`: const_1 = Node{Float64}(; val=1.0) diff --git a/test/test_parametric_expression.jl b/test/test_parametric_expression.jl index e222765f..b8acf2c0 100644 --- a/test/test_parametric_expression.jl +++ b/test/test_parametric_expression.jl @@ -349,7 +349,7 @@ end @test val isa Float64 @test grad isa NamedTuple @test grad.tree isa DynamicExpressions.ChainRulesModule.NodeTangent{ - Float64,ParametricNode{Float64},Vector{Float64} + Float64,<:ParametricNode{Float64},Vector{Float64} } @test grad.metadata._data.parameters isa Matrix{Float64} From 97c88687205c3abf02ede86cd2c614856d0e5596 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 3 May 2025 18:52:22 -0400 Subject: [PATCH 11/74] fix: permit vector children to constructor --- src/Node.jl | 2 +- test/test_parse.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 2783c42d..1ab61528 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -285,7 +285,7 @@ end end """Create an operator node.""" @inline function node_factory( - ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Tuple, allocator::F, + ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Union{Tuple,AbstractVector}, allocator::F, ) where {N<:AbstractExpressionNode,F} T = promote_type(map(eltype, children)...) 
# Always prefer existing nodes, so we don't mess up references from conversion D2 = length(children) diff --git a/test/test_parse.jl b/test/test_parse.jl index fb30469f..2fce5289 100644 --- a/test/test_parse.jl +++ b/test/test_parse.jl @@ -127,7 +127,7 @@ end variable_names = ["x"], node_type = Node{Union{Int,Vector{Int}}} ) - @test typeof(ex.tree) === Node{Union{Int,Vector{Int}}} + @test typeof(ex.tree) <: Node{Union{Int,Vector{Int}}} @test typeof(ex.metadata.operators) <: GenericOperatorEnum s = sprint((io, e) -> show(io, MIME("text/plain"), e), ex) @test s == "[1, 2, 3] * tan(cos(5 + x))" From 33e59660fa6f33b534f35541d583afa428870c58 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 4 May 2025 15:32:31 -0400 Subject: [PATCH 12/74] fix: other fixes for d-degree nodes --- src/Node.jl | 5 +++ src/base.jl | 52 +++++++++++++++---------- test/test_non_number_eval_tree_array.jl | 10 ++--- test/test_tree_construction.jl | 4 +- 4 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 1ab61528..512c8a5f 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -206,6 +206,11 @@ end @make_accessors Node @make_accessors GraphNode +@inline function children(node::AbstractNode, ::Val{n}) where {n} + cs = node.children + return ntuple(i -> cs[i][], Val(n)) +end + ################################################################################ #! format: on diff --git a/src/base.jl b/src/base.jl index 416ce93b..ec209214 100644 --- a/src/base.jl +++ b/src/base.jl @@ -120,28 +120,38 @@ struct TreeMapreducer{ f_on_shared::H end -function call_mapreducer(mapreducer::TreeMapreducer{2,ID}, tree::AbstractNode) where {ID} - key = ID <: Dict ? 
objectid(tree) : nothing - if ID <: Dict && haskey(mapreducer.id_map, key) - result = @inbounds(mapreducer.id_map[key]) - return mapreducer.f_on_shared(result, true) - else - result = if tree.degree == 0 - mapreducer.f_leaf(tree) - elseif tree.degree == 1 - mapreducer.op(mapreducer.f_branch(tree), call_mapreducer(mapreducer, tree.l)) - else - mapreducer.op( - mapreducer.f_branch(tree), - call_mapreducer(mapreducer, tree.l), - call_mapreducer(mapreducer, tree.r), - ) - end - if ID <: Dict - mapreducer.id_map[key] = result - return mapreducer.f_on_shared(result, false) +@generated function call_mapreducer( + mapreducer::TreeMapreducer{D,ID}, tree::AbstractNode +) where {D,ID} + quote + key = ID <: Dict ? objectid(tree) : nothing + if ID <: Dict && haskey(mapreducer.id_map, key) + result = @inbounds(mapreducer.id_map[key]) + return mapreducer.f_on_shared(result, true) else - return result + d = tree.degree + result = if d == 0 + mapreducer.f_leaf(tree) + else + Base.Cartesian.@nif( + $D, + i -> i == d, + i -> let cs = children(tree, Val(i)) + mapreducer.op( + mapreducer.f_branch(tree), + Base.Cartesian.@ntuple( + i, j -> call_mapreducer(mapreducer, cs[j]) + )..., + ) + end + ) + end + if ID <: Dict + mapreducer.id_map[key] = result + return mapreducer.f_on_shared(result, false) + else + return result + end end end end diff --git a/test/test_non_number_eval_tree_array.jl b/test/test_non_number_eval_tree_array.jl index a77ed55f..ca5db078 100644 --- a/test/test_non_number_eval_tree_array.jl +++ b/test/test_non_number_eval_tree_array.jl @@ -165,11 +165,11 @@ Base.invokelatest( () -> begin # test operator extended operators - @test hasmethod(q, Tuple{Node{Max2Tensor{Float64}}}) - @test hasmethod(a, Tuple{Max2Tensor{Float64},Node{Max2Tensor{Float64}}}) - @test hasmethod(a, Tuple{Node{Max2Tensor{Float64}},Node{Max2Tensor{Float64}}}) - @test !hasmethod(a, Tuple{Float64,Node{Float64}}) - @test !hasmethod(a, Tuple{Node{Max2Tensor{Float32}},Node{Max2Tensor{Float32}}}) + @test 
hasmethod(q, Tuple{Node{Max2Tensor{Float64},2}}) + @test hasmethod(a, Tuple{Max2Tensor{Float64},Node{Max2Tensor{Float64},2}}) + @test hasmethod(a, Tuple{Node{Max2Tensor{Float64},2},Node{Max2Tensor{Float64},2}}) + @test !hasmethod(a, Tuple{Float64,Node{Float64,2}}) + @test !hasmethod(a, Tuple{Node{Max2Tensor{Float32},2},Node{Max2Tensor{Float32},2}}) tree = a(Node{Max2Tensor{Float64}}(; feature=1), Max2Tensor{Float64}(3.0)) results = tree( diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl index 60899008..40dd67c0 100644 --- a/test/test_tree_construction.jl +++ b/test/test_tree_construction.jl @@ -126,7 +126,7 @@ end x = N{BigFloat}(; feature=1) @test_throws AssertionError N{Float32}(1, x) @test N{BigFloat}(1, x) == N(1, x) - @test typeof(N(1, x, N{Float32}(; val=1))) === N{BigFloat} - @test typeof(N(1, N{Float32}(; val=1), x)) === N{BigFloat} + @test typeof(N(1, x, N{Float32}(; val=1))) <: N{BigFloat} + @test typeof(N(1, N{Float32}(; val=1), x)) <: N{BigFloat} end end From 24aec43d67dd094876551cbfdde5c4b11c87cad9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 4 May 2025 16:15:17 -0400 Subject: [PATCH 13/74] fix: type assertion issue in constructor --- src/Node.jl | 3 ++- test/test_tree_construction.jl | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Node.jl b/src/Node.jl index 512c8a5f..0d954365 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -292,7 +292,8 @@ end @inline function node_factory( ::Type{N}, ::Type, ::Nothing, ::Nothing, op::Integer, children::Union{Tuple,AbstractVector}, allocator::F, ) where {N<:AbstractExpressionNode,F} - T = promote_type(map(eltype, children)...) # Always prefer existing nodes, so we don't mess up references from conversion + T = defines_eltype(N) ? eltype(N) : promote_type(map(eltype, children)...) + defines_eltype(N) && @assert T === promote_type(T, map(eltype, children)...) 
D2 = length(children) @assert D2 <= max_degree(N) NT = with_type_parameters(N, T) diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl index 40dd67c0..fc5e0cac 100644 --- a/test/test_tree_construction.jl +++ b/test/test_tree_construction.jl @@ -126,6 +126,7 @@ end x = N{BigFloat}(; feature=1) @test_throws AssertionError N{Float32}(1, x) @test N{BigFloat}(1, x) == N(1, x) + @test N{BigFloat}(1, x) isa N{BigFloat} @test typeof(N(1, x, N{Float32}(; val=1))) <: N{BigFloat} @test typeof(N(1, N{Float32}(; val=1), x)) <: N{BigFloat} end From 5f647ff42c9693ab9363f0d79dccf68158924f06 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 6 May 2025 20:36:54 +0100 Subject: [PATCH 14/74] test: fix interfaces --- src/Interfaces.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index c2d44b59..65acf0d5 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -70,8 +70,8 @@ function _check_get_metadata(ex::AbstractExpression) new_ex = with_metadata(ex, get_metadata(ex)) return new_ex == ex && new_ex isa typeof(ex) end -function _check_get_tree(ex::AbstractExpression{T,N}) where {T,N} - return get_tree(ex) isa N || get_tree(ex) isa AbstractReadOnlyNode{T,N} +function _check_get_tree(ex::AbstractExpression{T,N}) where {T,D,N<:AbstractExpressionNode{T,D}} + return get_tree(ex) isa N || get_tree(ex) isa AbstractReadOnlyNode{T,D,N} end function _check_get_operators(ex::AbstractExpression) return get_operators(ex) isa AbstractOperatorEnum From 867dab6d8d4ebce048242ca6e9494200c174370c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 6 May 2025 20:45:03 +0100 Subject: [PATCH 15/74] test: fix `children` call in ReadOnlyNode --- src/Interfaces.jl | 4 +++- src/Node.jl | 3 ++- src/ReadOnlyNode.jl | 7 +++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 65acf0d5..df2784d5 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -70,7 +70,9 
@@ function _check_get_metadata(ex::AbstractExpression) new_ex = with_metadata(ex, get_metadata(ex)) return new_ex == ex && new_ex isa typeof(ex) end -function _check_get_tree(ex::AbstractExpression{T,N}) where {T,D,N<:AbstractExpressionNode{T,D}} +function _check_get_tree( + ex::AbstractExpression{T,N} +) where {T,D,N<:AbstractExpressionNode{T,D}} return get_tree(ex) isa N || get_tree(ex) isa AbstractReadOnlyNode{T,D,N} end function _check_get_operators(ex::AbstractExpression) diff --git a/src/Node.jl b/src/Node.jl index 0d954365..099d2206 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -206,8 +206,9 @@ end @make_accessors Node @make_accessors GraphNode +@inline children(node::AbstractNode) = node.children @inline function children(node::AbstractNode, ::Val{n}) where {n} - cs = node.children + cs = children(node) return ntuple(i -> cs[i][], Val(n)) end diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl index c96f2fd4..c10bfdac 100644 --- a/src/ReadOnlyNode.jl +++ b/src/ReadOnlyNode.jl @@ -1,7 +1,7 @@ module ReadOnlyNodeModule using ..NodeModule: AbstractExpressionNode, Node -import ..NodeModule: default_allocator, with_type_parameters, constructorof +import ..NodeModule: default_allocator, with_type_parameters, constructorof, children abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D}} <: AbstractExpressionNode{T,D} end @@ -16,11 +16,14 @@ constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode @inline function Base.getproperty(n::AbstractReadOnlyNode, s::Symbol) out = getproperty(getfield(n, :_inner), s) if out isa AbstractExpressionNode - return constructorof(typeof(n))(out) + return ReadOnlyNode(out) else return out end end +@inline function children(node::AbstractReadOnlyNode) + return map(ReadOnlyNode, children(node)) +end function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v) return error("Cannot set properties on a ReadOnlyNode") end From b6d187b3bd7a838d064d2e5e5ebc0b61412c2557 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: 
Tue, 6 May 2025 21:11:40 +0100 Subject: [PATCH 16/74] fix: read only nodes when given ref --- src/Interfaces.jl | 6 ++++-- src/ReadOnlyNode.jl | 33 +++++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index df2784d5..267496c4 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -145,9 +145,11 @@ end function _check_constructorof(ex::AbstractExpression) return constructorof(typeof(ex)) isa Base.Callable end -function _check_tree_mapreduce(ex::AbstractExpression{T,N}) where {T,N} +function _check_tree_mapreduce( + ex::AbstractExpression{T,N} +) where {T,D,N<:AbstractExpressionNode{T,D}} return tree_mapreduce(node -> [node], vcat, ex) isa - (Vector{N2} where {N2<:Union{N,AbstractReadOnlyNode{T,N}}}) + (Vector{N2} where {N2<:Union{N,AbstractReadOnlyNode{T,D,N}}}) end #! format: off diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl index c10bfdac..98edf13c 100644 --- a/src/ReadOnlyNode.jl +++ b/src/ReadOnlyNode.jl @@ -1,28 +1,45 @@ module ReadOnlyNodeModule +using DispatchDoctor: @unstable + using ..NodeModule: AbstractExpressionNode, Node import ..NodeModule: default_allocator, with_type_parameters, constructorof, children -abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D}} <: +abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D},IS_REF} <: AbstractExpressionNode{T,D} end """A type of expression node that prevents writing to the inner node""" -struct ReadOnlyNode{T,D,N} <: AbstractReadOnlyNode{T,D,N} +struct ReadOnlyNode{T,D,N,IS_REF} <: AbstractReadOnlyNode{T,D,N,IS_REF} _inner::N - ReadOnlyNode(n::N) where {T,D,N<:AbstractExpressionNode{T,D}} = new{T,D,N}(n) + function ReadOnlyNode( + n::N, ::Val{IS_REF} + ) where {T,D,N<:AbstractExpressionNode{T,D},IS_REF} + return new{T,D,N,IS_REF}(n) + end + function ReadOnlyNode(n::N) where {T,D,N<:AbstractExpressionNode{T,D}} + return ReadOnlyNode(n, Val(false)) + end + function 
ReadOnlyNode(n::AbstractReadOnlyNode) + return n + end + function ReadOnlyNode(n::Ref{<:AbstractExpressionNode}) + return ReadOnlyNode(n[], Val(true)) + end end -constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode +@inline inner(n::AbstractReadOnlyNode) = getfield(n, :_inner) +@unstable constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode +Base.getindex(n::AbstractReadOnlyNode{T,D,N,true} where {T,D,N}) = n @inline function Base.getproperty(n::AbstractReadOnlyNode, s::Symbol) - out = getproperty(getfield(n, :_inner), s) - if out isa AbstractExpressionNode + out = getproperty(inner(n), s) + if out isa Union{AbstractExpressionNode,Ref{<:AbstractExpressionNode}} return ReadOnlyNode(out) else return out end end -@inline function children(node::AbstractReadOnlyNode) - return map(ReadOnlyNode, children(node)) +@inline function children(node::AbstractReadOnlyNode, ::Val{n}) where {n} + return map(ReadOnlyNode, children(inner(node), Val(n))) end function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v) return error("Cannot set properties on a ReadOnlyNode") From 87102d8201d8c7bd94fef0c80f8f8c668c6ca2b6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 6 May 2025 22:13:00 +0100 Subject: [PATCH 17/74] feat: create n-arity operator enum --- src/Evaluate.jl | 10 ++++++++-- src/OperatorEnum.jl | 40 ++++++++++++++++++++++++++++++---------- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index c8947929..7f4e3c14 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -251,8 +251,14 @@ end # These are marked unstable due to issues discussed on # https://github.com/JuliaLang/julia/issues/55147 -@unstable get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U) -@unstable get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B) +@unstable function get_nuna(::Type{<:OperatorEnum{OPS}}) where {OPS} + ts = OPS.types + return isempty(ts) ? 
0 : counttuple(ts[1]) +end +@unstable function get_nbin(::Type{<:OperatorEnum{OPS}}) where {OPS} + ts = OPS.types + return length(ts) == 1 ? 0 : counttuple(ts[2]) +end function _eval_tree_array( tree::AbstractExpressionNode{T}, diff --git a/src/OperatorEnum.jl b/src/OperatorEnum.jl index 6da6779c..16620acf 100644 --- a/src/OperatorEnum.jl +++ b/src/OperatorEnum.jl @@ -6,29 +6,49 @@ abstract type AbstractOperatorEnum end OperatorEnum Defines an enum over operators, along with their derivatives. + # Fields -- `binops`: A tuple of binary operators. Scalar input type. -- `unaops`: A tuple of unary operators. Scalar input type. +- `ops`: A tuple of operators, with index `i` corresponding to the operator tuple for a node of degree `i`. """ -struct OperatorEnum{B,U} <: AbstractOperatorEnum - binops::B - unaops::U +struct OperatorEnum{OPS<:Tuple{Vararg{Tuple}}} <: AbstractOperatorEnum + ops::OPS +end + +function OperatorEnum(binary_operators::Tuple, unary_operators::Tuple) + return OperatorEnum((unary_operators, binary_operators)) end """ GenericOperatorEnum Defines an enum over operators, along with their derivatives. + # Fields -- `binops`: A tuple of binary operators. -- `unaops`: A tuple of unary operators. +- `ops`: A tuple of operators, with index `i` corresponding to the operator tuple for a node of degree `i`. """ -struct GenericOperatorEnum{B,U} <: AbstractOperatorEnum - binops::B - unaops::U +struct GenericOperatorEnum{OPS<:Tuple{Vararg{Tuple}}} <: AbstractOperatorEnum + ops::OPS +end + +function GenericOperatorEnum(binops::Tuple, unaops::Tuple) + return GenericOperatorEnum((unaops, binops)) end Base.copy(op::AbstractOperatorEnum) = op # TODO: Is this safe? What if a vector is passed here? 
+@inline function Base.getindex(op::AbstractOperatorEnum, i::Int) + return getfield(op, :ops)[i] +end +@inline function Base.getproperty(op::AbstractOperatorEnum, k::Symbol) + if k == :unaops + return getfield(op, :ops)[1] + elseif k == :binops + ops = getfield(op, :ops) + return length(ops) > 1 ? ops[2] : () + else + return getfield(op, k) + end +end + end From cda089653297352554ff463b1a911438274b157d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 7 May 2025 23:30:13 +0100 Subject: [PATCH 18/74] feat: add generic degree pathway in evaluation --- src/Evaluate.jl | 82 ++++++++++++++++++++++++++++++++++++++++--------- src/Node.jl | 1 + 2 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 7f4e3c14..e33149bf 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -2,7 +2,7 @@ module EvaluateModule using DispatchDoctor: @stable, @unstable -import ..NodeModule: AbstractExpressionNode, constructorof +import ..NodeModule: AbstractExpressionNode, constructorof, max_degree, children import ..StringsModule: string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk @@ -281,11 +281,27 @@ function _eval_tree_array( elseif tree.degree == 1 op_idx = tree.op return dispatch_deg1_eval(tree, cX, op_idx, operators, eval_options) - else + elseif max_degree(tree) == 2 || tree.degree == 2 # TODO - add op(op2(x, y), z) and op(x, op2(y, z)) # op(x, y), where x, y are constants or variables. 
op_idx = tree.op return dispatch_deg2_eval(tree, cX, op_idx, operators, eval_options) + else + op_idx = tree.op + return dispatch_degn_eval(tree, cX, op_idx, operators, eval_options) + end +end + +@generated function degn_eval( + cumulators::NTuple{N,<:AbstractVector{T}}, op::F, ::EvalOptions{false} +)::ResultOk where {N,T,F} + # Fast general implementation of `cumulators[1] .= op.(cumulators[1], cumulators[2], ...)` + quote + Base.Cartesian.@nexprs($N, i -> cumulator_i = cumulators[i]) + @inbounds @simd for j in eachindex(cumulator_1) + cumulator_1[j] = Base.Cartesian.@ncall($N, op, i -> cumulator_i[j])::T + end + return ResultOk(cumulator_1, true) end end @@ -293,23 +309,15 @@ function deg2_eval( cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, - ::EvalOptions{false}, + eval_options::EvalOptions{false}, )::ResultOk where {T,F} - @inbounds @simd for j in eachindex(cumulator_l) - x = op(cumulator_l[j], cumulator_r[j])::T - cumulator_l[j] = x - end - return ResultOk(cumulator_l, true) + return degn_eval((cumulator_l, cumulator_r), op, eval_options) end function deg1_eval( - cumulator::AbstractVector{T}, op::F, ::EvalOptions{false} + cumulator::AbstractVector{T}, op::F, eval_options::EvalOptions{false} )::ResultOk where {T,F} - @inbounds @simd for j in eachindex(cumulator) - x = op(cumulator[j])::T - cumulator[j] = x - end - return ResultOk(cumulator, true) + return degn_eval((cumulator,), op, eval_options) end function deg0_eval( @@ -324,6 +332,52 @@ function deg0_eval( end end +@generated function inner_dispatch_degn_eval( + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op_idx::Integer, + ::Val{degree}, + operators::OperatorEnum{OPS}, + eval_options::EvalOptions, +) where {T,degree,OPS} + nops = length(OPS.types[degree].types) + return quote + cs = children(tree, Val($degree)) + Base.Cartesian.@nexprs( + $degree, + i -> begin + result_i = _eval_tree_array(cs[i], cX, operators, eval_options) + !result_i.ok && return result_i + 
@return_on_nonfinite_array(eval_options, result_i.x) + end + ) + cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x) + Base.Cartesian.@nif( + $nops, + i -> i == op_idx, + i -> degn_eval(cumulators, operators[$degree][i], eval_options), + ) + end +end +@generated function dispatch_degn_eval( + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op_idx::Integer, + operators::OperatorEnum, + eval_options::EvalOptions, +) where {T} + D = max_degree(tree) + return quote + # If statement over degrees + degree = tree.degree + return Base.Cartesian.@nif( + $D, + d -> d == degree, + d -> + inner_dispatch_degn_eval(tree, cX, op_idx, Val(d), operators, eval_options) + ) + end +end @generated function dispatch_deg2_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, diff --git a/src/Node.jl b/src/Node.jl index 099d2206..e430c313 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -220,6 +220,7 @@ Base.eltype(::AbstractExpressionNode{T}) where {T} = T max_degree(::Type{<:AbstractNode}) = 2 # Default max_degree(::Type{<:AbstractNode{D}}) where {D} = D +max_degree(node::AbstractNode) = max_degree(typeof(node)) @unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T} @unstable constructorof(::Type{N}) where {N<:GraphNode} = From 673146c9944ac9b2073e85d0825c31bba2911f3d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 7 May 2025 23:34:56 +0100 Subject: [PATCH 19/74] feat: n-arity strings --- src/Strings.jl | 71 ++++++++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/src/Strings.jl b/src/Strings.jl index 66b08d1f..3730cb97 100644 --- a/src/Strings.jl +++ b/src/Strings.jl @@ -2,28 +2,43 @@ module StringsModule using ..UtilsModule: deprecate_varmap using ..OperatorEnumModule: AbstractOperatorEnum -using ..NodeModule: AbstractExpressionNode, tree_mapreduce +using ..NodeModule: AbstractExpressionNode, tree_mapreduce, max_degree function dispatch_op_name( ::Val{deg}, 
::Nothing, idx, pretty::Bool )::Vector{Char} where {deg} - return vcat( - collect(deg == 1 ? "unary_operator[" : "binary_operator["), - collect(string(idx)), - [']'], - ) + return vcat(collect( + if deg == 1 + "unary_operator[" + elseif deg == 2 + "binary_operator[" + else + "operator_deg$deg[" + end, + ), collect(string(idx)), [']']) end function dispatch_op_name( ::Val{deg}, operators::AbstractOperatorEnum, idx, pretty::Bool ) where {deg} - op = if deg == 1 - operators.unaops[idx] - else - operators.binops[idx] - end + op = operators[deg][idx] return collect((pretty ? get_pretty_op_name(op) : get_op_name(op))::String) end +struct OpNameDispatcher{D,O<:AbstractOperatorEnum} <: Function + operators::O + pretty::Bool +end +@generated function (f::OpNameDispatcher{D,O})(branch) where {D,O} + return quote + degree = branch.degree + Base.Cartesian.@nif( + $D, + d -> d == degree, + d -> dispatch_op_name(Val(d), f.operators, branch.op, f.pretty), + )::Vector{Char} + end +end + const OP_NAME_CACHE = (; x=Dict{UInt64,String}(), lock=Threads.SpinLock()) function get_op_name(op::F) where {F} @@ -89,36 +104,30 @@ function string_variable(feature, variable_names) end # Vector of chars is faster than strings, so we use that. 
-function combine_op_with_inputs(op, l, r)::Vector{Char} - if first(op) in ('+', '-', '*', '/', '^', '.', '>', '<', '=') || op == "!=" +function combine_op_with_inputs(op, args::Vararg{Any,D})::Vector{Char} where {D} + if D == 2 && (first(op) in ('+', '-', '*', '/', '^', '.', '>', '<', '=') || op == "!=") # "(l op r)" out = ['('] - append!(out, l) + append!(out, args[1]) push!(out, ' ') append!(out, op) push!(out, ' ') - append!(out, r) + append!(out, args[2]) push!(out, ')') else # "op(l, r)" out = copy(op) push!(out, '(') - append!(out, strip_brackets(l)) - push!(out, ',') - push!(out, ' ') - append!(out, strip_brackets(r)) + for i in 1:(D - 1) + append!(out, strip_brackets(args[i])) + push!(out, ',') + push!(out, ' ') + end + append!(out, strip_brackets(args[D])) push!(out, ')') return out end end -function combine_op_with_inputs(op, l) - # "op(l)" - out = copy(op) - push!(out, '(') - append!(out, strip_brackets(l)) - push!(out, ')') - return out -end """ string_tree( @@ -169,13 +178,7 @@ function string_tree( collect(f_variable(leaf.feature, variable_names))::Vector{Char} end end, - let operators = operators - (branch,) -> if branch.degree == 1 - dispatch_op_name(Val(1), operators, branch.op, pretty)::Vector{Char} - else - dispatch_op_name(Val(2), operators, branch.op, pretty)::Vector{Char} - end - end, + OpNameDispatcher{max_degree(tree),typeof(operators)}(operators, pretty), combine_op_with_inputs, tree, Vector{Char}; From 9ccb46a00bc1b441929e5c29a886e264a5468e86 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 9 May 2025 12:25:44 +0100 Subject: [PATCH 20/74] feat: get expression operators working with 3-arg input --- src/ExpressionAlgebra.jl | 120 ++++++++++++++++++++------------------- src/OperatorEnum.jl | 4 +- src/Strings.jl | 2 +- 3 files changed, 65 insertions(+), 61 deletions(-) diff --git a/src/ExpressionAlgebra.jl b/src/ExpressionAlgebra.jl index f5dcd7b4..dd9159ec 100644 --- a/src/ExpressionAlgebra.jl +++ b/src/ExpressionAlgebra.jl @@ -54,41 
+54,23 @@ of the expression. """ declare_operator_alias(op::F, _) where {F<:Function} = op -function apply_operator(op::F, l::AbstractExpression) where {F<:Function} - operators = get_operators(l, nothing) - op_idx = findfirst( - ==(op), map(Base.Fix2(declare_operator_alias, Val(1)), operators.unaops) - ) - if op_idx === nothing +function apply_operator(op::F, args::Vararg{Any,D}) where {F<:Function,D} + idx = findfirst(e -> e isa AbstractExpression, args)::Int + example_expr = args[idx] + E = typeof(example_expr) + @assert all(e -> !(e isa AbstractExpression) || typeof(e) === E, args) + operators = get_operators(example_expr, nothing) + + op_idx = findfirst(==(op), map(Base.Fix2(declare_operator_alias, Val(D)), operators[D])) + if isnothing(op_idx) throw( MissingOperatorError( - "Operator $op not found in operators for expression type $(typeof(l)) with unary operators $(operators.unaops)", + "Operator $op not found in operators for expression type " * + "$(typeof(l)) with $(D)-degree operators $(operators[D])", ), ) end - return insert_operator_index(op_idx, (l,), l) -end -function apply_operator(op::F, l, r) where {F<:Function} - (operators, example_expr) = if l isa AbstractExpression && r isa AbstractExpression - @assert typeof(r) === typeof(l) - (get_operators(l, nothing), l) - elseif l isa AbstractExpression - (get_operators(l, nothing), l) - else - r::AbstractExpression - (get_operators(r, nothing), r) - end - op_idx = findfirst( - ==(op), map(Base.Fix2(declare_operator_alias, Val(2)), operators.binops) - ) - if op_idx === nothing - throw( - MissingOperatorError( - "Operator $op not found in operators for expression type $(typeof(l)) with binary operators $(operators.binops)", - ), - ) - end - return insert_operator_index(op_idx, (l, r), example_expr) + return insert_operator_index(op_idx, args, example_expr) end """ @@ -96,44 +78,59 @@ end Declare an operator function for `AbstractExpression` types. 
-This macro generates a method for the given operator `op` that works with -`AbstractExpression` arguments. The `arity` parameter specifies whether -the operator is unary (1) or binary (2). - -# Arguments -- `op`: The operator to be declared (e.g., `Base.sin`, `Base.:+`). -- `arity`: The number of arguments the operator takes (1 for unary, 2 for binary). +This macro generates methods for the given operator `op` that work with +`AbstractExpression` arguments. The `arity` parameter specifies the number +of arguments the operator takes. """ macro declare_expression_operator(op, arity) - @assert arity ∈ (1, 2) + syms = [Symbol('x', i) for i in 1:arity] + AE = :($(AbstractExpression)) if arity == 1 return esc( quote - $op(l::AbstractExpression) = $(apply_operator)($op, l) + $op($(only(syms))::$(AE)) = $(apply_operator)($op, $(only(syms))) end, ) - elseif arity == 2 - return esc( - quote - function $op(l::AbstractExpression, r::AbstractExpression) - return $(apply_operator)($op, l, r) - end - function $op(l::T, r::AbstractExpression{T}) where {T} - return $(apply_operator)($op, l, r) - end - function $op(l::AbstractExpression{T}, r::T) where {T} - return $(apply_operator)($op, l, r) - end - # Convenience methods for Number types - function $op(l::Number, r::AbstractExpression{T}) where {T} - return $(apply_operator)($op, l, r) - end - function $op(l::AbstractExpression{T}, r::Number) where {T} - return $(apply_operator)($op, l, r) - end - end, + end + + wrappers = (AE, :($(AE){T}), :T, :Number) + methods = Expr(:block) + + for types in Iterators.product(ntuple(_ -> wrappers, arity)...) 
+ has_expr = any( + t -> t == AE || (t isa Expr && t.head == :curly && t.args[1] == AE), types + ) + has_plain_T = any(==(:T), types) + has_abstract_expr_T = any( + t -> t isa Expr && t.head == :curly && t.args[1] == AE && :T in t.args, types ) + has_abstract_expr_plain = any(==(AE), types) + if any(( + !has_expr, + # ^At least one arg must be an AbstractExpression (avoid type‑piracy) + has_abstract_expr_plain && has_abstract_expr_T, + # ^If a plain `T` appears, ensure an `AbstractExpression{T}` is also present + has_plain_T ⊻ has_abstract_expr_T, + # ^Do not mix bare `AbstractExpression` with `AbstractExpression{T}` + )) + continue + end + + + arglist = [Expr(:(::), syms[i], types[i]) for i in 1:arity] + signature = Expr(:call, op, arglist...) + if any(t -> t == :T || (t isa Expr && t.head == :curly && :T in t.args), types) + signature = Expr(:where, signature, :(T)) + end + + body = Expr(:block, :(return $(apply_operator)($op, $(syms...)))) + + fn = Expr(:function, signature, body) + + push!(methods.args, fn) end + + return esc(methods) end #! format: off @@ -159,6 +156,11 @@ for op in ( ) @eval @declare_expression_operator Base.$(op) 2 end +for op in ( + :*, :+, :clamp, :max, :min, :fma, :muladd, +) + @eval @declare_expression_operator Base.$(op) 3 +end #! format: on end diff --git a/src/OperatorEnum.jl b/src/OperatorEnum.jl index 16620acf..3f1d131d 100644 --- a/src/OperatorEnum.jl +++ b/src/OperatorEnum.jl @@ -1,5 +1,7 @@ module OperatorEnumModule +using DispatchDoctor: @unstable + abstract type AbstractOperatorEnum end """ @@ -37,7 +39,7 @@ end Base.copy(op::AbstractOperatorEnum) = op # TODO: Is this safe? What if a vector is passed here? 
-@inline function Base.getindex(op::AbstractOperatorEnum, i::Int) +@unstable @inline function Base.getindex(op::AbstractOperatorEnum, i::Int) return getfield(op, :ops)[i] end @inline function Base.getproperty(op::AbstractOperatorEnum, k::Symbol) diff --git a/src/Strings.jl b/src/Strings.jl index 3730cb97..ba5e5f7e 100644 --- a/src/Strings.jl +++ b/src/Strings.jl @@ -24,7 +24,7 @@ function dispatch_op_name( return collect((pretty ? get_pretty_op_name(op) : get_op_name(op))::String) end -struct OpNameDispatcher{D,O<:AbstractOperatorEnum} <: Function +struct OpNameDispatcher{D,O<:Union{AbstractOperatorEnum,Nothing}} <: Function operators::O pretty::Bool end From 633cc943d562c895a803bed60a96f68b1fd83446 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 13:44:25 +0100 Subject: [PATCH 21/74] feat: get expression algebra working --- src/Expression.jl | 10 ++++++++-- src/ExpressionAlgebra.jl | 19 ++++++++++++++++--- src/OperatorEnum.jl | 3 +++ src/ParametricExpression.jl | 4 +++- test/test_buffered_evaluation.jl | 2 +- 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/Expression.jl b/src/Expression.jl index 9e7325a6..9e8ac626 100644 --- a/src/Expression.jl +++ b/src/Expression.jl @@ -8,7 +8,8 @@ using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum using ..UtilsModule: Undefined using ..ChainRulesModule: NodeTangent -import ..NodeModule: copy_node, set_node!, count_nodes, tree_mapreduce, constructorof +import ..NodeModule: + copy_node, set_node!, count_nodes, tree_mapreduce, constructorof, max_degree import ..NodeUtilsModule: preserve_sharing, count_constant_nodes, @@ -99,9 +100,14 @@ end return Expression(tree, Metadata(d)) end +has_node_type(::Union{E,Type{E}}) where {N,E<:AbstractExpression{<:Any,N}} = true +has_node_type(::Union{E,Type{E}}) where {E<:AbstractExpression} = false node_type(::Union{E,Type{E}}) where {N,E<:AbstractExpression{<:Any,N}} = N +function max_degree(::Union{E,Type{E}}) where {E<:AbstractExpression} + 
return has_node_type(E) ? max_degree(node_type(E)) : max_degree(Node) +end @unstable default_node_type(_) = Node -default_node_type(::Type{<:AbstractExpression{T}}) where {T} = Node{T} +default_node_type(::Type{N}) where {T,N<:AbstractExpression{T}} = Node{T,max_degree(N)} ######################################################## # Abstract interface ################################### diff --git a/src/ExpressionAlgebra.jl b/src/ExpressionAlgebra.jl index dd9159ec..b1e0bd08 100644 --- a/src/ExpressionAlgebra.jl +++ b/src/ExpressionAlgebra.jl @@ -54,6 +54,10 @@ of the expression. """ declare_operator_alias(op::F, _) where {F<:Function} = op +allow_chaining(@nospecialize(op)) = false +allow_chaining(::typeof(+)) = true +allow_chaining(::typeof(*)) = true + function apply_operator(op::F, args::Vararg{Any,D}) where {F<:Function,D} idx = findfirst(e -> e isa AbstractExpression, args)::Int example_expr = args[idx] @@ -61,12 +65,22 @@ function apply_operator(op::F, args::Vararg{Any,D}) where {F<:Function,D} @assert all(e -> !(e isa AbstractExpression) || typeof(e) === E, args) operators = get_operators(example_expr, nothing) - op_idx = findfirst(==(op), map(Base.Fix2(declare_operator_alias, Val(D)), operators[D])) + op_idx = if length(operators) >= D + findfirst(==(op), map(Base.Fix2(declare_operator_alias, Val(D)), operators[D])) + else + nothing + end if isnothing(op_idx) + if allow_chaining(op) && D > 2 + # These operators might get chained by Julia, so we check + # downward for any matching arity. + inner = apply_operator(op, args[1:(end - 1)]...) 
+ return apply_operator(op, inner, args[end]) + end throw( MissingOperatorError( "Operator $op not found in operators for expression type " * - "$(typeof(l)) with $(D)-degree operators $(operators[D])", + "$(E) with $(D)-degree operators $(operators[D])", ), ) end @@ -116,7 +130,6 @@ macro declare_expression_operator(op, arity) continue end - arglist = [Expr(:(::), syms[i], types[i]) for i in 1:arity] signature = Expr(:call, op, arglist...) if any(t -> t == :T || (t isa Expr && t.head == :curly && :T in t.args), types) diff --git a/src/OperatorEnum.jl b/src/OperatorEnum.jl index 3f1d131d..77049b13 100644 --- a/src/OperatorEnum.jl +++ b/src/OperatorEnum.jl @@ -42,6 +42,9 @@ Base.copy(op::AbstractOperatorEnum) = op @unstable @inline function Base.getindex(op::AbstractOperatorEnum, i::Int) return getfield(op, :ops)[i] end +@inline function Base.length(op::AbstractOperatorEnum) + return length(getfield(op, :ops)) +end @inline function Base.getproperty(op::AbstractOperatorEnum, k::Symbol) if k == :unaops return getfield(op, :ops)[1] diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 3db7666c..3414001b 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -122,7 +122,9 @@ function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ParametricNode,T} return ParametricNode{T,max_degree(N)} end @unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T} -default_node_type(::Type{<:ParametricExpression{T}}) where {T} = ParametricNode{T,2} +function default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}} + return ParametricNode{T,max_degree(N)} +end preserve_sharing(::Union{Type{<:ParametricNode},ParametricNode}) = false # TODO: Change this? 
function leaf_copy(t::ParametricNode{T}) where {T} if t.constant diff --git a/test/test_buffered_evaluation.jl b/test/test_buffered_evaluation.jl index a8d052b9..a1b41b9a 100644 --- a/test/test_buffered_evaluation.jl +++ b/test/test_buffered_evaluation.jl @@ -147,7 +147,7 @@ end result2, ok2 = eval_tree_array(tree, X, operators; eval_options) # Results should be identical - @test result1 ≈ result2 + @test isapprox(result1, result2; atol=1e-10) @test ok1 == ok2 end end From 38101d1dee4d8d47ac3806f2e4a8dae73156da50 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 14:05:24 +0100 Subject: [PATCH 22/74] fix: node conversion changing degree --- src/Node.jl | 18 ++++++++++++++---- src/base.jl | 8 +++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index e430c313..80e2ddb2 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -218,13 +218,20 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -max_degree(::Type{<:AbstractNode}) = 2 # Default +const DEFAULT_MAX_DEGREE = 2 +max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE max_degree(::Type{<:AbstractNode{D}}) where {D} = D max_degree(node::AbstractNode) = max_degree(typeof(node)) -@unstable constructorof(::Type{N}) where {N<:Node} = Node{T,max_degree(N)} where {T} -@unstable constructorof(::Type{N}) where {N<:GraphNode} = - GraphNode{T,max_degree(N)} where {T} +has_max_degree(::Type{<:AbstractNode}) = false +has_max_degree(::Type{<:AbstractNode{D}}) where {D} = true + +@unstable function constructorof(::Type{N}) where {N<:Node} + return Node{T,max_degree(N)} where {T} +end +@unstable function constructorof(::Type{N}) where {N<:GraphNode} + return GraphNode{T,max_degree(N)} where {T} +end function with_type_parameters(::Type{N}, ::Type{T}) where {N<:Node,T} return Node{T,max_degree(N)} @@ -233,6 +240,9 @@ function with_type_parameters(::Type{N}, ::Type{T}) where {N<:GraphNode,T} return 
GraphNode{T,max_degree(N)} end +with_max_degree(::Type{N}, ::Val{D}) where {T,N<:Node{T},D} = Node{T,D} +with_max_degree(::Type{N}, ::Val{D}) where {T,N<:GraphNode{T},D} = GraphNode{T,D} + function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T} return with_type_parameters(N, T)() end diff --git a/src/base.jl b/src/base.jl index ec209214..81cec93a 100644 --- a/src/base.jl +++ b/src/base.jl @@ -520,10 +520,11 @@ using `convert(T1, tree.val)` at constant nodes. """ function convert( ::Type{N1}, tree::N2 -) where {T1,T2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2}} +) where {T1,T2,D1,D2,N1<:AbstractExpressionNode{T1,D1},N2<:AbstractExpressionNode{T2,D2}} if N1 === N2 return tree end + @assert max_degree(N1) == max_degree(N2) return tree_mapreduce( Base.Fix1(leaf_convert, N1), identity, @@ -533,6 +534,11 @@ function convert( ) # TODO: Need to allow user to overload this! end +function convert( + ::Type{N1}, tree::N2 +) where {T1,T2,D2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2,D2}} + return convert(with_max_degree(N1, Val(D2)), tree) +end function convert( ::Type{N1}, tree::N2 ) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}} From 22154d0575ffed901e9b4211b67a97cbc99a7bf8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 15:10:26 +0100 Subject: [PATCH 23/74] fix: type instabilities --- ext/DynamicExpressionsSymbolicUtilsExt.jl | 10 +++++++--- src/Evaluate.jl | 5 +++-- src/ParametricExpression.jl | 4 ++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index 30a987d3..e455eda7 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -175,9 +175,13 @@ function Base.convert( findoperation(op, operators.unaops) end - return constructorof(N)(; - op=ind, children=map(x -> convert(N, x, operators; variable_names), args) - ) + if 
length(args) == 2 + children = map(x -> convert(N, x, operators; variable_names), (args[1], args[2])) + return constructorof(N)(; op=ind, children) + else + children = map(x -> convert(N, x, operators; variable_names), (only(args),)) + return constructorof(N)(; op=ind, children) + end end _node_type(::Type{<:AbstractExpression{T,N}}) where {T,N<:AbstractExpressionNode} = N diff --git a/src/Evaluate.jl b/src/Evaluate.jl index e33149bf..ee27251d 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -2,7 +2,8 @@ module EvaluateModule using DispatchDoctor: @stable, @unstable -import ..NodeModule: AbstractExpressionNode, constructorof, max_degree, children +import ..NodeModule: + AbstractExpressionNode, constructorof, max_degree, children, with_type_parameters import ..StringsModule: string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk @@ -244,7 +245,7 @@ function eval_tree_array( kws..., ) where {T1,T2} T = promote_type(T1, T2) - tree = convert(constructorof(typeof(tree)){T}, tree) + tree = convert(with_type_parameters(typeof(tree), T), tree) cX = Base.Fix1(convert, T).(cX) return eval_tree_array(tree, cX, operators; kws...) 
end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 3414001b..8f8048cc 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -12,6 +12,7 @@ using ..ChainRulesModule: NodeTangent import ..NodeModule: constructorof, with_type_parameters, + with_max_degree, max_degree, preserve_sharing, leaf_copy, @@ -121,6 +122,9 @@ end function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ParametricNode,T} return ParametricNode{T,max_degree(N)} end +function with_max_degree(::Type{N}, ::Val{D}) where {T,N<:ParametricNode{T},D} + return ParametricNode{T,D} +end @unstable default_node_type(::Type{<:ParametricExpression}) = ParametricNode{T,2} where {T} function default_node_type(::Type{N}) where {T,N<:ParametricExpression{T}} return ParametricNode{T,max_degree(N)} From 092b9454bf5d5c4d25365e265482fe2927a0e35c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 15:42:16 +0100 Subject: [PATCH 24/74] feat: n-arity LoopVectorization compat --- ext/DynamicExpressionsLoopVectorizationExt.jl | 31 +++++++------------ src/Evaluate.jl | 4 +-- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index bae2b835..f65fcf4b 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -8,6 +8,7 @@ using DynamicExpressions.EvaluateModule: import DynamicExpressions.EvaluateModule: deg1_eval, deg2_eval, + degn_eval, deg1_l2_ll0_lr0_eval, deg1_l1_ll0_eval, deg2_l0_r0_eval, @@ -18,27 +19,17 @@ import DynamicExpressions.ExtensionInterfaceModule: _is_loopvectorization_loaded(::Int) = true -function deg2_eval( - cumulator_l::AbstractVector{T}, - cumulator_r::AbstractVector{T}, - op::F, - ::EvalOptions{true}, -)::ResultOk where {T<:Number,F} - @turbo for j in eachindex(cumulator_l) - x = op(cumulator_l[j], cumulator_r[j]) - cumulator_l[j] = x - end - return ResultOk(cumulator_l, 
true) -end - -function deg1_eval( - cumulator::AbstractVector{T}, op::F, ::EvalOptions{true} -)::ResultOk where {T<:Number,F} - @turbo for j in eachindex(cumulator) - x = op(cumulator[j]) - cumulator[j] = x +@generated function degn_eval( + cumulators::NTuple{N,<:AbstractVector{T}}, op::F, ::EvalOptions{true} +)::ResultOk where {N,T,F} + # Fast general implementation of `cumulators[1] .= op.(cumulators[1], cumulators[2], ...)` + quote + Base.Cartesian.@nexprs($N, i -> cumulator_i = cumulators[i]) + @turbo for j in eachindex(cumulator_1) + cumulator_1[j] = Base.Cartesian.@ncall($N, op, i -> cumulator_i[j]) + end + return ResultOk(cumulator_1, true) end - return ResultOk(cumulator, true) end function deg1_l2_ll0_lr0_eval( diff --git a/src/Evaluate.jl b/src/Evaluate.jl index ee27251d..c8291907 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -310,13 +310,13 @@ function deg2_eval( cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, - eval_options::EvalOptions{false}, + eval_options::EvalOptions, )::ResultOk where {T,F} return degn_eval((cumulator_l, cumulator_r), op, eval_options) end function deg1_eval( - cumulator::AbstractVector{T}, op::F, eval_options::EvalOptions{false} + cumulator::AbstractVector{T}, op::F, eval_options::EvalOptions )::ResultOk where {T,F} return degn_eval((cumulator,), op, eval_options) end From f7b49a5fed6ac25a32c1ff01b46a4091dcf1ea6a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 20:52:53 +0100 Subject: [PATCH 25/74] refactor: cleaner `call_mapreducer` for n-arity --- src/base.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/base.jl b/src/base.jl index 81cec93a..a7d8f4c4 100644 --- a/src/base.jl +++ b/src/base.jl @@ -133,15 +133,16 @@ end result = if d == 0 mapreducer.f_leaf(tree) else + branch = mapreducer.f_branch(tree) Base.Cartesian.@nif( $D, i -> i == d, i -> let cs = children(tree, Val(i)) - mapreducer.op( - mapreducer.f_branch(tree), - 
Base.Cartesian.@ntuple( - i, j -> call_mapreducer(mapreducer, cs[j]) - )..., + Base.Cartesian.@ncall( + i, + mapreducer.op, + branch, + j -> call_mapreducer(mapreducer, cs[j]) ) end ) From 560af0750b25f315e5eae9e0491f5cac45237a0f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 21:42:56 +0100 Subject: [PATCH 26/74] refactor: simplify constant eval code for n-arity --- src/Evaluate.jl | 92 ++++++++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 50 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index c8291907..b96184f1 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -708,47 +708,54 @@ gives better performance, as we do not need to perform computation over an entire array when the values are all the same. """ @generated function dispatch_constant_tree( - tree::AbstractExpressionNode{T}, operators::OperatorEnum -) where {T} - nuna = get_nuna(operators) - nbin = get_nbin(operators) - deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN - quote - deg1_eval_constant(tree, operators.unaops[op_idx], operators)::ResultOk{T} - end - else - quote - Base.Cartesian.@nif( - $nuna, - i -> i == op_idx, - i -> deg1_eval_constant(tree, operators.unaops[i], operators)::ResultOk{T} - ) - end + tree::AbstractExpressionNode{T,D}, operators::OperatorEnum +) where {T,D} + quote + deg = tree.degree + deg == 0 && return deg0_eval_constant(tree) + Base.Cartesian.@nif( + $D, + i -> i == deg, + i -> inner_dispatch_degn_eval_constant(tree, Val(i), operators) + ) end - deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN - quote - deg2_eval_constant(tree, operators.binops[op_idx], operators)::ResultOk{T} +end + +# Now that we have the degree, we can get the operator +@generated function inner_dispatch_degn_eval_constant( + tree::AbstractExpressionNode{T}, ::Val{degree}, operators::OperatorEnum{OPS} +) where {T,degree,OPS} + nops = length(OPS.types[degree].types) + get_inputs = quote + cs = children(tree, Val($degree)) + 
Base.Cartesian.@nexprs( + $degree, + i -> begin + input_i = let result = dispatch_constant_tree(cs[i], operators) + !result.ok && return result + result.x + end + end + ) + inputs = Base.Cartesian.@ntuple($degree, i -> input_i) + end + if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN + return quote + $get_inputs + op_idx = tree.op + degn_eval_constant(inputs, operators[$degree][op_idx])::ResultOk{T} end else - quote + return quote + $get_inputs + op_idx = tree.op Base.Cartesian.@nif( - $nbin, + $nops, i -> i == op_idx, - i -> deg2_eval_constant(tree, operators.binops[i], operators)::ResultOk{T} + i -> degn_eval_constant(inputs, operators[$degree][i])::ResultOk{T} ) end end - return quote - if tree.degree == 0 - return deg0_eval_constant(tree)::ResultOk{T} - elseif tree.degree == 1 - op_idx = tree.op - return $deg1_branch - else - op_idx = tree.op - return $deg2_branch - end - end end @inline function deg0_eval_constant(tree::AbstractExpressionNode{T}) where {T} @@ -756,23 +763,8 @@ end return ResultOk(output, is_valid(output))::ResultOk{T} end -function deg1_eval_constant( - tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum -) where {T,F} - result = dispatch_constant_tree(tree.l, operators) - !result.ok && return result - output = op(result.x)::T - return ResultOk(output, is_valid(output))::ResultOk{T} -end - -function deg2_eval_constant( - tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum -) where {T,F} - cumulator = dispatch_constant_tree(tree.l, operators) - !cumulator.ok && return cumulator - result_r = dispatch_constant_tree(tree.r, operators) - !result_r.ok && return result_r - output = op(cumulator.x, result_r.x)::T +function degn_eval_constant(inputs::Tuple{T,Vararg{T}}, op::F) where {T,F} + output = op(inputs...)::T return ResultOk(output, is_valid(output))::ResultOk{T} end From a51f1fb0fd90c6f2453ab3eabaa3af30a421ba57 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 10 May 2025 21:43:28 +0100 Subject: [PATCH 27/74] refactor: 
avoid need for refs by setting tuple to self --- src/Node.jl | 42 ++++++++++++++++++++++------------ src/NodeUtils.jl | 22 +++++++++--------- src/ParametricExpression.jl | 2 +- test/test_custom_node_type.jl | 14 +++++++----- test/test_extra_node_fields.jl | 4 ++-- 5 files changed, 50 insertions(+), 34 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 80e2ddb2..0d55b7b4 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -76,7 +76,7 @@ for N in (:Node, :GraphNode) val::T # If is a constant, this stores the actual value feature::UInt16 # (Possibly undefined) If is a variable (e.g., x in cos(x)), this stores the feature index. op::UInt8 # (Possibly undefined) If operator, this is the index of the operator in the degree-specific operator enum - children::NTuple{D,Base.RefValue{$N{T,D}}} # Children nodes + children::NTuple{D,$N{T,D}} ################# ## Constructors: @@ -113,7 +113,7 @@ nodes, you can evaluate or print a given expression. operator in `operators.binops`. In other words, this is an enum of the operators, and is dependent on the specific `OperatorEnum` object. Only defined if `degree >= 1` -- `children::NTuple{D,Base.RefValue{Node{T,D}}}`: Children of the node. Only defined up to `degree` +- `children::NTuple{D,Node{T,D}}`: Children of the node. Only defined up to `degree` # Constructors @@ -166,14 +166,24 @@ when constructing or setting properties. """ GraphNode +function get_poison(n::AbstractNode) + # We don't want to use `nothing` because the type instability + # hits memory hard. + # Setting itself as the right child is the best thing, + # because it (1) doesn't allocate, and (2) will trigger + # infinite recursion errors if someone is mistakenly trying + # to access the right child when `.degree == 1`. + return n +end + macro make_accessors(node_type) esc(quote @inline function Base.getproperty(n::$node_type, k::Symbol) if k == :l # TODO: Should a depwarn be raised here? Or too slow? 
- return getfield(n, :children)[1][] + return getfield(n, :children)[1] elseif k == :r - return getfield(n, :children)[2][] + return getfield(n, :children)[2] else return getfield(n, k) end @@ -181,16 +191,19 @@ macro make_accessors(node_type) @inline function Base.setproperty!(n::$node_type, k::Symbol, v) if k == :l if isdefined(n, :children) - getfield(n, :children)[1][] = v + old = getfield(n, :children) + setfield!(n, :children, (v, old[2])) + v else - r = Ref(v) - setfield!(n, :children, (r, Ref{typeof(n)}())) - r + poison = $(get_poison)(n) + setfield!(n, :children, (v, poison)) + v end elseif k == :r # TODO: Remove this assert once we know that this is safe - @assert isdefined(n, :children) - getfield(n, :children)[2][] = v + old = getfield(n, :children) + setfield!(n, :children, (old[1], v)) + v else T = fieldtype(typeof(n), k) if v isa T @@ -203,13 +216,13 @@ macro make_accessors(node_type) end) end -@make_accessors Node -@make_accessors GraphNode +@make_accessors Node{T,2} where {T} +@make_accessors GraphNode{T,2} where {T} @inline children(node::AbstractNode) = node.children @inline function children(node::AbstractNode, ::Val{n}) where {n} cs = children(node) - return ntuple(i -> cs[i][], Val(n)) + return ntuple(i -> cs[i], Val(n)) end ################################################################################ @@ -312,7 +325,8 @@ end n = allocator(N, T) n.degree = D2 n.op = op - n.children = ntuple(i -> i <= D2 ? Ref(convert(NT, children[i])) : Ref{NT}(), Val(max_degree(N))) + poison = get_poison(n) + n.children = ntuple(i -> i <= D2 ? 
convert(NT, children[i]) : poison, Val(max_degree(N))) return n end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 293f4e97..35bfda13 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -6,6 +6,7 @@ import ..NodeModule: Node, preserve_sharing, constructorof, + get_poison, copy_node, count_nodes, tree_mapreduce, @@ -146,20 +147,19 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D} degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc. val::T # If is a constant, this stores the actual value # ------------------- (possibly undefined below) - children::NTuple{D,Base.RefValue{NodeIndex{T,D}}} + children::NTuple{D,NodeIndex{T,D}} function NodeIndex(::Type{_T}, ::Val{_D}, val) where {_T,_D} - return new{_T,_D}( - 0, convert(_T, val), ntuple(_ -> Ref{NodeIndex{_T,_D}}(), Val(_D)) - ) + return new{_T,_D}(0, convert(_T, val)) end function NodeIndex( - ::Type{_T}, ::Val{_D}, children::Vararg{NodeIndex{_T,_D},_D2} + ::Type{_T}, ::Val{_D}, child::NodeIndex{_T,_D}, childs::Vararg{NodeIndex{_T,_D},_D2} ) where {_T,_D,_D2} - _children = ntuple( - i -> i <= _D2 ? Ref(children[i]) : Ref{NodeIndex{_T,_D}}(), Val(_D) - ) - return new{_T,_D}(convert(UInt8, _D2), zero(_T), _children) + node = NodeIndex(_T, Val(_D)) + poison = get_poison(node) + children = (child, childs...) + node.children = ntuple(i -> i <= _D2 + 1 ? children[i] : poison, Val(_D)) + return node end end NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T)) @@ -167,9 +167,9 @@ NodeIndex(::Type{T}, ::Val{D}) where {T,D} = NodeIndex(T, Val(D), zero(T)) @inline function Base.getproperty(n::NodeIndex, k::Symbol) if k == :l # TODO: Should a depwarn be raised here? Or too slow? 
- return getfield(n, :children)[1][] + return getfield(n, :children)[1] elseif k == :r - return getfield(n, :children)[2][] + return getfield(n, :children)[2] else return getfield(n, k) end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 8f8048cc..305c2810 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -57,7 +57,7 @@ mutable struct ParametricNode{T,D} <: AbstractExpressionNode{T,D} parameter::UInt16 # Stores index of per-class parameter op::UInt8 - children::NTuple{D,Base.RefValue{ParametricNode{T,D}}} # Children nodes + children::NTuple{D,ParametricNode{T,D}} # Children nodes function ParametricNode{_T,_D}() where {_T,_D} n = new{_T,_D}() diff --git a/test/test_custom_node_type.jl b/test/test_custom_node_type.jl index 57a3706c..a2e65717 100644 --- a/test/test_custom_node_type.jl +++ b/test/test_custom_node_type.jl @@ -5,16 +5,18 @@ mutable struct MyCustomNode{A,B} <: AbstractNode{2} degree::Int val1::A val2::B - children::NTuple{2,Base.RefValue{MyCustomNode{A,B}}} + children::NTuple{2,MyCustomNode{A,B}} MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2) function MyCustomNode(val1, val2, l) - return new{typeof(val1),typeof(val2)}( - 1, val1, val2, (Ref(l), Ref{MyCustomNode{typeof(val1),typeof(val2)}}()) - ) + n = MyCustomNode(val1, val2) + poison = n + n.degree = 1 + n.children = (l, poison) + return n end function MyCustomNode(val1, val2, l, r) - return new{typeof(val1),typeof(val2)}(2, val1, val2, (Ref(l), Ref(r))) + return new{typeof(val1),typeof(val2)}(2, val1, val2, (l, r)) end end @@ -29,7 +31,7 @@ node2 = MyCustomNode(1.5, 3, node1) @test typeof(node2) == MyCustomNode{Float64,Int} @test node2.degree == 1 -@test node2.children[1][].degree == 0 +@test node2.children[1].degree == 0 @test count_depth(node2) == 2 @test count_nodes(node2) == 2 diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 60b35595..0c7fef36 100644 --- 
a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -99,5 +99,5 @@ ex = parse_expression( @test string_tree(ex) == "x + sin(y + 2.1)" @test ex.tree.frozen == false -@test ex.tree.children[2][].frozen == true -@test ex.tree.children[2][].children[1][].frozen == false +@test ex.tree.children[2].frozen == true +@test ex.tree.children[2].children[1].frozen == false From 1d3d834a441ba9feec94c4f362a131eb177f455e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 03:31:32 +0100 Subject: [PATCH 28/74] feat: `any` and `==` working with n-arity nodes --- src/Node.jl | 7 +- src/base.jl | 107 +++--- test/runtests.jl | 6 +- test/test_extra_node_fields.jl | 2 +- test/test_n_arity_nodes.jl | 632 +++++++++++++++++++++++++++++++++ 5 files changed, 697 insertions(+), 57 deletions(-) create mode 100644 test/test_n_arity_nodes.jl diff --git a/src/Node.jl b/src/Node.jl index 0d55b7b4..ba78fbd7 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -216,8 +216,11 @@ macro make_accessors(node_type) end) end -@make_accessors Node{T,2} where {T} -@make_accessors GraphNode{T,2} where {T} +# @make_accessors Node{T,2} where {T} +# @make_accessors GraphNode{T,2} where {T} +@make_accessors Node +@make_accessors GraphNode +# TODO: Disable the `.l` accessors eventually, once the codebase is fully generic @inline children(node::AbstractNode) = node.children @inline function children(node::AbstractNode, ::Val{n}) where {n} diff --git a/src/base.jl b/src/base.jl index a7d8f4c4..e39e2c90 100644 --- a/src/base.jl +++ b/src/base.jl @@ -174,13 +174,19 @@ end Reduce a flag function over a tree, returning `true` if the function returns `true` for any node. By using this instead of tree_mapreduce, we can take advantage of early exits. 
""" -function any(f::F, tree::AbstractNode) where {F<:Function} - if tree.degree == 0 - return @inline(f(tree))::Bool - elseif tree.degree == 1 - return @inline(f(tree))::Bool || any(f, tree.l) - else - return @inline(f(tree))::Bool || any(f, tree.l) || any(f, tree.r) +@generated function any(f::F, tree::AbstractNode{D}) where {F<:Function,D} + quote + deg = tree.degree + + deg == 0 && return @inline(f(tree)) + + return ( + @inline(f(tree)) || Base.Cartesian.@nif( + $D, i -> deg == i, i -> let cs = children(tree, Val(i)) + Base.Cartesian.@nany(i, j -> any(f, cs[j])) + end + ) + ) end end @@ -189,49 +195,49 @@ function Base.:(==)(a::AbstractExpressionNode, b::AbstractExpressionNode) end function Base.:(==)(a::N, b::N)::Bool where {N<:AbstractExpressionNode} if preserve_sharing(N) - return inner_is_equal_shared(a, b, Dict{UInt,Nothing}(), Dict{UInt,Nothing}()) + return inner_is_equal(a, b, (; a=Dict{UInt,Nothing}(), b=Dict{UInt,Nothing}())) else - return inner_is_equal(a, b) + return inner_is_equal(a, b, nothing) end end -function inner_is_equal(a, b) - (degree = a.degree) != b.degree && return false - if degree == 0 - return leaf_equal(a, b) - elseif degree == 1 - return branch_equal(a, b) && inner_is_equal(a.l, b.l) - else - return branch_equal(a, b) && inner_is_equal(a.l, b.l) && inner_is_equal(a.r, b.r) - end -end -function inner_is_equal_shared(a, b, id_map_a, id_map_b) - id_a = objectid(a) - id_b = objectid(b) - has_a = haskey(id_map_a, id_a) - has_b = haskey(id_map_b, id_b) - - if has_a && has_b - return true - elseif has_a ⊻ has_b - return false - end - - (degree = a.degree) != b.degree && return false +@generated function inner_is_equal( + a::AbstractNode{D}, b::AbstractNode{D}, id_maps::Union{Nothing,NamedTuple} +) where {D} + quote + ids = !isnothing(id_maps) ? 
(; a=objectid(a), b=objectid(b)) : nothing + + if !isnothing(id_maps) + has_a = haskey(id_maps.a, ids.a) + has_b = haskey(id_maps.b, ids.b) + if has_a && has_b + return true + elseif has_a ⊻ has_b + return false + end + end - result = if degree == 0 - leaf_equal(a, b) - elseif degree == 1 - branch_equal(a, b) && inner_is_equal_shared(a.l, b.l, id_map_a, id_map_b) - else - branch_equal(a, b) && - inner_is_equal_shared(a.l, b.l, id_map_a, id_map_b) && - inner_is_equal_shared(a.r, b.r, id_map_a, id_map_b) + deg = a.degree + result = if deg != b.degree + false + elseif deg == 0 + leaf_equal(a, b) + else + ( + branch_equal(a, b) && Base.Cartesian.@nif( + $D, + i -> deg == i, + i -> let cs_a = children(a, Val(i)), cs_b = children(b, Val(i)) + Base.Cartesian.@nall(i, j -> inner_is_equal(cs_a[j], cs_b[j], id_maps)) + end + ) + ) + end + if !isnothing(ids) + id_maps.a[ids.a] = nothing + id_maps.b[ids.b] = nothing + end + return result end - - id_map_a[id_a] = nothing - id_map_b[id_b] = nothing - - return result end @inline function branch_equal(a::AbstractExpressionNode, b::AbstractExpressionNode) @@ -240,7 +246,8 @@ end @inline function leaf_equal( a::AbstractExpressionNode{T1}, b::AbstractExpressionNode{T2} ) where {T1,T2} - (constant = a.constant) != b.constant && return false + constant = a.constant + constant != b.constant && return false if constant return a.val::T1 == b.val::T2 else @@ -521,11 +528,10 @@ using `convert(T1, tree.val)` at constant nodes. """ function convert( ::Type{N1}, tree::N2 -) where {T1,T2,D1,D2,N1<:AbstractExpressionNode{T1,D1},N2<:AbstractExpressionNode{T2,D2}} +) where {T1,T2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2}} if N1 === N2 return tree end - @assert max_degree(N1) == max_degree(N2) return tree_mapreduce( Base.Fix1(leaf_convert, N1), identity, @@ -535,11 +541,6 @@ function convert( ) # TODO: Need to allow user to overload this! 
end -function convert( - ::Type{N1}, tree::N2 -) where {T1,T2,D2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2,D2}} - return convert(with_max_degree(N1, Val(D2)), tree) -end function convert( ::Type{N1}, tree::N2 ) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}} diff --git a/test/runtests.jl b/test/runtests.jl index c24811a3..70de3cf8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using TestItemRunner # Check if SR_ENZYME_TEST is set in env test_name = split(get(ENV, "SR_TEST", "main"), ",") -unknown_tests = filter(Base.Fix2(∉, ["enzyme", "jet", "main"]), test_name) +unknown_tests = filter(Base.Fix2(∉, ["enzyme", "jet", "main", "narity"]), test_name) if !isempty(unknown_tests) error("Unknown test names: $unknown_tests") @@ -49,3 +49,7 @@ if "main" in test_name include("unittest.jl") @run_package_tests end +if "narity" in test_name + include("test_n_arity_nodes.jl") + @run_package_tests filter = ti -> (:narity in ti.tags) +end diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 0c7fef36..996cefe4 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -11,7 +11,7 @@ mutable struct FrozenNode{T,D} <: AbstractExpressionNode{T,D} frozen::Bool # Extra field! 
feature::UInt16 op::UInt8 - children::NTuple{D,Base.RefValue{FrozenNode{T,D}}} + children::NTuple{D,FrozenNode{T,D}} function FrozenNode{_T,_D}() where {_T,_D} n = new{_T,_D}() diff --git a/test/test_n_arity_nodes.jl b/test/test_n_arity_nodes.jl new file mode 100644 index 00000000..500af513 --- /dev/null +++ b/test/test_n_arity_nodes.jl @@ -0,0 +1,632 @@ +@testitem "N-ary Node Construction and Properties" tags = [:narity] begin + using DynamicExpressions + using Test + + # Define some simple operators for structure, not evaluation here + my_unary_op(x) = x # Placeholder + my_binary_op(x, y) = x # Placeholder + my_ternary_op(x, y, z) = x # Placeholder + + # Corrected OperatorEnum constructor: + operators = OperatorEnum(((my_unary_op,), (my_binary_op,), (my_ternary_op,))) + + # Arity 1 (Unary) in a Node{T,3} type (max_degree is 3) + n_una_leaf = Node{Float64,3}(; feature=1) + n_una = Node{Float64,3}(; op=1, children=(n_una_leaf,)) # op=1 for my_unary_op + @test n_una.degree == 1 + @test n_una.op == 1 + @test DynamicExpressions.NodeModule.max_degree(n_una) == 3 + @test DynamicExpressions.NodeModule.max_degree(typeof(n_una)) == 3 + @test n_una.children[1] === n_una_leaf + # Test poison (node refers to itself for unused children slots) + @test n_una.children[2] === n_una + @test n_una.children[3] === n_una + + # Arity 2 (Binary) in a Node{T,3} type + n_bin_leaf1 = Node{Float64,3}(; feature=1) + n_bin_leaf2 = Node{Float64,3}(; val=2.0) + n_bin = Node{Float64,3}(; op=1, children=(n_bin_leaf1, n_bin_leaf2)) # op=1 for my_binary_op + @test n_bin.degree == 2 + @test n_bin.op == 1 + @test DynamicExpressions.NodeModule.max_degree(n_bin) == 3 + @test n_bin.children[1] === n_bin_leaf1 + @test n_bin.children[2] === n_bin_leaf2 + @test n_bin.children[3] === n_bin # Poison + @test DynamicExpressions.NodeModule.children(n_bin, Val(2)) == + (n_bin_leaf1, n_bin_leaf2) + # .l and .r should NOT be used for Node{T,3} as @make_accessors is for Node{T,2} + @test_throws ErrorException 
n_bin.l # getfield Node: no field l + + # Arity 3 (Ternary) in a Node{T,3} type + n_ter_leaf1 = Node{Float64,3}(; feature=1) + n_ter_leaf2 = Node{Float64,3}(; feature=2) + n_ter_leaf3 = Node{Float64,3}(; val=0.5) + n_ter = Node{Float64,3}(; op=1, children=(n_ter_leaf1, n_ter_leaf2, n_ter_leaf3)) # op=1 for my_ternary_op + @test n_ter.degree == 3 + @test n_ter.op == 1 + @test DynamicExpressions.NodeModule.max_degree(n_ter) == 3 + @test n_ter.children[1] === n_ter_leaf1 + @test n_ter.children[2] === n_ter_leaf2 + @test n_ter.children[3] === n_ter_leaf3 + @test DynamicExpressions.NodeModule.children(n_ter, Val(3)) == + (n_ter_leaf1, n_ter_leaf2, n_ter_leaf3) + + # Test .l and .r accessors ONLY for Node{T,2} + n2_leaf1_for_l_r = Node{Float64,2}(; feature=1) + n2_leaf2_for_l_r = Node{Float64,2}(; val=2.0) + # Need an operator enum where binary op is index 1 + ops_for_d2_accessors = OperatorEnum(((), (my_binary_op,))) + n2_bin_for_l_r = Node{Float64,2}(; op=1, children=(n2_leaf1_for_l_r, n2_leaf2_for_l_r)) + @test n2_bin_for_l_r.l === n2_leaf1_for_l_r + @test n2_bin_for_l_r.r === n2_leaf2_for_l_r + n2_new_leaf_for_l_r = Node{Float64,2}(; feature=3) + n2_bin_for_l_r.l = n2_new_leaf_for_l_r + @test n2_bin_for_l_r.children[1] === n2_new_leaf_for_l_r + + # Test default D=2 for Node{T} + n_default_D_leaf1 = Node{Float64}(; feature=1) # This is Node{Float64,2} + @test DynamicExpressions.NodeModule.max_degree(typeof(n_default_D_leaf1)) == 2 + + # Test promoting node types + n_f32_d3_promo = Node{Float32,3}(; val=1.0f0) + n_f64_d3_promo = Node{Float64,3}(; val=2.0) + promoted_nodes = promote(n_f32_d3_promo, n_f64_d3_promo) + @test promoted_nodes[1] isa Node{Float64,3} + @test promoted_nodes[2] isa Node{Float64,3} + + # Test with_max_degree + @test DynamicExpressions.NodeModule.with_max_degree(Node{Float64,2}, Val(3)) == + Node{Float64,3} + @test DynamicExpressions.NodeModule.with_max_degree(Node{Float64,3}, Val(2)) == + Node{Float64,2} + # Node{Float64} is UnionAll, 
Node{Float64,2} after constructor. + # with_max_degree(Node{Float64} where T, Val(D)) is used by convert. + @test DynamicExpressions.NodeModule.with_max_degree(Node{Float64}, Val(4)) == + Node{Float64,4} + + # Test defines_eltype (internal, but used in node_factory) + @test DynamicExpressions.NodeModule.defines_eltype(Node{Float64,2}) == true + @test DynamicExpressions.NodeModule.defines_eltype(Node) == false # Node is UnionAll +end + +@testitem "N-ary OperatorEnum Structure" tags = [:narity] begin + using DynamicExpressions + using Test + + my_unary_op(x) = x + my_binary_op(x, y) = x + my_ternary_op(x, y, z) = x + + operators_unary_only = OperatorEnum(((my_unary_op,),)) + @test length(operators_unary_only) == 1 + @test operators_unary_only.unaops == (my_unary_op,) + @test operators_unary_only.binops == () + @test operators_unary_only[1] == (my_unary_op,) + + operators_binary_only = OperatorEnum(((), (my_binary_op,))) # Empty tuple for unary + @test length(operators_binary_only) == 2 + @test operators_binary_only.unaops == () + @test operators_binary_only.binops == (my_binary_op,) + @test operators_binary_only[2] == (my_binary_op,) + + operators_full = OperatorEnum(((my_unary_op,), (my_binary_op,), (my_ternary_op,))) + @test length(operators_full) == 3 + @test operators_full.unaops == (my_unary_op,) + @test operators_full.binops == (my_binary_op,) + @test operators_full[1] == (my_unary_op,) + @test operators_full[2] == (my_binary_op,) + @test operators_full[3] == (my_ternary_op,) + + @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_full)) == 1 + @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_full)) == 1 + @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_unary_only)) == 1 + @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_unary_only)) == 0 + @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_binary_only)) == 0 # Correct + @test 
DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_binary_only)) == 1 +end + +@testitem "N-ary Evaluation (targeting dispatch_degn_eval)" tags = [:narity] begin + using DynamicExpressions + using Test + using Random + + my_eval_unary_op(x) = sin(x) + my_eval_binary_op(x, y) = x^2 + y + my_eval_ternary_op(x, y, z) = x * y - z + + # Operators for Node{Float64,3} which will force dispatch_degn_eval + operators_d3 = OperatorEnum(( + (my_eval_unary_op,), (my_eval_binary_op,), (my_eval_ternary_op,) + )) + + x1 = Node{Float64,3}(; feature=1) + x2 = Node{Float64,3}(; feature=2) + x3 = Node{Float64,3}(; feature=3) + c1 = Node{Float64,3}(; val=0.5) + + X = randn(MersenneTwister(0), Float64, 3, 10) + + # Scenario 1: A binary operator in a Node{Float64,3} type instance. + # This should hit `dispatch_degn_eval` because max_degree(typeof(tree_bin_in_d3)) == 3. + tree_bin_in_d3 = Node{Float64,3}(; op=1, children=(x1, c1)) # my_eval_binary_op(x1, 0.5) + expected_bin_in_d3 = my_eval_binary_op.(X[1, :], 0.5) + # Crucial assumption: is_constant and its dependencies (all/any in base.jl) must be D-arity aware + # for this test to pass without erroring before reaching dispatch_degn_eval. + # If they are not, this test *will* fail, pointing to that issue in the base library code. + output_bin_in_d3, flag_bin_in_d3 = eval_tree_array(tree_bin_in_d3, X, operators_d3) + @test flag_bin_in_d3 + @test output_bin_in_d3 ≈ expected_bin_in_d3 + + # Scenario 2: A ternary operator in a Node{Float64,3} type instance. + # This should also hit `dispatch_degn_eval`. 
+ tree_ter_in_d3 = Node{Float64,3}(; op=1, children=(x1, x2, c1)) # my_eval_ternary_op(x1, x2, 0.5) + expected_ter_in_d3 = my_eval_ternary_op.(X[1, :], X[2, :], 0.5) + output_ter_in_d3, flag_ter_in_d3 = eval_tree_array(tree_ter_in_d3, X, operators_d3) + @test flag_ter_in_d3 + @test output_ter_in_d3 ≈ expected_ter_in_d3 + + # Test nested with different arities, ensuring children are also Node{Float64,3} + unary_child = Node{Float64,3}(; op=1, children=(x1,)) # my_eval_unary_op(x1) + binary_child = Node{Float64,3}(; op=1, children=(x3, c1)) # my_eval_binary_op(x3, 0.5) + tree_nested = Node{Float64,3}(; op=1, children=(unary_child, x2, binary_child)) # my_eval_ternary_op(...) + expected_nested = + my_eval_ternary_op.( + my_eval_unary_op.(X[1, :]), X[2, :], my_eval_binary_op.(X[3, :], 0.5) + ) + output_nested, flag_nested = eval_tree_array(tree_nested, X, operators_d3) + @test flag_nested + @test output_nested ≈ expected_nested + + # Test with type promotion in eval_tree_array (target node type max_degree inferred from input tree) + tree_f32_c1 = Node{Float32,3}(; feature=1) + tree_f32_c2 = Node{Float32,3}(; feature=2) + tree_f32_c3 = Node{Float32,3}(; val=0.5f0) + tree_f32 = Node{Float32,3}(; op=1, children=(tree_f32_c1, tree_f32_c2, tree_f32_c3)) # Ternary op + X_f64 = randn(MersenneTwister(1), Float64, 2, 5) # Only 2 features needed for this tree's variable nodes if used + + output_promoted, flag_promoted = eval_tree_array(tree_f32, X_f64, operators_d3) # operators_d3 has Float64 ops + @test flag_promoted + @test eltype(output_promoted) == Float64 + expected_promoted = my_eval_ternary_op.(X_f64[1, :], X_f64[2, :], 0.5) + @test output_promoted ≈ expected_promoted +end + +@testitem "N-ary Constant Evaluation (targeting inner_dispatch_degn_eval_constant)" tags = [ + :narity +] begin + using DynamicExpressions + using Test + + my_c_unary_op(x) = sin(x) + my_c_binary_op(x, y) = x^2 + y + my_c_ternary_op(x, y, z) = x * y - z + + operators = OperatorEnum(((my_c_unary_op,), 
(my_c_binary_op,), (my_c_ternary_op,))) + + c1 = Node{Float64,3}(; val=0.5) + c2 = Node{Float64,3}(; val=1.5) + c3 = Node{Float64,3}(; val=2.5) + X_dummy = zeros(Float64, 1, 1) # eval_tree_array needs X + + # Test structure: op_ter(op_una(c1), c2, op_bin(c1,c3)) + # This ensures recursive calls to dispatch_constant_tree and inner_dispatch_degn_eval_constant + const_una_child = Node{Float64,3}(; op=1, children=(c1,)) # my_c_unary_op(0.5) + const_bin_child = Node{Float64,3}(; op=1, children=(c1, c3)) # my_c_binary_op(0.5, 2.5) + tree_const_nested = Node{Float64,3}(; + op=1, children=(const_una_child, c2, const_bin_child) + ) + + expected_val_nested = my_c_ternary_op(my_c_unary_op(0.5), 1.5, my_c_binary_op(0.5, 2.5)) + + output_const_nested, flag_const_nested = eval_tree_array( + tree_const_nested, X_dummy, operators + ) + @test flag_const_nested + @test all(output_const_nested .≈ expected_val_nested) + + const_eval_res_nested = DynamicExpressions.EvaluateModule.dispatch_constant_tree( + tree_const_nested, operators + ) + @test const_eval_res_nested.ok + @test const_eval_res_nested.x ≈ expected_val_nested +end + +@testitem "N-ary ExpressionAlgebra (targeting apply_operator, @declare_expression_operator)" tags = [ + :narity +] begin + using DynamicExpressions + using Test + using Random + + # `clamp` is one of the default 3-arity ops handled by @declare_expression_operator + operators_clamp = OperatorEnum(((), (), (clamp,))) # arity 1 (empty), 2 (empty), 3 + DynamicExpressions.@extend_operators operators_clamp + + ex_x1 = Expression(Node{Float64,3}(; feature=1); operators=operators_clamp) + ex_val_low = Expression(Node{Float64,3}(; val=0.0); operators=operators_clamp) + ex_val_high = Expression(Node{Float64,3}(; val=1.0); operators=operators_clamp) + + expr_clamp3 = clamp(ex_x1, ex_val_low, ex_val_high) # Uses @declare_expression_operator + @test expr_clamp3.tree.degree == 3 + @test expr_clamp3.tree.op == 1 # Index of clamp in ternary list + X = [-0.5, 0.5, 1.5]' + 
expected_clamp3 = clamp.(X[1, :], 0.0, 1.0) + # Test evaluation of the Expression object + output_clamp3, flag_clamp3 = eval_tree_array(expr_clamp3.tree, X, operators_clamp) + @test flag_clamp3 + @test output_clamp3 ≈ expected_clamp3 + + # Test chaining for `+` (another default N-ary) + operators_plus_chain = OperatorEnum(((), (+,), (+,))) # Binary plus, Ternary plus + DynamicExpressions.@extend_operators operators_plus_chain + ex_p_x1 = Expression(Node{Float64,3}(; feature=1); operators=operators_plus_chain) + ex_p_x2 = Expression(Node{Float64,3}(; feature=2); operators=operators_plus_chain) + ex_p_x3 = Expression(Node{Float64,3}(; feature=3); operators=operators_plus_chain) + + # x1 + x2 + x3 should use ternary plus + expr_plus_ter = ex_p_x1 + ex_p_x2 + ex_p_x3 + @test expr_plus_ter.tree.degree == 3 + @test expr_plus_ter.tree.op == 1 # Index of ternary + + + # x1 + x2 (constant) should use binary plus + expr_plus_bin_const = ex_p_x1 + 0.5 + @test expr_plus_bin_const.tree.degree == 2 + @test expr_plus_bin_const.tree.op == 1 # Index of binary + +end + +@testitem "N-ary String Representation (targeting Strings.jl changes)" tags = [:narity] begin + using DynamicExpressions + using Test + + my_str_unary_op(x) = x + my_str_binary_op(x, y) = x + my_str_ternary_op(x, y, z) = x + + operators = OperatorEnum(( + (my_str_unary_op,), (my_str_binary_op,), (my_str_ternary_op,) + )) + DynamicExpressions.@extend_operators operators # For Expression creation + + x1 = Node{Float64,3}(; feature=1) + x2 = Node{Float64,3}(; feature=2) + x3 = Node{Float64,3}(; feature=3) + + # Wrap in Expression to use its string_tree method + tree_unary_expr = Expression(Node{Float64,3}(; op=1, children=(x1,)); operators) + @test string_tree(tree_unary_expr) == "my_str_unary_op(x1)" + + tree_binary_expr = Expression(Node{Float64,3}(; op=1, children=(x1, x2)); operators) + @test string_tree(tree_binary_expr) == "my_str_binary_op(x1, x2)" + + tree_ternary_expr = Expression( + Node{Float64,3}(; op=1, 
children=(x1, x2, x3)); operators + ) + @test string_tree(tree_ternary_expr) == "my_str_ternary_op(x1, x2, x3)" + + # Default naming for unknown operators (passed as nothing to string_tree) + # These are raw nodes, so string_tree is called directly on them + tree_unknown_unary = Node{Float64,3}(; op=1, children=(x1,)) + @test string_tree(tree_unknown_unary, nothing) == "unary_operator[1](x1)" # Assuming default op names + tree_unknown_ternary = Node{Float64,3}(; op=1, children=(x1, x2, x3)) + @test string_tree(tree_unknown_ternary, nothing) == "operator_deg3[1](x1, x2, x3)" +end + +@testitem "N-ary tree_mapreduce and base.jl convert" tags = [:narity] begin + using DynamicExpressions + using Test + + my_tmr_unary_op(x) = x + my_tmr_binary_op(x, y) = x + my_tmr_ternary_op(x, y, z) = x + operators_tmr = OperatorEnum(( + (my_tmr_unary_op,), (my_tmr_binary_op,), (my_tmr_ternary_op,) + )) + + x1_tmr = Node{Float64,3}(; feature=1) + x2_tmr = Node{Float64,3}(; feature=2) + x3_tmr = Node{Float64,3}(; feature=3) + c1_tmr = Node{Float64,3}(; val=0.5) + + unary_child_tmr = Node{Float64,3}(; op=1, children=(x1_tmr,)) + binary_child_tmr = Node{Float64,3}(; op=1, children=(x3_tmr, c1_tmr)) + tree_tmr = Node{Float64,3}(; op=1, children=(unary_child_tmr, x2_tmr, binary_child_tmr)) + + num_nodes = tree_mapreduce(_ -> 1, (p, c...) 
-> p + sum(c), tree_tmr, Int) + @test num_nodes == 7 + + tree_f32_tmr = convert(Node{Float32,3}, tree_tmr) # Converts Node{Float64,3} -> Node{Float32,3} + @test typeof(tree_f32_tmr) == Node{Float32,3} + @test typeof(tree_f32_tmr.children[1].children[1]) == Node{Float32,3} # Grandchild (x1_tmr) + @test tree_f32_tmr.children[3].children[2].val ≈ Float32(0.5) # Grandchild (c1_tmr) + + # Test the convert variant: convert(Node{T1,D1_implicit}, node_of_type_N2{T2,D2}) + # It should become Node{T1,D2} + tree_f64_d3_for_convert = tree_tmr # This is Node{Float64,3} + # Convert to Node{Float32} (which implies Node{Float32,2} as target MAX_DEGREE initially) + # but then with_max_degree(Node{Float32}, Val(3)) is used. + converted_tree_f32_d3 = convert(Node{Float32}, tree_f64_d3_for_convert) + @test typeof(converted_tree_f32_d3) == Node{Float32,3} +end + +@testitem "LoopVectorizationExt with N-ary (degn_eval)" tags = [:narity] begin + using DynamicExpressions + using Test + using Random + + my_lv_unary_op(x) = sin(x) + my_lv_binary_op(x, y) = x^2 + y + my_lv_ternary_op(x, y, z) = x * y - z + + let operators_for_lv = OperatorEnum(( + (my_lv_unary_op,), (my_lv_binary_op,), (my_lv_ternary_op,) + )) + if DynamicExpressions.ExtensionInterfaceModule._is_loopvectorization_loaded(0) + x1_lv = Node{Float64,3}(; feature=1) + x2_lv = Node{Float64,3}(; feature=2) + x3_lv = Node{Float64,3}(; feature=3) + + tree_lv_ternary = Node{Float64,3}(; op=1, children=(x1_lv, x2_lv, x3_lv)) + X_lv = randn(MersenneTwister(3), Float64, 3, 100) + expected_lv_ternary = my_lv_ternary_op.(X_lv[1, :], X_lv[2, :], X_lv[3, :]) + + output_lv_ternary_turbo, flag_turbo = eval_tree_array( + tree_lv_ternary, X_lv, operators_for_lv; turbo=true + ) + @test flag_turbo + @test output_lv_ternary_turbo ≈ expected_lv_ternary + + output_lv_ternary_noturbo, flag_noturbo = eval_tree_array( + tree_lv_ternary, X_lv, operators_for_lv; turbo=false + ) + @test flag_noturbo + @test output_lv_ternary_noturbo ≈ expected_lv_ternary + 
else + @warn "LoopVectorization not loaded or extension not triggered, skipping LoopVectorizationExt N-ary test." + end + end +end + +@testitem "SymbolicUtilsExt convert for N-ary" tags = [:narity] begin + using DynamicExpressions + using Test + + SU_EXT_LOADED = + Base.get_extension(DynamicExpressions, :DynamicExpressionsSymbolicUtilsExt) !== + nothing + + if SU_EXT_LOADED + DynamicExpressionsSymbolicUtilsExt = Base.get_extension( + DynamicExpressions, :DynamicExpressionsSymbolicUtilsExt + ) + SymbolicUtils = DynamicExpressionsSymbolicUtilsExt.SymbolicUtils + + my_su_unary_op(x) = sin(x) # Needs to be ::Number for SU usually + my_su_binary_op(x, y) = x * x + y + # For ternary, must be careful how SU handles it. Let's use a registered one for safety if possible. + # If not, SU might expand x*y-z into Term(-, [Term(*,...),...]) + # The `convert` diff handles SymbolicUtils.arguments(ex), which are direct arguments to a symbolic function. + @eval MySUTernaryOp(x, y, z) = x * y - z # Dummy for registration + SymbolicUtils.เด็กชาย(MySUTernaryOp) # Make it known to SU + + operators_su = OperatorEnum(( + (my_su_unary_op,), (my_su_binary_op,), (MySUTernaryOp,) + )) + + SymbolicUtils.@syms x_sym y_sym z_sym + + # Unary: args length 1 + expr_su_unary = my_su_unary_op(x_sym) + node_su_unary = convert( + Node{Float64,3}, + expr_su_unary, + operators_su; + variable_names=["x_sym", "y_sym", "z_sym"], + ) + @test node_su_unary.degree == 1 && + node_su_unary.op == 1 && + node_su_unary.children[1].feature == 1 + + # Binary: args length 2 + expr_su_binary = my_su_binary_op(x_sym, y_sym) + node_su_binary = convert( + Node{Float64,3}, + expr_su_binary, + operators_su; + variable_names=["x_sym", "y_sym", "z_sym"], + ) + @test node_su_binary.degree == 2 && + node_su_binary.op == 1 && + node_su_binary.children[1].feature == 1 && + node_su_binary.children[2].feature == 2 + + # Ternary: args length 3. The diff's `else { (only(args),) }` path will be taken. 
+ # `only(args)` will error because `args` (from `SymbolicUtils.arguments`) has 3 elements. + expr_su_ternary = MySUTernaryOp(x_sym, y_sym, z_sym) + # This tests the code as written in the diff: + @test_throws ArgumentError convert( + Node{Float64,3}, + expr_su_ternary, + operators_su; + variable_names=["x_sym", "y_sym", "z_sym"], + ) + else + @warn "SymbolicUtils extension not loaded, skipping SymbolicUtilsExt N-ary test." + end +end + +@testitem "ParametricExpression with N-ary Node" tags = [:narity] begin + using DynamicExpressions + using Test + using Random + + my_p_unary_op(x) = sin(x) + my_p_binary_op(x, y) = x^2 + y + my_p_ternary_op(x, y, z) = x * y - z + + operators_param = OperatorEnum(( + (my_p_unary_op,), (my_p_binary_op,), (my_p_ternary_op,) + )) + DynamicExpressions.@extend_operators operators_param + + pn_x1 = ParametricNode{Float64,3}(; feature=1) + pn_x2 = ParametricNode{Float64,3}(; feature=2) + pn_p1 = ParametricNode{Float64,3}() + pn_p1.degree = UInt8(0) + pn_p1.constant = false + pn_p1.is_parameter = true + pn_p1.parameter = UInt16(1) + + tree_parametric_ter = ParametricNode{Float64,3}(; op=1, children=(pn_p1, pn_x1, pn_x2)) + + ex_param_ter = ParametricExpression( + tree_parametric_ter; + operators=operators_param, + variable_names=["x1", "x2"], # x1 is feature 1, x2 is feature 2 + parameters=reshape([0.5, 1.5], 1, 2), + parameter_names=["p1"], + ) + + @test DynamicExpressions.ExpressionModule.max_degree(ex_param_ter) == 3 + # node_type of ParametricExpression{T,N,D} is N, which is ParametricNode{T,D_node_type} + # For this ex_param_ter, N is ParametricNode{Float64,3} + @test DynamicExpressions.ExpressionModule.node_type(ex_param_ter) == + ParametricNode{Float64,3} + + X_p = randn(MersenneTwister(4), Float64, 2, 10) # 2 features: x1, x2 + classes_p = rand(MersenneTwister(5), 1:2, 10) + + expected_p = [ + my_p_ternary_op( + ex_param_ter.metadata.parameters[1, classes_p[i]], X_p[1, i], X_p[2, i] + ) for i in 1:10 + ] + output_p, flag_p = 
eval_tree_array(ex_param_ter, X_p, classes_p, operators_param) + @test flag_p + @test output_p ≈ expected_p + + node_from_pex = convert(Node, ex_param_ter) + @test typeof(node_from_pex) == Node{Float64,3} # D is from ParametricNode + @test node_from_pex.degree == 3 && node_from_pex.op == 1 + @test node_from_pex.children[1].feature == 1 # p1 (num_params=1, so parameter 1 becomes feature 1) + @test node_from_pex.children[2].feature == 2 # x1 (orig feat 1 becomes feat 1+1=2) + @test node_from_pex.children[3].feature == 3 # x2 (orig feat 2 becomes feat 2+1=3) +end + +@testitem "ReadOnlyNode with N-ary Node" tags = [:narity] begin + using DynamicExpressions + using Test + + my_ro_unary_op(x) = x + my_ro_binary_op(x, y) = x + my_ro_ternary_op(x, y, z) = x + + operators_ro = OperatorEnum(( + (my_ro_unary_op,), (my_ro_binary_op,), (my_ro_ternary_op,) + )) + DynamicExpressions.@extend_operators operators_ro + + x1_ro = Node{Float64,3}(; feature=1) + x2_ro = Node{Float64,3}(; feature=2) + x3_ro = Node{Float64,3}(; feature=3) + tree_ro_ter = Node{Float64,3}(; op=1, children=(x1_ro, x2_ro, x3_ro)) + + expr_ro = Expression(tree_ro_ter; operators=operators_ro) + readonly_tree = DynamicExpressions.get_tree(expr_ro) + + @test readonly_tree isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode + inner_node_ro = DynamicExpressions.ReadOnlyNodeModule.inner(readonly_tree) + @test DynamicExpressions.NodeModule.max_degree(inner_node_ro) == 3 # D of inner node + @test readonly_tree.degree == 3 # Forwarded from inner node + @test readonly_tree.op == 1 # Forwarded + + ro_children = DynamicExpressions.NodeModule.children(readonly_tree, Val(3)) + @test length(ro_children) == 3 + @test ro_children[1] isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode + @test ro_children[1].feature == 1 # Forwarded + @test ro_children[2].feature == 2 + @test ro_children[3].feature == 3 + + # .l and .r access on ReadOnlyNode wrapping Node{T,3}. 
+ # This should error because inner node Node{T,3} doesn't have .l/.r fields (only properties for Node{T,2}). + @test_throws FieldError readonly_tree.l + @test_throws FieldError readonly_tree.r +end + +@testitem "Expression.jl default_node_type for N-ary" tags = [:narity] begin + using DynamicExpressions + using Test + + # Default node type for Expression{Float64} (which implies Node{T, DEFAULT_MAX_DEGREE=2} as its node_type parameter) + # max_degree(Expression{Float64}) will be max_degree(Node{Float64,2}) = 2. + # So, default_node_type(Expression{Float64}) becomes Node{Float64, 2}. + DefaultNodeForExprT = DynamicExpressions.ExpressionModule.default_node_type( + Expression{Float64} + ) + @test DefaultNodeForExprT == Node{Float64,2} + + # If Expression is explicitly parameterized with Node{Float64,3} + # max_degree(Expression{Float64, Node{Float64,3}}) will be max_degree(Node{Float64,3}) = 3. + # So, default_node_type(...) becomes Node{Float64,3}. + dt_expr_node3 = DynamicExpressions.ExpressionModule.default_node_type( + Expression{Float64,Node{Float64,3}} + ) + @test dt_expr_node3 == Node{Float64,3} + + # Test a custom expression type that influences default_node_type through its own max_degree + # This struct itself determines a max_degree for expressions of its type + struct MyCustomExprOverallArity{T,N<:AbstractExpressionNode{T},OVERALL_ARITY_PARAM} <: + AbstractExpression{T,N} + tree::N + end + DynamicExpressions.ExpressionModule.has_node_type( + ::Type{<:MyCustomExprOverallArity{T,N,ARITY}} + ) where {T,N,ARITY} = true + DynamicExpressions.ExpressionModule.node_type( + ::Type{<:MyCustomExprOverallArity{T,N,ARITY}} + ) where {T,N,ARITY} = N + DynamicExpressions.NodeModule.max_degree( + ::Type{<:MyCustomExprOverallArity{T,N,ARITY}} + ) where {T,N,ARITY} = ARITY # Expression type itself has a max_degree + + # default_node_type(MyCustomExprOverallArity{Float32, Node{Float32,2}, 4}) + # T = Float32. max_degree of this expression TYPE is 4. 
+ # So, default_node_type should be Node{Float32, 4}. + dt_custom_overall_arity = DynamicExpressions.ExpressionModule.default_node_type( + MyCustomExprOverallArity{Float32,Node{Float32,2},4} + ) + @test dt_custom_overall_arity == Node{Float32,4} + + # A custom expression that defaults to max_degree(Node) because has_node_type is false. + struct MySimpleExprNoNodeParam{T} <: AbstractExpression{T,Nothing} end + DynamicExpressions.ExpressionModule.has_node_type(::Type{<:MySimpleExprNoNodeParam}) = + false + dt_my_simple_expr = DynamicExpressions.ExpressionModule.default_node_type( + MySimpleExprNoNodeParam{Float64} + ) + @test dt_my_simple_expr == Node{Float64,2} # max_degree(Node) is 2 +end + +@testitem "NodeUtils.jl NodeIndex for N-ary" tags = [:narity] begin + using DynamicExpressions + using Test + using DynamicExpressions.NodeUtilsModule: index_constant_nodes, NodeIndex + + my_idx_unary(x) = x + my_idx_binary(x, y) = x + my_idx_ternary(x, y, z) = x + operators_idx = OperatorEnum(((my_idx_unary,), (my_idx_binary,), (my_idx_ternary,))) + + c1_idx = Node{Float64,3}(; val=1.0) + f1_idx = Node{Float64,3}(; feature=1) + c2_idx = Node{Float64,3}(; val=2.0) + tree_idx = Node{Float64,3}(; op=1, children=(c1_idx, f1_idx, c2_idx)) # my_idx_ternary_op(1.0, x1, 2.0) + + idx_tree = index_constant_nodes(tree_idx) # Should produce NodeIndex{UInt16,3} + + @test idx_tree isa NodeIndex{UInt16,3} + @test DynamicExpressions.NodeModule.max_degree(typeof(idx_tree)) == 3 + @test idx_tree.degree == 3 + @test idx_tree.children[1].degree == 0 && idx_tree.children[1].val == UInt16(1) # Constant 1.0 is 1st const + @test idx_tree.children[2].degree == 0 && idx_tree.children[2].val == UInt16(0) # Feature node + @test idx_tree.children[3].degree == 0 && idx_tree.children[3].val == UInt16(2) # Constant 2.0 is 2nd const +end From d5a69b74d512e3e673ca1edd344b3d207ae1cd3a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 04:09:17 +0100 Subject: [PATCH 29/74] fix: various issues 
with n-arity parametric node --- src/ParametricExpression.jl | 17 +- test/test_n_arity_nodes.jl | 327 +++++++++++++----------------------- 2 files changed, 129 insertions(+), 215 deletions(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 305c2810..d8039ed3 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -313,18 +313,23 @@ end function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T} num_params = UInt16(size(ex.metadata.parameters, 1)) + tree = get_tree(ex) + _NT = typeof(tree) + D = max_degree(_NT) + NT = with_max_degree(with_type_parameters(Node, T), Val(D)) + return tree_mapreduce( leaf -> if leaf.constant - Node(; val=leaf.val) + NT(; val=leaf.val) elseif leaf.is_parameter - Node(T; feature=leaf.parameter) + NT(T; feature=leaf.parameter) else - Node(T; feature=leaf.feature + num_params) + NT(T; feature=leaf.feature + num_params) end, branch -> branch.op, - (op, children...) -> Node(; op, children), - get_tree(ex), - Node{T}, + (op, children...) 
-> NT(; op, children), + tree, + NT, ) end function CRC.rrule(::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}) where {T} diff --git a/test/test_n_arity_nodes.jl b/test/test_n_arity_nodes.jl index 500af513..583814dc 100644 --- a/test/test_n_arity_nodes.jl +++ b/test/test_n_arity_nodes.jl @@ -2,30 +2,30 @@ using DynamicExpressions using Test - # Define some simple operators for structure, not evaluation here - my_unary_op(x) = x # Placeholder - my_binary_op(x, y) = x # Placeholder - my_ternary_op(x, y, z) = x # Placeholder + my_unary_op(x) = x + my_binary_op(x, y) = x + my_ternary_op(x, y, z) = x - # Corrected OperatorEnum constructor: operators = OperatorEnum(((my_unary_op,), (my_binary_op,), (my_ternary_op,))) # Arity 1 (Unary) in a Node{T,3} type (max_degree is 3) n_una_leaf = Node{Float64,3}(; feature=1) - n_una = Node{Float64,3}(; op=1, children=(n_una_leaf,)) # op=1 for my_unary_op + n_una = Node{Float64,3}(; op=1, children=(n_una_leaf,)) @test n_una.degree == 1 @test n_una.op == 1 @test DynamicExpressions.NodeModule.max_degree(n_una) == 3 @test DynamicExpressions.NodeModule.max_degree(typeof(n_una)) == 3 @test n_una.children[1] === n_una_leaf # Test poison (node refers to itself for unused children slots) - @test n_una.children[2] === n_una - @test n_una.children[3] === n_una + @test n_una.children[2] === n_una # Poison value is the node itself + @test n_una.children[3] === n_una # Poison value is the node itself + # Test .l accessor (should work due to @make_accessors Node) + @test n_una.l === n_una_leaf # Arity 2 (Binary) in a Node{T,3} type n_bin_leaf1 = Node{Float64,3}(; feature=1) n_bin_leaf2 = Node{Float64,3}(; val=2.0) - n_bin = Node{Float64,3}(; op=1, children=(n_bin_leaf1, n_bin_leaf2)) # op=1 for my_binary_op + n_bin = Node{Float64,3}(; op=1, children=(n_bin_leaf1, n_bin_leaf2)) @test n_bin.degree == 2 @test n_bin.op == 1 @test DynamicExpressions.NodeModule.max_degree(n_bin) == 3 @@ -34,14 +34,15 @@ @test n_bin.children[3] === n_bin # 
Poison @test DynamicExpressions.NodeModule.children(n_bin, Val(2)) == (n_bin_leaf1, n_bin_leaf2) - # .l and .r should NOT be used for Node{T,3} as @make_accessors is for Node{T,2} - @test_throws ErrorException n_bin.l # getfield Node: no field l + # .l and .r should work for Node{T,3} due to general @make_accessors Node + @test n_bin.l === n_bin_leaf1 + @test n_bin.r === n_bin_leaf2 # Arity 3 (Ternary) in a Node{T,3} type n_ter_leaf1 = Node{Float64,3}(; feature=1) n_ter_leaf2 = Node{Float64,3}(; feature=2) n_ter_leaf3 = Node{Float64,3}(; val=0.5) - n_ter = Node{Float64,3}(; op=1, children=(n_ter_leaf1, n_ter_leaf2, n_ter_leaf3)) # op=1 for my_ternary_op + n_ter = Node{Float64,3}(; op=1, children=(n_ter_leaf1, n_ter_leaf2, n_ter_leaf3)) @test n_ter.degree == 3 @test n_ter.op == 1 @test DynamicExpressions.NodeModule.max_degree(n_ter) == 3 @@ -50,43 +51,38 @@ @test n_ter.children[3] === n_ter_leaf3 @test DynamicExpressions.NodeModule.children(n_ter, Val(3)) == (n_ter_leaf1, n_ter_leaf2, n_ter_leaf3) + @test n_ter.l === n_ter_leaf1 + @test n_ter.r === n_ter_leaf2 - # Test .l and .r accessors ONLY for Node{T,2} + # Test .l and .r accessors explicitly for Node{T,2} as per diff's specific @make_accessors Node{T,2} n2_leaf1_for_l_r = Node{Float64,2}(; feature=1) n2_leaf2_for_l_r = Node{Float64,2}(; val=2.0) - # Need an operator enum where binary op is index 1 ops_for_d2_accessors = OperatorEnum(((), (my_binary_op,))) n2_bin_for_l_r = Node{Float64,2}(; op=1, children=(n2_leaf1_for_l_r, n2_leaf2_for_l_r)) @test n2_bin_for_l_r.l === n2_leaf1_for_l_r @test n2_bin_for_l_r.r === n2_leaf2_for_l_r n2_new_leaf_for_l_r = Node{Float64,2}(; feature=3) - n2_bin_for_l_r.l = n2_new_leaf_for_l_r + n2_bin_for_l_r.l = n2_new_leaf_for_l_r # Uses setproperty! 
@test n2_bin_for_l_r.children[1] === n2_new_leaf_for_l_r - # Test default D=2 for Node{T} - n_default_D_leaf1 = Node{Float64}(; feature=1) # This is Node{Float64,2} + n_default_D_leaf1 = Node{Float64}(; feature=1) @test DynamicExpressions.NodeModule.max_degree(typeof(n_default_D_leaf1)) == 2 - # Test promoting node types n_f32_d3_promo = Node{Float32,3}(; val=1.0f0) n_f64_d3_promo = Node{Float64,3}(; val=2.0) promoted_nodes = promote(n_f32_d3_promo, n_f64_d3_promo) @test promoted_nodes[1] isa Node{Float64,3} @test promoted_nodes[2] isa Node{Float64,3} - # Test with_max_degree @test DynamicExpressions.NodeModule.with_max_degree(Node{Float64,2}, Val(3)) == Node{Float64,3} @test DynamicExpressions.NodeModule.with_max_degree(Node{Float64,3}, Val(2)) == Node{Float64,2} - # Node{Float64} is UnionAll, Node{Float64,2} after constructor. - # with_max_degree(Node{Float64} where T, Val(D)) is used by convert. @test DynamicExpressions.NodeModule.with_max_degree(Node{Float64}, Val(4)) == Node{Float64,4} - # Test defines_eltype (internal, but used in node_factory) @test DynamicExpressions.NodeModule.defines_eltype(Node{Float64,2}) == true - @test DynamicExpressions.NodeModule.defines_eltype(Node) == false # Node is UnionAll + @test DynamicExpressions.NodeModule.defines_eltype(Node) == false end @testitem "N-ary OperatorEnum Structure" tags = [:narity] begin @@ -103,7 +99,7 @@ end @test operators_unary_only.binops == () @test operators_unary_only[1] == (my_unary_op,) - operators_binary_only = OperatorEnum(((), (my_binary_op,))) # Empty tuple for unary + operators_binary_only = OperatorEnum(((), (my_binary_op,))) @test length(operators_binary_only) == 2 @test operators_binary_only.unaops == () @test operators_binary_only.binops == (my_binary_op,) @@ -121,7 +117,7 @@ end @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_full)) == 1 @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_unary_only)) == 1 @test 
DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_unary_only)) == 0 - @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_binary_only)) == 0 # Correct + @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_binary_only)) == 0 @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_binary_only)) == 1 end @@ -134,7 +130,6 @@ end my_eval_binary_op(x, y) = x^2 + y my_eval_ternary_op(x, y, z) = x * y - z - # Operators for Node{Float64,3} which will force dispatch_degn_eval operators_d3 = OperatorEnum(( (my_eval_unary_op,), (my_eval_binary_op,), (my_eval_ternary_op,) )) @@ -142,33 +137,25 @@ end x1 = Node{Float64,3}(; feature=1) x2 = Node{Float64,3}(; feature=2) x3 = Node{Float64,3}(; feature=3) - c1 = Node{Float64,3}(; val=0.5) + c1_node = Node{Float64,3}(; val=0.5) # Renamed to avoid conflict X = randn(MersenneTwister(0), Float64, 3, 10) - # Scenario 1: A binary operator in a Node{Float64,3} type instance. - # This should hit `dispatch_degn_eval` because max_degree(typeof(tree_bin_in_d3)) == 3. - tree_bin_in_d3 = Node{Float64,3}(; op=1, children=(x1, c1)) # my_eval_binary_op(x1, 0.5) + tree_bin_in_d3 = Node{Float64,3}(; op=1, children=(x1, c1_node)) expected_bin_in_d3 = my_eval_binary_op.(X[1, :], 0.5) - # Crucial assumption: is_constant and its dependencies (all/any in base.jl) must be D-arity aware - # for this test to pass without erroring before reaching dispatch_degn_eval. - # If they are not, this test *will* fail, pointing to that issue in the base library code. output_bin_in_d3, flag_bin_in_d3 = eval_tree_array(tree_bin_in_d3, X, operators_d3) @test flag_bin_in_d3 @test output_bin_in_d3 ≈ expected_bin_in_d3 - # Scenario 2: A ternary operator in a Node{Float64,3} type instance. - # This should also hit `dispatch_degn_eval`. 
- tree_ter_in_d3 = Node{Float64,3}(; op=1, children=(x1, x2, c1)) # my_eval_ternary_op(x1, x2, 0.5) + tree_ter_in_d3 = Node{Float64,3}(; op=1, children=(x1, x2, c1_node)) expected_ter_in_d3 = my_eval_ternary_op.(X[1, :], X[2, :], 0.5) output_ter_in_d3, flag_ter_in_d3 = eval_tree_array(tree_ter_in_d3, X, operators_d3) @test flag_ter_in_d3 @test output_ter_in_d3 ≈ expected_ter_in_d3 - # Test nested with different arities, ensuring children are also Node{Float64,3} - unary_child = Node{Float64,3}(; op=1, children=(x1,)) # my_eval_unary_op(x1) - binary_child = Node{Float64,3}(; op=1, children=(x3, c1)) # my_eval_binary_op(x3, 0.5) - tree_nested = Node{Float64,3}(; op=1, children=(unary_child, x2, binary_child)) # my_eval_ternary_op(...) + unary_child = Node{Float64,3}(; op=1, children=(x1,)) + binary_child_for_nest = Node{Float64,3}(; op=1, children=(x3, c1_node)) + tree_nested = Node{Float64,3}(; op=1, children=(unary_child, x2, binary_child_for_nest)) expected_nested = my_eval_ternary_op.( my_eval_unary_op.(X[1, :]), X[2, :], my_eval_binary_op.(X[3, :], 0.5) @@ -177,18 +164,26 @@ end @test flag_nested @test output_nested ≈ expected_nested - # Test with type promotion in eval_tree_array (target node type max_degree inferred from input tree) tree_f32_c1 = Node{Float32,3}(; feature=1) tree_f32_c2 = Node{Float32,3}(; feature=2) tree_f32_c3 = Node{Float32,3}(; val=0.5f0) - tree_f32 = Node{Float32,3}(; op=1, children=(tree_f32_c1, tree_f32_c2, tree_f32_c3)) # Ternary op - X_f64 = randn(MersenneTwister(1), Float64, 2, 5) # Only 2 features needed for this tree's variable nodes if used + tree_f32 = Node{Float32,3}(; op=1, children=(tree_f32_c1, tree_f32_c2, tree_f32_c3)) + X_f64 = randn(MersenneTwister(1), Float64, 2, 5) - output_promoted, flag_promoted = eval_tree_array(tree_f32, X_f64, operators_d3) # operators_d3 has Float64 ops + output_promoted, flag_promoted = eval_tree_array(tree_f32, X_f64, operators_d3) @test flag_promoted @test eltype(output_promoted) == Float64 
expected_promoted = my_eval_ternary_op.(X_f64[1, :], X_f64[2, :], 0.5) @test output_promoted ≈ expected_promoted + + operators_d2 = OperatorEnum(((), (my_eval_binary_op,))) + x1_d2 = Node{Float64,2}(; feature=1) + c1_d2_node = Node{Float64,2}(; val=0.5) # Renamed + tree_binary_d2 = Node{Float64,2}(; op=1, children=(x1_d2, c1_d2_node)) + expected_binary_d2 = my_eval_binary_op.(X[1, :], 0.5) + output_binary_d2, flag_binary_d2 = eval_tree_array(tree_binary_d2, X, operators_d2) + @test flag_binary_d2 + @test output_binary_d2 ≈ expected_binary_d2 end @testitem "N-ary Constant Evaluation (targeting inner_dispatch_degn_eval_constant)" tags = [ @@ -203,19 +198,16 @@ end operators = OperatorEnum(((my_c_unary_op,), (my_c_binary_op,), (my_c_ternary_op,))) - c1 = Node{Float64,3}(; val=0.5) - c2 = Node{Float64,3}(; val=1.5) - c3 = Node{Float64,3}(; val=2.5) - X_dummy = zeros(Float64, 1, 1) # eval_tree_array needs X + c1_const = Node{Float64,3}(; val=0.5) # Renamed + c2_const = Node{Float64,3}(; val=1.5) # Renamed + c3_const = Node{Float64,3}(; val=2.5) # Renamed + X_dummy = zeros(Float64, 1, 1) - # Test structure: op_ter(op_una(c1), c2, op_bin(c1,c3)) - # This ensures recursive calls to dispatch_constant_tree and inner_dispatch_degn_eval_constant - const_una_child = Node{Float64,3}(; op=1, children=(c1,)) # my_c_unary_op(0.5) - const_bin_child = Node{Float64,3}(; op=1, children=(c1, c3)) # my_c_binary_op(0.5, 2.5) + const_una_child = Node{Float64,3}(; op=1, children=(c1_const,)) + const_bin_child = Node{Float64,3}(; op=1, children=(c1_const, c3_const)) tree_const_nested = Node{Float64,3}(; - op=1, children=(const_una_child, c2, const_bin_child) + op=1, children=(const_una_child, c2_const, const_bin_child) ) - expected_val_nested = my_c_ternary_op(my_c_unary_op(0.5), 1.5, my_c_binary_op(0.5, 2.5)) output_const_nested, flag_const_nested = eval_tree_array( @@ -238,40 +230,35 @@ end using Test using Random - # `clamp` is one of the default 3-arity ops handled by 
@declare_expression_operator - operators_clamp = OperatorEnum(((), (), (clamp,))) # arity 1 (empty), 2 (empty), 3 + operators_clamp = OperatorEnum(((), (), (clamp,))) DynamicExpressions.@extend_operators operators_clamp ex_x1 = Expression(Node{Float64,3}(; feature=1); operators=operators_clamp) ex_val_low = Expression(Node{Float64,3}(; val=0.0); operators=operators_clamp) ex_val_high = Expression(Node{Float64,3}(; val=1.0); operators=operators_clamp) - expr_clamp3 = clamp(ex_x1, ex_val_low, ex_val_high) # Uses @declare_expression_operator + expr_clamp3 = clamp(ex_x1, ex_val_low, ex_val_high) @test expr_clamp3.tree.degree == 3 - @test expr_clamp3.tree.op == 1 # Index of clamp in ternary list - X = [-0.5, 0.5, 1.5]' - expected_clamp3 = clamp.(X[1, :], 0.0, 1.0) - # Test evaluation of the Expression object - output_clamp3, flag_clamp3 = eval_tree_array(expr_clamp3.tree, X, operators_clamp) + @test expr_clamp3.tree.op == 1 + X_clamp = [-0.5, 0.5, 1.5]' + expected_clamp3 = clamp.(X_clamp[1, :], 0.0, 1.0) + output_clamp3, flag_clamp3 = eval_tree_array(expr_clamp3.tree, X_clamp, operators_clamp) @test flag_clamp3 @test output_clamp3 ≈ expected_clamp3 - # Test chaining for `+` (another default N-ary) - operators_plus_chain = OperatorEnum(((), (+,), (+,))) # Binary plus, Ternary plus + operators_plus_chain = OperatorEnum(((), (+,), (+,))) DynamicExpressions.@extend_operators operators_plus_chain ex_p_x1 = Expression(Node{Float64,3}(; feature=1); operators=operators_plus_chain) ex_p_x2 = Expression(Node{Float64,3}(; feature=2); operators=operators_plus_chain) ex_p_x3 = Expression(Node{Float64,3}(; feature=3); operators=operators_plus_chain) - # x1 + x2 + x3 should use ternary plus expr_plus_ter = ex_p_x1 + ex_p_x2 + ex_p_x3 @test expr_plus_ter.tree.degree == 3 - @test expr_plus_ter.tree.op == 1 # Index of ternary + + @test expr_plus_ter.tree.op == 1 - # x1 + x2 (constant) should use binary plus expr_plus_bin_const = ex_p_x1 + 0.5 @test expr_plus_bin_const.tree.degree == 2 - 
@test expr_plus_bin_const.tree.op == 1 # Index of binary + + @test expr_plus_bin_const.tree.op == 1 end @testitem "N-ary String Representation (targeting Strings.jl changes)" tags = [:narity] begin @@ -285,29 +272,28 @@ end operators = OperatorEnum(( (my_str_unary_op,), (my_str_binary_op,), (my_str_ternary_op,) )) - DynamicExpressions.@extend_operators operators # For Expression creation + DynamicExpressions.@extend_operators operators - x1 = Node{Float64,3}(; feature=1) - x2 = Node{Float64,3}(; feature=2) - x3 = Node{Float64,3}(; feature=3) + x1_str = Node{Float64,3}(; feature=1) # Renamed + x2_str = Node{Float64,3}(; feature=2) # Renamed + x3_str = Node{Float64,3}(; feature=3) # Renamed - # Wrap in Expression to use its string_tree method - tree_unary_expr = Expression(Node{Float64,3}(; op=1, children=(x1,)); operators) + tree_unary_expr = Expression(Node{Float64,3}(; op=1, children=(x1_str,)); operators) @test string_tree(tree_unary_expr) == "my_str_unary_op(x1)" - tree_binary_expr = Expression(Node{Float64,3}(; op=1, children=(x1, x2)); operators) + tree_binary_expr = Expression( + Node{Float64,3}(; op=1, children=(x1_str, x2_str)); operators + ) @test string_tree(tree_binary_expr) == "my_str_binary_op(x1, x2)" tree_ternary_expr = Expression( - Node{Float64,3}(; op=1, children=(x1, x2, x3)); operators + Node{Float64,3}(; op=1, children=(x1_str, x2_str, x3_str)); operators ) @test string_tree(tree_ternary_expr) == "my_str_ternary_op(x1, x2, x3)" - # Default naming for unknown operators (passed as nothing to string_tree) - # These are raw nodes, so string_tree is called directly on them - tree_unknown_unary = Node{Float64,3}(; op=1, children=(x1,)) - @test string_tree(tree_unknown_unary, nothing) == "unary_operator[1](x1)" # Assuming default op names - tree_unknown_ternary = Node{Float64,3}(; op=1, children=(x1, x2, x3)) + tree_unknown_unary = Node{Float64,3}(; op=1, children=(x1_str,)) + @test string_tree(tree_unknown_unary, nothing) == "unary_operator[1](x1)" + 
tree_unknown_ternary = Node{Float64,3}(; op=1, children=(x1_str, x2_str, x3_str)) @test string_tree(tree_unknown_ternary, nothing) == "operator_deg3[1](x1, x2, x3)" end @@ -325,33 +311,38 @@ end x1_tmr = Node{Float64,3}(; feature=1) x2_tmr = Node{Float64,3}(; feature=2) x3_tmr = Node{Float64,3}(; feature=3) - c1_tmr = Node{Float64,3}(; val=0.5) + c1_tmr_node = Node{Float64,3}(; val=0.5) # Renamed unary_child_tmr = Node{Float64,3}(; op=1, children=(x1_tmr,)) - binary_child_tmr = Node{Float64,3}(; op=1, children=(x3_tmr, c1_tmr)) + binary_child_tmr = Node{Float64,3}(; op=1, children=(x3_tmr, c1_tmr_node)) tree_tmr = Node{Float64,3}(; op=1, children=(unary_child_tmr, x2_tmr, binary_child_tmr)) num_nodes = tree_mapreduce(_ -> 1, (p, c...) -> p + sum(c), tree_tmr, Int) @test num_nodes == 7 - tree_f32_tmr = convert(Node{Float32,3}, tree_tmr) # Converts Node{Float64,3} -> Node{Float32,3} - @test typeof(tree_f32_tmr) == Node{Float32,3} - @test typeof(tree_f32_tmr.children[1].children[1]) == Node{Float32,3} # Grandchild (x1_tmr) - @test tree_f32_tmr.children[3].children[2].val ≈ Float32(0.5) # Grandchild (c1_tmr) - - # Test the convert variant: convert(Node{T1,D1_implicit}, node_of_type_N2{T2,D2}) - # It should become Node{T1,D2} - tree_f64_d3_for_convert = tree_tmr # This is Node{Float64,3} - # Convert to Node{Float32} (which implies Node{Float32,2} as target MAX_DEGREE initially) - # but then with_max_degree(Node{Float32}, Val(3)) is used. 
- converted_tree_f32_d3 = convert(Node{Float32}, tree_f64_d3_for_convert) - @test typeof(converted_tree_f32_d3) == Node{Float32,3} + tree_f32_d3_target = convert(Node{Float32,3}, tree_tmr) + @test typeof(tree_f32_d3_target) == Node{Float32,3} + @test typeof(tree_f32_d3_target.children[1].children[1]) == Node{Float32,3} + @test tree_f32_d3_target.children[3].children[2].val ≈ Float32(0.5) + + # Test convert(Node{T1}, node_of_type_N2{T2,D2}) + # This specific call needs Node{T1, D_source} as target for it to work without error with current convert. + tree_f64_d3_for_convert = tree_tmr + converted_tree_f32_d3_explicit_D = convert(Node{Float32,3}, tree_f64_d3_for_convert) + @test typeof(converted_tree_f32_d3_explicit_D) == Node{Float32,3} end @testitem "LoopVectorizationExt with N-ary (degn_eval)" tags = [:narity] begin using DynamicExpressions using Test using Random + try + using LoopVectorization + catch e + @warn "LoopVectorization not installed, skipping LoopVectorizationExt N-ary test." 
+ # To satisfy the rest of the test structure if LV is not available + @eval const LoopVectorization = Nothing + end my_lv_unary_op(x) = sin(x) my_lv_binary_op(x, y) = x^2 + y @@ -360,7 +351,8 @@ end let operators_for_lv = OperatorEnum(( (my_lv_unary_op,), (my_lv_binary_op,), (my_lv_ternary_op,) )) - if DynamicExpressions.ExtensionInterfaceModule._is_loopvectorization_loaded(0) + if LoopVectorization !== Nothing && + DynamicExpressions.ExtensionInterfaceModule._is_loopvectorization_loaded(0) x1_lv = Node{Float64,3}(; feature=1) x2_lv = Node{Float64,3}(; feature=2) x3_lv = Node{Float64,3}(; feature=3) @@ -386,74 +378,6 @@ end end end -@testitem "SymbolicUtilsExt convert for N-ary" tags = [:narity] begin - using DynamicExpressions - using Test - - SU_EXT_LOADED = - Base.get_extension(DynamicExpressions, :DynamicExpressionsSymbolicUtilsExt) !== - nothing - - if SU_EXT_LOADED - DynamicExpressionsSymbolicUtilsExt = Base.get_extension( - DynamicExpressions, :DynamicExpressionsSymbolicUtilsExt - ) - SymbolicUtils = DynamicExpressionsSymbolicUtilsExt.SymbolicUtils - - my_su_unary_op(x) = sin(x) # Needs to be ::Number for SU usually - my_su_binary_op(x, y) = x * x + y - # For ternary, must be careful how SU handles it. Let's use a registered one for safety if possible. - # If not, SU might expand x*y-z into Term(-, [Term(*,...),...]) - # The `convert` diff handles SymbolicUtils.arguments(ex), which are direct arguments to a symbolic function. 
- @eval MySUTernaryOp(x, y, z) = x * y - z # Dummy for registration - SymbolicUtils.เด็กชาย(MySUTernaryOp) # Make it known to SU - - operators_su = OperatorEnum(( - (my_su_unary_op,), (my_su_binary_op,), (MySUTernaryOp,) - )) - - SymbolicUtils.@syms x_sym y_sym z_sym - - # Unary: args length 1 - expr_su_unary = my_su_unary_op(x_sym) - node_su_unary = convert( - Node{Float64,3}, - expr_su_unary, - operators_su; - variable_names=["x_sym", "y_sym", "z_sym"], - ) - @test node_su_unary.degree == 1 && - node_su_unary.op == 1 && - node_su_unary.children[1].feature == 1 - - # Binary: args length 2 - expr_su_binary = my_su_binary_op(x_sym, y_sym) - node_su_binary = convert( - Node{Float64,3}, - expr_su_binary, - operators_su; - variable_names=["x_sym", "y_sym", "z_sym"], - ) - @test node_su_binary.degree == 2 && - node_su_binary.op == 1 && - node_su_binary.children[1].feature == 1 && - node_su_binary.children[2].feature == 2 - - # Ternary: args length 3. The diff's `else { (only(args),) }` path will be taken. - # `only(args)` will error because `args` (from `SymbolicUtils.arguments`) has 3 elements. - expr_su_ternary = MySUTernaryOp(x_sym, y_sym, z_sym) - # This tests the code as written in the diff: - @test_throws ArgumentError convert( - Node{Float64,3}, - expr_su_ternary, - operators_su; - variable_names=["x_sym", "y_sym", "z_sym"], - ) - else - @warn "SymbolicUtils extension not loaded, skipping SymbolicUtilsExt N-ary test." 
- end -end - @testitem "ParametricExpression with N-ary Node" tags = [:narity] begin using DynamicExpressions using Test @@ -481,18 +405,16 @@ end ex_param_ter = ParametricExpression( tree_parametric_ter; operators=operators_param, - variable_names=["x1", "x2"], # x1 is feature 1, x2 is feature 2 + variable_names=["x1", "x2"], parameters=reshape([0.5, 1.5], 1, 2), parameter_names=["p1"], ) @test DynamicExpressions.ExpressionModule.max_degree(ex_param_ter) == 3 - # node_type of ParametricExpression{T,N,D} is N, which is ParametricNode{T,D_node_type} - # For this ex_param_ter, N is ParametricNode{Float64,3} @test DynamicExpressions.ExpressionModule.node_type(ex_param_ter) == ParametricNode{Float64,3} - X_p = randn(MersenneTwister(4), Float64, 2, 10) # 2 features: x1, x2 + X_p = randn(MersenneTwister(4), Float64, 2, 10) classes_p = rand(MersenneTwister(5), 1:2, 10) expected_p = [ @@ -505,11 +427,11 @@ end @test output_p ≈ expected_p node_from_pex = convert(Node, ex_param_ter) - @test typeof(node_from_pex) == Node{Float64,3} # D is from ParametricNode + @test typeof(node_from_pex) == Node{Float64,3} @test node_from_pex.degree == 3 && node_from_pex.op == 1 - @test node_from_pex.children[1].feature == 1 # p1 (num_params=1, so parameter 1 becomes feature 1) - @test node_from_pex.children[2].feature == 2 # x1 (orig feat 1 becomes feat 1+1=2) - @test node_from_pex.children[3].feature == 3 # x2 (orig feat 2 becomes feat 2+1=3) + @test node_from_pex.children[1].feature == 1 + @test node_from_pex.children[2].feature == 2 + @test node_from_pex.children[3].feature == 3 end @testitem "ReadOnlyNode with N-ary Node" tags = [:narity] begin @@ -535,75 +457,62 @@ end @test readonly_tree isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode inner_node_ro = DynamicExpressions.ReadOnlyNodeModule.inner(readonly_tree) - @test DynamicExpressions.NodeModule.max_degree(inner_node_ro) == 3 # D of inner node - @test readonly_tree.degree == 3 # Forwarded from inner node - @test 
readonly_tree.op == 1 # Forwarded + @test DynamicExpressions.NodeModule.max_degree(inner_node_ro) == 3 + @test readonly_tree.degree == 3 + @test readonly_tree.op == 1 ro_children = DynamicExpressions.NodeModule.children(readonly_tree, Val(3)) @test length(ro_children) == 3 @test ro_children[1] isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode - @test ro_children[1].feature == 1 # Forwarded + @test ro_children[1].feature == 1 @test ro_children[2].feature == 2 @test ro_children[3].feature == 3 - # .l and .r access on ReadOnlyNode wrapping Node{T,3}. - # This should error because inner node Node{T,3} doesn't have .l/.r fields (only properties for Node{T,2}). - @test_throws FieldError readonly_tree.l - @test_throws FieldError readonly_tree.r + @test readonly_tree.l.feature == 1 + @test readonly_tree.r.feature == 2 end @testitem "Expression.jl default_node_type for N-ary" tags = [:narity] begin using DynamicExpressions using Test - # Default node type for Expression{Float64} (which implies Node{T, DEFAULT_MAX_DEGREE=2} as its node_type parameter) - # max_degree(Expression{Float64}) will be max_degree(Node{Float64,2}) = 2. - # So, default_node_type(Expression{Float64}) becomes Node{Float64, 2}. DefaultNodeForExprT = DynamicExpressions.ExpressionModule.default_node_type( Expression{Float64} ) @test DefaultNodeForExprT == Node{Float64,2} - # If Expression is explicitly parameterized with Node{Float64,3} - # max_degree(Expression{Float64, Node{Float64,3}}) will be max_degree(Node{Float64,3}) = 3. - # So, default_node_type(...) becomes Node{Float64,3}. 
dt_expr_node3 = DynamicExpressions.ExpressionModule.default_node_type( Expression{Float64,Node{Float64,3}} ) @test dt_expr_node3 == Node{Float64,3} - # Test a custom expression type that influences default_node_type through its own max_degree - # This struct itself determines a max_degree for expressions of its type - struct MyCustomExprOverallArity{T,N<:AbstractExpressionNode{T},OVERALL_ARITY_PARAM} <: - AbstractExpression{T,N} - tree::N + abstract type AbstractMyCustomExpr{T,N<:AbstractExpressionNode{T}} <: + AbstractExpression{T,N} end + struct MyCustomExprWithArity{T,N_NODE<:AbstractExpressionNode{T,NODE_D},EXPR_ARITY} <: + AbstractMyCustomExpr{T,N_NODE} where {NODE_D} + tree::N_NODE end - DynamicExpressions.ExpressionModule.has_node_type( - ::Type{<:MyCustomExprOverallArity{T,N,ARITY}} - ) where {T,N,ARITY} = true + DynamicExpressions.ExpressionModule.has_node_type(::Type{<:MyCustomExprWithArity}) = + true DynamicExpressions.ExpressionModule.node_type( - ::Type{<:MyCustomExprOverallArity{T,N,ARITY}} - ) where {T,N,ARITY} = N + ::Type{<:MyCustomExprWithArity{T,N_NODE,EXPR_ARITY}} + ) where {T,N_NODE,EXPR_ARITY} = N_NODE DynamicExpressions.NodeModule.max_degree( - ::Type{<:MyCustomExprOverallArity{T,N,ARITY}} - ) where {T,N,ARITY} = ARITY # Expression type itself has a max_degree + ::Type{<:MyCustomExprWithArity{T,N_NODE,EXPR_ARITY}} + ) where {T,N_NODE,EXPR_ARITY} = EXPR_ARITY - # default_node_type(MyCustomExprOverallArity{Float32, Node{Float32,2}, 4}) - # T = Float32. max_degree of this expression TYPE is 4. - # So, default_node_type should be Node{Float32, 4}. dt_custom_overall_arity = DynamicExpressions.ExpressionModule.default_node_type( - MyCustomExprOverallArity{Float32,Node{Float32,2},4} + MyCustomExprWithArity{Float32,Node{Float32,2},4} ) @test dt_custom_overall_arity == Node{Float32,4} - # A custom expression that defaults to max_degree(Node) because has_node_type is false. 
struct MySimpleExprNoNodeParam{T} <: AbstractExpression{T,Nothing} end DynamicExpressions.ExpressionModule.has_node_type(::Type{<:MySimpleExprNoNodeParam}) = false dt_my_simple_expr = DynamicExpressions.ExpressionModule.default_node_type( MySimpleExprNoNodeParam{Float64} ) - @test dt_my_simple_expr == Node{Float64,2} # max_degree(Node) is 2 + @test dt_my_simple_expr == Node{Float64,2} end @testitem "NodeUtils.jl NodeIndex for N-ary" tags = [:narity] begin @@ -619,14 +528,14 @@ end c1_idx = Node{Float64,3}(; val=1.0) f1_idx = Node{Float64,3}(; feature=1) c2_idx = Node{Float64,3}(; val=2.0) - tree_idx = Node{Float64,3}(; op=1, children=(c1_idx, f1_idx, c2_idx)) # my_idx_ternary_op(1.0, x1, 2.0) + tree_idx = Node{Float64,3}(; op=1, children=(c1_idx, f1_idx, c2_idx)) - idx_tree = index_constant_nodes(tree_idx) # Should produce NodeIndex{UInt16,3} + idx_tree = index_constant_nodes(tree_idx) @test idx_tree isa NodeIndex{UInt16,3} @test DynamicExpressions.NodeModule.max_degree(typeof(idx_tree)) == 3 @test idx_tree.degree == 3 - @test idx_tree.children[1].degree == 0 && idx_tree.children[1].val == UInt16(1) # Constant 1.0 is 1st const - @test idx_tree.children[2].degree == 0 && idx_tree.children[2].val == UInt16(0) # Feature node - @test idx_tree.children[3].degree == 0 && idx_tree.children[3].val == UInt16(2) # Constant 2.0 is 2nd const + @test idx_tree.children[1].degree == 0 && idx_tree.children[1].val == UInt16(1) + @test idx_tree.children[2].degree == 0 && idx_tree.children[2].val == UInt16(0) + @test idx_tree.children[3].degree == 0 && idx_tree.children[3].val == UInt16(2) end From 90fe1776846f5da62ff19aee813168413ff8148e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 04:23:14 +0100 Subject: [PATCH 30/74] fix: fix degree in `NodeIndex` --- src/NodeUtils.jl | 1 + test/test_n_arity_nodes.jl | 46 ++------------------------------------ 2 files changed, 3 insertions(+), 44 deletions(-) diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 
35bfda13..71a915fa 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -158,6 +158,7 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D} node = NodeIndex(_T, Val(_D)) poison = get_poison(node) children = (child, childs...) + node.degree = _D2 + 1 node.children = ntuple(i -> i <= _D2 + 1 ? children[i] : poison, Val(_D)) return node end diff --git a/test/test_n_arity_nodes.jl b/test/test_n_arity_nodes.jl index 583814dc..7adf3d40 100644 --- a/test/test_n_arity_nodes.jl +++ b/test/test_n_arity_nodes.jl @@ -453,7 +453,7 @@ end tree_ro_ter = Node{Float64,3}(; op=1, children=(x1_ro, x2_ro, x3_ro)) expr_ro = Expression(tree_ro_ter; operators=operators_ro) - readonly_tree = DynamicExpressions.get_tree(expr_ro) + readonly_tree = DynamicExpressions.ReadOnlyNode(DynamicExpressions.get_tree(expr_ro)) @test readonly_tree isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode inner_node_ro = DynamicExpressions.ReadOnlyNodeModule.inner(readonly_tree) @@ -472,49 +472,6 @@ end @test readonly_tree.r.feature == 2 end -@testitem "Expression.jl default_node_type for N-ary" tags = [:narity] begin - using DynamicExpressions - using Test - - DefaultNodeForExprT = DynamicExpressions.ExpressionModule.default_node_type( - Expression{Float64} - ) - @test DefaultNodeForExprT == Node{Float64,2} - - dt_expr_node3 = DynamicExpressions.ExpressionModule.default_node_type( - Expression{Float64,Node{Float64,3}} - ) - @test dt_expr_node3 == Node{Float64,3} - - abstract type AbstractMyCustomExpr{T,N<:AbstractExpressionNode{T}} <: - AbstractExpression{T,N} end - struct MyCustomExprWithArity{T,N_NODE<:AbstractExpressionNode{T,NODE_D},EXPR_ARITY} <: - AbstractMyCustomExpr{T,N_NODE} where {NODE_D} - tree::N_NODE - end - DynamicExpressions.ExpressionModule.has_node_type(::Type{<:MyCustomExprWithArity}) = - true - DynamicExpressions.ExpressionModule.node_type( - ::Type{<:MyCustomExprWithArity{T,N_NODE,EXPR_ARITY}} - ) where {T,N_NODE,EXPR_ARITY} = N_NODE - DynamicExpressions.NodeModule.max_degree( 
- ::Type{<:MyCustomExprWithArity{T,N_NODE,EXPR_ARITY}} - ) where {T,N_NODE,EXPR_ARITY} = EXPR_ARITY - - dt_custom_overall_arity = DynamicExpressions.ExpressionModule.default_node_type( - MyCustomExprWithArity{Float32,Node{Float32,2},4} - ) - @test dt_custom_overall_arity == Node{Float32,4} - - struct MySimpleExprNoNodeParam{T} <: AbstractExpression{T,Nothing} end - DynamicExpressions.ExpressionModule.has_node_type(::Type{<:MySimpleExprNoNodeParam}) = - false - dt_my_simple_expr = DynamicExpressions.ExpressionModule.default_node_type( - MySimpleExprNoNodeParam{Float64} - ) - @test dt_my_simple_expr == Node{Float64,2} -end - @testitem "NodeUtils.jl NodeIndex for N-ary" tags = [:narity] begin using DynamicExpressions using Test @@ -534,6 +491,7 @@ end @test idx_tree isa NodeIndex{UInt16,3} @test DynamicExpressions.NodeModule.max_degree(typeof(idx_tree)) == 3 + @test tree_idx.degree == 3 @test idx_tree.degree == 3 @test idx_tree.children[1].degree == 0 && idx_tree.children[1].val == UInt16(1) @test idx_tree.children[2].degree == 0 && idx_tree.children[2].val == UInt16(0) From b4a5ba8c9af9c60a820c3f0a3a46162985bd524b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 06:15:58 +0100 Subject: [PATCH 31/74] fix: `set_node!` should set to `children` --- src/Node.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index ba78fbd7..1e0d25ed 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -398,10 +398,7 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod end else tree.op = new_tree.op - tree.l = new_tree.l - if new_tree.degree == 2 - tree.r = new_tree.r - end + tree.children = new_tree.children end return nothing end From 3905fc8a197e3d8eac92077395a10f81b511f0e8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 06:32:48 +0100 Subject: [PATCH 32/74] fix: node preallocation for n-arity nodes --- src/NodePreallocation.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 
deletions(-) diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl index ccce372d..fb2a0d85 100644 --- a/src/NodePreallocation.jl +++ b/src/NodePreallocation.jl @@ -6,7 +6,8 @@ using ..NodeModule: tree_mapreduce, leaf_copy, branch_copy, - set_node! + set_node!, + get_poison """ allocate_container(prototype::AbstractExpressionNode, n=nothing) @@ -56,13 +57,11 @@ end # COV_EXCL_STOP function branch_copy_into!( dest::N, src::N, children::Vararg{N,M} -) where {N<:AbstractExpressionNode,M} +) where {T,D,N<:AbstractExpressionNode{T,D},M} dest.degree = M dest.op = src.op - dest.l = children[1] - if M == 2 - dest.r = children[2] - end + poison = get_poison(dest) + dest.children = ntuple(i -> i <= M ? children[i] : poison, D) return dest end From 59c0878e079e37fde341f3d5cdeb3037c175af79 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 06:46:36 +0100 Subject: [PATCH 33/74] feat: complete node interface for n-arity --- src/Interfaces.jl | 39 +++++++++++++------------------------ src/Node.jl | 2 +- test/test_node_interface.jl | 24 +++++++++++++++++++++++ 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 267496c4..d542dc41 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -12,6 +12,7 @@ using ..NodeModule: constructorof, default_allocator, with_type_parameters, + children, leaf_copy, leaf_convert, leaf_hash, @@ -248,8 +249,8 @@ function _check_eltype(tree::AbstractExpressionNode{T}) where {T} end function _check_with_type_parameters(tree::AbstractExpressionNode{T}) where {T} N = typeof(tree) - NT = with_type_parameters(Base.typename(N).wrapper, eltype(tree)) - return NT == typeof(tree) + Nf16 = with_type_parameters(N, Float16) + return Nf16 <: AbstractExpressionNode{Float16} end function _check_default_allocator(tree::AbstractExpressionNode) N = Base.typename(typeof(tree)).wrapper @@ -299,35 +300,21 @@ function _check_leaf_equal(tree::AbstractExpressionNode) return leaf_equal(tree, 
copy(tree)) end function _check_branch_copy(tree::AbstractExpressionNode) - if tree.degree == 0 - return true - elseif tree.degree == 1 - return branch_copy(tree, tree.l) isa typeof(tree) - else - return branch_copy(tree, tree.l, tree.r) isa typeof(tree) - end + tree.degree == 0 && return true + return branch_copy(tree, children(tree, Val(tree.degree))...) isa typeof(tree) end function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T} - if tree.degree == 0 - return true - end + tree.degree == 0 && return true new_branch = constructorof(typeof(tree))(; val=zero(T)) - if tree.degree == 1 - ret = branch_copy_into!(new_branch, tree, copy(tree.l)) - return new_branch == tree && ret === new_branch - else - ret = branch_copy_into!(new_branch, tree, copy(tree.l), copy(tree.r)) - return new_branch == tree && ret === new_branch - end + ret = branch_copy_into!( + new_branch, tree, map(copy, children(tree, Val(tree.degree)))... + ) + return new_branch == tree && ret === new_branch end function _check_branch_convert(tree::AbstractExpressionNode) - if tree.degree == 0 - return true - elseif tree.degree == 1 - return branch_convert(typeof(tree), tree, tree.l) isa typeof(tree) - else - return branch_convert(typeof(tree), tree, tree.l, tree.r) isa typeof(tree) - end + tree.degree == 0 && return true + return branch_convert(typeof(tree), tree, children(tree, Val(tree.degree))...) 
isa + typeof(tree) end function _check_branch_hash(tree::AbstractExpressionNode) tree.degree == 0 && return true diff --git a/src/Node.jl b/src/Node.jl index 1e0d25ed..30476ace 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -225,7 +225,7 @@ end @inline children(node::AbstractNode) = node.children @inline function children(node::AbstractNode, ::Val{n}) where {n} cs = children(node) - return ntuple(i -> cs[i], Val(n)) + return ntuple(i -> cs[i], Val(Int(n))) end ################################################################################ diff --git a/test/test_node_interface.jl b/test/test_node_interface.jl index 1ee01688..c4f37529 100644 --- a/test/test_node_interface.jl +++ b/test/test_node_interface.jl @@ -33,3 +33,27 @@ ], ) end + +@testitem "Node interface on n-arity nodes" begin + using DynamicExpressions + using DynamicExpressions: NodeInterface + using Interfaces: Interfaces + + for D in (3, 4, 5) + x = [Node{Float64,D}(; feature=i) for i in 1:3] + operator_tuple = ((sin, cos, exp), (+, *, /, -), (fma, clamp), (max, min), ()) + operators = OperatorEnum(operator_tuple[1:D]) + DynamicExpressions.OperatorEnumConstructionModule.empty_all_globals!() + let tree = Node{Float64,D}(; op=2, children=(x[1], x[2])) # * + if D > 2 + fma_idx = 1 + tree = Node{Float64,D}(; op=fma_idx, children=(tree, x[1], x[2])) # fma + end + if D > 3 + idx_max = 1 + tree = Node{Float64,D}(; op=idx_max, children=(tree, x[1], x[2], x[3])) # max + end + @test Interfaces.test(NodeInterface, Node, tree) + end + end +end From b78097a3df6330a6819cf2e61877b18469fb3bc8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 06:50:56 +0100 Subject: [PATCH 34/74] feat: add `children` to required interface --- src/Interfaces.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index d542dc41..f4f57b4a 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -226,6 +226,13 @@ function _check_create_node(tree::AbstractExpressionNode) NT = 
with_type_parameters(N, Float16) return NT() isa NT end +function _check_children(tree::AbstractExpressionNode{T,D}) where {T,D} + tree.degree == 0 && return true + return children(tree) isa Tuple{typeof(tree),Vararg{typeof(tree)}} && + children(tree, Val(D)) isa Tuple && + length(children(tree, Val(D))) == D && + length(children(tree, Val(1))) == 1 +end function _check_copy(tree::AbstractExpressionNode) return copy(tree) isa typeof(tree) end @@ -360,6 +367,7 @@ end ni_components = ( mandatory = ( create_node = "creates a new instance of the node type" => _check_create_node, + children = "returns the children of the node" => _check_children, copy = "returns a copy of the tree" => _check_copy, hash = "returns the hash of the tree" => _check_hash, any = "checks if any element of the tree satisfies a condition" => _check_any, From 008dfbc818302bf5a1e5da558732ceee5fad612e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 19:49:00 +0100 Subject: [PATCH 35/74] feat: better interface for children --- src/Evaluate.jl | 6 ++-- src/Interfaces.jl | 20 ++++++------ src/Node.jl | 63 +++++++++++++++++++++++--------------- src/NodePreallocation.jl | 5 ++- src/NodeUtils.jl | 5 ++- src/ReadOnlyNode.jl | 6 ++-- src/base.jl | 13 +++++--- test/test_n_arity_nodes.jl | 6 ++-- 8 files changed, 70 insertions(+), 54 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index b96184f1..069e339e 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -3,7 +3,7 @@ module EvaluateModule using DispatchDoctor: @stable, @unstable import ..NodeModule: - AbstractExpressionNode, constructorof, max_degree, children, with_type_parameters + AbstractExpressionNode, constructorof, max_degree, get_children, with_type_parameters import ..StringsModule: string_tree import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk @@ -343,7 +343,7 @@ end ) where {T,degree,OPS} nops = length(OPS.types[degree].types) return quote - cs = 
children(tree, Val($degree)) + cs = get_children(tree, Val($degree)) Base.Cartesian.@nexprs( $degree, i -> begin @@ -727,7 +727,7 @@ end ) where {T,degree,OPS} nops = length(OPS.types[degree].types) get_inputs = quote - cs = children(tree, Val($degree)) + cs = get_children(tree, Val($degree)) Base.Cartesian.@nexprs( $degree, i -> begin diff --git a/src/Interfaces.jl b/src/Interfaces.jl index f4f57b4a..5842ec34 100644 --- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -12,7 +12,7 @@ using ..NodeModule: constructorof, default_allocator, with_type_parameters, - children, + get_children, leaf_copy, leaf_convert, leaf_hash, @@ -226,12 +226,12 @@ function _check_create_node(tree::AbstractExpressionNode) NT = with_type_parameters(N, Float16) return NT() isa NT end -function _check_children(tree::AbstractExpressionNode{T,D}) where {T,D} +function _check_get_children(tree::AbstractExpressionNode{T,D}) where {T,D} tree.degree == 0 && return true - return children(tree) isa Tuple{typeof(tree),Vararg{typeof(tree)}} && - children(tree, Val(D)) isa Tuple && - length(children(tree, Val(D))) == D && - length(children(tree, Val(1))) == 1 + return get_children(tree) isa Tuple{typeof(tree),Vararg{typeof(tree)}} && + get_children(tree, Val(D)) isa Tuple && + length(get_children(tree, Val(D))) == D && + length(get_children(tree, Val(1))) == 1 end function _check_copy(tree::AbstractExpressionNode) return copy(tree) isa typeof(tree) @@ -308,19 +308,19 @@ function _check_leaf_equal(tree::AbstractExpressionNode) end function _check_branch_copy(tree::AbstractExpressionNode) tree.degree == 0 && return true - return branch_copy(tree, children(tree, Val(tree.degree))...) isa typeof(tree) + return branch_copy(tree, get_children(tree, Val(tree.degree))...) 
isa typeof(tree) end function _check_branch_copy_into!(tree::AbstractExpressionNode{T}) where {T} tree.degree == 0 && return true new_branch = constructorof(typeof(tree))(; val=zero(T)) ret = branch_copy_into!( - new_branch, tree, map(copy, children(tree, Val(tree.degree)))... + new_branch, tree, map(copy, get_children(tree, Val(tree.degree)))... ) return new_branch == tree && ret === new_branch end function _check_branch_convert(tree::AbstractExpressionNode) tree.degree == 0 && return true - return branch_convert(typeof(tree), tree, children(tree, Val(tree.degree))...) isa + return branch_convert(typeof(tree), tree, get_children(tree, Val(tree.degree))...) isa typeof(tree) end function _check_branch_hash(tree::AbstractExpressionNode) @@ -367,7 +367,7 @@ end ni_components = ( mandatory = ( create_node = "creates a new instance of the node type" => _check_create_node, - children = "returns the children of the node" => _check_children, + get_children = "returns the children of the node" => _check_get_children, copy = "returns a copy of the tree" => _check_copy, hash = "returns the hash of the tree" => _check_hash, any = "checks if any element of the tree satisfies a condition" => _check_any, diff --git a/src/Node.jl b/src/Node.jl index 30476ace..99dd0d31 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -176,14 +176,42 @@ function get_poison(n::AbstractNode) return n end +@inline function get_children(node::AbstractNode) + return getfield(node, :children) +end +@inline function get_children(node::AbstractNode, ::Val{n}) where {n} + cs = get_children(node) + return ntuple(i -> cs[i], Val(Int(n))) +end +@inline function get_child(n::AbstractNode{D}, i::Int) where {D} + return get_children(n)[i] +end +@inline function set_child!(n::AbstractNode{D}, child::AbstractNode{D}, i::Int) where {D} + set_children!(n, Base.setindex(get_children(n), child, i)) + return child +end +@inline function set_children!(n::AbstractNode{D}, children::NTuple{D2,AbstractNode{D}}) where {D,D2} + 
if D === D2 + n.children = children + else + poison = get_poison(n) + # We insert poison at the end of the tuple so that + # errors will appear loudly if accessed. + # This poison should be efficient to insert. So + # for simplicity, we can just use poison == n, which + # will trigger infinite recursion errors if accessed. + n.children = ntuple(i -> i <= D2 ? children[i] : poison, Val(D)) + end +end + macro make_accessors(node_type) esc(quote @inline function Base.getproperty(n::$node_type, k::Symbol) if k == :l # TODO: Should a depwarn be raised here? Or too slow? - return getfield(n, :children)[1] + return $(get_child)(n, 1) elseif k == :r - return getfield(n, :children)[2] + return $(get_child)(n, 2) else return getfield(n, k) end @@ -191,19 +219,13 @@ macro make_accessors(node_type) @inline function Base.setproperty!(n::$node_type, k::Symbol, v) if k == :l if isdefined(n, :children) - old = getfield(n, :children) - setfield!(n, :children, (v, old[2])) - v + $(set_child!)(n, v, 1) else - poison = $(get_poison)(n) - setfield!(n, :children, (v, poison)) + $(set_children!)(n, (v,)) v end elseif k == :r - # TODO: Remove this assert once we know that this is safe - old = getfield(n, :children) - setfield!(n, :children, (old[1], v)) - v + $(set_child!)(n, v, 2) else T = fieldtype(typeof(n), k) if v isa T @@ -222,12 +244,6 @@ end @make_accessors GraphNode # TODO: Disable the `.l` accessors eventually, once the codebase is fully generic -@inline children(node::AbstractNode) = node.children -@inline function children(node::AbstractNode, ::Val{n}) where {n} - cs = children(node) - return ntuple(i -> cs[i], Val(Int(n))) -end - ################################################################################ #! 
format: on @@ -273,11 +289,11 @@ include("base.jl") @inline function (::Type{N})( ::Type{T1}=Undefined; val=nothing, feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, allocator::F=default_allocator, ) where {T1,N<:AbstractExpressionNode{T} where T,F} - _children = if l !== nothing && r === nothing - @assert children === nothing + _children = if !isnothing(l) && isnothing(r) + @assert isnothing(children) (l,) - elseif l !== nothing && r !== nothing - @assert children === nothing + elseif !isnothing(l) && !isnothing(r) + @assert isnothing(children) (l, r) else children @@ -328,8 +344,7 @@ end n = allocator(N, T) n.degree = D2 n.op = op - poison = get_poison(n) - n.children = ntuple(i -> i <= D2 ? convert(NT, children[i]) : poison, Val(max_degree(N))) + set_children!(n, children) return n end @@ -398,7 +413,7 @@ function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNod end else tree.op = new_tree.op - tree.children = new_tree.children + set_children!(tree, get_children(new_tree)) end return nothing end diff --git a/src/NodePreallocation.jl b/src/NodePreallocation.jl index fb2a0d85..207773a7 100644 --- a/src/NodePreallocation.jl +++ b/src/NodePreallocation.jl @@ -7,7 +7,7 @@ using ..NodeModule: leaf_copy, branch_copy, set_node!, - get_poison + set_children! """ allocate_container(prototype::AbstractExpressionNode, n=nothing) @@ -60,8 +60,7 @@ function branch_copy_into!( ) where {T,D,N<:AbstractExpressionNode{T,D},M} dest.degree = M dest.op = src.op - poison = get_poison(dest) - dest.children = ntuple(i -> i <= M ? 
children[i] : poison, D) + set_children!(dest, children) return dest end diff --git a/src/NodeUtils.jl b/src/NodeUtils.jl index 71a915fa..1bba822e 100644 --- a/src/NodeUtils.jl +++ b/src/NodeUtils.jl @@ -6,7 +6,7 @@ import ..NodeModule: Node, preserve_sharing, constructorof, - get_poison, + set_children!, copy_node, count_nodes, tree_mapreduce, @@ -156,10 +156,9 @@ mutable struct NodeIndex{T,D} <: AbstractNode{D} ::Type{_T}, ::Val{_D}, child::NodeIndex{_T,_D}, childs::Vararg{NodeIndex{_T,_D},_D2} ) where {_T,_D,_D2} node = NodeIndex(_T, Val(_D)) - poison = get_poison(node) children = (child, childs...) node.degree = _D2 + 1 - node.children = ntuple(i -> i <= _D2 + 1 ? children[i] : poison, Val(_D)) + set_children!(node, children) return node end end diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl index 98edf13c..8b45612a 100644 --- a/src/ReadOnlyNode.jl +++ b/src/ReadOnlyNode.jl @@ -3,7 +3,7 @@ module ReadOnlyNodeModule using DispatchDoctor: @unstable using ..NodeModule: AbstractExpressionNode, Node -import ..NodeModule: default_allocator, with_type_parameters, constructorof, children +import ..NodeModule: default_allocator, with_type_parameters, constructorof, get_children abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D},IS_REF} <: AbstractExpressionNode{T,D} end @@ -38,8 +38,8 @@ Base.getindex(n::AbstractReadOnlyNode{T,D,N,true} where {T,D,N}) = n return out end end -@inline function children(node::AbstractReadOnlyNode, ::Val{n}) where {n} - return map(ReadOnlyNode, children(inner(node), Val(n))) +@inline function get_children(node::AbstractReadOnlyNode) + return map(ReadOnlyNode, get_children(inner(node))) end function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v) return error("Cannot set properties on a ReadOnlyNode") diff --git a/src/base.jl b/src/base.jl index e39e2c90..9b3dae31 100644 --- a/src/base.jl +++ b/src/base.jl @@ -137,7 +137,7 @@ end Base.Cartesian.@nif( $D, i -> i == d, - i -> let cs = children(tree, Val(i)) + 
i -> let cs = get_children(tree, Val(i)) Base.Cartesian.@ncall( i, mapreducer.op, @@ -182,7 +182,7 @@ By using this instead of tree_mapreduce, we can take advantage of early exits. return ( @inline(f(tree)) || Base.Cartesian.@nif( - $D, i -> deg == i, i -> let cs = children(tree, Val(i)) + $D, i -> deg == i, i -> let cs = get_children(tree, Val(i)) Base.Cartesian.@nany(i, j -> any(f, cs[j])) end ) @@ -226,9 +226,12 @@ end branch_equal(a, b) && Base.Cartesian.@nif( $D, i -> deg == i, - i -> let cs_a = children(a, Val(i)), cs_b = children(b, Val(i)) - Base.Cartesian.@nall(i, j -> inner_is_equal(cs_a[j], cs_b[j], id_maps)) - end + i -> + let cs_a = get_children(a, Val(i)), cs_b = get_children(b, Val(i)) + Base.Cartesian.@nall( + i, j -> inner_is_equal(cs_a[j], cs_b[j], id_maps) + ) + end ) ) end diff --git a/test/test_n_arity_nodes.jl b/test/test_n_arity_nodes.jl index 7adf3d40..36b28c70 100644 --- a/test/test_n_arity_nodes.jl +++ b/test/test_n_arity_nodes.jl @@ -32,7 +32,7 @@ @test n_bin.children[1] === n_bin_leaf1 @test n_bin.children[2] === n_bin_leaf2 @test n_bin.children[3] === n_bin # Poison - @test DynamicExpressions.NodeModule.children(n_bin, Val(2)) == + @test DynamicExpressions.NodeModule.get_children(n_bin, Val(2)) == (n_bin_leaf1, n_bin_leaf2) # .l and .r should work for Node{T,3} due to general @make_accessors Node @test n_bin.l === n_bin_leaf1 @@ -49,7 +49,7 @@ @test n_ter.children[1] === n_ter_leaf1 @test n_ter.children[2] === n_ter_leaf2 @test n_ter.children[3] === n_ter_leaf3 - @test DynamicExpressions.NodeModule.children(n_ter, Val(3)) == + @test DynamicExpressions.NodeModule.get_children(n_ter, Val(3)) == (n_ter_leaf1, n_ter_leaf2, n_ter_leaf3) @test n_ter.l === n_ter_leaf1 @test n_ter.r === n_ter_leaf2 @@ -461,7 +461,7 @@ end @test readonly_tree.degree == 3 @test readonly_tree.op == 1 - ro_children = DynamicExpressions.NodeModule.children(readonly_tree, Val(3)) + ro_children = DynamicExpressions.NodeModule.get_children(readonly_tree, Val(3)) @test 
length(ro_children) == 3 @test ro_children[1] isa DynamicExpressions.ReadOnlyNodeModule.AbstractReadOnlyNode @test ro_children[1].feature == 1 From 97abbd0199524f85383adec29247b7cc2e3d79d9 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 20:23:50 +0100 Subject: [PATCH 36/74] feat: make differentiable eval work for n-arity --- src/Evaluate.jl | 84 ++++++++++++++++++++++++++----------------------- 1 file changed, 45 insertions(+), 39 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 069e339e..90d456bd 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -781,53 +781,59 @@ function differentiable_eval_tree_array( end @generated function _differentiable_eval_tree_array( + tree::AbstractExpressionNode{T1,D}, cX::AbstractMatrix{T}, operators::OperatorEnum +)::ResultOk where {T<:Number,T1,D} + quote + tree.degree == 0 && return deg0_diff_eval(tree, cX, operators) + op_idx = tree.op + deg = tree.degree + Base.Cartesian.@nif( + $D, + i -> i == deg, + i -> dispatch_degn_diff_eval(tree, cX, op_idx, Val(i), operators) + ) + end +end + + +function deg0_diff_eval( tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum )::ResultOk where {T<:Number,T1} - nuna = get_nuna(operators) - nbin = get_nbin(operators) - quote - if tree.degree == 0 - if tree.constant - ResultOk(fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true) - else - ResultOk(cX[tree.feature, :], true) - end - elseif tree.degree == 1 - op_idx = tree.op - Base.Cartesian.@nif( - $nuna, - i -> i == op_idx, - i -> deg1_diff_eval(tree, cX, operators.unaops[i], operators) - ) - else - op_idx = tree.op - Base.Cartesian.@nif( - $nbin, - i -> i == op_idx, - i -> deg2_diff_eval(tree, cX, operators.binops[i], operators) - ) - end + if tree.constant + ResultOk(fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true) + else + ResultOk(cX[tree.feature, :], true) end end -function deg1_diff_eval( - tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, op::F, 
operators::OperatorEnum -)::ResultOk where {T<:Number,F,T1} - left = _differentiable_eval_tree_array(tree.l, cX, operators) - !left.ok && return left - out = op.(left.x) +function degn_diff_eval(cumulators::C, op::F) where {C<:Tuple,F} + out = op.(cumulators...) return ResultOk(out, all(isfinite, out)) end -function deg2_diff_eval( - tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum -)::ResultOk where {T<:Number,F,T1} - left = _differentiable_eval_tree_array(tree.l, cX, operators) - !left.ok && return left - right = _differentiable_eval_tree_array(tree.r, cX, operators) - !right.ok && return right - out = op.(left.x, right.x) - return ResultOk(out, all(isfinite, out)) +@generated function dispatch_degn_diff_eval( + tree::AbstractExpressionNode{T1,D}, + cX::AbstractMatrix{T}, + op_idx::Integer, + ::Val{degree}, + operators::OperatorEnum{OPS} +) where {T<:Number,T1,D,degree,OPS} + nops = length(OPS.types[degree].types) + quote + cs = get_children(tree, Val($degree)) + Base.Cartesian.@nexprs($degree, i -> begin + cumulator_i = let result = _differentiable_eval_tree_array(cs[i], cX, operators) + !result.ok && return result + result.x + end + end) + cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i) + Base.Cartesian.@nif( + $nops, + i -> i == op_idx, + i -> degn_diff_eval(cumulators, operators[$degree][i]) + ) + end end """ From 067734ac2b387b2def187e4a7e8728a756d95804 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 20:40:02 +0100 Subject: [PATCH 37/74] feat: make generic eval allow n-arity nodes --- src/Evaluate.jl | 141 ++++++++++++++++++++++++++---------------------- 1 file changed, 77 insertions(+), 64 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 90d456bd..7a48f261 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -795,7 +795,6 @@ end end end - function deg0_diff_eval( tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum )::ResultOk where 
{T<:Number,T1} @@ -816,22 +815,24 @@ end cX::AbstractMatrix{T}, op_idx::Integer, ::Val{degree}, - operators::OperatorEnum{OPS} + operators::OperatorEnum{OPS}, ) where {T<:Number,T1,D,degree,OPS} nops = length(OPS.types[degree].types) quote cs = get_children(tree, Val($degree)) - Base.Cartesian.@nexprs($degree, i -> begin - cumulator_i = let result = _differentiable_eval_tree_array(cs[i], cX, operators) - !result.ok && return result - result.x + Base.Cartesian.@nexprs( + $degree, + i -> begin + cumulator_i = + let result = _differentiable_eval_tree_array(cs[i], cX, operators) + !result.ok && return result + result.x + end end - end) + ) cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i) Base.Cartesian.@nif( - $nops, - i -> i == op_idx, - i -> degn_diff_eval(cumulators, operators[$degree][i]) + $nops, i -> i == op_idx, i -> degn_diff_eval(cumulators, operators[$degree][i]) ) end end @@ -910,77 +911,89 @@ function eval(current_node) end end -@unstable function _eval_tree_array_generic( - tree::AbstractExpressionNode{T1}, +@generated function _eval_tree_array_generic( + tree::AbstractExpressionNode{T1,D}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum, ::Val{throw_errors}, -) where {T1,T2,N,throw_errors} - if tree.degree == 0 - if tree.constant - if N == 1 - return (tree.val::T1), true - else - return fill(tree.val::T1, size(cX)[2:N]), true - end +) where {T1,D,T2,N,throw_errors} + quote + tree.degree == 0 && return deg0_eval_generic(tree, cX) + op_idx = tree.op + deg = tree.degree + Base.Cartesian.@nif( + $D, + i -> i == deg, + i -> dispatch_degn_eval_generic( + tree, cX, op_idx, Val(i), operators, Val(throw_errors) + ) + ) + end +end + +@unstable function deg0_eval_generic( + tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N} +) where {T1,T2,N} + if tree.constant + if N == 1 + return (tree.val::T1), true else - if N == 1 - return (cX[tree.feature]), true - else - return copy(selectdim(cX, 1, tree.feature)), true - end + return 
fill(tree.val::T1, size(cX)[2:N]), true end - elseif tree.degree == 1 - return deg1_eval_generic( - tree, cX, operators.unaops[tree.op], operators, Val(throw_errors) - ) else - return deg2_eval_generic( - tree, cX, operators.binops[tree.op], operators, Val(throw_errors) - ) + if N == 1 + return (cX[tree.feature]), true + else + return copy(selectdim(cX, 1, tree.feature)), true + end end end -@unstable function deg1_eval_generic( - tree::AbstractExpressionNode{T1}, - cX::AbstractArray{T2,N}, - op::F, - operators::GenericOperatorEnum, - ::Val{throw_errors}, -) where {F,T1,T2,N,throw_errors} - left, complete = _eval_tree_array_generic(tree.l, cX, operators, Val(throw_errors)) - !throw_errors && !complete && return nothing, false - !throw_errors && - !hasmethod(op, N == 1 ? Tuple{typeof(left)} : Tuple{eltype(left)}) && - return nothing, false +@unstable function degn_eval_generic( + cumulators::C, op::F, ::Val{N}, ::Val{throw_errors} +) where {C<:Tuple,F,N,throw_errors} + if !throw_errors + input_type = N == 1 ? C : Tuple{map(eltype, cumulators)...} + !hasmethod(op, input_type) && return nothing, false + end if N == 1 - return op(left), true + return op(cumulators...), true else - return op.(left), true + return op.(cumulators...), true end end -@unstable function deg2_eval_generic( +@generated function dispatch_degn_eval_generic( tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N}, - op::F, - operators::GenericOperatorEnum, + op_idx::Integer, + ::Val{degree}, + operators::GenericOperatorEnum{OPS}, ::Val{throw_errors}, -) where {F,T1,T2,N,throw_errors} - left, complete = _eval_tree_array_generic(tree.l, cX, operators, Val(throw_errors)) - !throw_errors && !complete && return nothing, false - right, complete = _eval_tree_array_generic(tree.r, cX, operators, Val(throw_errors)) - !throw_errors && !complete && return nothing, false - !throw_errors && - !hasmethod( - op, - N == 1 ? 
Tuple{typeof(left),typeof(right)} : Tuple{eltype(left),eltype(right)}, - ) && - return nothing, false - if N == 1 - return op(left, right), true - else - return op.(left, right), true +) where {T1,T2,N,degree,throw_errors,OPS} + nops = length(OPS.types[degree].types) + quote + cs = get_children(tree, Val($degree)) + Base.Cartesian.@nexprs( + $degree, + i -> begin + cumulator_i = + let (x, complete) = _eval_tree_array_generic( + cs[i], cX, operators, Val(throw_errors) + ) + !throw_errors && !complete && return nothing, false + x + end + end + ) + cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i) + Base.Cartesian.@nif( + $nops, + i -> i == op_idx, + i -> degn_eval_generic( + cumulators, operators[$degree][i], Val(N), Val(throw_errors) + ) + ) end end From 7b04a1530e97d112500557dab0dc86de19d7619f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 20:50:16 +0100 Subject: [PATCH 38/74] refactor: remove reference stuff from read only nodes --- src/ReadOnlyNode.jl | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl index 8b45612a..6d4999f2 100644 --- a/src/ReadOnlyNode.jl +++ b/src/ReadOnlyNode.jl @@ -2,38 +2,24 @@ module ReadOnlyNodeModule using DispatchDoctor: @unstable -using ..NodeModule: AbstractExpressionNode, Node +using ..NodeModule: AbstractExpressionNode, Node, max_degree import ..NodeModule: default_allocator, with_type_parameters, constructorof, get_children -abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D},IS_REF} <: +abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D}} <: AbstractExpressionNode{T,D} end """A type of expression node that prevents writing to the inner node""" -struct ReadOnlyNode{T,D,N,IS_REF} <: AbstractReadOnlyNode{T,D,N,IS_REF} +struct ReadOnlyNode{T,D,N} <: AbstractReadOnlyNode{T,D,N} _inner::N - function ReadOnlyNode( - n::N, ::Val{IS_REF} - ) where 
{T,D,N<:AbstractExpressionNode{T,D},IS_REF} - return new{T,D,N,IS_REF}(n) - end - function ReadOnlyNode(n::N) where {T,D,N<:AbstractExpressionNode{T,D}} - return ReadOnlyNode(n, Val(false)) - end - function ReadOnlyNode(n::AbstractReadOnlyNode) - return n - end - function ReadOnlyNode(n::Ref{<:AbstractExpressionNode}) - return ReadOnlyNode(n[], Val(true)) - end + ReadOnlyNode(n::N) where {T,N<:AbstractExpressionNode{T}} = new{T,max_degree(N),N}(n) end @inline inner(n::AbstractReadOnlyNode) = getfield(n, :_inner) @unstable constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode -Base.getindex(n::AbstractReadOnlyNode{T,D,N,true} where {T,D,N}) = n @inline function Base.getproperty(n::AbstractReadOnlyNode, s::Symbol) out = getproperty(inner(n), s) - if out isa Union{AbstractExpressionNode,Ref{<:AbstractExpressionNode}} - return ReadOnlyNode(out) + if out isa AbstractExpressionNode + return constructorof(typeof(n))(out) else return out end @@ -44,7 +30,7 @@ end function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v) return error("Cannot set properties on a ReadOnlyNode") end -Base.propertynames(n::AbstractReadOnlyNode) = propertynames(getfield(n, :_inner)) -Base.copy(n::AbstractReadOnlyNode) = ReadOnlyNode(copy(getfield(n, :_inner))) +Base.propertynames(n::AbstractReadOnlyNode) = propertynames(inner(n)) +Base.copy(n::AbstractReadOnlyNode) = ReadOnlyNode(copy(inner(n))) end From 83902e3d847ae01db66f9a9a28a32a1bdcf34e2e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 23:39:56 +0100 Subject: [PATCH 39/74] feat: make diff compatibility with n-arity --- ext/DynamicExpressionsZygoteExt.jl | 25 +---- src/EvaluateDerivative.jl | 157 ++++++++++++--------------- src/ExtensionInterface.jl | 11 +- src/Node.jl | 11 +- test/test_zygote_gradient_wrapper.jl | 21 ---- 5 files changed, 80 insertions(+), 145 deletions(-) diff --git a/ext/DynamicExpressionsZygoteExt.jl b/ext/DynamicExpressionsZygoteExt.jl index 5654c27e..37d86926 100644 --- 
a/ext/DynamicExpressionsZygoteExt.jl +++ b/ext/DynamicExpressionsZygoteExt.jl @@ -3,30 +3,17 @@ module DynamicExpressionsZygoteExt using Zygote: gradient import DynamicExpressions.ExtensionInterfaceModule: _zygote_gradient, ZygoteGradient -function _zygote_gradient(op::F, ::Val{1}) where {F} - return ZygoteGradient{F,1,1}(op) -end -function _zygote_gradient(op::F, ::Val{2}, ::Val{side}=Val(nothing)) where {F,side} - # side should be either nothing (for both), 1, or 2 - @assert side === nothing || side in (1, 2) - return ZygoteGradient{F,2,side}(op) +function _zygote_gradient(op::F, ::Val{degree}) where {F,degree} + return ZygoteGradient{F,degree}(op) end -function (g::ZygoteGradient{F,1,1})(x) where {F} +function (g::ZygoteGradient{F,1})(x) where {F} out = only(gradient(g.op, x)) return out === nothing ? zero(x) : out end -function (g::ZygoteGradient{F,2,nothing})(x, y) where {F} - (∂x, ∂y) = gradient(g.op, x, y) - return (∂x === nothing ? zero(x) : ∂x, ∂y === nothing ? zero(y) : ∂y) -end -function (g::ZygoteGradient{F,2,1})(x, y) where {F} - ∂x = only(gradient(Base.Fix2(g.op, y), x)) - return ∂x === nothing ? zero(x) : ∂x -end -function (g::ZygoteGradient{F,2,2})(x, y) where {F} - ∂y = only(gradient(Base.Fix1(g.op, x), y)) - return ∂y === nothing ? zero(y) : ∂y +function (g::ZygoteGradient{F,degree})(args::Vararg{Any,degree}) where {F,degree} + partials = gradient(g.op, args...) 
+ return ntuple(i -> @something(partials[i], zero(args[i])), Val(degree)) end end diff --git a/src/EvaluateDerivative.jl b/src/EvaluateDerivative.jl index 1c4070f0..2ef1dadb 100644 --- a/src/EvaluateDerivative.jl +++ b/src/EvaluateDerivative.jl @@ -1,6 +1,6 @@ module EvaluateDerivativeModule -import ..NodeModule: AbstractExpressionNode, constructorof +import ..NodeModule: AbstractExpressionNode, constructorof, get_children import ..OperatorEnumModule: OperatorEnum import ..UtilsModule: fill_similar, ResultOk2 import ..ValueInterfaceModule: is_valid_array @@ -66,54 +66,18 @@ function eval_diff_tree_array( end @generated function _eval_diff_tree_array( - tree::AbstractExpressionNode{T}, + tree::AbstractExpressionNode{T,D}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer, -)::ResultOk2 where {T<:Number} - nuna = get_nuna(operators) - nbin = get_nbin(operators) - deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN - quote - diff_deg1_eval(tree, cX, operators.unaops[op_idx], operators, direction) - end - else - quote - Base.Cartesian.@nif( - $nuna, - i -> i == op_idx, - i -> - diff_deg1_eval(tree, cX, operators.unaops[i], operators, direction) - ) - end - end - deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN - quote - diff_deg2_eval(tree, cX, operators.binops[op_idx], operators, direction) - end - else - quote - Base.Cartesian.@nif( - $nbin, - i -> i == op_idx, - i -> - diff_deg2_eval(tree, cX, operators.binops[i], operators, direction) - ) - end - end +)::ResultOk2 where {T<:Number,D} quote - result = if tree.degree == 0 - diff_deg0_eval(tree, cX, direction) - elseif tree.degree == 1 - op_idx = tree.op - $deg1_branch - else - op_idx = tree.op - $deg2_branch - end - !result.ok && return result - return ResultOk2( - result.x, result.dx, is_valid_array(result.x) && is_valid_array(result.dx) + deg = tree.degree + deg == 0 && return diff_deg0_eval(tree, cX, direction) + Base.Cartesian.@nif( + $D, + i -> i == deg, + i -> 
dispatch_diff_degn_eval(tree, cX, Val(i), operators, direction) ) end end @@ -130,58 +94,71 @@ function diff_deg0_eval( return ResultOk2(const_part, derivative_part, true) end -function diff_deg1_eval( - tree::AbstractExpressionNode{T}, - cX::AbstractMatrix{T}, - op::F, - operators::OperatorEnum, - direction::Integer, -) where {T<:Number,F} - result = _eval_diff_tree_array(tree.l, cX, operators, direction) - !result.ok && return result - - # TODO - add type assertions to get better speed: - cumulator = result.x - dcumulator = result.dx - diff_op = _zygote_gradient(op, Val(1)) - @inbounds @simd for j in eachindex(cumulator) - x = op(cumulator[j])::T - dx = diff_op(cumulator[j])::T * dcumulator[j] - - cumulator[j] = x - dcumulator[j] = dx +@generated function diff_degn_eval( + x_cumulators::NTuple{N}, dx_cumulators::NTuple{N}, op::F, direction::Integer +) where {N,F} + quote + Base.Cartesian.@nexprs($N, i -> begin + x_cumulator_i = x_cumulators[i] + dx_cumulator_i = dx_cumulators[i] + end) + diff_op = _zygote_gradient(op, Val(N)) + @inbounds @simd for j in eachindex(x_cumulator_1) + x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j]) + Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall( + $N, diff_op, i -> x_cumulator_i[j] + ) + dx = Base.Cartesian.@ncall($N, +, i -> grad_i * dx_cumulator_i[j]) + x_cumulator_1[j] = x + dx_cumulator_1[j] = dx + end + return ResultOk2(x_cumulator_1, dx_cumulator_1, true) end - return result end -function diff_deg2_eval( - tree::AbstractExpressionNode{T}, +@generated function dispatch_diff_degn_eval( + tree::AbstractExpressionNode{T,D}, cX::AbstractMatrix{T}, - op::F, - operators::OperatorEnum, + ::Val{degree}, + operators::OperatorEnum{OPS}, direction::Integer, -) where {T<:Number,F} - result_l = _eval_diff_tree_array(tree.l, cX, operators, direction) - !result_l.ok && return result_l - result_r = _eval_diff_tree_array(tree.r, cX, operators, direction) - !result_r.ok && return result_r - - ar_l = result_l.x - d_ar_l = 
result_l.dx - ar_r = result_r.x - d_ar_r = result_r.dx - diff_op = _zygote_gradient(op, Val(2)) - - @inbounds @simd for j in eachindex(ar_l) - x = op(ar_l[j], ar_r[j])::T - - first, second = diff_op(ar_l[j], ar_r[j])::Tuple{T,T} - dx = first * d_ar_l[j] + second * d_ar_r[j] +) where {T<:Number,D,degree,OPS} + nops = length(OPS.types[degree].types) + + setup = quote + cs = get_children(tree, Val($degree)) + Base.Cartesian.@nexprs( + $degree, + i -> begin + result_i = _eval_diff_tree_array(cs[i], cX, operators, direction) + !result_i.ok && return result_i + end + ) + x_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x) + dx_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx) + op_idx = tree.op + end - ar_l[j] = x - d_ar_l[j] = dx + if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN + quote + $setup + diff_degn_eval( + x_cumulators, dx_cumulators, operators[$degree][op_idx], direction + ) + end + else + quote + $setup + Base.Cartesian.@nif( + $nops, + i -> i == op_idx, + i -> diff_degn_eval( + x_cumulators, dx_cumulators, operators[$degree][i], direction + ) + ) + end end - return result_l + # TODO: Need to add the case for many operators end """ diff --git a/src/ExtensionInterface.jl b/src/ExtensionInterface.jl index 1628683d..6eed12b0 100644 --- a/src/ExtensionInterface.jl +++ b/src/ExtensionInterface.jl @@ -7,19 +7,12 @@ function symbolic_to_node(args...; kws...) 
return error("Please load the `SymbolicUtils` package to use `symbolic_to_node`.") end -struct ZygoteGradient{F,degree,arg} <: Function +struct ZygoteGradient{F,degree} <: Function op::F end -function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg} +function Base.show(io::IO, g::ZygoteGradient{F,degree}) where {F,degree} print(io, "∂") - if degree == 2 - if arg == 1 - print(io, "₁") - elseif arg == 2 - print(io, "₂") - end - end print(io, g.op) return nothing end diff --git a/src/Node.jl b/src/Node.jl index 99dd0d31..fa9b7d81 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -6,15 +6,15 @@ import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 +const DEFAULT_MAX_DEGREE = 2 """ AbstractNode{D} Abstract type for D-arity trees. Must have the following fields: -- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1, - then `l` needs to be defined as the left child. If 2, - then `r` also needs to be defined as the right child. +- `degree::UInt8`: Degree of the node. This should be a value + between 0 and `DEFAULT_MAX_DEGREE`. - `children`: A collection of D references to children nodes. # Deprecated fields @@ -25,7 +25,7 @@ Abstract type for D-arity trees. Must have the following fields: Don't use `nothing` to represent an undefined value as it will incur a large performance penalty. - `r::AbstractNode{D}`: Right child of the current node. Should only - be defined if `degree == 2`. + be defined if `degree >= 2`. 
""" abstract type AbstractNode{D} end @@ -82,7 +82,7 @@ for N in (:Node, :GraphNode) ## Constructors: ################# $N{_T,_D}() where {_T,_D} = new{_T,_D::Int}() - $N{_T}() where {_T} = $N{_T,2}() + $N{_T}() where {_T} = $N{_T,DEFAULT_MAX_DEGREE}() # TODO: Test with this disabled to spot any unintended uses end end @@ -250,7 +250,6 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T -const DEFAULT_MAX_DEGREE = 2 max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE max_degree(::Type{<:AbstractNode{D}}) where {D} = D max_degree(node::AbstractNode) = max_degree(typeof(node)) diff --git a/test/test_zygote_gradient_wrapper.jl b/test/test_zygote_gradient_wrapper.jl index 7eed34bb..38f899aa 100644 --- a/test/test_zygote_gradient_wrapper.jl +++ b/test/test_zygote_gradient_wrapper.jl @@ -11,23 +11,8 @@ g(x, y) = x * y @test repr(_zygote_gradient(g, Val(2))) == "∂g" - # Test binary gradient (first partial) - @test repr(_zygote_gradient(g, Val(2), Val(1))) == "∂₁g" - - # Test binary gradient (second partial) - @test repr(_zygote_gradient(g, Val(2), Val(2))) == "∂₂g" - # Test with standard operators @test repr(_zygote_gradient(+, Val(2))) == "∂+" - @test repr(_zygote_gradient(*, Val(2), Val(1))) == "∂₁*" - @test repr(_zygote_gradient(*, Val(2), Val(2))) == "∂₂*" - - first_partial = _zygote_gradient(log, Val(2), Val(1)) - nested = _zygote_gradient(first_partial, Val(1)) - @test repr(nested) == "∂∂₁log" - - # Also should work with text/plain - @test repr("text/plain", nested) == "∂∂₁log" end @testitem "ZygoteGradient evaluation" begin @@ -45,10 +30,4 @@ end # Test binary gradient (both partials) g(x, y) = x * y @test (_zygote_gradient(g, Val(2)))(x, y) == (3.0, 2.0) - - # Test binary gradient (first partial) - @test (_zygote_gradient(g, Val(2), Val(1)))(x, y) == 3.0 - - # Test second partial - @test (_zygote_gradient(g, Val(2), Val(2)))(x, y) == 2.0 end From 3845539f50d6dca2ca837c1562eaba609b85810a 
Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 11 May 2025 23:42:45 +0100 Subject: [PATCH 40/74] refactor: simplify eval code --- src/Evaluate.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 7a48f261..6c93141b 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -288,8 +288,7 @@ function _eval_tree_array( op_idx = tree.op return dispatch_deg2_eval(tree, cX, op_idx, operators, eval_options) else - op_idx = tree.op - return dispatch_degn_eval(tree, cX, op_idx, operators, eval_options) + return dispatch_degn_eval(tree, cX, operators, eval_options) end end @@ -336,7 +335,6 @@ end @generated function inner_dispatch_degn_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, - op_idx::Integer, ::Val{degree}, operators::OperatorEnum{OPS}, eval_options::EvalOptions, @@ -352,6 +350,7 @@ end @return_on_nonfinite_array(eval_options, result_i.x) end ) + op_idx = tree.op cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x) Base.Cartesian.@nif( $nops, @@ -363,7 +362,6 @@ end @generated function dispatch_degn_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, - op_idx::Integer, operators::OperatorEnum, eval_options::EvalOptions, ) where {T} @@ -374,8 +372,7 @@ end return Base.Cartesian.@nif( $D, d -> d == degree, - d -> - inner_dispatch_degn_eval(tree, cX, op_idx, Val(d), operators, eval_options) + d -> inner_dispatch_degn_eval(tree, cX, Val(d), operators, eval_options) ) end end From 5f977abf649e44a8844161332ae22669ff999096 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 00:05:44 +0100 Subject: [PATCH 41/74] feat: make grad compatible with n-arity --- ext/DynamicExpressionsZygoteExt.jl | 5 +- src/EvaluateDerivative.jl | 189 +++++++++++---------------- test/test_zygote_gradient_wrapper.jl | 2 +- 3 files changed, 77 insertions(+), 119 deletions(-) diff --git a/ext/DynamicExpressionsZygoteExt.jl b/ext/DynamicExpressionsZygoteExt.jl index 37d86926..f42a3f89 
100644 --- a/ext/DynamicExpressionsZygoteExt.jl +++ b/ext/DynamicExpressionsZygoteExt.jl @@ -7,10 +7,7 @@ function _zygote_gradient(op::F, ::Val{degree}) where {F,degree} return ZygoteGradient{F,degree}(op) end -function (g::ZygoteGradient{F,1})(x) where {F} - out = only(gradient(g.op, x)) - return out === nothing ? zero(x) : out -end +# All this does is remove `nothing`, so that we get type stability function (g::ZygoteGradient{F,degree})(args::Vararg{Any,degree}) where {F,degree} partials = gradient(g.op, args...) return ntuple(i -> @something(partials[i], zero(args[i])), Val(degree)) diff --git a/src/EvaluateDerivative.jl b/src/EvaluateDerivative.jl index 2ef1dadb..91470946 100644 --- a/src/EvaluateDerivative.jl +++ b/src/EvaluateDerivative.jl @@ -254,55 +254,97 @@ function eval_grad_tree_array( end @generated function _eval_grad_tree_array( - tree::AbstractExpressionNode{T}, + tree::AbstractExpressionNode{T,D}, n_gradients, - index_tree::Union{NodeIndex,Nothing}, + index_tree::Union{NodeIndex{<:Any,D},Nothing}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{mode}, -)::ResultOk2 where {T<:Number,mode} - nuna = get_nuna(operators) - nbin = get_nbin(operators) - deg1_branch_skeleton = quote - grad_deg1_eval( - tree, n_gradients, index_tree, cX, operators.unaops[i], operators, Val(mode) +)::ResultOk2 where {T<:Number,D,mode} + quote + deg = tree.degree + deg == 0 && return grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode)) + Base.Cartesian.@nif( + $D, + i -> i == deg, + i -> dispatch_grad_degn_eval( + tree, n_gradients, index_tree, cX, Val(i), operators, Val(mode) + ) ) end - deg2_branch_skeleton = quote - grad_deg2_eval( - tree, n_gradients, index_tree, cX, operators.binops[i], operators, Val(mode) +end + +@generated function dispatch_grad_degn_eval( + tree::AbstractExpressionNode{T}, + n_gradients, + index_tree::Union{NodeIndex,Nothing}, + cX::AbstractMatrix{T}, + ::Val{degree}, + operators::OperatorEnum{OPS}, + ::Val{mode}, +) where 
{T<:Number,degree,OPS,mode} + setup = quote + cs = get_children(tree, Val($degree)) + index_cs = + isnothing(index_tree) ? index_tree : get_children(index_tree, Val($degree)) + Base.Cartesian.@nexprs( + $degree, + i -> begin + result_i = eval_grad_tree_array( + cs[i], + n_gradients, + isnothing(index_cs) ? index_cs : index_cs[i], + cX, + operators, + Val(mode), + ) + !result_i.ok && return result_i + end ) + x_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.x) + d_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx) + op_idx = tree.op end - deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN - quote - i = tree.op - $deg1_branch_skeleton - end - else - quote - op_idx = tree.op - Base.Cartesian.@nif($nuna, i -> i == op_idx, i -> $deg1_branch_skeleton) - end - end - deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN + nops = length(OPS.types[degree].types) + if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN quote - i = tree.op - $deg2_branch_skeleton + $setup + grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][op_idx]) end else quote - op_idx = tree.op - Base.Cartesian.@nif($nbin, i -> i == op_idx, i -> $deg2_branch_skeleton) + $setup + Base.Cartesian.@nif( + $nops, + i -> i == op_idx, + i -> grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][i]) + ) end end +end + +@generated function grad_degn_eval( + x_cumulators::NTuple{N}, d_cumulators::NTuple{N}, op::F +) where {N,F} quote - if tree.degree == 0 - grad_deg0_eval(tree, n_gradients, index_tree, cX, Val(mode)) - elseif tree.degree == 1 - $deg1_branch - else - $deg2_branch + Base.Cartesian.@nexprs($N, i -> begin + x_cumulator_i = x_cumulators[i] + d_cumulator_i = d_cumulators[i] + end) + diff_op = _zygote_gradient(op, Val($N)) + @inbounds @simd for j in eachindex(x_cumulator_1) + x = Base.Cartesian.@ncall($N, op, i -> x_cumulator_i[j]) + Base.Cartesian.@ntuple($N, i -> grad_i) = Base.Cartesian.@ncall( + $N, diff_op, i -> x_cumulator_i[j] + ) + x_cumulator_1[j] = x + for k in 
axes(d_cumulator_1, 1) + d_cumulator_1[k, j] = Base.Cartesian.@ncall( + $N, +, i -> grad_i * d_cumulator_i[k, j] + ) + end end + return ResultOk2(x_cumulator_1, d_cumulator_1, true) end end @@ -344,85 +386,4 @@ function grad_deg0_eval( return ResultOk2(const_part, derivative_part, true) end -function grad_deg1_eval( - tree::AbstractExpressionNode{T}, - n_gradients, - index_tree::Union{NodeIndex,Nothing}, - cX::AbstractMatrix{T}, - op::F, - operators::OperatorEnum, - ::Val{mode}, -)::ResultOk2 where {T<:Number,F,mode} - result = eval_grad_tree_array( - tree.l, - n_gradients, - index_tree === nothing ? index_tree : index_tree.l, - cX, - operators, - Val(mode), - ) - !result.ok && return result - - cumulator = result.x - dcumulator = result.dx - diff_op = _zygote_gradient(op, Val(1)) - @inbounds @simd for j in axes(dcumulator, 2) - x = op(cumulator[j])::T - dx = diff_op(cumulator[j])::T - - cumulator[j] = x - for k in axes(dcumulator, 1) - dcumulator[k, j] = dx * dcumulator[k, j] - end - end - return result -end - -function grad_deg2_eval( - tree::AbstractExpressionNode{T}, - n_gradients, - index_tree::Union{NodeIndex,Nothing}, - cX::AbstractMatrix{T}, - op::F, - operators::OperatorEnum, - ::Val{mode}, -)::ResultOk2 where {T<:Number,F,mode} - result_l = eval_grad_tree_array( - tree.l, - n_gradients, - index_tree === nothing ? index_tree : index_tree.l, - cX, - operators, - Val(mode), - ) - !result_l.ok && return result_l - result_r = eval_grad_tree_array( - tree.r, - n_gradients, - index_tree === nothing ? 
index_tree : index_tree.r, - cX, - operators, - Val(mode), - ) - !result_r.ok && return result_r - - cumulator_l = result_l.x - dcumulator_l = result_l.dx - cumulator_r = result_r.x - dcumulator_r = result_r.dx - diff_op = _zygote_gradient(op, Val(2)) - @inbounds @simd for j in axes(dcumulator_l, 2) - c1 = cumulator_l[j] - c2 = cumulator_r[j] - x = op(c1, c2)::T - dx1, dx2 = diff_op(c1, c2)::Tuple{T,T} - cumulator_l[j] = x - for k in axes(dcumulator_l, 1) - dcumulator_l[k, j] = dx1 * dcumulator_l[k, j] + dx2 * dcumulator_r[k, j] - end - end - - return result_l -end - end diff --git a/test/test_zygote_gradient_wrapper.jl b/test/test_zygote_gradient_wrapper.jl index 38f899aa..44db43b2 100644 --- a/test/test_zygote_gradient_wrapper.jl +++ b/test/test_zygote_gradient_wrapper.jl @@ -25,7 +25,7 @@ end # Test unary gradient f(x) = x^2 - @test (_zygote_gradient(f, Val(1)))(x) == 4.0 + @test (_zygote_gradient(f, Val(1)))(x) == (4.0,) # Test binary gradient (both partials) g(x, y) = x * y From 7b51c060f5b3b6e87c46287a64d1ac9fb079f183 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 00:22:37 +0100 Subject: [PATCH 42/74] feat: n-arity compat with simplification --- src/Evaluate.jl | 3 +++ src/Simplify.jl | 21 +++++++++------------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 6c93141b..e6ba96b3 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -332,6 +332,7 @@ function deg0_eval( end end +# This basically forms an if statement over the operators for the degree. @generated function inner_dispatch_degn_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, @@ -359,6 +360,8 @@ end ) end end + +# This forms an if statement over the degree of a given node. 
@generated function dispatch_degn_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, diff --git a/src/Simplify.jl b/src/Simplify.jl index cf2592a2..ef17b33f 100644 --- a/src/Simplify.jl +++ b/src/Simplify.jl @@ -5,8 +5,7 @@ import ..NodeUtilsModule: tree_mapreduce, is_node_constant import ..OperatorEnumModule: AbstractOperatorEnum import ..ValueInterfaceModule: is_valid -_una_op_kernel(f::F, l::T) where {F,T} = f(l) -_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r) +_op_kernel(f::F, l::T, ls::T...) where {F,T} = f(l, ls...) is_commutative(::typeof(*)) = true is_commutative(::typeof(+)) = true @@ -17,8 +16,8 @@ is_subtraction(_) = false combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree # This is only defined for `Node` as it is not possible for, e.g., -# `GraphNode`. -function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T} +# `GraphNode`, and n-arity nodes. +function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) where {T} # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before. # ((const + var) + const) => (const + var) # ((const * var) * const) => (const * var) @@ -51,10 +50,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where if below.degree == 2 && below.op == op if is_node_constant(below.l) tree = below - tree.l.val = _bin_op_kernel(operators.binops[op], tree.l.val, topconstant) + tree.l.val = _op_kernel(operators.binops[op], tree.l.val, topconstant) elseif is_node_constant(below.r) tree = below - tree.r.val = _bin_op_kernel(operators.binops[op], tree.r.val, topconstant) + tree.r.val = _op_kernel(operators.binops[op], tree.r.val, topconstant) end end end @@ -106,15 +105,13 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where return tree end -function combine_children!(operators, p::N, c::N...) 
where {T,N<:AbstractExpressionNode{T}} +function combine_children!( + operators, p::N, c::Vararg{N,degree} +) where {T,N<:AbstractExpressionNode{T},degree} all(is_node_constant, c) || return p vals = map(n -> n.val, c) all(is_valid, vals) || return p - out = if length(c) == 1 - _una_op_kernel(operators.unaops[p.op], vals...) - else - _bin_op_kernel(operators.binops[p.op], vals...) - end + out = _op_kernel(operators[degree][p.op], vals...) is_valid(out) || return p new_node = constructorof(N)(T; val=convert(T, out)) set_node!(p, new_node) From 4f31a85519ce9287c5bd9509b23f29eb74a8f67b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 11:35:06 +0100 Subject: [PATCH 43/74] feat: n-arity compat with bumper --- ext/DynamicExpressionsBumperExt.jl | 81 +++++++------------ ext/DynamicExpressionsLoopVectorizationExt.jl | 21 ++--- src/ExtensionInterface.jl | 3 +- 3 files changed, 39 insertions(+), 66 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index 6e99927b..0b745a85 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -5,8 +5,7 @@ using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions using DynamicExpressions.UtilsModule: ResultOk, counttuple -import DynamicExpressions.ExtensionInterfaceModule: - bumper_eval_tree_array, bumper_kern1!, bumper_kern2! +import DynamicExpressions.ExtensionInterfaceModule: bumper_eval_tree_array, bumper_kern! 
function bumper_eval_tree_array( tree::AbstractExpressionNode{T}, @@ -37,8 +36,7 @@ function bumper_eval_tree_array( branch_node -> branch_node, # In the evaluation kernel, we combine the branch nodes # with the arrays created by the leaf nodes: - ((args::Vararg{Any,M}) where {M}) -> - dispatch_kerns!(operators, args..., eval_options), + KernelDispatcher(operators, eval_options), tree; break_sharing=Val(true), ) @@ -49,63 +47,44 @@ function bumper_eval_tree_array( return (result, all_ok[]) end -function dispatch_kerns!( - operators, branch_node, cumulator, eval_options::EvalOptions{<:Any,true,early_exit} -) where {early_exit} - cumulator.ok || return cumulator - - out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, eval_options) - return ResultOk(out, early_exit ? is_valid_array(out) : true) -end -function dispatch_kerns!( - operators, - branch_node, - cumulator1, - cumulator2, - eval_options::EvalOptions{<:Any,true,early_exit}, -) where {early_exit} - cumulator1.ok || return cumulator1 - cumulator2.ok || return cumulator2 - - out = dispatch_kern2!( - operators.binops, branch_node.op, cumulator1.x, cumulator2.x, eval_options - ) - return ResultOk(out, early_exit ? 
is_valid_array(out) : true) +struct KernelDispatcher{O<:OperatorEnum,E<:EvalOptions{<:Any,true,<:Any}} <: Function + operators::O + eval_options::E end -@generated function dispatch_kern1!(unaops, op_idx, cumulator, eval_options::EvalOptions) - nuna = counttuple(unaops) +@generated function (kd::KernelDispatcher{<:Any,<:EvalOptions{<:Any,true,early_exit}})( + branch_node, inputs::Vararg{Any,degree} +) where {degree,early_exit} quote - Base.@nif( - $nuna, - i -> i == op_idx, - i -> let op = unaops[i] - return bumper_kern1!(op, cumulator, eval_options) - end, - ) + Base.Cartesian.@nexprs($degree, i -> inputs[i].ok || return inputs[i]) + cumulators = Base.Cartesian.@ntuple($degree, i -> inputs[i].x) + out = dispatch_kerns!(kd.operators, branch_node, cumulators, kd.eval_options) + return ResultOk(out, early_exit ? is_valid_array(out) : true) end end -@generated function dispatch_kern2!( - binops, op_idx, cumulator1, cumulator2, eval_options::EvalOptions -) - nbin = counttuple(binops) +@generated function dispatch_kerns!( + operators::OperatorEnum{OPS}, + branch_node, + cumulators::Tuple{Vararg{Any,degree}}, + eval_options::EvalOptions, +) where {OPS,degree} + nops = length(OPS.types[degree].types) quote - Base.@nif( - $nbin, + op_idx = branch_node.op + Base.Cartesian.@nif( + $nops, i -> i == op_idx, - i -> let op = binops[i] - return bumper_kern2!(op, cumulator1, cumulator2, eval_options) - end, + i -> bumper_kern!(operators[$degree][i], cumulators, eval_options) ) end end -function bumper_kern1!(op::F, cumulator, ::EvalOptions{false,true}) where {F} - @. cumulator = op(cumulator) - return cumulator -end -function bumper_kern2!(op::F, cumulator1, cumulator2, ::EvalOptions{false,true}) where {F} - @. cumulator1 = op(cumulator1, cumulator2) - return cumulator1 + +function bumper_kern!( + op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{false,true,early_exit} +) where {F,degree,early_exit} + cumulator_1 = first(cumulators) + @. cumulator_1 = op(cumulators...) 
+ return cumulator_1 end end diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index f65fcf4b..4d2f7c33 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -15,7 +15,7 @@ import DynamicExpressions.EvaluateModule: deg2_l0_eval, deg2_r0_eval import DynamicExpressions.ExtensionInterfaceModule: - _is_loopvectorization_loaded, bumper_kern1!, bumper_kern2! + _is_loopvectorization_loaded, bumper_kern! _is_loopvectorization_loaded(::Int) = true @@ -208,18 +208,13 @@ function deg2_r0_eval( end end -## Interface with Bumper.jl -function bumper_kern1!( - op::F, cumulator, ::EvalOptions{true,true,early_exit} -) where {F,early_exit} - @turbo @. cumulator = op(cumulator) - return cumulator -end -function bumper_kern2!( - op::F, cumulator1, cumulator2, ::EvalOptions{true,true,early_exit} -) where {F,early_exit} - @turbo @. cumulator1 = op(cumulator1, cumulator2) - return cumulator1 +# Interface with Bumper.jl +function bumper_kern!( + op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{true,true,early_exit} +) where {F,degree,early_exit} + cumulator_1 = first(cumulators) + @turbo @. cumulator_1 = op(cumulators...) + return cumulator_1 end end diff --git a/src/ExtensionInterface.jl b/src/ExtensionInterface.jl index 6eed12b0..b62a5a5f 100644 --- a/src/ExtensionInterface.jl +++ b/src/ExtensionInterface.jl @@ -25,8 +25,7 @@ end function bumper_eval_tree_array(args...) return error("Please load the Bumper.jl package to use this feature.") end -function bumper_kern1! end -function bumper_kern2! end +function bumper_kern! 
end _is_loopvectorization_loaded(_) = false From f9d21c6bdd8c066d2fa6d51123a4acfd218b953a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 11:43:06 +0100 Subject: [PATCH 44/74] refactor: avoid `NTuple` typing --- src/Node.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Node.jl b/src/Node.jl index fa9b7d81..61cf9189 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -190,7 +190,7 @@ end set_children!(n, Base.setindex(get_children(n), child, i)) return child end -@inline function set_children!(n::AbstractNode{D}, children::NTuple{D2,AbstractNode{D}}) where {D,D2} +@inline function set_children!(n::AbstractNode{D}, children::Tuple{Vararg{AbstractNode{D},D2}}) where {D,D2} if D === D2 n.children = children else From 69cd3727cd1a08b2eb1595fd57706b9a820e56f6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 11:54:18 +0100 Subject: [PATCH 45/74] docs: tweak docstring --- src/ExtensionInterface.jl | 2 +- src/Node.jl | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/ExtensionInterface.jl b/src/ExtensionInterface.jl index b62a5a5f..ca42430f 100644 --- a/src/ExtensionInterface.jl +++ b/src/ExtensionInterface.jl @@ -27,6 +27,6 @@ function bumper_eval_tree_array(args...) end function bumper_kern! end -_is_loopvectorization_loaded(_) = false +_is_loopvectorization_loaded(_) = false # COV_EXCL_LINE end diff --git a/src/Node.jl b/src/Node.jl index 61cf9189..08bca8df 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -11,11 +11,14 @@ const DEFAULT_MAX_DEGREE = 2 """ AbstractNode{D} -Abstract type for D-arity trees. Must have the following fields: +Abstract type for trees which can have up to `D` children per node. +Must have the following fields: - `degree::UInt8`: Degree of the node. This should be a value - between 0 and `DEFAULT_MAX_DEGREE`. -- `children`: A collection of D references to children nodes. + between 0 and `D`, inclusive. +- `children`: A collection of up to `D` children nodes. 
The number + of children which are _active_ is given by the `degree` field, + but for type stability reasons, you can have inactive children. # Deprecated fields @@ -250,12 +253,14 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T +# COV_EXCL_START max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE max_degree(::Type{<:AbstractNode{D}}) where {D} = D max_degree(node::AbstractNode) = max_degree(typeof(node)) has_max_degree(::Type{<:AbstractNode}) = false has_max_degree(::Type{<:AbstractNode{D}}) where {D} = true +# COV_EXCL_STOP @unstable function constructorof(::Type{N}) where {N<:Node} return Node{T,max_degree(N)} where {T} @@ -358,8 +363,8 @@ end eltype(N) end end -defines_eltype(::Type{<:AbstractExpressionNode}) = false -defines_eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = true +defines_eltype(::Type{<:AbstractExpressionNode}) = false # COV_EXCL_LINE +defines_eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = true # COV_EXCL_LINE #! 
format: on function (::Type{N})( From 422b5c51035efb7d182191db60f2495f9987d4fd Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 15:10:42 +0100 Subject: [PATCH 46/74] feat: parsing for D-degree nodes --- src/Parse.jl | 66 ++++++++++++++++++++++++---------------------------- src/base.jl | 2 -- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/src/Parse.jl b/src/Parse.jl index 10d121db..255afe57 100644 --- a/src/Parse.jl +++ b/src/Parse.jl @@ -282,50 +282,46 @@ end evaluate_on::Union{Nothing,AbstractVector}; kws..., )::N where {F<:Function,N<:AbstractExpressionNode,E<:AbstractExpression} - if length(args) == 2 && func ∈ operators.unaops - # Regular unary operator - op = findfirst(==(func), operators.unaops)::Int + degree = length(args) - 1 + if degree <= length(operators.ops) && func ∈ operators[degree] + op_idx = findfirst(==(func), operators[degree]) return N(; - op=op::Int, - l=_parse_expression( - args[2], operators, variable_names, N, E, evaluate_on; kws... - ), - ) - elseif length(args) == 3 && func ∈ operators.binops - # Regular binary operator - op = findfirst(==(func), operators.binops)::Int - return N(; - op=op::Int, - l=_parse_expression( - args[2], operators, variable_names, N, E, evaluate_on; kws... - ), - r=_parse_expression( - args[3], operators, variable_names, N, E, evaluate_on; kws... + op=op_idx::Int, + children=map( + arg -> _parse_expression( + arg, operators, variable_names, N, E, evaluate_on; kws... + ), + (args[2:end]...,), ), ) - elseif length(args) > 3 && func in (+, -, *) && func ∈ operators.binops - # Either + or - but used with more than two arguments - op = findfirst(==(func), operators.binops)::Int + elseif degree > 2 && func ∈ (+, -, *) && func ∈ operators[2] + op_idx = findfirst(==(func), operators[2])::Int inner = N(; - op=op::Int, - l=_parse_expression( - args[2], operators, variable_names, N, E, evaluate_on; kws... 
- ), - r=_parse_expression( - args[3], operators, variable_names, N, E, evaluate_on; kws... + op=op_idx::Int, + children=( + _parse_expression( + args[2], operators, variable_names, N, E, evaluate_on; kws... + ), + _parse_expression( + args[3], operators, variable_names, N, E, evaluate_on; kws... + ), ), ) for arg in args[4:end] inner = N(; - op=op::Int, - l=inner, - r=_parse_expression( - arg, operators, variable_names, N, E, evaluate_on; kws... + op=op_idx::Int, + children=( + inner, + _parse_expression( + arg, operators, variable_names, N, E, evaluate_on; kws... + ), ), ) end return inner - elseif evaluate_on !== nothing && func in evaluate_on + end + + if evaluate_on !== nothing && func in evaluate_on # External function func( map( @@ -337,10 +333,8 @@ end ) else matching_s = let - s = if length(args) == 2 - "`" * string(operators.unaops) * "`" - elseif length(args) == 3 - "`" * string(operators.binops) * "`" + s = if degree <= length(operators.ops) + join(('`', operators[degree], '`')) else "" end diff --git a/src/base.jl b/src/base.jl index 9b3dae31..e86ccf59 100644 --- a/src/base.jl +++ b/src/base.jl @@ -177,9 +177,7 @@ By using this instead of tree_mapreduce, we can take advantage of early exits. 
@generated function any(f::F, tree::AbstractNode{D}) where {F<:Function,D} quote deg = tree.degree - deg == 0 && return @inline(f(tree)) - return ( @inline(f(tree)) || Base.Cartesian.@nif( $D, i -> deg == i, i -> let cs = get_children(tree, Val(i)) From fb248de3151918ed7d38d5a2676bc470e65d335f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 15:12:13 +0100 Subject: [PATCH 47/74] test: integrate Supposition testing --- test/Project.toml | 1 + test/supposition_utils.jl | 66 ++++++++++++++++++++++++++++ test/test_supposition_consistency.jl | 44 +++++++++++++++++++ test/unittest.jl | 1 + 4 files changed, 112 insertions(+) create mode 100644 test/supposition_utils.jl create mode 100644 test/test_supposition_consistency.jl diff --git a/test/Project.toml b/test/Project.toml index 952e0370..c1e6c23e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Supposition = "5a0628fe-1738-4658-9b6d-0b7605a9755b" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/supposition_utils.jl b/test/supposition_utils.jl new file mode 100644 index 00000000..fd002301 --- /dev/null +++ b/test/supposition_utils.jl @@ -0,0 +1,66 @@ +# supposition_utils.jl +# +# Helper that builds a Supposition generator returning fully-random +# DynamicExpressions.Expression objects whose node type is Node{T,D}. +# D is inferred from `operators`. + +module SuppositionUtils + +using Supposition: Data +using DynamicExpressions: Node, Expression, OperatorEnum +using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals! 
+empty_all_globals!() + +function make_expression_generator( + ::Type{T}; + num_features::Int=5, + operators::OperatorEnum=OperatorEnum(((abs, cos), (+, -, *, /))), + max_layers::Int=3, +) where {T} + D = length(operators.ops) + + val_gen = Data.Floats{T}(; nans=false, infs=false) + val_node_gen = map(v -> Node{T,D}(; val=v), val_gen) + + feature_gen = Data.SampledFrom(1:num_features) + feature_node_gen = map(i -> Node{T,D}(; feature=i), feature_gen) + + leaf_gen = val_node_gen | feature_node_gen + + wrapper_funcs = ntuple( + degree -> let op_list = operators[degree] + op_gen = Data.SampledFrom(1:length(op_list)) + + child -> map( + (op_idx, args...) -> Node{T,D}(; op=op_idx, children=args), + op_gen, + ntuple(_ -> child, degree)..., + ) + end, + Val(D), + ) + expr_wrap(child) = foldl(|, (w(child) for w in wrapper_funcs)) + tree_gen = Data.Recursive(leaf_gen, expr_wrap; max_layers) + return map( + t -> Expression(t; operators, variable_names=["x$i" for i in 1:num_features]), + tree_gen, + ) +end + +# inside module SuppositionUtils +function make_input_matrix_generator( + ::Type{T}=Float64; n_features::Int=5, min_batch::Int=1, max_batch::Int=16 +) where {T} + elem_gen = Data.Floats{T}(; nans=false, infs=false) + batch_gen = Data.Integers(min_batch, max_batch) + + Data.bind(batch_gen) do bs + vec_len = n_features * bs + vec_gen = Data.Vectors(elem_gen; min_size=vec_len, max_size=vec_len) + Data.map(v -> reshape(v, n_features, bs), vec_gen) + end +end + +end + +using .SuppositionUtils: make_input_matrix_generator, make_expression_generator diff --git a/test/test_supposition_consistency.jl b/test/test_supposition_consistency.jl new file mode 100644 index 00000000..377f45c3 --- /dev/null +++ b/test/test_supposition_consistency.jl @@ -0,0 +1,44 @@ +@testitem "Supposition round-trip consistency" begin + using Test + using Random + + using Supposition + using Supposition: @check, Data + using DynamicExpressions + using DynamicExpressions: + string_tree, parse_expression, 
eval_tree_array, Node, get_operators, get_tree + + # bring the generator into scope + include("supposition_utils.jl") + + n_features = 5 + max_layers = 20 + T = Float64 + operators = OperatorEnum(((abs, cos, exp), (+, -, *, /), (fma, clamp, +, max))) + + expr_gen = make_expression_generator( + T; num_features=n_features, max_layers=max_layers, operators=operators + ) + + @check function roundtrip_string(ex=expr_gen) + tree_str = string_tree(ex) + ex_parsed = parse_expression( + Meta.parse(tree_str); + operators=get_operators(ex), + variable_names=["x$i" for i in 1:n_features], + node_type=Node{Float64,3}, + ) + return ex == ex_parsed + end + + input_gen = make_input_matrix_generator(T; n_features) + @check max_examples = 1024 function eval_against_string(ex=expr_gen, X=input_gen) + expression_result, ok = eval_tree_array(ex, X) + !ok && return true # If the expression is not valid, we can't test it + tree_str = string_tree(ex) + f_sym = gensym("f") + f = eval(Meta.parse("(x1, x2, x3, x4, x5) -> ($tree_str)")) + true_result = Float64[Base.invokelatest(f, x...) for x in eachcol(X)] + return expression_result ≈ true_result + end +end diff --git a/test/unittest.jl b/test/unittest.jl index 42ae11bb..42e8a286 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -133,3 +133,4 @@ include("test_expression_math.jl") include("test_structured_expression.jl") include("test_readonlynode.jl") include("test_zygote_gradient_wrapper.jl") +include("test_supposition_consistency.jl") From e3da69bcc44022fa7b6c75d0933166e1bc29753c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 22:53:13 +0100 Subject: [PATCH 48/74] refactor: better name for poison node --- src/Node.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 08bca8df..a72a077a 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -169,7 +169,14 @@ when constructing or setting properties. 
""" GraphNode -function get_poison(n::AbstractNode) +""" + poison_node(n::AbstractNode) + +Return a placeholder for unused child slots, ensuring type stability. +Accessing this node should trigger some kind of noticable error +(e.g., default returns itself, which causes infinite recursion). +""" +function poison_node(n::AbstractNode) # We don't want to use `nothing` because the type instability # hits memory hard. # Setting itself as the right child is the best thing, @@ -197,7 +204,7 @@ end if D === D2 n.children = children else - poison = get_poison(n) + poison = poison_node(n) # We insert poison at the end of the tuple so that # errors will appear loudly if accessed. # This poison should be efficient to insert. So From 4f2c117a10632d60c5f1297e3edf0f8ce5fd02fd Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 12 May 2025 23:26:51 +0100 Subject: [PATCH 49/74] refactor: fix type instability in parametric expression converter --- src/ParametricExpression.jl | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index d8039ed3..21790064 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -311,6 +311,24 @@ function extract_gradient( return vcat(d_constants, d_params) # Same shape as `get_scalar_constants` end +struct BranchConverter{NT<:Node} <: Function end +struct LeafConverter{NT<:Node} <: Function + num_params::UInt16 +end +function (bc::BranchConverter{NT})( + branch::ParametricNode, children::Vararg{Any,M} +) where {NT,M} + return NT(; branch.op, children) +end +function (lc::LeafConverter{NT})(leaf::ParametricNode) where {NT} + if leaf.constant + return NT(; val=leaf.val) + elseif leaf.is_parameter + return NT(; feature=leaf.parameter) + else + return NT(; feature=leaf.feature + lc.num_params) + end +end function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T} num_params = UInt16(size(ex.metadata.parameters, 1)) 
tree = get_tree(ex) @@ -319,17 +337,7 @@ function Base.convert(::Type{Node}, ex::ParametricExpression{T}) where {T} NT = with_max_degree(with_type_parameters(Node, T), Val(D)) return tree_mapreduce( - leaf -> if leaf.constant - NT(; val=leaf.val) - elseif leaf.is_parameter - NT(T; feature=leaf.parameter) - else - NT(T; feature=leaf.feature + num_params) - end, - branch -> branch.op, - (op, children...) -> NT(; op, children), - tree, - NT, + LeafConverter{NT}(num_params), branch -> branch.op, BranchConverter{NT}(), tree, NT ) end function CRC.rrule(::typeof(convert), ::Type{Node}, ex::ParametricExpression{T}) where {T} From b309f52378224f1b31516087b380055d3cc9fc3e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 01:00:16 +0100 Subject: [PATCH 50/74] refactor: reduce new type instabilities in generic eval --- src/Evaluate.jl | 74 ++++++++++++++++++++++++------------- src/EvaluateDerivative.jl | 23 +++++++----- src/ParametricExpression.jl | 6 +-- test/test_n_arity_nodes.jl | 20 +++++++--- 4 files changed, 78 insertions(+), 45 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index e6ba96b3..2ac8f33f 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -5,7 +5,7 @@ using DispatchDoctor: @stable, @unstable import ..NodeModule: AbstractExpressionNode, constructorof, max_degree, get_children, with_type_parameters import ..StringsModule: string_tree -import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum +import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk import ..NodeUtilsModule: is_constant import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded @@ -252,13 +252,11 @@ end # These are marked unstable due to issues discussed on # https://github.com/JuliaLang/julia/issues/55147 -@unstable function get_nuna(::Type{<:OperatorEnum{OPS}}) where {OPS} - ts = OPS.types - return isempty(ts) ? 
0 : counttuple(ts[1]) -end -@unstable function get_nbin(::Type{<:OperatorEnum{OPS}}) where {OPS} - ts = OPS.types - return length(ts) == 1 ? 0 : counttuple(ts[2]) +@unstable function get_nops( + ::Type{O}, ::Val{degree} +) where {OPS,O<:Union{OperatorEnum{OPS},GenericOperatorEnum{OPS}},degree} + max_degree = counttuple(OPS) + return degree > max_degree ? 0 : counttuple(OPS.types[degree]) end function _eval_tree_array( @@ -332,15 +330,30 @@ function deg0_eval( end end +# This is used for type stability, since Julia will fail inference +# when the operator list is empty, even if that node type never appears +@inline function get_op( + operators::AbstractOperatorEnum, ::Val{degree}, ::Val{i} +) where {degree,i} + ops = operators[degree] + if isempty(ops) + error( + lazy"Invalid access: a node has degree $degree, but no operators were passed for this degree.", + ) + else + return ops[i] + end +end + # This basically forms an if statement over the operators for the degree. @generated function inner_dispatch_degn_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, ::Val{degree}, - operators::OperatorEnum{OPS}, + operators::O, eval_options::EvalOptions, -) where {T,degree,OPS} - nops = length(OPS.types[degree].types) +) where {T,degree,O<:OperatorEnum} + nops = get_nops(O, Val(degree)) return quote cs = get_children(tree, Val($degree)) Base.Cartesian.@nexprs( @@ -356,7 +369,9 @@ end Base.Cartesian.@nif( $nops, i -> i == op_idx, - i -> degn_eval(cumulators, operators[$degree][i], eval_options), + i -> degn_eval( + cumulators, get_op(operators, Val($degree), Val(i)), eval_options + ), ) end end @@ -386,7 +401,7 @@ end operators::OperatorEnum, eval_options::EvalOptions, ) where {T} - nbin = get_nbin(operators) + nbin = get_nops(operators, Val(2)) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote @@ -440,7 +455,7 @@ end operators::OperatorEnum, eval_options::EvalOptions, ) where {T} - nuna = get_nuna(operators) + 
nuna = get_nops(operators, Val(1)) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote @@ -723,9 +738,9 @@ end # Now that we have the degree, we can get the operator @generated function inner_dispatch_degn_eval_constant( - tree::AbstractExpressionNode{T}, ::Val{degree}, operators::OperatorEnum{OPS} -) where {T,degree,OPS} - nops = length(OPS.types[degree].types) + tree::AbstractExpressionNode{T}, ::Val{degree}, operators::OperatorEnum +) where {T,degree} + nops = get_nops(operators, Val(degree)) get_inputs = quote cs = get_children(tree, Val($degree)) Base.Cartesian.@nexprs( @@ -752,7 +767,9 @@ end Base.Cartesian.@nif( $nops, i -> i == op_idx, - i -> degn_eval_constant(inputs, operators[$degree][i])::ResultOk{T} + i -> degn_eval_constant( + inputs, get_op(operators, Val($degree), Val(i)) + )::ResultOk{T} ) end end @@ -815,9 +832,9 @@ end cX::AbstractMatrix{T}, op_idx::Integer, ::Val{degree}, - operators::OperatorEnum{OPS}, -) where {T<:Number,T1,D,degree,OPS} - nops = length(OPS.types[degree].types) + operators::OperatorEnum, +) where {T<:Number,T1,D,degree} + nops = get_nops(operators, Val(degree)) quote cs = get_children(tree, Val($degree)) Base.Cartesian.@nexprs( @@ -832,7 +849,9 @@ end ) cumulators = Base.Cartesian.@ntuple($degree, i -> cumulator_i) Base.Cartesian.@nif( - $nops, i -> i == op_idx, i -> degn_diff_eval(cumulators, operators[$degree][i]) + $nops, + i -> i == op_idx, + i -> degn_diff_eval(cumulators, get_op(operators, Val($degree), Val(i))) ) end end @@ -968,10 +987,10 @@ end cX::AbstractArray{T2,N}, op_idx::Integer, ::Val{degree}, - operators::GenericOperatorEnum{OPS}, + operators::GenericOperatorEnum, ::Val{throw_errors}, -) where {T1,T2,N,degree,throw_errors,OPS} - nops = length(OPS.types[degree].types) +) where {T1,T2,N,degree,throw_errors} + nops = get_nops(operators, Val(degree)) quote cs = get_children(tree, Val($degree)) Base.Cartesian.@nexprs( @@ -991,7 +1010,10 @@ end $nops, i -> i == op_idx, 
i -> degn_eval_generic( - cumulators, operators[$degree][i], Val(N), Val(throw_errors) + cumulators, + get_op(operators, Val($degree), Val(i)), + Val(N), + Val(throw_errors), ) ) end diff --git a/src/EvaluateDerivative.jl b/src/EvaluateDerivative.jl index 91470946..fc983164 100644 --- a/src/EvaluateDerivative.jl +++ b/src/EvaluateDerivative.jl @@ -6,7 +6,7 @@ import ..UtilsModule: fill_similar, ResultOk2 import ..ValueInterfaceModule: is_valid_array import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, NodeIndex import ..EvaluateModule: - deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN, EvalOptions + deg0_eval, get_op, get_nops, OPERATOR_LIMIT_BEFORE_SLOWDOWN, EvalOptions import ..ExtensionInterfaceModule: _zygote_gradient """ @@ -120,10 +120,10 @@ end tree::AbstractExpressionNode{T,D}, cX::AbstractMatrix{T}, ::Val{degree}, - operators::OperatorEnum{OPS}, + operators::OperatorEnum, direction::Integer, -) where {T<:Number,D,degree,OPS} - nops = length(OPS.types[degree].types) +) where {T<:Number,D,degree} + nops = get_nops(operators, Val(degree)) setup = quote cs = get_children(tree, Val($degree)) @@ -153,7 +153,10 @@ end $nops, i -> i == op_idx, i -> diff_degn_eval( - x_cumulators, dx_cumulators, operators[$degree][i], direction + x_cumulators, + dx_cumulators, + get_op(operators, Val($degree), Val(i)), + direction, ) ) end @@ -280,9 +283,9 @@ end index_tree::Union{NodeIndex,Nothing}, cX::AbstractMatrix{T}, ::Val{degree}, - operators::OperatorEnum{OPS}, + operators::OperatorEnum, ::Val{mode}, -) where {T<:Number,degree,OPS,mode} +) where {T<:Number,degree,mode} setup = quote cs = get_children(tree, Val($degree)) index_cs = @@ -305,7 +308,7 @@ end d_cumulators = Base.Cartesian.@ntuple($degree, i -> result_i.dx) op_idx = tree.op end - nops = length(OPS.types[degree].types) + nops = get_nops(operators, Val(degree)) if nops > OPERATOR_LIMIT_BEFORE_SLOWDOWN quote $setup @@ -317,7 +320,9 @@ end Base.Cartesian.@nif( $nops, i -> i == op_idx, - 
i -> grad_degn_eval(x_cumulators, d_cumulators, operators[$degree][i]) + i -> grad_degn_eval( + x_cumulators, d_cumulators, get_op(operators, Val($degree), Val(i)) + ) ) end end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 21790064..bea449c2 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -315,10 +315,8 @@ struct BranchConverter{NT<:Node} <: Function end struct LeafConverter{NT<:Node} <: Function num_params::UInt16 end -function (bc::BranchConverter{NT})( - branch::ParametricNode, children::Vararg{Any,M} -) where {NT,M} - return NT(; branch.op, children) +function (bc::BranchConverter{NT})(op::Integer, children::Vararg{Any,M}) where {NT,M} + return NT(; op, children) end function (lc::LeafConverter{NT})(leaf::ParametricNode) where {NT} if leaf.constant diff --git a/test/test_n_arity_nodes.jl b/test/test_n_arity_nodes.jl index 36b28c70..d3ddcdcb 100644 --- a/test/test_n_arity_nodes.jl +++ b/test/test_n_arity_nodes.jl @@ -113,12 +113,20 @@ end @test operators_full[2] == (my_binary_op,) @test operators_full[3] == (my_ternary_op,) - @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_full)) == 1 - @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_full)) == 1 - @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_unary_only)) == 1 - @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_unary_only)) == 0 - @test DynamicExpressions.EvaluateModule.get_nuna(typeof(operators_binary_only)) == 0 - @test DynamicExpressions.EvaluateModule.get_nbin(typeof(operators_binary_only)) == 1 + @test DynamicExpressions.EvaluateModule.get_nops(typeof(operators_full), Val(1)) == 1 + @test DynamicExpressions.EvaluateModule.get_nops(typeof(operators_full), Val(2)) == 1 + @test DynamicExpressions.EvaluateModule.get_nops( + typeof(operators_unary_only), Val(1) + ) == 1 + @test DynamicExpressions.EvaluateModule.get_nops( + typeof(operators_unary_only), Val(2) + ) == 0 + @test 
DynamicExpressions.EvaluateModule.get_nops( + typeof(operators_binary_only), Val(1) + ) == 0 + @test DynamicExpressions.EvaluateModule.get_nops( + typeof(operators_binary_only), Val(2) + ) == 1 end @testitem "N-ary Evaluation (targeting dispatch_degn_eval)" tags = [:narity] begin From b66b1782b958519374810c5cdd5242a5af3121e0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 02:22:13 +0100 Subject: [PATCH 51/74] refactor: eliminate other type instabilities --- src/Evaluate.jl | 25 +++++++++++++++++++++---- src/Node.jl | 11 +++++++---- src/Utils.jl | 2 ++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 2ac8f33f..4833e60e 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -255,8 +255,7 @@ end @unstable function get_nops( ::Type{O}, ::Val{degree} ) where {OPS,O<:Union{OperatorEnum{OPS},GenericOperatorEnum{OPS}},degree} - max_degree = counttuple(OPS) - return degree > max_degree ? 0 : counttuple(OPS.types[degree]) + return degree > counttuple(OPS) ? 0 : counttuple(OPS.types[degree]) end function _eval_tree_array( @@ -345,8 +344,26 @@ end end end +# TODO: Hack to fix type instability in some branches that can't be inferred. +# It does this using the other branches, which _can_ be inferred. +function _get_return_type(tree, cX, operators, eval_options) + # public Julia API version of `Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))` + return eltype([_eval_tree_array(tree, cX, operators, eval_options) for _ in 1:0]) +end + # This basically forms an if statement over the operators for the degree. 
-@generated function inner_dispatch_degn_eval( +function inner_dispatch_degn_eval( + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + ::Val{degree}, + operators::OperatorEnum, + eval_options::EvalOptions, +) where {T,degree} + return _inner_dispatch_degn_eval( + tree, cX, Val(degree), operators, eval_options + )::(_get_return_type(tree, cX, operators, eval_options)) +end +@generated function _inner_dispatch_degn_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, ::Val{degree}, @@ -371,7 +388,7 @@ end i -> i == op_idx, i -> degn_eval( cumulators, get_op(operators, Val($degree), Val(i)), eval_options - ), + ) ) end end diff --git a/src/Node.jl b/src/Node.jl index a72a077a..5995935b 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -2,7 +2,6 @@ module NodeModule using DispatchDoctor: @unstable -import ..OperatorEnumModule: AbstractOperatorEnum import ..UtilsModule: deprecate_varmap, Undefined const DEFAULT_NODE_TYPE = Float32 @@ -260,14 +259,18 @@ end Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T Base.eltype(::AbstractExpressionNode{T}) where {T} = T +#! format: off # COV_EXCL_START -max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE -max_degree(::Type{<:AbstractNode{D}}) where {D} = D -max_degree(node::AbstractNode) = max_degree(typeof(node)) +# These are marked unstable due to issues discussed on +# https://github.com/JuliaLang/julia/issues/55147 +@unstable max_degree(::Type{<:AbstractNode}) = DEFAULT_MAX_DEGREE +@unstable max_degree(::Type{<:AbstractNode{D}}) where {D} = D +@unstable max_degree(node::AbstractNode) = max_degree(typeof(node)) has_max_degree(::Type{<:AbstractNode}) = false has_max_degree(::Type{<:AbstractNode{D}}) where {D} = true # COV_EXCL_STOP +#! 
format: on @unstable function constructorof(::Type{N}) where {N<:Node} return Node{T,max_degree(N)} where {T} diff --git a/src/Utils.jl b/src/Utils.jl index 6db149ff..be7fdf72 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -28,6 +28,8 @@ function deprecate_varmap(variable_names, varMap, func_name) return variable_names end +# These are marked unstable due to issues discussed on +# https://github.com/JuliaLang/julia/issues/55147 @unstable counttuple(::Type{<:NTuple{N,Any}}) where {N} = N """ From 6e78b279a68617b257e0b247f4f2cdc01ec074da Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 13:50:59 +0100 Subject: [PATCH 52/74] fix: attempt type stability fix for union of `Type{Union{}}` with other type --- src/Evaluate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 4833e60e..3e178604 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -346,7 +346,7 @@ end # TODO: Hack to fix type instability in some branches that can't be inferred. # It does this using the other branches, which _can_ be inferred. 
-function _get_return_type(tree, cX, operators, eval_options) +@unstable function _get_return_type(tree, cX, operators, eval_options) # public Julia API version of `Core.Compiler.return_type(_eval_tree_array, typeof((tree, cX, operators, eval_options)))` return eltype([_eval_tree_array(tree, cX, operators, eval_options) for _ in 1:0]) end From bae706393c62c645752e2188e80ed691acf21a90 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 14:24:47 +0100 Subject: [PATCH 53/74] test: move test to main file --- test/runtests.jl | 4 ---- test/unittest.jl | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 70de3cf8..3fe9f6a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,7 +49,3 @@ if "main" in test_name include("unittest.jl") @run_package_tests end -if "narity" in test_name - include("test_n_arity_nodes.jl") - @run_package_tests filter = ti -> (:narity in ti.tags) -end diff --git a/test/unittest.jl b/test/unittest.jl index 42e8a286..1cead521 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -134,3 +134,4 @@ include("test_structured_expression.jl") include("test_readonlynode.jl") include("test_zygote_gradient_wrapper.jl") include("test_supposition_consistency.jl") +include("test_n_arity_nodes.jl") From 1d6119246428f1f53fca28843b0f28a764ac4813 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 14:30:25 +0100 Subject: [PATCH 54/74] fix: `convert` should check degree --- src/base.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/base.jl b/src/base.jl index e86ccf59..19fab57c 100644 --- a/src/base.jl +++ b/src/base.jl @@ -529,10 +529,15 @@ using `convert(T1, tree.val)` at constant nodes. 
""" function convert( ::Type{N1}, tree::N2 -) where {T1,T2,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2}} +) where {T1,T2,D1,D2,N1<:AbstractExpressionNode{T1,D1},N2<:AbstractExpressionNode{T2,D2}} if N1 === N2 return tree end + if D1 !== D2 + throw( + ArgumentError("Cannot convert $N2 to $N1 because they have different degrees.") + ) + end return tree_mapreduce( Base.Fix1(leaf_convert, N1), identity, @@ -542,6 +547,11 @@ function convert( ) # TODO: Need to allow user to overload this! end +function convert( + ::Type{N1}, tree::N2 +) where {T1,T2,D,N1<:AbstractExpressionNode{T1},N2<:AbstractExpressionNode{T2,D}} + return convert(with_max_degree(N1, Val(D)), tree) +end function convert( ::Type{N1}, tree::N2 ) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}} From 24b4ed006784beb7f5083f8fd7655fd41f63fd40 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 14:34:06 +0100 Subject: [PATCH 55/74] test: make `FrozenNode` also overload `with_max_degree` --- test/test_extra_node_fields.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index 996cefe4..d74135a4 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -27,6 +27,9 @@ function DynamicExpressions.with_type_parameters( ) where {T,N<:FrozenNode} return FrozenNode{T,max_degree(N)} end +function DynamicExpressions.with_max_degree(::Type{N}, ::Val{D}) where {T,N<:FrozenNode{T}} + return FrozenNode{T,D} +end function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T} out = if t.constant constructorof(typeof(t))(; val=t.val) From befee815e9027a01805736e98c21575b466085d1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 14:46:26 +0100 Subject: [PATCH 56/74] feat: add `with_max_degree` to required interface --- src/Interfaces.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/Interfaces.jl b/src/Interfaces.jl index 5842ec34..e5e846a5 100644 
--- a/src/Interfaces.jl +++ b/src/Interfaces.jl @@ -12,6 +12,8 @@ using ..NodeModule: constructorof, default_allocator, with_type_parameters, + with_max_degree, + max_degree, get_children, leaf_copy, leaf_convert, @@ -259,6 +261,12 @@ function _check_with_type_parameters(tree::AbstractExpressionNode{T}) where {T} Nf16 = with_type_parameters(N, Float16) return Nf16 <: AbstractExpressionNode{Float16} end +function _check_with_max_degree(tree::AbstractExpressionNode) + N = typeof(tree) + new_D = max_degree(N) + 1 + N2 = with_max_degree(N, Val(new_D)) + return N2 <: AbstractExpressionNode && max_degree(N2) == new_D +end function _check_default_allocator(tree::AbstractExpressionNode) N = Base.typename(typeof(tree)).wrapper return default_allocator(N, Float64) isa with_type_parameters(N, Float64) @@ -376,6 +384,7 @@ ni_components = ( constructorof = "gets the constructor function for a node type" => _check_constructorof, eltype = "gets the element type of the node" => _check_eltype, with_type_parameters = "applies type parameters to the node type" => _check_with_type_parameters, + with_max_degree = "changes the maximum degree of a node type" => _check_with_max_degree, default_allocator = "gets the default allocator for the node type" => _check_default_allocator, set_node! 
= "sets the node's value" => _check_set_node!, count_nodes = "counts the number of nodes in the tree" => _check_count_nodes, From 0668a582c180f60628330de810c6a54bf9281cb6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 19:04:39 +0100 Subject: [PATCH 57/74] fix: missing import --- src/DynamicExpressions.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 051be69a..c1087516 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -52,6 +52,7 @@ import .NodeModule: with_type_parameters, preserve_sharing, max_degree, + with_max_degree, leaf_copy, branch_copy, leaf_hash, From a960c278e1d79388c6ed49a5acd3902526c46491 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 13 May 2025 19:23:18 +0100 Subject: [PATCH 58/74] test: move preferences to test/Project.toml --- test/LocalPreferences.toml | 3 --- test/Project.toml | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) delete mode 100644 test/LocalPreferences.toml diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml deleted file mode 100644 index d5e044f3..00000000 --- a/test/LocalPreferences.toml +++ /dev/null @@ -1,3 +0,0 @@ -[DynamicExpressions] -instability_check = "error" -instability_check_codegen = "min" diff --git a/test/Project.toml b/test/Project.toml index c1e6c23e..d33b1b78 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,3 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Aqua = "0.7" + +[preferences.DynamicExpressions] +dispatch_doctor_mode = "error" +dispatch_doctor_codegen_level = "min" From 7978292eefe9ed18d498b690559e05957830c7b1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 14 May 2025 10:58:19 +0100 Subject: [PATCH 59/74] refactor: additional tricks to try to improve type stability --- src/base.jl | 6 +++--- test/Project.toml | 2 +- test/test_extra_node_fields.jl | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/base.jl 
b/src/base.jl index 19fab57c..39ab306a 100644 --- a/src/base.jl +++ b/src/base.jl @@ -121,8 +121,8 @@ struct TreeMapreducer{ end @generated function call_mapreducer( - mapreducer::TreeMapreducer{D,ID}, tree::AbstractNode -) where {D,ID} + mapreducer::TreeMapreducer{D,ID,F1,F2,G,H}, tree::AbstractNode +) where {D,ID,F1,F2,G,H} quote key = ID <: Dict ? objectid(tree) : nothing if ID <: Dict && haskey(mapreducer.id_map, key) @@ -353,7 +353,7 @@ end Collect all nodes in a tree into a flat array in depth-first order. """ function collect(tree::AbstractNode; break_sharing::Val{BS}=Val(false)) where {BS} - return filter(Returns(true), tree; break_sharing=Val(BS)) + return filter(_ -> true, tree; break_sharing=Val(BS)) end Base.IteratorSize(::Type{<:AbstractNode}) = Base.HasLength() diff --git a/test/Project.toml b/test/Project.toml index d33b1b78..ac94070f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,5 +29,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Aqua = "0.7" [preferences.DynamicExpressions] -dispatch_doctor_mode = "error" dispatch_doctor_codegen_level = "min" +dispatch_doctor_mode = "error" diff --git a/test/test_extra_node_fields.jl b/test/test_extra_node_fields.jl index d74135a4..dc191fec 100644 --- a/test/test_extra_node_fields.jl +++ b/test/test_extra_node_fields.jl @@ -27,7 +27,9 @@ function DynamicExpressions.with_type_parameters( ) where {T,N<:FrozenNode} return FrozenNode{T,max_degree(N)} end -function DynamicExpressions.with_max_degree(::Type{N}, ::Val{D}) where {T,N<:FrozenNode{T}} +function DynamicExpressions.with_max_degree( + ::Type{N}, ::Val{D} +) where {T,N<:FrozenNode{T},D} return FrozenNode{T,D} end function DynamicExpressions.leaf_copy(t::FrozenNode{T}) where {T} From a623e8f189cd4dcac7ee4d8c8fb1d0b82ba96a32 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 14 May 2025 11:38:18 +0100 Subject: [PATCH 60/74] test: fix imaginary type instability from Interfaces.jl This was from `tree` being treated as an iterator! 
--- src/base.jl | 16 +++++++++------- test/test_node_interface.jl | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/base.jl b/src/base.jl index 39ab306a..20cba23a 100644 --- a/src/base.jl +++ b/src/base.jl @@ -121,8 +121,8 @@ struct TreeMapreducer{ end @generated function call_mapreducer( - mapreducer::TreeMapreducer{D,ID,F1,F2,G,H}, tree::AbstractNode -) where {D,ID,F1,F2,G,H} + mapreducer::TreeMapreducer{D,ID}, tree::AbstractNode +) where {D,ID} quote key = ID <: Dict ? objectid(tree) : nothing if ID <: Dict && haskey(mapreducer.id_map, key) @@ -369,11 +369,13 @@ function map( result_type::Type{RT}=Nothing; break_sharing::Val{BS}=Val(false), ) where {F<:Function,RT,BS} - if RT == Nothing - return map(f, collect(tree; break_sharing=Val(BS))) - else - return filter_map(Returns(true), f, tree, result_type; break_sharing=Val(BS)) - end + return _map(f, tree, result_type, Val(BS)) +end +function _map(f::F, tree::AbstractNode, ::Type{Nothing}, ::Val{BS}) where {F<:Function,BS} + return map(f, collect(tree; break_sharing=Val(BS))) +end +function _map(f::F, tree::AbstractNode, ::Type{RT}, ::Val{BS}) where {F<:Function,RT,BS} + return filter_map(Returns(true), f, tree, RT; break_sharing=Val(BS)) end """ diff --git a/test/test_node_interface.jl b/test/test_node_interface.jl index c4f37529..5a023fcb 100644 --- a/test/test_node_interface.jl +++ b/test/test_node_interface.jl @@ -53,7 +53,7 @@ end idx_max = 1 tree = Node{Float64,D}(; op=idx_max, children=(tree, x[1], x[2], x[3])) # max end - @test Interfaces.test(NodeInterface, Node, tree) + @test Interfaces.test(NodeInterface, Node, [tree]) end end end From 5472e8aefa68a7944d57966b52c7870cbf39e44e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 14 May 2025 12:10:32 +0100 Subject: [PATCH 61/74] test: fix dd seting --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3fe9f6a8..24054a93 100644 --- a/test/runtests.jl +++ 
b/test/runtests.jl @@ -18,7 +18,7 @@ end if "jet" in test_name @safetestset "JET" begin using Preferences - set_preferences!("DynamicExpressions", "instability_check" => "disable"; force=true) + set_preferences!("DynamicExpressions", "dispatch_doctor_mode" => "disable"; force=true) using JET using DynamicExpressions struct MyIgnoredModule From 2bf20fa3a6d41209f71fc5a9b39e969d1e0ad48d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 15 May 2025 20:49:43 +0100 Subject: [PATCH 62/74] feat: `set_children!` to work with vector --- src/Node.jl | 5 +++-- src/ReadOnlyNode.jl | 19 ++++++++++--------- test/runtests.jl | 4 +++- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index 5995935b..fb0cfd62 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -199,7 +199,8 @@ end set_children!(n, Base.setindex(get_children(n), child, i)) return child end -@inline function set_children!(n::AbstractNode{D}, children::Tuple{Vararg{AbstractNode{D},D2}}) where {D,D2} +@inline function set_children!(n::AbstractNode{D}, children::Union{Tuple,AbstractVector{<:AbstractNode{D}}}) where {D} + D2 = length(children) if D === D2 n.children = children else @@ -209,7 +210,7 @@ end # This poison should be efficient to insert. So # for simplicity, we can just use poison == n, which # will trigger infinite recursion errors if accessed. - n.children = ntuple(i -> i <= D2 ? children[i] : poison, Val(D)) + n.children = ntuple(i -> i <= D2 ? 
@inbounds(children[i]) : poison, Val(D)) end end diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl index 6d4999f2..1d830222 100644 --- a/src/ReadOnlyNode.jl +++ b/src/ReadOnlyNode.jl @@ -8,14 +8,7 @@ import ..NodeModule: default_allocator, with_type_parameters, constructorof, get abstract type AbstractReadOnlyNode{T,D,N<:AbstractExpressionNode{T,D}} <: AbstractExpressionNode{T,D} end -"""A type of expression node that prevents writing to the inner node""" -struct ReadOnlyNode{T,D,N} <: AbstractReadOnlyNode{T,D,N} - _inner::N - - ReadOnlyNode(n::N) where {T,N<:AbstractExpressionNode{T}} = new{T,max_degree(N),N}(n) -end @inline inner(n::AbstractReadOnlyNode) = getfield(n, :_inner) -@unstable constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode @inline function Base.getproperty(n::AbstractReadOnlyNode, s::Symbol) out = getproperty(inner(n), s) if out isa AbstractExpressionNode @@ -25,12 +18,20 @@ end end end @inline function get_children(node::AbstractReadOnlyNode) - return map(ReadOnlyNode, get_children(inner(node))) + return map(constructorof(typeof(node)), get_children(inner(node))) end function Base.setproperty!(::AbstractReadOnlyNode, ::Symbol, v) return error("Cannot set properties on a ReadOnlyNode") end Base.propertynames(n::AbstractReadOnlyNode) = propertynames(inner(n)) -Base.copy(n::AbstractReadOnlyNode) = ReadOnlyNode(copy(inner(n))) +Base.copy(n::AbstractReadOnlyNode) = constructorof(typeof(n))(copy(inner(n))) + +"""A type of expression node that prevents writing to the inner node""" +struct ReadOnlyNode{T,D,N} <: AbstractReadOnlyNode{T,D,N} + _inner::N + + ReadOnlyNode(n::N) where {T,N<:AbstractExpressionNode{T}} = new{T,max_degree(N),N}(n) +end +@unstable constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode end diff --git a/test/runtests.jl b/test/runtests.jl index 24054a93..0d9b3b05 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,9 @@ end if "jet" in test_name @safetestset "JET" begin using Preferences - 
set_preferences!("DynamicExpressions", "dispatch_doctor_mode" => "disable"; force=true) + set_preferences!( + "DynamicExpressions", "dispatch_doctor_mode" => "disable"; force=true + ) using JET using DynamicExpressions struct MyIgnoredModule From 8495d809d301dbeaf87850ab8835a5f296335e27 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 30 May 2025 23:15:11 +0100 Subject: [PATCH 63/74] test: remove Enzyme as required part of test suite --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index ac94070f..61ff11c1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,7 +5,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" From 0af17009d9bbbb1be67ddff08c2c5e67a5335c87 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 31 May 2025 00:32:27 +0100 Subject: [PATCH 64/74] feat: fix children set in ParametricNode --- src/ParametricExpression.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index bea449c2..372b364a 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -15,6 +15,8 @@ import ..NodeModule: with_max_degree, max_degree, preserve_sharing, + get_children, + set_children!, leaf_copy, leaf_convert, leaf_hash, @@ -144,8 +146,8 @@ function leaf_copy(t::ParametricNode{T}) where {T} end end function set_node!(tree::ParametricNode, new_tree::ParametricNode) - tree.degree = new_tree.degree - if new_tree.degree == 0 + tree.degree = (deg = new_tree.degree) + if deg == 0 if new_tree.constant tree.constant = true tree.val = 
new_tree.val @@ -160,10 +162,7 @@ function set_node!(tree::ParametricNode, new_tree::ParametricNode) end else tree.op = new_tree.op - tree.l = new_tree.l - if new_tree.degree == 2 - tree.r = new_tree.r - end + set_children!(tree, get_children(new_tree)) end return nothing end From 7d2a542baa24c1f4ccaf257d8799d24ca0b3d971 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 31 May 2025 00:33:32 +0100 Subject: [PATCH 65/74] test: incorporate turbo and bumper in supposition test --- test/test_supposition_consistency.jl | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/test/test_supposition_consistency.jl b/test/test_supposition_consistency.jl index 377f45c3..db0d07fa 100644 --- a/test/test_supposition_consistency.jl +++ b/test/test_supposition_consistency.jl @@ -7,6 +7,8 @@ using DynamicExpressions using DynamicExpressions: string_tree, parse_expression, eval_tree_array, Node, get_operators, get_tree + using LoopVectorization: LoopVectorization + using Bumper: Bumper # bring the generator into scope include("supposition_utils.jl") @@ -20,7 +22,7 @@ T; num_features=n_features, max_layers=max_layers, operators=operators ) - @check function roundtrip_string(ex=expr_gen) + result = @check function roundtrip_string(ex=expr_gen) tree_str = string_tree(ex) ex_parsed = parse_expression( Meta.parse(tree_str); @@ -30,15 +32,25 @@ ) return ex == ex_parsed end + @test something(result.result) isa Supposition.Pass input_gen = make_input_matrix_generator(T; n_features) - @check max_examples = 1024 function eval_against_string(ex=expr_gen, X=input_gen) - expression_result, ok = eval_tree_array(ex, X) - !ok && return true # If the expression is not valid, we can't test it + args_gen = map( + (ex, X, turbo, bumper) -> (; ex, X, turbo, bumper), + expr_gen, + input_gen, + Data.Booleans(), + Data.Booleans(), + ) + # We only consider expressions that don't have NaN/Inf/etc. 
+ clean_args_gen = filter(args -> eval_tree_array(args.ex, args.X)[2], args_gen) + result2 = @check max_examples = 1000 function eval_against_string(args=clean_args_gen) + (; ex, X, turbo, bumper) = args + expression_result, ok = eval_tree_array(ex, X; turbo, bumper) tree_str = string_tree(ex) - f_sym = gensym("f") f = eval(Meta.parse("(x1, x2, x3, x4, x5) -> ($tree_str)")) true_result = Float64[Base.invokelatest(f, x...) for x in eachcol(X)] - return expression_result ≈ true_result + return ok && expression_result ≈ true_result end + @test something(result2.result) isa Supposition.Pass end From 358d21f3156e6c6421329b2f4372c01293834193 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 31 May 2025 01:39:47 +0100 Subject: [PATCH 66/74] test: refactor supposition test --- test/test_supposition_consistency.jl | 115 +++++++++++++++++++++------ 1 file changed, 91 insertions(+), 24 deletions(-) diff --git a/test/test_supposition_consistency.jl b/test/test_supposition_consistency.jl index db0d07fa..344b7d4b 100644 --- a/test/test_supposition_consistency.jl +++ b/test/test_supposition_consistency.jl @@ -10,47 +10,114 @@ using LoopVectorization: LoopVectorization using Bumper: Bumper - # bring the generator into scope + # Bring the generator into scope include("supposition_utils.jl") - n_features = 5 - max_layers = 20 - T = Float64 - operators = OperatorEnum(((abs, cos, exp), (+, -, *, /), (fma, clamp, +, max))) + # Test configuration constants + const N_FEATURES = 5 + const MAX_LAYERS = 20 + const NUMERIC_TYPE = Float64 + const OPERATORS = OperatorEnum(((abs, cos, exp), (+, -, *, /), (fma, clamp, +, max))) + const VARIABLE_NAMES = ["x$i" for i in 1:N_FEATURES] + # Create expression generator expr_gen = make_expression_generator( - T; num_features=n_features, max_layers=max_layers, operators=operators + NUMERIC_TYPE; num_features=N_FEATURES, max_layers=MAX_LAYERS, operators=OPERATORS ) + # Test 1: Round-trip string parsing consistency result = @check function 
roundtrip_string(ex=expr_gen) tree_str = string_tree(ex) ex_parsed = parse_expression( Meta.parse(tree_str); operators=get_operators(ex), - variable_names=["x$i" for i in 1:n_features], + variable_names=VARIABLE_NAMES, node_type=Node{Float64,3}, ) return ex == ex_parsed end @test something(result.result) isa Supposition.Pass - input_gen = make_input_matrix_generator(T; n_features) - args_gen = map( - (ex, X, turbo, bumper) -> (; ex, X, turbo, bumper), - expr_gen, - input_gen, - Data.Booleans(), - Data.Booleans(), - ) - # We only consider expressions that don't have NaN/Inf/etc. - clean_args_gen = filter(args -> eval_tree_array(args.ex, args.X)[2], args_gen) - result2 = @check max_examples = 1000 function eval_against_string(args=clean_args_gen) - (; ex, X, turbo, bumper) = args - expression_result, ok = eval_tree_array(ex, X; turbo, bumper) + # Test 2: Evaluation consistency against string representation + input_gen = make_input_matrix_generator(NUMERIC_TYPE; n_features=N_FEATURES) + + # Helper function to create clean argument generators + function clean_args_gen_maker(default_turbo) + args_gen = map( + (ex, X, turbo, bumper) -> let + result, ok = eval_tree_array(ex, X; turbo, bumper) + (; ex, X, turbo, bumper, result, ok) + end, + expr_gen, + input_gen, + map(_ -> default_turbo, Data.Booleans()), + Data.Booleans(), + ) + # We only consider expressions that don't have NaN/Inf/etc. 
+ return filter(args -> args.ok, args_gen) + end + + # Helper function to create turbo evaluation function + function create_turbo_function(tree_str) + turbo_expr = "(x1, x2, x3, x4, x5) -> let y = deepcopy(x1); @turbo(@.(y = ($tree_str))); y; end" + return eval(Meta.parse(turbo_expr)) + end + + # Helper function to create regular evaluation function + function create_regular_function(tree_str) + regular_expr = "(x1, x2, x3, x4, x5) -> ($tree_str)" + return eval(Meta.parse(regular_expr)) + end + + # Helper function to evaluate with turbo + function evaluate_with_turbo(f, X) + return Base.invokelatest(f, X[1, :], X[2, :], X[3, :], X[4, :], X[5, :]) + end + + # Helper function to evaluate without turbo + function evaluate_without_turbo(f, X) + return Float64[Base.invokelatest(f, x...) for x in eachcol(X)] + end + + # Helper function to evaluate expression against its string representation + function _eval_against_string((; ex, X, turbo, bumper, result)) tree_str = string_tree(ex) - f = eval(Meta.parse("(x1, x2, x3, x4, x5) -> ($tree_str)")) - true_result = Float64[Base.invokelatest(f, x...) for x in eachcol(X)] - return ok && expression_result ≈ true_result + true_result = if turbo + # Turbo changes the operators, so we need to use a different function + f = create_turbo_function(tree_str) + evaluate_with_turbo(f, X) + else + f = create_regular_function(tree_str) + evaluate_without_turbo(f, X) + end + + return result ≈ true_result end - @test something(result2.result) isa Supposition.Pass + + # Test evaluation consistency without turbo + no_turbo_args_gen = clean_args_gen_maker(false) + result2_noturbo = @check max_examples = 2000 function eval_against_string( + args=no_turbo_args_gen + ) + return _eval_against_string(args) + end + @test something(result2_noturbo.result) isa Supposition.Pass + + # TODO: We need to run this test manually, as there are too many + # examples where turbo evaluation is slightly different. 
+ # # Test evaluation consistency with turbo (fewer examples due to performance) + # turbo_args_gen = clean_args_gen_maker(true) + # counter = Ref(0) + # result2_turbo = @check max_examples = 50 function eval_against_string( + # args=turbo_args_gen + # ) + # c = (counter[] += 1) + # if c > 50 + # # Supposition seems to not listen to max_examples sometimes + # return true + # else + # return _eval_against_string(args) + # end + # end + # @test something(result2_turbo.result) isa Supposition.Pass end From 7a69a7bab0674f856baed353c008d26c3044d17d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 31 May 2025 23:21:50 +0100 Subject: [PATCH 67/74] fix: fix JET identified missing method --- src/ReadOnlyNode.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/ReadOnlyNode.jl b/src/ReadOnlyNode.jl index 1d830222..7ab4cc23 100644 --- a/src/ReadOnlyNode.jl +++ b/src/ReadOnlyNode.jl @@ -33,5 +33,13 @@ struct ReadOnlyNode{T,D,N} <: AbstractReadOnlyNode{T,D,N} ReadOnlyNode(n::N) where {T,N<:AbstractExpressionNode{T}} = new{T,max_degree(N),N}(n) end @unstable constructorof(::Type{<:ReadOnlyNode}) = ReadOnlyNode +# TODO: Should this provide the degree? Or is it fine, since it always infers it from the inner node type? 
+ +function with_type_parameters(::Type{N}, ::Type{T}) where {N<:ReadOnlyNode,T} + return ReadOnlyNode{T,max_degree(N),with_type_parameters(inner_node_type(N), T)} +end + +@inline inner_node_type(::Type{<:(ReadOnlyNode{T,D,N} where {T,D})}) where {N} = N +@inline inner_node_type(::Type{<:ReadOnlyNode}) = Node end From 76899991fd8760179f6db2872f47b8bdabdbc497 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 31 May 2025 23:59:35 +0100 Subject: [PATCH 68/74] chore: bump major version with n-ary nodes --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3c24580c..f5d15f5e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicExpressions" uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b" authors = ["MilesCranmer "] -version = "1.10.1" +version = "2.0.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 0514c90f6882a8a3a6cb2cf793aa0485e06b85bc Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 Jun 2025 00:13:48 +0100 Subject: [PATCH 69/74] refactor: more generic `set_node!` implementation --- src/Node.jl | 24 ++++++++++++++++-------- src/ParametricExpression.jl | 30 ++++++++++++------------------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/Node.jl b/src/Node.jl index fb0cfd62..dad32cca 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -420,17 +420,25 @@ Set every field of `tree` equal to the corresponding field of `new_tree`. 
function set_node!(tree::AbstractExpressionNode, new_tree::AbstractExpressionNode) tree.degree = new_tree.degree if new_tree.degree == 0 - tree.constant = new_tree.constant - if new_tree.constant - tree.val = new_tree.val::eltype(new_tree) - else - tree.feature = new_tree.feature - end + set_leaf!(tree, new_tree) else - tree.op = new_tree.op - set_children!(tree, get_children(new_tree)) + set_branch!(tree, new_tree) end return nothing end +function set_leaf!(tree::AbstractExpressionNode, new_leaf::AbstractExpressionNode) + tree.constant = new_leaf.constant + if new_leaf.constant + tree.val = new_leaf.val::eltype(new_leaf) + else + tree.feature = new_leaf.feature + end + return nothing +end +function set_branch!(tree::AbstractExpressionNode, new_branch::AbstractExpressionNode) + tree.op = new_branch.op + set_children!(tree, get_children(new_branch)) + return nothing +end end diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 372b364a..a6c38e00 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -21,7 +21,7 @@ import ..NodeModule: leaf_convert, leaf_hash, leaf_equal, - set_node!, + set_leaf!, @make_accessors import ..NodePreallocationModule: copy_into!, allocate_container import ..NodeUtilsModule: @@ -145,24 +145,18 @@ function leaf_copy(t::ParametricNode{T}) where {T} return n end end -function set_node!(tree::ParametricNode, new_tree::ParametricNode) - tree.degree = (deg = new_tree.degree) - if deg == 0 - if new_tree.constant - tree.constant = true - tree.val = new_tree.val - elseif !new_tree.is_parameter - tree.constant = false - tree.is_parameter = false - tree.feature = new_tree.feature - else - tree.constant = false - tree.is_parameter = true - tree.parameter = new_tree.parameter - end +function set_leaf!(tree::ParametricNode, new_leaf::ParametricNode) + if new_leaf.constant + tree.constant = true + tree.val = new_leaf.val + elseif !new_leaf.is_parameter + tree.constant = false + tree.is_parameter = false + 
tree.feature = new_leaf.feature else - tree.op = new_tree.op - set_children!(tree, get_children(new_tree)) + tree.constant = false + tree.is_parameter = true + tree.parameter = new_leaf.parameter end return nothing end From f16af837762405fcdef6ebe933a6c8aa360dda88 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 Jun 2025 00:48:01 +0100 Subject: [PATCH 70/74] refactor: use generic getters and setters --- ext/DynamicExpressionsLoopVectorizationExt.jl | 64 +++++----- ext/DynamicExpressionsSymbolicUtilsExt.jl | 11 +- src/DynamicExpressions.jl | 4 + src/Evaluate.jl | 109 ++++++++++-------- src/Simplify.jl | 96 +++++++-------- 5 files changed, 154 insertions(+), 130 deletions(-) diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 4d2f7c33..7b41f767 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -2,6 +2,7 @@ module DynamicExpressionsLoopVectorizationExt using LoopVectorization: @turbo using DynamicExpressions: AbstractExpressionNode +using DynamicExpressions.NodeModule: get_child using DynamicExpressions.UtilsModule: ResultOk using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions, get_array, get_feature_array, get_filled_array @@ -39,9 +40,10 @@ function deg1_l2_ll0_lr0_eval( op_l::F2, eval_options::EvalOptions{true}, ) where {T<:Number,F,F2} - if tree.l.l.constant && tree.l.r.constant - val_ll = tree.l.l.val - val_lr = tree.l.r.val + if get_child(get_child(tree, 1), 1).constant && + get_child(get_child(tree, 1), 2).constant + val_ll = get_child(get_child(tree, 1), 1).val + val_lr = get_child(get_child(tree, 1), 2).val @return_on_nonfinite_val(eval_options, val_ll, cX) @return_on_nonfinite_val(eval_options, val_lr, cX) x_l = op_l(val_ll, val_lr)::T @@ -49,10 +51,10 @@ function deg1_l2_ll0_lr0_eval( x = op(x_l)::T @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(get_filled_array(eval_options.buffer, x, 
cX, axes(cX, 2)), true) - elseif tree.l.l.constant - val_ll = tree.l.l.val + elseif get_child(get_child(tree, 1), 1).constant + val_ll = get_child(get_child(tree, 1), 1).val @return_on_nonfinite_val(eval_options, val_ll, cX) - feature_lr = tree.l.r.feature + feature_lr = get_child(get_child(tree, 1), 2).feature cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @turbo for j in axes(cX, 2) x_l = op_l(val_ll, cX[feature_lr, j]) @@ -60,9 +62,9 @@ function deg1_l2_ll0_lr0_eval( cumulator[j] = x end return ResultOk(cumulator, true) - elseif tree.l.r.constant - feature_ll = tree.l.l.feature - val_lr = tree.l.r.val + elseif get_child(get_child(tree, 1), 2).constant + feature_ll = get_child(get_child(tree, 1), 1).feature + val_lr = get_child(get_child(tree, 1), 2).val @return_on_nonfinite_val(eval_options, val_lr, cX) cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @turbo for j in axes(cX, 2) @@ -72,8 +74,8 @@ function deg1_l2_ll0_lr0_eval( end return ResultOk(cumulator, true) else - feature_ll = tree.l.l.feature - feature_lr = tree.l.r.feature + feature_ll = get_child(get_child(tree, 1), 1).feature + feature_lr = get_child(get_child(tree, 1), 2).feature cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @turbo for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j], cX[feature_lr, j]) @@ -91,8 +93,8 @@ function deg1_l1_ll0_eval( op_l::F2, eval_options::EvalOptions{true}, ) where {T<:Number,F,F2} - if tree.l.l.constant - val_ll = tree.l.l.val + if get_child(get_child(tree, 1), 1).constant + val_ll = get_child(get_child(tree, 1), 1).val @return_on_nonfinite_val(eval_options, val_ll, cX) x_l = op_l(val_ll)::T @return_on_nonfinite_val(eval_options, x_l, cX) @@ -100,7 +102,7 @@ function deg1_l1_ll0_eval( @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true) else - feature_ll = tree.l.l.feature + feature_ll = get_child(get_child(tree, 1), 1).feature cumulator = 
get_array(eval_options.buffer, cX, axes(cX, 2)) @turbo for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j]) @@ -117,28 +119,28 @@ function deg2_l0_r0_eval( op::F, eval_options::EvalOptions{true}, ) where {T<:Number,F} - if tree.l.constant && tree.r.constant - val_l = tree.l.val + if get_child(tree, 1).constant && get_child(tree, 2).constant + val_l = get_child(tree, 1).val @return_on_nonfinite_val(eval_options, val_l, cX) - val_r = tree.r.val + val_r = get_child(tree, 2).val @return_on_nonfinite_val(eval_options, val_r, cX) x = op(val_l, val_r)::T @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true) - elseif tree.l.constant + elseif get_child(tree, 1).constant cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) - val_l = tree.l.val + val_l = get_child(tree, 1).val @return_on_nonfinite_val(eval_options, val_l, cX) - feature_r = tree.r.feature + feature_r = get_child(tree, 2).feature @turbo for j in axes(cX, 2) x = op(val_l, cX[feature_r, j]) cumulator[j] = x end return ResultOk(cumulator, true) - elseif tree.r.constant + elseif get_child(tree, 2).constant cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) - feature_l = tree.l.feature - val_r = tree.r.val + feature_l = get_child(tree, 1).feature + val_r = get_child(tree, 2).val @return_on_nonfinite_val(eval_options, val_r, cX) @turbo for j in axes(cX, 2) x = op(cX[feature_l, j], val_r) @@ -147,8 +149,8 @@ function deg2_l0_r0_eval( return ResultOk(cumulator, true) else cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) - feature_l = tree.l.feature - feature_r = tree.r.feature + feature_l = get_child(tree, 1).feature + feature_r = get_child(tree, 2).feature @turbo for j in axes(cX, 2) x = op(cX[feature_l, j], cX[feature_r, j]) cumulator[j] = x @@ -165,8 +167,8 @@ function deg2_l0_eval( op::F, eval_options::EvalOptions{true}, ) where {T<:Number,F} - if tree.l.constant - val = tree.l.val + if get_child(tree, 1).constant + 
val = get_child(tree, 1).val @return_on_nonfinite_val(eval_options, val, cX) @turbo for j in eachindex(cumulator) x = op(val, cumulator[j]) @@ -174,7 +176,7 @@ function deg2_l0_eval( end return ResultOk(cumulator, true) else - feature = tree.l.feature + feature = get_child(tree, 1).feature @turbo for j in eachindex(cumulator) x = op(cX[feature, j], cumulator[j]) cumulator[j] = x @@ -190,8 +192,8 @@ function deg2_r0_eval( op::F, eval_options::EvalOptions{true}, ) where {T<:Number,F} - if tree.r.constant - val = tree.r.val + if get_child(tree, 2).constant + val = get_child(tree, 2).val @return_on_nonfinite_val(eval_options, val, cX) @turbo for j in eachindex(cumulator) x = op(cumulator[j], val) @@ -199,7 +201,7 @@ function deg2_r0_eval( end return ResultOk(cumulator, true) else - feature = tree.r.feature + feature = get_child(tree, 2).feature @turbo for j in eachindex(cumulator) x = op(cumulator[j], cX[feature, j]) cumulator[j] = x diff --git a/ext/DynamicExpressionsSymbolicUtilsExt.jl b/ext/DynamicExpressionsSymbolicUtilsExt.jl index e455eda7..b3423e9f 100644 --- a/ext/DynamicExpressionsSymbolicUtilsExt.jl +++ b/ext/DynamicExpressionsSymbolicUtilsExt.jl @@ -3,7 +3,7 @@ module DynamicExpressionsSymbolicUtilsExt using DynamicExpressions: AbstractExpression, get_tree, get_operators, get_variable_names, default_node_type using DynamicExpressions.NodeModule: - AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE + AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE, get_child using DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum using DynamicExpressions.UtilsModule: deprecate_varmap @@ -39,7 +39,7 @@ end subs_bad(x) = is_valid(x) ? x : Inf function parse_tree_to_eqs( - tree::AbstractExpressionNode{T}, + tree::AbstractExpressionNode{T,2}, operators::AbstractOperatorEnum, index_functions::Bool=false, ) where {T} @@ -50,7 +50,8 @@ function parse_tree_to_eqs( end # Collect the next children # TODO: Type instability! 
- children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,) + children = + tree.degree == 2 ? (get_child(tree, 1), get_child(tree, 2)) : (get_child(tree, 1),) # Get the operation op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op] # Create an N tuple of Numbers for each argument @@ -219,13 +220,13 @@ will generate a symbolic equation in SymbolicUtils.jl format. (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84). """ function node_to_symbolic( - tree::AbstractExpressionNode, + tree::AbstractExpressionNode{T,2}, operators::AbstractOperatorEnum; variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, index_functions::Bool=false, # Deprecated: varMap=nothing, -) +) where {T} variable_names = deprecate_varmap(variable_names, varMap, :node_to_symbolic) expr = subs_bad(parse_tree_to_eqs(tree, operators, index_functions)) # Check for NaN and Inf diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index c1087516..8931a172 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -41,6 +41,10 @@ import .ValueInterfaceModule: AbstractExpressionNode, GraphNode, Node, + get_child, + set_child!, + get_children, + set_children!, copy_node, set_node!, tree_mapreduce, diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 3e178604..6cf01375 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -3,7 +3,12 @@ module EvaluateModule using DispatchDoctor: @stable, @unstable import ..NodeModule: - AbstractExpressionNode, constructorof, max_degree, get_children, with_type_parameters + AbstractExpressionNode, + constructorof, + max_degree, + get_children, + get_child, + with_type_parameters import ..StringsModule: string_tree import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum import ..UtilsModule: fill_similar, counttuple, ResultOk @@ -422,10 +427,10 @@ end long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time 
return quote - result_l = _eval_tree_array(tree.l, cX, operators, eval_options) + result_l = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) - result_r = _eval_tree_array(tree.r, cX, operators, eval_options) + result_r = _eval_tree_array(get_child(tree, 2), cX, operators, eval_options) !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), for any x or y @@ -437,25 +442,27 @@ end $nbin, i -> i == op_idx, i -> let op = operators.binops[i] - if tree.l.degree == 0 && tree.r.degree == 0 + if get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) - elseif tree.r.degree == 0 - result_l = _eval_tree_array(tree.l, cX, operators, eval_options) + elseif get_child(tree, 2).degree == 0 + result_l = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) - elseif tree.l.degree == 0 - result_r = _eval_tree_array(tree.r, cX, operators, eval_options) + elseif get_child(tree, 1).degree == 0 + result_r = _eval_tree_array(get_child(tree, 2), cX, operators, eval_options) !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), where x is a constant or variable but y is not. 
deg2_l0_eval(tree, result_r.x, cX, op, eval_options) else - result_l = _eval_tree_array(tree.l, cX, operators, eval_options) + result_l = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) - result_r = _eval_tree_array(tree.r, cX, operators, eval_options) + result_r = _eval_tree_array( + get_child(tree, 2), cX, operators, eval_options + ) !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), for any x or y @@ -476,7 +483,7 @@ end long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote - result = _eval_tree_array(tree.l, cX, operators, eval_options) + result = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) deg1_eval(result.x, operators.unaops[op_idx], eval_options) @@ -489,21 +496,24 @@ end $nuna, i -> i == op_idx, i -> let op = operators.unaops[i] - if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0 + if get_child(tree, 1).degree == 2 && + get_child(get_child(tree, 1), 1).degree == 0 && + get_child(get_child(tree, 1), 2).degree == 0 # op(op2(x, y)), where x, y, z are constants or variables. - l_op_idx = tree.l.op + l_op_idx = get_child(tree, 1).op dispatch_deg1_l2_ll0_lr0_eval( tree, cX, op, l_op_idx, operators.binops, eval_options ) - elseif tree.l.degree == 1 && tree.l.l.degree == 0 + elseif get_child(tree, 1).degree == 1 && + get_child(get_child(tree, 1), 1).degree == 0 # op(op2(x)), where x is a constant or variable. - l_op_idx = tree.l.op + l_op_idx = get_child(tree, 1).op dispatch_deg1_l1_ll0_eval( tree, cX, op, l_op_idx, operators.unaops, eval_options ) else # op(x), for any x. 
- result = _eval_tree_array(tree.l, cX, operators, eval_options) + result = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) deg1_eval(result.x, op, eval_options) @@ -560,9 +570,10 @@ function deg1_l2_ll0_lr0_eval( op_l::F2, eval_options::EvalOptions{false,false}, ) where {T,F,F2} - if tree.l.l.constant && tree.l.r.constant - val_ll = tree.l.l.val - val_lr = tree.l.r.val + if get_child(get_child(tree, 1), 1).constant && + get_child(get_child(tree, 1), 2).constant + val_ll = get_child(get_child(tree, 1), 1).val + val_lr = get_child(get_child(tree, 1), 2).val @return_on_nonfinite_val(eval_options, val_ll, cX) @return_on_nonfinite_val(eval_options, val_lr, cX) x_l = op_l(val_ll, val_lr)::T @@ -570,10 +581,10 @@ function deg1_l2_ll0_lr0_eval( x = op(x_l)::T @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true) - elseif tree.l.l.constant - val_ll = tree.l.l.val + elseif get_child(get_child(tree, 1), 1).constant + val_ll = get_child(get_child(tree, 1), 1).val @return_on_nonfinite_val(eval_options, val_ll, cX) - feature_lr = tree.l.r.feature + feature_lr = get_child(get_child(tree, 1), 2).feature cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) x_l = op_l(val_ll, cX[feature_lr, j])::T @@ -581,9 +592,9 @@ function deg1_l2_ll0_lr0_eval( cumulator[j] = x end return ResultOk(cumulator, true) - elseif tree.l.r.constant - feature_ll = tree.l.l.feature - val_lr = tree.l.r.val + elseif get_child(get_child(tree, 1), 2).constant + feature_ll = get_child(get_child(tree, 1), 1).feature + val_lr = get_child(get_child(tree, 1), 2).val @return_on_nonfinite_val(eval_options, val_lr, cX) cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) @@ -593,8 +604,8 @@ function deg1_l2_ll0_lr0_eval( end return ResultOk(cumulator, true) else - 
feature_ll = tree.l.l.feature - feature_lr = tree.l.r.feature + feature_ll = get_child(get_child(tree, 1), 1).feature + feature_lr = get_child(get_child(tree, 1), 2).feature cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T @@ -613,8 +624,8 @@ function deg1_l1_ll0_eval( op_l::F2, eval_options::EvalOptions{false,false}, ) where {T,F,F2} - if tree.l.l.constant - val_ll = tree.l.l.val + if get_child(get_child(tree, 1), 1).constant + val_ll = get_child(get_child(tree, 1), 1).val @return_on_nonfinite_val(eval_options, val_ll, cX) x_l = op_l(val_ll)::T @return_on_nonfinite_val(eval_options, x_l, cX) @@ -622,7 +633,7 @@ function deg1_l1_ll0_eval( @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true) else - feature_ll = tree.l.l.feature + feature_ll = get_child(get_child(tree, 1), 1).feature cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j])::T @@ -633,35 +644,35 @@ function deg1_l1_ll0_eval( end end -# op(x, y) for x and y variable/constant +# op(x, y), for x, y either constants or variables. 
function deg2_l0_r0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, eval_options::EvalOptions{false,false}, ) where {T,F} - if tree.l.constant && tree.r.constant - val_l = tree.l.val + if get_child(tree, 1).constant && get_child(tree, 2).constant + val_l = get_child(tree, 1).val @return_on_nonfinite_val(eval_options, val_l, cX) - val_r = tree.r.val + val_r = get_child(tree, 2).val @return_on_nonfinite_val(eval_options, val_r, cX) x = op(val_l, val_r)::T @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true) - elseif tree.l.constant + elseif get_child(tree, 1).constant cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) - val_l = tree.l.val + val_l = get_child(tree, 1).val @return_on_nonfinite_val(eval_options, val_l, cX) - feature_r = tree.r.feature + feature_r = get_child(tree, 2).feature @inbounds @simd for j in axes(cX, 2) x = op(val_l, cX[feature_r, j])::T cumulator[j] = x end return ResultOk(cumulator, true) - elseif tree.r.constant + elseif get_child(tree, 2).constant cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) - feature_l = tree.l.feature - val_r = tree.r.val + feature_l = get_child(tree, 1).feature + val_r = get_child(tree, 2).val @return_on_nonfinite_val(eval_options, val_r, cX) @inbounds @simd for j in axes(cX, 2) x = op(cX[feature_l, j], val_r)::T @@ -670,8 +681,8 @@ function deg2_l0_r0_eval( return ResultOk(cumulator, true) else cumulator = get_array(eval_options.buffer, cX, axes(cX, 2)) - feature_l = tree.l.feature - feature_r = tree.r.feature + feature_l = get_child(tree, 1).feature + feature_r = get_child(tree, 2).feature @inbounds @simd for j in axes(cX, 2) x = op(cX[feature_l, j], cX[feature_r, j])::T cumulator[j] = x @@ -680,7 +691,7 @@ function deg2_l0_r0_eval( end end -# op(x, y) for x variable/constant, y arbitrary +# op(x, y), where y is a constant or variable but x is not. 
function deg2_l0_eval( tree::AbstractExpressionNode{T}, cumulator::AbstractVector{T}, @@ -688,8 +699,8 @@ function deg2_l0_eval( op::F, eval_options::EvalOptions{false,false}, ) where {T,F} - if tree.l.constant - val = tree.l.val + if get_child(tree, 1).constant + val = get_child(tree, 1).val @return_on_nonfinite_val(eval_options, val, cX) @inbounds @simd for j in eachindex(cumulator) x = op(val, cumulator[j])::T @@ -697,7 +708,7 @@ function deg2_l0_eval( end return ResultOk(cumulator, true) else - feature = tree.l.feature + feature = get_child(tree, 1).feature @inbounds @simd for j in eachindex(cumulator) x = op(cX[feature, j], cumulator[j])::T cumulator[j] = x @@ -714,8 +725,8 @@ function deg2_r0_eval( op::F, eval_options::EvalOptions{false,false}, ) where {T,F} - if tree.r.constant - val = tree.r.val + if get_child(tree, 2).constant + val = get_child(tree, 2).val @return_on_nonfinite_val(eval_options, val, cX) @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j], val)::T @@ -723,7 +734,7 @@ function deg2_r0_eval( end return ResultOk(cumulator, true) else - feature = tree.r.feature + feature = get_child(tree, 2).feature @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j], cX[feature, j])::T cumulator[j] = x diff --git a/src/Simplify.jl b/src/Simplify.jl index ef17b33f..842bafbb 100644 --- a/src/Simplify.jl +++ b/src/Simplify.jl @@ -1,6 +1,7 @@ module SimplifyModule -import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node! 
+import ..NodeModule: + AbstractExpressionNode, constructorof, Node, copy_node, set_node!, set_child!, get_child import ..NodeUtilsModule: tree_mapreduce, is_node_constant import ..OperatorEnumModule: AbstractOperatorEnum import ..ValueInterfaceModule: is_valid @@ -27,33 +28,38 @@ function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) whe if tree.degree == 0 return tree elseif tree.degree == 1 - tree.l = combine_operators(tree.l, operators) + set_child!(tree, combine_operators(get_child(tree, 1), operators), 1) elseif tree.degree == 2 - tree.l = combine_operators(tree.l, operators) - tree.r = combine_operators(tree.r, operators) + set_child!(tree, combine_operators(get_child(tree, 1), operators), 1) + set_child!(tree, combine_operators(get_child(tree, 2), operators), 2) end top_level_constant = - tree.degree == 2 && (is_node_constant(tree.l) || is_node_constant(tree.r)) + tree.degree == 2 && + (is_node_constant(get_child(tree, 1)) || is_node_constant(get_child(tree, 2))) if tree.degree == 2 && is_commutative(operators.binops[tree.op]) && top_level_constant # TODO: Does this break SymbolicRegression.jl due to the different names of operators? op = tree.op # Put the constant in r. Need to assume var in left for simplification assumption. 
- if is_node_constant(tree.l) - tmp = tree.r - tree.r = tree.l - tree.l = tmp + if is_node_constant(get_child(tree, 1)) + tmp = get_child(tree, 2) + set_child!(tree, get_child(tree, 1), 2) + set_child!(tree, tmp, 1) end - topconstant = tree.r.val + topconstant = get_child(tree, 2).val # Simplify down first - below = tree.l + below = get_child(tree, 1) if below.degree == 2 && below.op == op - if is_node_constant(below.l) + if is_node_constant(get_child(below, 1)) tree = below - tree.l.val = _op_kernel(operators.binops[op], tree.l.val, topconstant) - elseif is_node_constant(below.r) + get_child(tree, 1).val = _op_kernel( + operators.binops[op], get_child(tree, 1).val, topconstant + ) + elseif is_node_constant(get_child(below, 2)) tree = below - tree.r.val = _op_kernel(operators.binops[op], tree.r.val, topconstant) + get_child(tree, 2).val = _op_kernel( + operators.binops[op], get_child(tree, 2).val, topconstant + ) end end end @@ -62,42 +68,42 @@ function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) whe # Currently just simplifies subtraction. (can't assume both plus and sub are operators) # Not commutative, so use different op. 
- if is_node_constant(tree.l) - if tree.r.degree == 2 && tree.op == tree.r.op - if is_node_constant(tree.r.l) + if is_node_constant(get_child(tree, 1)) + if get_child(tree, 2).degree == 2 && tree.op == get_child(tree, 2).op + if is_node_constant(get_child(get_child(tree, 2), 1)) #(const - (const - var)) => (var - const) - l = tree.l - r = tree.r - simplified_const = (r.l.val - l.val) #neg(sub(l.val, r.l.val)) - tree.l = tree.r.r - tree.r = l - tree.r.val = simplified_const - elseif is_node_constant(tree.r.r) + l = get_child(tree, 1) + r = get_child(tree, 2) + simplified_const = (get_child(r, 1).val - l.val) #neg(sub(l.val, r.l.val)) + set_child!(tree, get_child(get_child(tree, 2), 2), 1) + set_child!(tree, l, 2) + get_child(tree, 2).val = simplified_const + elseif is_node_constant(get_child(get_child(tree, 2), 2)) #(const - (var - const)) => (const - var) - l = tree.l - r = tree.r - simplified_const = l.val + r.r.val #plus(l.val, r.r.val) - tree.r = tree.r.l - tree.l.val = simplified_const + l = get_child(tree, 1) + r = get_child(tree, 2) + simplified_const = l.val + get_child(r, 2).val #plus(l.val, r.r.val) + set_child!(tree, get_child(get_child(tree, 2), 1), 2) + get_child(tree, 1).val = simplified_const end end - else #tree.r is a constant - if tree.l.degree == 2 && tree.op == tree.l.op - if is_node_constant(tree.l.l) + else #get_child(tree, 2) is a constant + if get_child(tree, 1).degree == 2 && tree.op == get_child(tree, 1).op + if is_node_constant(get_child(get_child(tree, 1), 1)) #((const - var) - const) => (const - var) - l = tree.l - r = tree.r - simplified_const = l.l.val - r.val#sub(l.l.val, r.val) - tree.r = tree.l.r - tree.l = r - tree.l.val = simplified_const - elseif is_node_constant(tree.l.r) + l = get_child(tree, 1) + r = get_child(tree, 2) + simplified_const = get_child(l, 1).val - r.val#sub(l.l.val, r.val) + set_child!(tree, get_child(get_child(tree, 1), 2), 2) + set_child!(tree, r, 1) + get_child(tree, 1).val = simplified_const + elseif 
is_node_constant(get_child(get_child(tree, 1), 2)) #((var - const) - const) => (var - const) - l = tree.l - r = tree.r - simplified_const = r.val + l.r.val #plus(r.val, l.r.val) - tree.l = tree.l.l - tree.r.val = simplified_const + l = get_child(tree, 1) + r = get_child(tree, 2) + simplified_const = r.val + get_child(l, 2).val #plus(r.val, l.r.val) + set_child!(tree, get_child(get_child(tree, 1), 1), 1) + get_child(tree, 2).val = simplified_const end end end From 981e3b022d7c587f2c107729364e661c5c29550a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 Jun 2025 01:39:20 +0100 Subject: [PATCH 71/74] docs: describe generic getters and setters --- docs/src/api.md | 38 +++++++++++++++++++++++++++++++++++++- src/Node.jl | 32 ++++++++++++++++++++++++++++++++ src/Simplify.jl | 8 ++++---- 3 files changed, 73 insertions(+), 5 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 6bf37b73..248ababf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -75,6 +75,42 @@ You can create a copy of a node with `copy_node`: copy_node ``` +## Generic Node Accessors + +For working with nodes of arbitrary arity: + +```@docs +get_child +set_child! +get_children +set_children! 
+``` + +Examples: + +```julia +# Define operators including ternary +my_ternary(x, y, z) = x + y * z +operators = OperatorEnum(((sin,), (+, *), (my_ternary,))) # (unary, binary, ternary) + +tree = Node{Float64,3}(; op=1, children=(Node{Float64,3}(; val=1.0), Node{Float64,3}(; val=2.0))) +new_child = Node{Float64,3}(; val=3.0) + +left_child = get_child(tree, 1) +right_child = get_child(tree, 2) + +set_child!(tree, new_child, 1) + +children = get_children(tree) +left, right = get_children(tree, Val(2)) # type stable + +# Transform to ternary operation +child1, child2, child3 = Node{Float64,3}(; val=4.0), Node{Float64,3}(; val=5.0), Node{Float64,3}(; val=6.0) +set_children!(tree, (child1, child2, child3)) +tree.op = 1 # my_ternary +tree.degree = 3 +``` + ## Graph Nodes You can describe an equation as a *graph* rather than a tree @@ -109,7 +145,7 @@ This means that we only need to change it once to have changes propagate across the expression: ```julia -julia> y.r.val *= 0.9 +julia> get_child(y, 2).val *= 0.9 1.35 julia> z diff --git a/src/Node.jl b/src/Node.jl index dad32cca..f7d2517f 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -117,6 +117,9 @@ nodes, you can evaluate or print a given expression. object. Only defined if `degree >= 1` - `children::NTuple{D,Node{T,D}}`: Children of the node. Only defined up to `degree` +For accessing and modifying children, use [`get_child`](@ref), [`set_child!`](@ref), +[`get_children`](@ref), and [`set_children!`](@ref). + # Constructors @@ -188,17 +191,46 @@ end @inline function get_children(node::AbstractNode) return getfield(node, :children) end + +""" + get_children(node::AbstractNode) + get_children(node::AbstractNode, ::Val{n}) + +Return the children tuple of a node. The first form returns the complete children tuple as stored. +The second form returns a tuple of exactly `n` children, useful for type stability when the +number of children needed is known at compile time. 
+""" @inline function get_children(node::AbstractNode, ::Val{n}) where {n} cs = get_children(node) return ntuple(i -> cs[i], Val(Int(n))) end + +""" + get_child(node::AbstractNode, i::Int) + +Return the `i`-th child of a node (1-indexed). +""" @inline function get_child(n::AbstractNode{D}, i::Int) where {D} return get_children(n)[i] end + +""" + set_child!(node::AbstractNode, child::AbstractNode, i::Int) + +Replace the `i`-th child of a node (1-indexed) with the given child node. +Returns the new child. Mutates `node` by storing an updated copy of its children tuple. +""" @inline function set_child!(n::AbstractNode{D}, child::AbstractNode{D}, i::Int) where {D} set_children!(n, Base.setindex(get_children(n), child, i)) return child end + +""" + set_children!(node::AbstractNode, children::Union{Tuple,AbstractVector}) + +Replace all children of a node with the given tuple or vector. If fewer children are +provided than the node's maximum degree, remaining slots are filled with poison nodes. +""" @inline function set_children!(n::AbstractNode{D}, children::Union{Tuple,AbstractVector{<:AbstractNode{D}}}) where {D} D2 = length(children) if D === D2 diff --git a/src/Simplify.jl b/src/Simplify.jl index 842bafbb..5ef6649c 100644 --- a/src/Simplify.jl +++ b/src/Simplify.jl @@ -74,7 +74,7 @@ function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) whe #(const - (const - var)) => (var - const) l = get_child(tree, 1) r = get_child(tree, 2) - simplified_const = (get_child(r, 1).val - l.val) #neg(sub(l.val, r.l.val)) + simplified_const = (get_child(r, 1).val - l.val) #neg(sub(l.val, get_child(r, 1).val)) set_child!(tree, get_child(get_child(tree, 2), 2), 1) set_child!(tree, l, 2) get_child(tree, 2).val = simplified_const @@ -82,7 +82,7 @@ function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) whe #(const - (var - const)) => (const - var) l = get_child(tree, 1) r = get_child(tree, 2) - simplified_const = l.val + get_child(r, 2).val #plus(l.val, r.r.val) + simplified_const = l.val + 
get_child(r, 2).val #plus(l.val, get_child(r, 2).val) set_child!(tree, get_child(get_child(tree, 2), 1), 2) get_child(tree, 1).val = simplified_const end @@ -93,7 +93,7 @@ function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) whe #((const - var) - const) => (const - var) l = get_child(tree, 1) r = get_child(tree, 2) - simplified_const = get_child(l, 1).val - r.val#sub(l.l.val, r.val) + simplified_const = get_child(l, 1).val - r.val#sub(get_child(l, 1).val, r.val) set_child!(tree, get_child(get_child(tree, 1), 2), 2) set_child!(tree, r, 1) get_child(tree, 1).val = simplified_const @@ -101,7 +101,7 @@ function combine_operators(tree::Node{T,2}, operators::AbstractOperatorEnum) whe #((var - const) - const) => (var - const) l = get_child(tree, 1) r = get_child(tree, 2) - simplified_const = r.val + get_child(l, 2).val #plus(r.val, l.r.val) + simplified_const = r.val + get_child(l, 2).val #plus(r.val, get_child(l, 2).val) set_child!(tree, get_child(get_child(tree, 1), 1), 1) get_child(tree, 2).val = simplified_const end From 84f81211b3264232f28ba291e2d89f2986d1acf4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 Jun 2025 06:07:46 +0100 Subject: [PATCH 72/74] docs: describe changelog --- CHANGELOG.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..659fcc00 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,38 @@ +# CHANGELOG + +## 2.0.0 + +- Nodes can now have arbitrary numbers of children, not just binary trees. This was designed in such a way to have identical performance to the previous version. Essentially, the `Node{T,D}` type is a wrapper around a tuple of `D` children. + - Note that `Node{T}` automatically converts to `Node{T,2}` in many contexts, for backwards compatibility. +- Node type signature changed from `Node{T}` to `Node{T,D}`. Usually `D` is 2. 
+- Direct property access `.l` and `.r` is automatically forwarded to `get_child(tree, 1)` and `get_child(tree, 2)`, + but it is recommended to use the generic accessors instead. +- Similarly, `.l = child` should be replaced with `set_child!(tree, child, 1)` and similar for `.r`. +- All internal code migrated to use generic accessors. + +### Backwards Compatibility + +- Existing `.l` and `.r` access continues to work without warnings + +### Breaking Changes + +Existing code only breaks in the following cases: + +1. You have any types that are subtyped to `<:AbstractExpressionNode{T}` or `<:AbstractNode`. These should now be subtyped to `<:AbstractExpressionNode{T,2}` or `<:AbstractNode{2}`. You may also allow a `D` parameter in case you want to support higher-arity trees. +2. You assume a tree has type, e.g., `=== Node{T}`, rather than `<: Node{T}`. So any methods dispatched to `::Type{Node{T}}` will also break. (To be safe you should always use a form `<: Node{T}` in case of future type changes - in any library.) +3. You assume `tree.degree <= 2` in conditional logic, and your code interacts with a tree that is _not_ a binary tree. For example, the following pattern was common before this change: + + ```julia + if tree.degree == 0 + #= operations on leaf node =# + elseif tree.degree == 1 + #= operations on unary node =# + else + # BAD: ASSUMED TO BE BINARY + #= operations on binary node, using `.l` and `.r` only =# + end + ``` + + This will obviously break if you pass a tree that is not binary, such as `tree::Node{T,3}`. + - To fix this, you can use the `get_children` function to get the children of the tree as a tuple of `D` children, and then index up to `tree.degree`. + - Inside DynamicExpressions, we commonly use `Base.Cartesian.@nif` to generate code for different degrees, to avoid any unstable types. 
From caca4878575a3b53c455938c24536feba0d6eb0e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 Jun 2025 06:08:14 +0100 Subject: [PATCH 73/74] docs: fix signatures --- docs/src/api.md | 2 +- docs/src/utils.md | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 248ababf..edd91cd4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -60,7 +60,7 @@ When using these node constructors, types will automatically be promoted. You can convert the type of a node using `convert`: ```@docs -convert(::Type{AbstractExpressionNode{T1}}, tree::AbstractExpressionNode{T2}) where {T1, T2} +convert(::Type{N1}, tree::N2) where {T1,T2,D1,D2,N1<:AbstractExpressionNode{T1,D1},N2<:AbstractExpressionNode{T2,D2}} ``` You can set a `tree` (in-place) with `set_node!`: diff --git a/docs/src/utils.md b/docs/src/utils.md index 5ae581ee..74ef7e9d 100644 --- a/docs/src/utils.md +++ b/docs/src/utils.md @@ -15,7 +15,6 @@ mapreduce(f::F, op::G, tree::AbstractNode; return_type, f_on_shared, break_shari any(f::F, tree::AbstractNode) where {F<:Function} all(f::F, tree::AbstractNode) where {F<:Function} map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing; break_sharing::Val=Val(false)) where {F<:Function,RT} -convert(::Type{<:AbstractExpressionNode{T1}}, n::AbstractExpressionNode{T2}) where {T1,T2} hash(tree::AbstractExpressionNode{T}, h::UInt; break_sharing::Val=Val(false)) where {T} ``` From acbef0f77ae37211d4eced764a4b28b6011e6650 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 Jun 2025 06:25:10 +0100 Subject: [PATCH 74/74] docs: fix example with base operations --- test/test_base_2.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_base_2.jl b/test/test_base_2.jl index e0ab645b..61a0c5a8 100644 --- a/test/test_base_2.jl +++ b/test/test_base_2.jl @@ -9,10 +9,10 @@ First, let's create a node to reference `feature=1` of our dataset: =# - using DynamicExpressions, Random + using 
DynamicExpressions, Random, Test x = Node{Float64}(; feature=1) - @test x isa Node{Float64} + @test x isa Node{Float64,2} # We can also create values, using `val`: const_1 = Node{Float64}(; val=1.0)