
Conversation

@AstitvaAggarwal (Member) commented Jan 14, 2026

Clear TODOs left:

  1. Write tests.
  2. Add docs.
  3. A struct to store results for ease of analysis.

Weak SDEPINN Solver

The current implementation of SDEPINN is a Physics-Informed Neural Network (PINN) solver for scalar SDEs that works at the density level: it solves the Fokker–Planck (Kolmogorov forward) PDE instead of sampling individual trajectories.

1. Problem

We consider a 1D Itô SDE:

$$ dX_t = f(X_t, t)\, dt + g(X_t, t)\, dW_t, $$

with scalar state $X_t \in \mathbb{R}$, general drift $f$ and diffusion $g$, and initial condition $X_{t_0} = x_0$.

The solver learns the probability density $p(x,t)$ satisfying the Fokker–Planck PDE:

$$ \frac{\partial p(x,t)}{\partial t} = -\frac{\partial}{\partial x}\Big[f(x,t) p(x,t)\Big] + \frac{1}{2} \frac{\partial^2}{\partial x^2} \Big[g(x,t)^2 p(x,t)\Big]. $$
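
As a concrete instance, for the zero-mean OU process $dX_t = -\theta X_t\, dt + \sigma\, dW_t$ (one of the test problems below), this PDE specializes to

$$ \frac{\partial p(x,t)}{\partial t} = \theta \frac{\partial}{\partial x}\big[x\, p(x,t)\big] + \frac{\sigma^2}{2} \frac{\partial^2 p(x,t)}{\partial x^2}, $$

whose stationary solution is the Gaussian $\mathcal{N}\big(0, \sigma^2/(2\theta)\big)$; the learned density should relax toward it, consistent with the Normal-distributed results reported below.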


2. Domain and boundary conditions

Spatial domain: $x \in [x_0, x_\text{end}]$. The user must choose this truncated domain based on the rough range of the expected SDE solution (for example, a few standard deviations beyond the density's expected support over the time span of interest).
The initial condition is approximated as a narrow PDF:

$$ p(x, t_0) \approx \mathcal{N}(x_0, \sigma_\text{bc}^2) $$
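
A minimal sketch of this IC density using Distributions.jl (the width `σ_bc` and the names below are illustrative, not the PR's API):

```julia
using Distributions

# Narrow Gaussian standing in for the delta initial condition δ(x - x₀);
# σ_bc is a small user-chosen width (an assumed name, not a PR parameter).
σ_bc = 0.05
p_ic(x, x₀) = pdf(Normal(x₀, σ_bc), x)
```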

Boundary conditions:

  1. Absorbing:
    $p(x_0, t) = p(x_\text{end}, t) = 0$

  2. Reflecting (zero flux at the chosen spatial domain's boundaries; see the sketch below):
    $J(x,t) = f(x,t)\,p(x,t) - \frac{1}{2} \frac{\partial}{\partial x}\big[g(x,t)^2 p(x,t)\big] = 0$

Currently (for the results below) only reflecting BCs were enabled; in practice, one should choose the BC loss terms consistent with the problem.
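
For concreteness, a numeric sketch of the probability flux behind the reflecting BC (the PR builds $J$ symbolically; the central-difference helper below is an illustrative assumption, not the PR's code):

```julia
# Probability flux J(x, t) = f(x, t) p(x, t) - (1/2) ∂ₓ[g(x, t)² p(x, t)],
# with the spatial derivative approximated by a central difference.
# p, f, g are callables (x, t) -> scalar.
function flux(p, f, g, x, t; h = 1e-5)
    g2p(ξ) = g(ξ, t)^2 * p(ξ, t)
    return f(x, t) * p(x, t) - 0.5 * (g2p(x + h) - g2p(x - h)) / (2h)
end
```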


3. Loss formulation

The total loss consists of:

  1. PDE residual (enforces the Fokker–Planck equation; see the sketch after this list):

$$ \mathcal{L}_\text{PDE} = \sum_{(x,t) \in \text{grid}} \Bigg| \frac{\partial p_\theta(x,t)}{\partial t} + \frac{\partial}{\partial x}[f(x,t)p_\theta(x,t)] - \frac{1}{2} \frac{\partial^2}{\partial x^2}[g(x,t)^2 p_\theta(x,t)] \Bigg|^2 $$

  2. Boundary condition loss:

$$ \mathcal{L}_\text{BC} = \sum_{x \in \{x_0, x_\text{end}\}} |p_\theta(x,t)|^2 + |J_\theta(x,t)|^2 $$

  3. Normalization loss (probability mass = 1; a quadrature sketch follows below):

$$ \mathcal{L}_\text{norm} = \sum_{t \in \text{grid}} \Big| \int_{x_0}^{x_\text{end}} p_\theta(x,t)\, dx - 1 \Big|^2 $$

  4. Initial condition loss:

$$ \mathcal{L}_\text{IC} = \sum_{x \in \text{IC points}} |p_\theta(x,t_0) - p_\text{IC}(x)|^2 $$
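
A pointwise sketch of the PDE residual from item 1 (the PR assembles it symbolically through NeuralPDE; nested ForwardDiff here is an illustrative stand-in):

```julia
using ForwardDiff

# Fokker–Planck residual at a single collocation point (x, t);
# p, f, g are scalar callables (x, t) -> value, generic enough to accept
# dual numbers. The PDE loss sums abs2 of this over the training grid.
function fp_residual(p, f, g, x, t)
    ∂t_p = ForwardDiff.derivative(τ -> p(x, τ), t)
    ∂x_fp = ForwardDiff.derivative(ξ -> f(ξ, t) * p(ξ, t), x)
    ∂xx_g2p = ForwardDiff.derivative(
        ξ -> ForwardDiff.derivative(η -> g(η, t)^2 * p(η, t), ξ), x)
    return ∂t_p + ∂x_fp - 0.5 * ∂xx_g2p
end
```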

The total loss minimized during training:

$$ \mathcal{L}_\text{total} = \mathcal{L}_\text{PDE} + \mathcal{L}_\text{IC} + \mathcal{L}_\text{BC} + \lambda_\text{norm}\, \mathcal{L}_\text{norm} $$

Currently, the IC loss is implemented pointwise, while the normalization loss enforces the integrated probability mass.
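
As referenced in item 3, a hedged sketch of the normalization penalty via Integrals.jl quadrature (the PR's commented-out code uses IntegralProblem similarly; `pθ`, `x0`, `x_end`, and `ts` are placeholders, not the PR's API):

```julia
using Integrals

# Squared deviation of the probability mass from 1 at each time node.
function norm_loss(pθ, x0, x_end, ts)
    sum(ts) do t
        prob = IntegralProblem((x, _) -> pθ(x, t), (x0, x_end))
        mass = solve(prob, QuadGKJL(); reltol = 1e-6, abstol = 1e-6).u
        abs2(mass - 1)
    end
end
```

The total objective is then assembled roughly as follows (a sketch; the four loss functions are assumed closures over the network and training grid, and only the normalization term carries an explicit weight, as in the formula above):

```julia
# λ_norm is the only explicitly weighted term in the PR's total loss.
λ_norm = 1.0
total_loss(θ) = pde_loss(θ) + ic_loss(θ) + bc_loss(θ) + λ_norm * norm_loss(θ)
```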


4. Potential improvements

  • Multidimensional SDEs: currently only scalar $X_t$; generalizing requires solving the multidimensional Fokker–Planck PDE for a vector state.
  • Boundary condition handling: give the user the choice of absorbing and/or reflecting BCs.
  • Batching & autodiff: ensure consistency for large grids.
  • Parameter estimation mode: ?

TL;DR

SDEPINN solves scalar SDEs at the density level via PINNs, enforcing PDE, BCs, ICs, and normalization.

Current results:

  1. GBM SDE (resulting state follows a LogNormal distribution)
     [figure: learned density for the GBM test problem]
  2. OU process (resulting state follows a Normal distribution)
     [figure: learned density for the OU test problem]
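
For reference, the standard closed-form densities for these two test problems (with the usual parameterizations $dX_t = \mu X_t\,dt + \sigma X_t\,dW_t$ for GBM and $dX_t = -\theta X_t\,dt + \sigma\,dW_t$ for OU; the PR's exact parameter values are not stated here):

$$ X_t^{\text{GBM}} \sim \text{LogNormal}\Big(\ln x_0 + \big(\mu - \tfrac{\sigma^2}{2}\big)t,\; \sigma^2 t\Big), \qquad X_t^{\text{OU}} \sim \mathcal{N}\Big(x_0 e^{-\theta t},\; \tfrac{\sigma^2}{2\theta}\big(1 - e^{-2\theta t}\big)\Big). $$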

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API


@AstitvaAggarwal (Member, Author) commented Jan 14, 2026

Messy PR aside, the CI needs work...

@avik-pal (Member) commented:

  Compiling Tuple{typeof(Cubature.integrands), Cubature.IntegrandData{IntegralsCubatureExt.var"#3#10"{ComponentArrays.ComponentVector{Float64, Vector{Float64}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1))), bias = ViewAxis(13:24, Shaped1DAxis((12,))))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12))), bias = ViewAxis(13:13, Shaped1DAxis((1,))))))}}}, SciMLBase.BatchIntegralFunction{false, SciMLBase.FullSpecialize, Integrals.var"#26#28"{typeof(Integrals.t2ujac), Vector{Float64}, Vector{Float64}, NeuralPDE.var"#integrand#124"{NeuralPDE.var"#262#263"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:cord, Symbol("##θ#226"), :phi, :derivative, :integral, :u, :p), NeuralPDE.var"#_RGF_ModTag", NeuralPDE.var"#_RGF_ModTag", (0x9688f044, 0xc7770165, 0xa8db61b5, 0x159219f0, 0x326e88bb), Expr}, NeuralPDE.var"#12#13", NeuralPDE.var"#304#311"{NeuralPDE.var"#304#305#312"{typeof(NeuralPDE.numeric_derivative)}, Dict{Symbol, Int64}, Dict{Symbol, Int64}, NeuralPDE.QuadratureTraining{Float64, Integrals.CubatureJLh}}, typeof(NeuralPDE.numeric_derivative), NeuralPDE.Phi{LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer{Val{true}, Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}, Nothing}, DataType, MLDataDevices.CPUDevice{Missing}}}, Nothing}}}, Bool, Bool, Bool}: UndefVarError: `spvals` not defined

(p̂(u0[i], t₀) .- Distributions.pdf(distrib[i], u0[i]) ~ P(0) for i in 1:length(u0))
end

# # inside optimization loss
Member:

what's this all for? It seems like it would be difficult for this to ever sample?

Member Author:

I was actually testing different forms of the loss; this approach did sample well for GBM but not for OU.

Comment on lines 140 to 144
# # I_ic = solve(IntegralProblem(f_ic, x_0, x_end, θ), norm_loss_alg,
# # HCubatureJL(),
# # reltol = 1e-8, abstol = 1e-8, maxiters = 10)[1]
# # return abs(I_ic) # I_ic loss AUC = 0?
# return sum(abs2, ftest_icloss(phi, θ)) # I_ic pointwise
Member:

What is this?

@AstitvaAggarwal (Member, Author) replied Jan 28, 2026:

I have yet to clean up this part (I was tidying some other parts of the code first). Here I was manually taking the quadrature of the difference between the expected and IC PDFs.

Comment on lines +164 to +179
# absorbing BCs
if absorbing_bc
    @info "absorbing BCs used"

    bcs = vcat(bcs, [p̂(x_0, T) ~ P(0),
        p̂(x_end, T) ~ P(0)]...)
end

# reflecting BCs
if reflective_bc
    @info "reflecting BCs used"

    bcs = vcat(bcs, [J(x_0, T) ~ P(0),
        J(x_end, T) ~ P(0)]...)
end
Member:

It's infinite domain though?

@AstitvaAggarwal (Member, Author) replied Jan 28, 2026:

I'm not sure how we can enforce that properly, so I thought we could just enforce the BCs on the user-chosen truncated domain.

@ChrisRackauckas (Member) commented:

Make the commented-out dead code into options that can be switched in the algorithm, or delete it if it is never useful.
