Skip to content
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicCompilerPasses = "3384d301-0fbe-4b40-9ae0-b0e68bedb069"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Expand Down Expand Up @@ -108,7 +109,8 @@ StaticArrays = "1.9.14"
StochasticDelayDiffEq = "1.11"
StochasticDiffEq = "6.82.0"
SymbolicIndexingInterface = "0.3.39"
SymbolicUtils = "4.11"
SymbolicCompilerPasses = "0.1.0"
SymbolicUtils = "4.13"
Symbolics = "7"
UnPack = "0.1, 1.0"
julia = "1.9"
Expand Down
10 changes: 8 additions & 2 deletions lib/ModelingToolkitBase/src/systems/codegen_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ function build_function_wrapper(
wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
add_observed = true, filter_observed = Returns(true),
create_bindings = false, output_type = nothing, mkarray = nothing,
wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true, kwargs...
wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true,
optimize = nothing, kwargs...
)
isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())
# filter observed equations
Expand Down Expand Up @@ -375,9 +376,14 @@ function build_function_wrapper(
if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
wrap_code = wrap_code[1]
end
return build_function(expr, args...; wrap_code, similarto, cse, kwargs...)

optimize = resolve_optimize_option(optimize)
return build_function(expr, args...; wrap_code, similarto, cse, optimize, kwargs...)
end

resolve_optimize_option(x) = x
resolve_optimize_option(::Nothing) = nothing

"""
$(TYPEDEF)

Expand Down
4 changes: 4 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
scalarize, hasderiv
import ModelingToolkitBase as MTKBase

import SymbolicCompilerPasses as SCP
import SymbolicCompilerPasses: MATMUL_ADD_RULE, LDIV_RULE, HVNCAT_STATIC_RULE,
TRIU_RULE, TRIL_RULE, NORMALIZE_RULE

import DiffEqBase: @add_kwonly
@reexport using Symbolics
@reexport using UnPack
Expand Down
39 changes: 39 additions & 0 deletions src/systems/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,42 @@ function get_semiquadratic_W_sparsity(
SparseMatrixCSC{Bool, Int64}((!iszero).(mm))
return (!_iszero).(jac) .| M_sparsity
end

const SCP_BASIC = [
MATMUL_ADD_RULE,
TRIU_RULE,
TRIL_RULE,
NORMALIZE_RULE,
LDIV_RULE,
]

const SCP_AGGRESSIVE = [
SCP_BASIC;
HVNCAT_STATIC_RULE;
]

const SCP_OPTIONS = Dict(
:basic => SCP_BASIC,
:aggressive => SCP_AGGRESSIVE,
:none => nothing
)

function MTKBase.resolve_optimize_option(o::Bool)
return MTKBase.resolve_optimize_option(o ? SCP_BASIC : nothing)
end

function MTKBase.resolve_optimize_option(o::Symbol)
rules = get(SCP_OPTIONS, o, nothing)
return MTKBase.resolve_optimize_option(rules)
end

function MTKBase.resolve_optimize_option(o::Int)
if o == 0
return MTKBase.resolve_optimize_option(false)
elseif o == 1
return MTKBase.resolve_optimize_option(:basic)
elseif o == 2
return MTKBase.resolve_optimize_option(:aggressive)
end
throw(ArgumentError("Invalid optimize option integer: $o"))
end
1 change: 1 addition & 0 deletions test/structural_transformation/index_reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ let
sol = solve(prob, Rodas5P())
@test SciMLBase.successful_retcode(sol)
@test sol[x^2 + y^2][end] < 1.1
@test_throws ArgumentError ODEProblem(sys, [x => 1, y => 0, D(x) => 0.0, g => 1], (0.0, 10.0), guesses ==> 0.0], optimize = 7)
end
2 changes: 2 additions & 0 deletions test/structural_transformation/tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ prob_complex = ODEProblem(sys, u0, (0, 1.0))
sol = solve(prob_complex, Tsit5())
@test all(sol[mass.v] .== 1)

@test_throws ArgumentError ODEProblem(sys, u0, (0, 1.0), optimize = 7)

using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks: Constant

Expand Down
Loading