Skip to content

Add support for setting constants by name #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ using Reexport
has_operators,
has_constants,
get_constants,
set_constants
set_constants,
get_named_constants,
set_named_constants!
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators
Expand Down
17 changes: 11 additions & 6 deletions src/Equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ nodes, you can evaluate or print a given expression.
argument to the binary operator.
"""
mutable struct Node{T}
name::Symbol # A unique identifier for each node
degree::Int # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
constant::Bool # false if variable
val::Union{T,Nothing} # If is a constant, this stores the actual value
Expand All @@ -43,22 +44,26 @@ mutable struct Node{T}
op::Int # If operator, this is the index of the operator in operators.binary_operators, or operators.unary_operators
l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
r::Node{T} # Right child node. Only defined for degree=2.

#################
## Constructors:
#################
Node(d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
Node(::Type{_T}, d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
Node(::Type{_T}, d::Int, c::Bool, v::Nothing, f::Int) where {_T} = new{_T}(d, c, v, f)
Node(d::Int, c::Bool, v::_T) where {_T} = new{_T}(gensym("Constant"), d, c, v)
function Node(::Type{_T}, d::Int, c::Bool, v::_T) where {_T}
return new{_T}(gensym("Constant"), d, c, v)
end
function Node(::Type{_T}, d::Int, c::Bool, v::Nothing, f::Int) where {_T}
return new{_T}(gensym("Feature"), d, c, v, f)
end
function Node(d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}) where {_T}
return new{_T}(d, c, v, f, o, l)
return new{_T}(gensym("Unary"), d, c, v, f, o, l)
end
function Node(
d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}, r::Node{_T}
) where {_T}
return new{_T}(d, c, v, f, o, l, r)
return new{_T}(gensym("Binary"), d, c, v, f, o, l, r)
end
end

################################################################################

"""
Expand Down
42 changes: 42 additions & 0 deletions src/EquationUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,48 @@ function set_constants(tree::Node{T}, constants::AbstractVector{T}) where {T}
end
end

# Get all the constants from a tree named
function get_named_constants(tree::Node{T}) where {T}
vals = Tuple{Symbol,Number}[]
_get_named_constants!(vals, tree)
return NamedTuple(vals)
end

function _get_named_constants!(
vals::Vector{Tuple{Symbol,<:Number}}, tree::Node{T}
) where {T}
if tree.degree == 0
if tree.constant
push!(vals, (tree.name, tree.val))
end
elseif tree.degree == 1
_get_named_constants!(vals, tree.l)
else
_get_named_constants!(vals, tree.l)
_get_named_constants!(vals, tree.r)
end
return nothing
end

# Set all the constants inside a tree
function set_named_constants!(tree::Node{T}, constants::C) where {T,C}
return _set_named_constants!(tree, constants)
end

function _set_named_constants!(tree::Node{T}, vals::C) where {T,C}
if tree.degree == 0
if tree.constant && (tree.name ∈ keys(vals))
tree.val = getfield(vals, tree.name)
end
elseif tree.degree == 1
_set_named_constants!(tree.l, vals)
else
_set_named_constants!(tree.l, vals)
_set_named_constants!(tree.r, vals)
end
return nothing
end

## Assign index to nodes of a tree
# This will mirror a Node struct, rather
# than adding a new attribute to Node.
Expand Down
19 changes: 19 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,26 @@ x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
tree = Node(; val=0.0)
set_constants(tree, [1.0])
@test repr(tree) == "1.0"

tree = Node(; val=0.0)
vals = get_named_constants(tree)
new_vals = map(zip(keys(vals), [1.0])) do (k, new)
(k, new)
end
set_named_constants!(tree, NamedTuple(new_vals))
@test repr(tree) == "1.0"

tree = x1 + Node(; val=0.0) - sin(x2 - Node(; val=0.5))
@test get_constants(tree) == [0.0, 0.5]
set_constants(tree, [1.0, 2.0])
@test repr(tree) == "((x1 + 1.0) - sin(x2 - 2.0))"

constant_node = Node(; val=0.0)
tree = x1 + constant_node - sin(x2 - constant_node)
vals = get_named_constants(tree)
@test length(vals) == 1
new_vals = map(zip(keys(vals), [1.0])) do (k, new)
(k, new)
end
set_named_constants!(tree, NamedTuple(new_vals))
@test repr(tree) == "((x1 + 1.0) - sin(x2 - 1.0))"