Skip to content

Commit 5faf961

Browse files
Merge pull request #4191 from DhairyaLGandhi/dg/scp
feat: allow SymbolicCompilerPasses (SCP) passes during code generation
2 parents fe0f674 + eab24fd commit 5faf961

File tree

6 files changed

+57
-3
lines changed

6 files changed

+57
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
3939
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4040
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
4141
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
42+
SymbolicCompilerPasses = "3384d301-0fbe-4b40-9ae0-b0e68bedb069"
4243
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4344
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
4445
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -108,7 +109,8 @@ StaticArrays = "1.9.14"
108109
StochasticDelayDiffEq = "1.11"
109110
StochasticDiffEq = "6.82.0"
110111
SymbolicIndexingInterface = "0.3.39"
111-
SymbolicUtils = "4.11"
112+
SymbolicCompilerPasses = "0.1.0"
113+
SymbolicUtils = "4.13"
112114
Symbolics = "7"
113115
UnPack = "0.1, 1.0"
114116
julia = "1.9"

lib/ModelingToolkitBase/src/systems/codegen_utils.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ function build_function_wrapper(
252252
wrap_delays = is_dde(sys), histfn = DDE_HISTORY_FUN, histfn_symbolic = histfn, wrap_code = identity,
253253
add_observed = true, filter_observed = Returns(true),
254254
create_bindings = false, output_type = nothing, mkarray = nothing,
255-
wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true, kwargs...
255+
wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true,
256+
optimize = nothing, kwargs...
256257
)
257258
isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())
258259
# filter observed equations
@@ -375,9 +376,14 @@ function build_function_wrapper(
375376
if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
376377
wrap_code = wrap_code[1]
377378
end
378-
return build_function(expr, args...; wrap_code, similarto, cse, kwargs...)
379+
380+
optimize = resolve_optimize_option(optimize)
381+
return build_function(expr, args...; wrap_code, similarto, cse, optimize, kwargs...)
379382
end
380383

384+
resolve_optimize_option(x) = x
385+
resolve_optimize_option(::Nothing) = nothing
386+
381387
"""
382388
$(TYPEDEF)
383389

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
7575
scalarize, hasderiv
7676
import ModelingToolkitBase as MTKBase
7777

78+
import SymbolicCompilerPasses as SCP
79+
import SymbolicCompilerPasses: MATMUL_ADD_RULE, LDIV_RULE, HVNCAT_STATIC_RULE,
80+
TRIU_RULE, TRIL_RULE, NORMALIZE_RULE
81+
7882
import DiffEqBase: @add_kwonly
7983
@reexport using Symbolics
8084
@reexport using UnPack

src/systems/codegen.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,3 +575,42 @@ function get_semiquadratic_W_sparsity(
575575
SparseMatrixCSC{Bool, Int64}((!iszero).(mm))
576576
return (!_iszero).(jac) .| M_sparsity
577577
end
578+
579+
const SCP_BASIC = [
580+
MATMUL_ADD_RULE,
581+
TRIU_RULE,
582+
TRIL_RULE,
583+
NORMALIZE_RULE,
584+
LDIV_RULE,
585+
]
586+
587+
const SCP_AGGRESSIVE = [
588+
SCP_BASIC;
589+
HVNCAT_STATIC_RULE;
590+
]
591+
592+
const SCP_OPTIONS = Dict(
593+
:basic => SCP_BASIC,
594+
:aggressive => SCP_AGGRESSIVE,
595+
:none => nothing
596+
)
597+
598+
function MTKBase.resolve_optimize_option(o::Bool)
599+
return MTKBase.resolve_optimize_option(o ? SCP_BASIC : nothing)
600+
end
601+
602+
function MTKBase.resolve_optimize_option(o::Symbol)
603+
rules = get(SCP_OPTIONS, o, nothing)
604+
return MTKBase.resolve_optimize_option(rules)
605+
end
606+
607+
function MTKBase.resolve_optimize_option(o::Int)
608+
if o == 0
609+
return MTKBase.resolve_optimize_option(false)
610+
elseif o == 1
611+
return MTKBase.resolve_optimize_option(:basic)
612+
elseif o == 2
613+
return MTKBase.resolve_optimize_option(:aggressive)
614+
end
615+
throw(ArgumentError("Invalid optimize option integer: $o"))
616+
end

test/structural_transformation/index_reduction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,5 @@ let
8282
sol = solve(prob, Rodas5P())
8383
@test SciMLBase.successful_retcode(sol)
8484
@test sol[x^2 + y^2][end] < 1.1
85+
@test_throws ArgumentError ODEProblem(sys, [x => 1, y => 0, D(x) => 0.0, g => 1], (0.0, 10.0), guesses ==> 0.0], optimize = 7)
8586
end

test/structural_transformation/tearing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ prob_complex = ODEProblem(sys, u0, (0, 1.0))
219219
sol = solve(prob_complex, Tsit5())
220220
@test all(sol[mass.v] .== 1)
221221

222+
@test_throws ArgumentError ODEProblem(sys, u0, (0, 1.0), optimize = 7)
223+
222224
using ModelingToolkitStandardLibrary.Electrical
223225
using ModelingToolkitStandardLibrary.Blocks: Constant
224226

0 commit comments

Comments
 (0)