Skip to content

Commit

Permalink
Support comparison operators (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Aug 13, 2024
1 parent 78561ec commit 5d375f2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/nonlinear_oracle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ Parse `expr` into `f` and return a `_Node`.
This function gets called recursively.
"""
function _Node(f::_Function, expr::Expr)
if Meta.isexpr(expr, :||) || Meta.isexpr(expr, :&&)
return _Node(f, expr.head, (_Node(f, arg) for arg in expr.args)...)
end
@assert Meta.isexpr(expr, :call)
# Performance optimization: most calls will be unary or binary
# operators. Therefore, we can specialize an if-statement to handle the
Expand Down Expand Up @@ -219,6 +222,10 @@ function _expr_to_symbolics(model::MOI.Nonlinear.Model, expr::_Node, p, x)
args = [_expr_to_symbolics(model, c, p, x) for c in expr.children]
if hasproperty(Base, expr.operation)
return getproperty(Base, expr.operation)(args...)
elseif expr.operation == :&&
return (&)(args...)
elseif expr.operation == :||
return (|)(args...)
end
# If the function isn't defined in Base, defer to the operator registry.
# We don't do this for all functions, because MOI uses NaNMath, which
Expand Down Expand Up @@ -518,6 +525,10 @@ function _to_expr(
MOI.VariableIndex(variable_order[node.index])
elseif node.type == MOI.Nonlinear.NODE_PARAMETER
data.parameters[node.index]
elseif node.type == MOI.Nonlinear.NODE_LOGIC
Expr(data.operators.logic_operators[node.index])
elseif node.type == MOI.Nonlinear.NODE_COMPARISON
Expr(:call, data.operators.comparison_operators[node.index])
else
@assert node.type == MOI.Nonlinear.NODE_VALUE
expr.values[node.index]
Expand Down
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,23 @@ function test_constant_subexpressions_expr()
return
end

function test_logic_comparison_expr()
if VERSION < v"1.7"
return # Symbolics doesn't support Base.ifelse in Julia v1.6
end
model = Model(Ipopt.Optimizer)
@variable(model, -1 <= x <= 1)
@objective(model, Max, ifelse(-0.5 <= x && x <= 0.5, 1 - x^2, 0))
set_attribute(
model,
MOI.AutomaticDifferentiationBackend(),
MathOptSymbolicAD.DefaultBackend(),
)
optimize!(model)
Test.@test termination_status(model) == LOCALLY_SOLVED
return
end

end # module

RunTests.runtests()

0 comments on commit 5d375f2

Please sign in to comment.