Skip to content

Commit 80dfccb

Browse files
committed
fix: Support SDEs defined via Brownian with jumps
1 parent fe0f674 commit 80dfccb

File tree

3 files changed

+326
-20
lines changed

3 files changed

+326
-20
lines changed

lib/ModelingToolkitBase/src/problems/jumpproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
has_vrjs = any(x -> x isa VariableRateJump, jumps(sys))
1111
has_eqs = !isempty(equations(sys))
12-
has_noise = get_noise_eqs(sys) !== nothing
12+
has_noise = get_noise_eqs(sys) !== nothing || !isempty(brownians(sys))
1313

1414
if (has_vrjs || has_eqs)
1515
if has_eqs && has_noise

lib/ModelingToolkitBase/src/systems/systems.jl

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ function _mtkcompile(sys::AbstractSystem; kwargs...)
125125
sys = noise_to_brownians(sys; names = :αₘₜₖ)
126126
end
127127
if !isempty(jumps(sys))
128+
# For systems with jumps, skip full structural simplification to preserve
129+
# variables that only appear in jumps. But if brownians are present,
130+
# we still need to extract them into noise_eqs for SDEProblem construction.
131+
if !isempty(brownians(sys))
132+
return extract_brownians_to_noise_eqs(sys)
133+
end
128134
return sys
129135
end
130136
if isempty(equations(sys)) && !is_time_dependent(sys) && !_iszero(cost(sys))
@@ -524,15 +530,20 @@ function add_array_observed!(obseqs::Vector{Equation})
524530
return
525531
end
526532

527-
function simplify_sde_system(sys::AbstractSystem; kwargs...)
528-
brown_vars = brownians(sys)
529-
@set! sys.brownians = SymbolicT[]
530-
sys = __mtkcompile(sys; kwargs...)
533+
"""
534+
_brownians_to_noise_eqs(eqs::Vector{Equation}, brown_vars::Vector)
531535
532-
new_eqs = copy(equations(sys))
536+
Extract brownian coefficients from equations and return (new_eqs, noise_eqs).
537+
The brownian terms are removed from the equations and collected into a noise matrix.
538+
This is a helper function used by both `extract_brownians_to_noise_eqs` and
539+
`simplify_sde_system`.
540+
"""
541+
function _brownians_to_noise_eqs(eqs::Vector{Equation}, brown_vars::Vector)
542+
new_eqs = copy(eqs)
533543
Is = Int[]
534544
Js = Int[]
535545
vals = SymbolicT[]
546+
536547
for (i, eq) in enumerate(new_eqs)
537548
resid = eq.rhs
538549
for (j, bvar) in enumerate(brown_vars)
@@ -557,24 +568,48 @@ function simplify_sde_system(sys::AbstractSystem; kwargs...)
557568
end
558569

559570
g = Matrix(sparse(Is, Js, vals, length(new_eqs), length(brown_vars)))
560-
@set! sys.eqs = new_eqs
571+
572+
# Determine noise type (scalar, diagonal, or general)
561573
# Fix for https://github.com/SciML/ModelingToolkit.jl/issues/2490
562-
if size(g, 2) == 1
563-
# If there's only one brownian variable referenced across all the equations,
564-
# we get a Nx1 matrix of noise equations, which is a special case known as scalar noise
565-
noise_eqs = reshape(g[:, 1], (:, 1))
566-
is_scalar_noise = true
574+
noise_eqs = if size(g, 2) == 1
575+
# Scalar noise: Nx1 matrix
576+
reshape(g[:, 1], (:, 1))
567577
elseif __num_isdiag_noise(g)
568-
# If each column of the noise matrix has either 0 or 1 non-zero entry, then this is "diagonal noise".
569-
# In this case, the solver just takes a vector column of equations and it interprets that to
570-
# mean that each noise process is independent
571-
noise_eqs = __get_num_diag_noise(g)
572-
is_scalar_noise = false
578+
# Diagonal noise: each column has 0 or 1 non-zero entry
579+
__get_num_diag_noise(g)
573580
else
574-
noise_eqs = g
575-
is_scalar_noise = false
581+
g
576582
end
577583

