3 changes: 2 additions & 1 deletion .github/workflows/Tests.yml
@@ -30,7 +30,8 @@ jobs:
 - "QA"
 - "ODEBPINN"
 - "PDEBPINN"
-- "NNSDE"
+- "NNSDE1"
+- "NNSDE2"
 - "NNPDE1"
 - "NNPDE2"
 - "AdaptiveLoss"
215 changes: 215 additions & 0 deletions src/NN_SDE_weaksolve.jl
@@ -0,0 +1,215 @@

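"""
    SDEPINN(; chain, optimalg, norm_loss_alg, x_0, x_end, distrib, kwargs...)

Algorithm for weakly solving a scalar `SDEProblem` with a physics-informed neural
network: `chain` is trained on the associated Fokker–Planck PDE, so that
`phi([x, t], θ)` approximates the probability density p̂(x, t) of the solution on
the truncated domain `[x_0, x_end] × tspan`.
"""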
@concrete struct SDEPINN
chain <: AbstractLuxLayer
optimalg
norm_loss_alg
initial_parameters

# domain + discretization
x_0::Float64
x_end::Float64
Nt::Int
dx::Float64

# IC & normalization
σ_var_bc::Float64
λ_ic::Float64
λ_norm::Float64
distrib::Distributions.Distribution

# solver options
strategy <: Union{Nothing,AbstractTrainingStrategy}
autodiff::Bool
batch::Bool
param_estim::Bool

# For postprocessing - solution handling
# xview::AbstractArray
# tview::AbstractArray
# phi::Phi

dataset <: Union{Nothing,Vector,Vector{<:Vector}}
additional_loss <: Union{Nothing,Function}
kwargs
end

function SDEPINN(;
chain,
optimalg=nothing,
norm_loss_alg=nothing,
initial_parameters=nothing,
x_0,
x_end,
Nt=50,
dx=0.01,
σ_var_bc=0.05,
λ_ic=1.0,
λ_norm=1.0,
distrib=Normal(0.5, 0.01),
strategy=nothing,
autodiff=true,
batch=false,
param_estim=false,
dataset=nothing,
additional_loss=nothing,
kwargs...
)
return SDEPINN(
chain,
optimalg,
norm_loss_alg,
initial_parameters,
x_0,
x_end,
Nt,
dx,
σ_var_bc,
λ_ic,
λ_norm,
distrib,
strategy,
autodiff,
batch,
param_estim,
dataset,
additional_loss,
kwargs
)
end

function SciMLBase.__solve(
prob::SciMLBase.AbstractSDEProblem,
alg::SDEPINN,
args...;
dt=nothing,
abstol=1.0f-6,
reltol=1.0f-3,
saveat=nothing,
tstops=nothing,
maxiters=200,
verbose=false,
kwargs...,
)
(; u0, tspan, f, g, p) = prob
P = eltype(u0)
t₀, t₁ = tspan

absorbing_bc = false
reflective_bc = true

(; x_0, x_end, Nt, dx, σ_var_bc, λ_ic, λ_norm,
distrib, optimalg, norm_loss_alg, initial_parameters, chain) = alg

dt = dt === nothing ? (t₁ - t₀) / Nt : dt
ts = collect(t₀:dt:t₁)

# Define FP PDE
@parameters X, T
@variables p̂(..)
Dx = Differential(X)
Dxx = Differential(X)^2
Dt = Differential(T)

J(x, T) = prob.f(x, p, T) * p̂(x, T) -
P(0.5) * Dx((prob.g(x, p, T))^2 * p̂(x, T))
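# J is the probability flux (current) of the Fokker–Planck equation:
# J(x, t) = f(x, t) p̂ - (1/2) ∂x[g(x, t)² p̂]. Reflecting BCs impose J = 0 at
# x_0 and x_end, so no probability mass crosses the boundaries.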

# IC symbolic equation form
f_icloss = if u0 isa Number
(p̂(u0, t₀) - Distributions.pdf(distrib, u0) ~ P(0),)
else
(p̂(u0[i], t₀) - Distributions.pdf(distrib[i], u0[i]) ~ P(0) for i in eachindex(u0))
end

eq = Dt(p̂(X, T)) ~ -Dx(f(X, p, T) * p̂(X, T)) +
P(0.5) * Dxx((g(X, p, T))^2 * p̂(X, T))
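# i.e. the 1-D Fokker–Planck (forward Kolmogorov) equation for the density
# p̂(x, t) of dX = f dt + g dW:  ∂p̂/∂t = -∂x[f p̂] + (1/2) ∂xx[g² p̂]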

# Using absorbing BCs (p̂ = 0 at the boundaries) together with normalization works;
# however, if the x domain is enlarged too much on either side, the PDF mass,
# although "conserved" inside the domain, can be forced to spread into the wrong regions.

bcs = [
# No probability enters or leaves the domain, so total mass is conserved;
# this matches an SDE on a truncated domain with reflecting BCs.

# IC loss (note: it gets amplified by the number of training points)
f_icloss...
]

# absorbing BCs
if absorbing_bc
@info "absorbing BCs used"

bcs = vcat(bcs, [p̂(x_0, T) ~ P(0),
p̂(x_end, T) ~ P(0)])
end

# reflecting BCs
if reflective_bc
@info "reflecting BCs used"

bcs = vcat(bcs, [J(x_0, T) ~ P(0),
J(x_end, T) ~ P(0)])
end
Comment on lines +140 to +155

Member

It's infinite domain though?

@AstitvaAggarwal (Member Author) Jan 28, 2026

I'm not sure how we can enforce that properly? (So I thought we could just enforce it on the user-chosen truncated domain.)


domains = [X ∈ (x_0, x_end), T ∈ (t₀, t₁)]

# Additional losses
# Handle norm_loss and IC loss for vector NN outputs; x_0, x_end, u0 handling
# will also need to be adjusted for this.

# Note: σ_var_bc must be narrow (it centers an approximate Dirac delta at the
# initial condition); the smaller it is, the taller the initial peak the NN must learn.
function norm_loss(phi, θ)
loss = P(0)
for t in ts
# define integrand as a function of x only (t fixed)
# perform ∫ f(x) dx over [x_0, x_end]
phi_normloss(x, θ) = u0 isa Number ? first(phi([x, t], θ)) : phi([x, t], θ)
I_est = solve(IntegralProblem(phi_normloss, x_0, x_end, θ), norm_loss_alg,
reltol=1e-8, abstol=1e-8, maxiters=10)[1]
loss += abs2(I_est - P(1))
end
return loss
end
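# norm_loss penalizes the deviation of ∫_{x_0}^{x_end} p̂(x, t) dx from 1 at every
# time-grid point, keeping the learned density normalized on the truncated domain.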

function combined_additional(phi, θ, _)
λ_norm * norm_loss(phi, θ)
end

# Discretization - GridTraining only
discretization = PhysicsInformedNN(
chain,
GridTraining([dx, dt]);
init_params=initial_parameters,
additional_loss=combined_additional
)

@named pdesys = PDESystem(eq, bcs, domains, [X, T], [p̂(X, T)])
opt_prob = discretize(pdesys, discretization)
phi = discretization.phi

sym = NeuralPDE.symbolic_discretize(pdesys, discretization)
pde_losses = sym.loss_functions.pde_loss_functions
bc_losses = sym.loss_functions.bc_loss_functions

cb = function (p, l)
(!verbose) && return false
println("loss = ", l)
println("pde = ", map(f -> f(p.u), pde_losses))
println("bc = ", map(f -> f(p.u), bc_losses))
println("norm = ", norm_loss(phi, p.u))
return false
end

res = Optimization.solve(
opt_prob,
optimalg;
callback=cb,
maxiters=maxiters,
kwargs...
)

# postprocessing?
return res, phi
end
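# A minimal usage sketch (mirrors the OU test in test/NN_SDE_weaksolve_tests.jl;
# assumes Lux, Distributions, Integrals, and OptimizationOptimJL are loaded):
#
#   prob = SDEProblem((u, p, t) -> -u, (u, p, t) -> 1.0, 0.5, (0.0, 1.0))
#   chain = Chain(Dense(2, 20, tanh), Dense(20, 20, tanh),
#       Dense(20, 1, Lux.logcosh)) |> f64
#   alg = SDEPINN(; chain, optimalg = BFGS(), norm_loss_alg = HCubatureJL(),
#       x_0 = -4.0, x_end = 4.0, distrib = Normal(0.5, 0.05))
#   res, phi = solve(prob, alg; maxiters = 500)
#   first(phi([0.0, 1.0], res.u))  # ≈ p̂(x = 0, t = 1)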
2 changes: 2 additions & 0 deletions src/NeuralPDE.jl
@@ -91,11 +91,13 @@ include("PDE_BPINN.jl")
 
 include("dgm.jl")
 include("NN_SDE_solve.jl")
+include("NN_SDE_weaksolve.jl")
 
 export PINOODE
 export NNODE, NNDAE
 export BNNODE, ahmc_bayesian_pinn_ode, ahmc_bayesian_pinn_pde
 export NNSDE
+export SDEPINN
 export PhysicsInformedNN, discretize
 export BPINNsolution, BayesianPINN
 export DeepGalerkin
136 changes: 136 additions & 0 deletions test/NN_SDE_weaksolve_tests.jl
@@ -0,0 +1,136 @@
@testitem "OU process" tags = [:nnsde2] begin
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, Optimisers
using OrdinaryDiffEq, Random, Distributions, Integrals, Cubature
using OptimizationOptimJL: BFGS
Random.seed!(100)

α = -1
β = 1
u0 = 0.5
t0 = 0.0
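# Ornstein–Uhlenbeck process: dX = αX dt + β dW with α = -1, β = 1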
f(u, p, t) = α * u
g(u, p, t) = β
tspan = (0.0, 1.0)
prob = SDEProblem(f, g, u0, tspan)

# Neural network
inn = 20
chain = Lux.Chain(Dense(2, inn, Lux.tanh),
    Dense(inn, inn, Lux.tanh),
    Dense(inn, 1, Lux.logcosh)) |> f64

# problem setting
dx = 0.01
x_0 = -4.0
x_end = 4.0
σ_var_bc = 0.05

alg = SDEPINN(
chain=chain,
optimalg=BFGS(),
norm_loss_alg=HCubatureJL(),
x_0=x_0,
x_end=x_end,
distrib=Normal(u0, σ_var_bc)
)

sol_OU, phi = solve(
prob,
alg,
maxiters=500,
)

# OU analytic solution
σ² = 0.5 # stationary variance: Var_∞ = β² / (2|α|) = 1/2
analytic_sol_func(x, t) = pdf(Normal(u0 * exp(-t), sqrt(σ² * (1 - exp(-2t)))), x) # mean μ and variance σ^2
xs = collect(x_0:dx:x_end)

# test from t = 0.1, not 0.0, because the analytic solution tends to a Dirac delta (infinite peak) as t → 0
ts = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0]

u_real = [[analytic_sol_func(x, t) for x in xs] for t in ts]
u_predict = [[first(phi([x, t], sol_OU.u)) for x in xs] for t in ts] # NeuralPDE predictions

# MSE across all x.
diff = u_real .- u_predict
@test mean(vcat([abs2.(diff_i) for diff_i in diff]...)) < 0.01

# using Plots
# plotly()
# plots_got = []
# for i in 1:length(ts)
# plot(xs, u_real[i], label="analytic t=$(ts[i])")
# push!(plots_got, plot!(xs, u_predict[i], label="predict t=$(ts[i])"))
# end
# plot(plots_got..., legend=:outerbottomright)
end

@testitem "GBM SDE" tags = [:nnsde2] begin
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, Optimisers
using OrdinaryDiffEq, Random, Distributions, Integrals, Cubature
using OptimizationOptimJL: BFGS
Random.seed!(100)

μ = 0.2
σ = 0.3
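# geometric Brownian motion: dX = μX dt + σX dW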
f(x, p, t) = μ * x
g(x, p, t) = σ * x
u0 = 1.0
tspan = (0.0, 1.0)
prob = SDEProblem(f, g, u0, tspan)

# Neural network
inn = 20
chain = Lux.Chain(Dense(2, inn, Lux.tanh),
    Dense(inn, inn, Lux.tanh),
    Dense(inn, 1, Lux.logcosh)) |> f64

# problem setting (results depend on the assumed range of x)

dx = 0.01
x_0 = 0.0
x_end = 3.0
σ_var_bc = 0.05
alg = SDEPINN(
chain=chain,
optimalg=BFGS(),
norm_loss_alg=HCubatureJL(),
x_0=x_0,
x_end=x_end,

# initial PDF: pdf(LogNormal(log(X₀), σ_var_bc), x)
# for GBM, a Normal initial distribution for X₀ also gives good results with absorbing_bc.
distrib=LogNormal(log(u0), σ_var_bc)
)

sol_GBM, phi = solve(
prob,
alg,
maxiters=500
)

analytic_sol_func(x, t) = pdf(LogNormal(log(u0) + (μ - 0.5 * σ^2) * t, sqrt(t) * σ), x)
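# X_t = u0 · exp((μ - σ²/2) t + σ W_t), hence X_t ~ LogNormal(log(u0) + (μ - σ²/2) t, σ√t)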
xs = collect(x_0:dx:x_end)

# test from t = 0.1, not 0.0, because the analytic solution tends to a Dirac delta (infinite peak) as t → 0
ts = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0]

u_real = [[analytic_sol_func(x, t) for x in xs] for t in ts]
u_predict = [[first(phi([x, t], sol_GBM.u)) for x in xs] for t in ts]

# MSE across all x.
diff = u_real .- u_predict
@test mean(vcat([abs2.(diff_i) for diff_i in diff]...)) < 0.01

# Compare with analytic GBM solution
# using Plots
# plotly()
# plots_got = []
# for i in 1:length(ts)
# plot(xs, u_real[i], label="analytic t=$(ts[i])")
# push!(plots_got, plot!(xs, u_predict[i], label="predict t=$(ts[i])"))
# end
# plot(plots_got..., legend=:outerbottomright)
end