diff --git a/Project.toml b/Project.toml index 8d6483227e..c2aff6ac2b 100644 --- a/Project.toml +++ b/Project.toml @@ -39,6 +39,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SymbolicCompilerPasses = "3384d301-0fbe-4b40-9ae0-b0e68bedb069" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -108,7 +109,8 @@ StaticArrays = "1.9.14" StochasticDelayDiffEq = "1.11" StochasticDiffEq = "6.82.0" SymbolicIndexingInterface = "0.3.39" -SymbolicUtils = "4.11" +SymbolicCompilerPasses = "0.1.0" +SymbolicUtils = "4.13" Symbolics = "7" UnPack = "0.1, 1.0" julia = "1.9" diff --git a/lib/ModelingToolkitBase/src/systems/codegen_utils.jl b/lib/ModelingToolkitBase/src/systems/codegen_utils.jl index a59e3efbbc..623c1c926d 100644 --- a/lib/ModelingToolkitBase/src/systems/codegen_utils.jl +++ b/lib/ModelingToolkitBase/src/systems/codegen_utils.jl @@ -252,7 +252,8 @@ function build_function_wrapper( wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity, add_observed = true, filter_observed = Returns(true), create_bindings = false, output_type = nothing, mkarray = nothing, - wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true, kwargs... + wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true, + optimize = nothing, kwargs... ) isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic()) # filter observed equations @@ -375,9 +376,14 @@ function build_function_wrapper( if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic() wrap_code = wrap_code[1] end - return build_function(expr, args...; wrap_code, similarto, cse, kwargs...) + + optimize = resolve_optimize_option(optimize) + return build_function(expr, args...; wrap_code, similarto, cse, optimize, kwargs...) end +resolve_optimize_option(x) = x +resolve_optimize_option(::Nothing) = nothing + """ $(TYPEDEF) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 850a45a5e8..3cf41c8441 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -75,6 +75,10 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity, scalarize, hasderiv import ModelingToolkitBase as MTKBase +import SymbolicCompilerPasses as SCP +import SymbolicCompilerPasses: MATMUL_ADD_RULE, LDIV_RULE, HVNCAT_STATIC_RULE, + TRIU_RULE, TRIL_RULE, NORMALIZE_RULE + import DiffEqBase: @add_kwonly @reexport using Symbolics @reexport using UnPack diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index a70ec0f558..14223590a4 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -575,3 +575,42 @@ function get_semiquadratic_W_sparsity( SparseMatrixCSC{Bool, Int64}((!iszero).(mm)) return (!_iszero).(jac) .| M_sparsity end + +const SCP_BASIC = [ + MATMUL_ADD_RULE, + TRIU_RULE, + TRIL_RULE, + NORMALIZE_RULE, + LDIV_RULE, +] + +const SCP_AGGRESSIVE = [ + SCP_BASIC; + HVNCAT_STATIC_RULE; +] + +const SCP_OPTIONS = Dict( + :basic => SCP_BASIC, + :aggressive => SCP_AGGRESSIVE, + :none => nothing +) + +function MTKBase.resolve_optimize_option(o::Bool) + return MTKBase.resolve_optimize_option(o ? SCP_BASIC : nothing) +end + +function MTKBase.resolve_optimize_option(o::Symbol) + rules = get(SCP_OPTIONS, o, nothing) + return MTKBase.resolve_optimize_option(rules) +end + +function MTKBase.resolve_optimize_option(o::Int) + if o == 0 + return MTKBase.resolve_optimize_option(false) + elseif o == 1 + return MTKBase.resolve_optimize_option(:basic) + elseif o == 2 + return MTKBase.resolve_optimize_option(:aggressive) + end + throw(ArgumentError("Invalid optimize option integer: $o")) +end diff --git a/test/structural_transformation/index_reduction.jl b/test/structural_transformation/index_reduction.jl index 41b0290e92..e65e07f052 100644 --- a/test/structural_transformation/index_reduction.jl +++ b/test/structural_transformation/index_reduction.jl @@ -82,4 +82,5 @@ let sol = solve(prob, Rodas5P()) @test SciMLBase.successful_retcode(sol) @test sol[x^2 + y^2][end] < 1.1 + @test_throws ArgumentError ODEProblem(sys, [x => 1, y => 0, D(x) => 0.0, g => 1], (0.0, 10.0), guesses = [λ => 0.0], optimize = 7) end diff --git a/test/structural_transformation/tearing.jl b/test/structural_transformation/tearing.jl index 725bfa325f..3554d45fb9 100644 --- a/test/structural_transformation/tearing.jl +++ b/test/structural_transformation/tearing.jl @@ -219,6 +219,8 @@ prob_complex = ODEProblem(sys, u0, (0, 1.0)) sol = solve(prob_complex, Tsit5()) @test all(sol[mass.v] .== 1) +@test_throws ArgumentError ODEProblem(sys, u0, (0, 1.0), optimize = 7) + using ModelingToolkitStandardLibrary.Electrical using ModelingToolkitStandardLibrary.Blocks: Constant