584+
return new_eqs, noise_eqs
585+
end
586+
587+
"""
588+
extract_brownians_to_noise_eqs(sys::AbstractSystem)
589+
590+
Extract brownian variables from equations and convert them to a noise_eqs matrix,
591+
without performing structural simplification. This is used for systems with both
592+
jumps and brownians, where full simplification could eliminate variables that
593+
only appear in jumps.
594+
"""
595+
function extract_brownians_to_noise_eqs(sys::AbstractSystem)
596+
brown_vars = brownians(sys)
597+
new_eqs, noise_eqs = _brownians_to_noise_eqs(equations(sys), brown_vars)
598+
599+
@set! sys.eqs = new_eqs
600+
@set! sys.noise_eqs = noise_eqs
601+
@set! sys.brownians = SymbolicT[]
602+
603+
return sys
604+
end
605+
606+
function simplify_sde_system(sys::AbstractSystem; kwargs...)
607+
brown_vars = brownians(sys)
608+
@set! sys.brownians = SymbolicT[]
609+
sys = __mtkcompile(sys; kwargs...)
610+
611+
new_eqs, noise_eqs = _brownians_to_noise_eqs(equations(sys), brown_vars)
612+
578613
dummy_sub = Dict{SymbolicT, SymbolicT}()
579614
for eq in new_eqs
580615
isdiffeq(eq) || continue

lib/ModelingToolkitBase/test/jumpsystem.jl

Lines changed: 272 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ModelingToolkitBase, DiffEqBase, JumpProcesses, Test, LinearAlgebra
22
using SymbolicIndexingInterface, OrderedCollections
33
using Random, StableRNGs, NonlinearSolve
4-
using OrdinaryDiffEq
4+
using OrdinaryDiffEq, StochasticDiffEq, Statistics
55
using ModelingToolkitBase: t_nounits as t, D_nounits as D
66
using BenchmarkTools
77
using Symbolics: SymbolicT, unwrap
@@ -831,3 +831,274 @@ end
831831
@test jprob.discrete_jump_aggregation.save_positions == (true, true)
832832
end
833833
end
834+
835+
# Test that JumpProblem correctly detects brownians and creates SDEProblem
836+
# Issue: JumpProblem was only checking get_noise_eqs(sys), not brownians(sys)
837+
# Also tests that mtkcompile properly processes brownians for systems with jumps
838+
@testset "JumpProblem with brownians creates SDEProblem" begin
839+
# Test 1: System with brownians and a mass action jump
840+
@testset "Brownians + MassActionJump" begin
841+
@variables X(t) = 10.0
842+
@parameters k = 1.0
843+
@brownians B
844+
845+
# Equation with Brownian noise: dX = -k*X*dt + sqrt(k)*dB
846+
eqs = [D(X) ~ -k * X + sqrt(k) * B]
847+
848+
# A simple mass action jump: X -> 0 with rate k
849+
jump = MassActionJump(k, [X => 1], [X => -1])
850+
851+
# Build the system with @mtkcompile - this properly processes brownians
852+
@mtkcompile sys = System(eqs, t; jumps = [jump])
853+
854+
# After mtkcompile, brownians are converted to noise_eqs
855+
@test MT.get_noise_eqs(sys) !== nothing
856+
857+
# Create JumpProblem - should create SDEProblem
858+
op = [X => 10.0, k => 1.0]
859+
tspan = (0.0, 1.0)
860+
jprob = JumpProblem(sys, op, tspan; rng)
861+
862+
# The underlying problem should be SDEProblem, not ODEProblem
863+
@test jprob.prob isa SDEProblem
864+
865+
# Should be solvable without error
866+
sol = solve(jprob, SOSRI())
867+
@test SciMLBase.successful_retcode(sol)
868+
end
869+
870+
# Test 2: System with brownians and a constant rate jump
871+
@testset "Brownians + ConstantRateJump" begin
872+
@variables X(t) = 5.0
873+
@parameters k = 0.5
874+
@brownians B
875+
876+
eqs = [D(X) ~ k + 0.1 * B]
877+
crj = ConstantRateJump(k * X, [X ~ Pre(X) - 1])
878+
879+
@mtkcompile sys = System(eqs, t; jumps = [crj])
880+
881+
@test MT.get_noise_eqs(sys) !== nothing
882+
883+
op = [X => 5.0, k => 0.5]
884+
tspan = (0.0, 1.0)
885+
jprob = JumpProblem(sys, op, tspan; rng)
886+
887+
@test jprob.prob isa SDEProblem
888+
889+
sol = solve(jprob, SOSRI())
890+
@test SciMLBase.successful_retcode(sol)
891+
end
892+
893+
# Test 3: System with brownians and a variable rate jump
894+
@testset "Brownians + VariableRateJump" begin
895+
@variables X(t) = 5.0
896+
@parameters k = 0.5
897+
@brownians B
898+
899+
eqs = [D(X) ~ k + 0.1 * B]
900+
vrj = VariableRateJump(k * (1 + sin(t)), [X ~ Pre(X) + 1])
901+
902+
@mtkcompile sys = System(eqs, t; jumps = [vrj])
903+
904+
@test MT.get_noise_eqs(sys) !== nothing
905+
906+
op = [X => 5.0, k => 0.5]
907+
tspan = (0.0, 1.0)
908+
jprob = JumpProblem(sys, op, tspan; rng)
909+
910+
@test jprob.prob isa SDEProblem
911+
912+
sol = solve(jprob, SOSRI())
913+
@test SciMLBase.successful_retcode(sol)
914+
end
915+
916+
# Test 4: System with brownians and multiple jump types
917+
@testset "Brownians + mixed jump types" begin
918+
@variables X(t) = 10.0 Y(t) = 5.0
919+
@parameters k1 = 1.0 k2 = 0.5
920+
@brownians B
921+
922+
eqs = [D(X) ~ -k1 * X + 0.1 * B, D(Y) ~ k2]
923+
maj = MassActionJump(k1, [X => 1], [X => -1])
924+
crj = ConstantRateJump(k2 * Y, [Y ~ Pre(Y) - 1])
925+
926+
@mtkcompile sys = System(eqs, t; jumps = [maj, crj])
927+
928+
@test MT.get_noise_eqs(sys) !== nothing
929+
930+
op = [X => 10.0, Y => 5.0, k1 => 1.0, k2 => 0.5]
931+
tspan = (0.0, 1.0)
932+
jprob = JumpProblem(sys, op, tspan; rng)
933+
934+
@test jprob.prob isa SDEProblem
935+
936+
sol = solve(jprob, SOSRI())
937+
@test SciMLBase.successful_retcode(sol)
938+
end
939+
940+
# Test 5: Ensure systems WITHOUT brownians still work correctly
941+
# (i.e., VRJ-only systems should create ODEProblem, not SDEProblem)
942+
@testset "No brownians, VRJ only -> ODEProblem" begin
943+
@variables X(t) = 5.0
944+
@parameters k = 0.5
945+
946+
# No brownians, but has equations and variable rate jump
947+
eqs = [D(X) ~ k]
948+
vrj = VariableRateJump(k * (1 + sin(t)), [X ~ Pre(X) + 1])
949+
950+
@mtkcompile sys = System(eqs, t; jumps = [vrj])
951+
952+
@test isempty(MT.brownians(sys))
953+
@test MT.get_noise_eqs(sys) === nothing
954+
955+
op = [X => 5.0, k => 0.5]
956+
tspan = (0.0, 1.0)
957+
jprob = JumpProblem(sys, op, tspan; rng)
958+
959+
# Should be ODEProblem since there are no brownians
960+
@test jprob.prob isa ODEProblem
961+
962+
sol = solve(jprob, Tsit5())
963+
@test SciMLBase.successful_retcode(sol)
964+
end
965+
end
966+
967+
# Correctness tests: verify symbolic SDE+jump solutions match analytical/direct expectations
968+
@testset "Brownians + Jumps correctness" begin
969+
# Test 1: Pure diffusion + constant rate jump
970+
# dX = sig*dB, X(0) = 0, with jumps X → X + delta at rate lam
971+
# E[X(T)] = lam*delta*T (diffusion has zero mean)
972+
@testset "Diffusion + CRJ mean" begin
973+
@variables X(t) = 0.0
974+
@parameters sig = 0.3 lam = 2.0 delta = 1.0
975+
@brownians B
976+
977+
eqs = [D(X) ~ sig * B]
978+
crj = ConstantRateJump(lam, [X ~ Pre(X) + delta])
979+
980+
# Must pass all parameters explicitly since System doesn't auto-collect from jumps
981+
@mtkcompile sys = System(eqs, t, [X], [sig, lam, delta], [B]; jumps = [crj])
982+
983+
T = 2.0
984+
Nsims = 4000
985+
sig_val, lam_val, delta_val = 0.3, 2.0, 1.0
986+
E_X = lam_val * delta_val * T # = 4.0
987+
988+
# Create JumpProblem once, use seed parameter to vary randomness
989+
jprob = JumpProblem(sys, [X => 0.0, sig => sig_val, lam => lam_val, delta => delta_val],
990+
(0.0, T); rng, save_positions = (false, false))
991+
992+
seed = 1111
993+
Xfinal = zeros(Nsims)
994+
for i in 1:Nsims
995+
sol = solve(jprob, SOSRI(); save_everystep = false, seed)
996+
Xfinal[i] = sol[X, end]
997+
seed += 1
998+
end
999+
1000+
sample_mean = mean(Xfinal)
1001+
rel_error = abs(sample_mean - E_X) / E_X
1002+
@test rel_error < 0.05 # 5% relative error
1003+
end
1004+
1005+
# Test 2: Compare symbolic vs direct JumpProcesses construction
1006+
# Verifies that the symbolic system produces the same statistics as manual construction
1007+
@testset "Symbolic vs Direct JumpProcesses" begin
1008+
sig_val = 0.2
1009+
lam_val = 3.0
1010+
delta_val = 0.5
1011+
X0 = 1.0
1012+
T = 1.5
1013+
Nsims = 3000
1014+
1015+
# Build symbolically
1016+
@variables X(t) = X0
1017+
@parameters sig = sig_val lam = lam_val delta = delta_val
1018+
@brownians B
1019+
1020+
eqs = [D(X) ~ sig * B]
1021+
crj = ConstantRateJump(lam, [X ~ Pre(X) + delta])
1022+
1023+
# Must pass all parameters explicitly since System doesn't auto-collect from jumps
1024+
@mtkcompile sys = System(eqs, t, [X], [sig, lam, delta], [B]; jumps = [crj])
1025+
1026+
# Create JumpProblem once for symbolic version
1027+
jprob_sym = JumpProblem(sys, [X => X0, sig => sig_val, lam => lam_val, delta => delta_val],
1028+
(0.0, T); rng, save_positions = (false, false))
1029+
1030+
seed = 2222
1031+
Xfinal_sym = zeros(Nsims)
1032+
for i in 1:Nsims
1033+
sol = solve(jprob_sym, SOSRI(); save_everystep = false, seed)
1034+
Xfinal_sym[i] = sol[X, end]
1035+
seed += 1
1036+
end
1037+
1038+
# Build directly with JumpProcesses
1039+
f_direct(du, u, p, t) = (du[1] = 0.0)
1040+
g_direct(du, u, p, t) = (du[1] = sig_val)
1041+
sprob = SDEProblem(f_direct, g_direct, [X0], (0.0, T))
1042+
rate_direct(u, p, t) = lam_val
1043+
affect_direct!(integ) = (integ.u[1] += delta_val)
1044+
crj_direct = ConstantRateJump(rate_direct, affect_direct!)
1045+
1046+
jprob_direct = JumpProblem(sprob, Direct(), crj_direct; rng, save_positions = (false, false))
1047+
1048+
seed = 2222 # Use same seeds for comparison
1049+
Xfinal_direct = zeros(Nsims)
1050+
for i in 1:Nsims
1051+
sol = solve(jprob_direct, SOSRI(); save_everystep = false, seed)
1052+
Xfinal_direct[i] = sol[end][1]
1053+
seed += 1
1054+
end
1055+
1056+
# Expected mean: X0 + lam*delta*T = 1.0 + 3.0*0.5*1.5 = 3.25
1057+
E_X = X0 + lam_val * delta_val * T
1058+
1059+
mean_sym = mean(Xfinal_sym)
1060+
mean_direct = mean(Xfinal_direct)
1061+
1062+
# Both should match each other and the analytical value within 5%
1063+
@test abs(mean_sym - mean_direct) / E_X < 0.05
1064+
@test abs(mean_sym - E_X) / E_X < 0.05
1065+
@test abs(mean_direct - E_X) / E_X < 0.05
1066+
end
1067+
1068+
# Test 3: Drift + diffusion + MassActionJump (birth-death with noise)
1069+
# dX = (alph - bet*X)*dt + sig*dB
1070+
# Birth: ∅ → X at rate gam
1071+
# At steady state (long time), E[X] ≈ (alph + gam) / bet
1072+
@testset "Drift + diffusion + MAJ steady state" begin
1073+
@variables X(t) = 5.0
1074+
@parameters alph = 2.0 bet = 0.5 gam = 3.0 sig = 0.1
1075+
@brownians B
1076+
1077+
# ODE part drives toward alph/bet, MAJ adds gam births per unit time
1078+
eqs = [D(X) ~ alph - bet * X + sig * B]
1079+
birth = MassActionJump(gam, [0 => 1], [X => 1])
1080+
1081+
# Must pass all parameters explicitly since System doesn't auto-collect from jumps
1082+
@mtkcompile sys = System(eqs, t, [X], [alph, bet, gam, sig], [B]; jumps = [birth])
1083+
1084+
T = 20.0 # Long enough to reach steady state
1085+
Nsims = 2000
1086+
alph_val, bet_val, gam_val, sig_val = 2.0, 0.5, 3.0, 0.1
1087+
E_X_ss = (alph_val + gam_val) / bet_val # = 10
1088+
1089+
jprob = JumpProblem(sys, [X => 5.0, alph => alph_val, bet => bet_val, gam => gam_val, sig => sig_val],
1090+
(0.0, T); rng, save_positions = (false, false))
1091+
1092+
seed = 3333
1093+
Xfinal = zeros(Nsims)
1094+
for i in 1:Nsims
1095+
sol = solve(jprob, SOSRI(); save_everystep = false, seed)
1096+
Xfinal[i] = sol[X, end]
1097+
seed += 1
1098+
end
1099+
1100+
sample_mean = mean(Xfinal)
1101+
rel_error = abs(sample_mean - E_X_ss) / E_X_ss
1102+
@test rel_error < 0.05 # 5% relative error
1103+
end
1104+
end

0 commit comments

Comments
 (0)