diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml
index 07dc53980..660c8e6c3 100644
--- a/.github/workflows/Tests.yml
+++ b/.github/workflows/Tests.yml
@@ -30,7 +30,8 @@ jobs:
           - "QA"
           - "ODEBPINN"
           - "PDEBPINN"
-          - "NNSDE"
+          - "NNSDE1"
+          - "NNSDE2"
           - "NNPDE1"
           - "NNPDE2"
           - "AdaptiveLoss"
diff --git a/src/NN_SDE_weaksolve.jl b/src/NN_SDE_weaksolve.jl
new file mode 100644
index 000000000..e8f864bc1
--- /dev/null
+++ b/src/NN_SDE_weaksolve.jl
@@ -0,0 +1,215 @@
+
+@concrete struct SDEPINN
+    chain <: AbstractLuxLayer
+    optimalg
+    norm_loss_alg
+    initial_parameters
+
+    # domain discretization
+    x_0::Float64
+    x_end::Float64
+    Nt::Int
+    dx::Float64
+
+    # IC & normalization
+    σ_var_bc::Float64
+    λ_ic::Float64
+    λ_norm::Float64
+    distrib::Distributions.Distribution
+
+    # solver options
+    strategy <: Union{Nothing, AbstractTrainingStrategy}
+    autodiff::Bool
+    batch::Bool
+    param_estim::Bool
+
+    # For postprocessing - solution handling
+    # xview::AbstractArray
+    # tview::AbstractArray
+    # phi::Phi
+
+    dataset <: Union{Nothing, Vector, Vector{<:Vector}}
+    additional_loss <: Union{Nothing, Function}
+    kwargs
+end
+
+function SDEPINN(;
+    chain,
+    optimalg=nothing,
+    norm_loss_alg=nothing,
+    initial_parameters=nothing,
+    x_0,
+    x_end,
+    Nt=50,
+    dx=0.01,
+    # width of the initial-condition PDF: must be narrow (near-Dirac centering);
+    # the smaller this is, the taller the initial peak the network must learn.
+    σ_var_bc=0.05,
+    λ_ic=1.0,  # currently unused; the IC is enforced through the BC loss
+    λ_norm=1.0,
+    distrib=Normal(0.5, 0.01),
+    strategy=nothing,
+    autodiff=true,
+    batch=false,
+    param_estim=false,
+    dataset=nothing,
+    additional_loss=nothing,
+    kwargs...
+)
+    return SDEPINN(
+        chain,
+        optimalg,
+        norm_loss_alg,
+        initial_parameters,
+        x_0,
+        x_end,
+        Nt,
+        dx,
+        σ_var_bc,
+        λ_ic,
+        λ_norm,
+        distrib,
+        strategy,
+        autodiff,
+        batch,
+        param_estim,
+        dataset,
+        additional_loss,
+        kwargs
+    )
+end
+
+function SciMLBase.__solve(
+    prob::SciMLBase.AbstractSDEProblem,
+    alg::SDEPINN,
+    args...;
+    dt=nothing,
+    abstol=1.0f-6,
+    reltol=1.0f-3,
+    saveat=nothing,
+    tstops=nothing,
+    maxiters=200,
+    verbose=false,
+    kwargs...,
+)
+    (; u0, tspan, f, g, p) = prob
+    P = eltype(u0)
+    t₀, t₁ = tspan
+
+    absorbing_bc = false
+    reflective_bc = true
+
+    (; x_0, x_end, Nt, dx, σ_var_bc, λ_ic, λ_norm,
+        distrib, optimalg, norm_loss_alg, initial_parameters, chain) = alg
+
+    # a user-supplied dt takes precedence over the Nt-based grid spacing
+    dt = dt === nothing ? (t₁ - t₀) / Nt : dt
+    ts = collect(t₀:dt:t₁)
+
+    # Define the Fokker–Planck (FP) PDE
+    @parameters X, T
+    @variables p̂(..)
+    Dx = Differential(X)
+    Dxx = Differential(X)^2
+    Dt = Differential(T)
+
+    # probability flux
+    J(x, t) = f(x, p, t) * p̂(x, t) -
+              P(0.5) * Dx((g(x, p, t))^2 * p̂(x, t))
+
+    # IC in symbolic equation form
+    f_icloss = if u0 isa Number
+        (p̂(u0, t₀) - Distributions.pdf(distrib, u0) ~ P(0),)
+    else
+        (p̂(u0[i], t₀) - Distributions.pdf(distrib[i], u0[i]) ~ P(0) for i in eachindex(u0))
+    end
+
+    eq = Dt(p̂(X, T)) ~ -Dx(f(X, p, T) * p̂(X, T)) +
+                       P(0.5) * Dxx((g(X, p, T))^2 * p̂(X, T))
+
+    # Using p = 0 boundaries together with normalization works; however, if
+    # the x domain is widened too much on either side, the PDF mass, although
+    # "conserved" inside the domain, can be forced to spread into spurious regions.
+
+    bcs = [
+        # No probability enters or leaves the domain, so total mass is
+        # conserved; this matches an SDE on a truncated but reflecting domain.
+
+        # IC loss (note: it is amplified by the number of training points)
+        f_icloss...
+    ]
+
+    # Absorbing BCs
+    if absorbing_bc
+        @info "absorbing BCs used"
+
+        bcs = vcat(bcs, [p̂(x_0, T) ~ P(0),
+            p̂(x_end, T) ~ P(0)])
+    end
+
+    # Reflecting BCs
+    if reflective_bc
+        @info "reflecting BCs used"
+
+        bcs = vcat(bcs, [J(x_0, T) ~ P(0),
+            J(x_end, T) ~ P(0)])
+    end
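+
+    # Why J = 0 at both walls conserves mass (worked equation, for reference):
+    # the FP equation above is the continuity equation ∂ₜp̂ = -∂ₓJ, so
+    #     d/dt ∫ p̂ dx = -∫ ∂ₓJ dx = J(x_0, T) - J(x_end, T) = 0
+    # whenever both boundary fluxes vanish, i.e. the total probability inside
+    # [x_0, x_end] stays constant and the normalization loss below only has
+    # to pin that constant to 1.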
+
+    domains = [X ∈ (x_0, x_end), T ∈ (t₀, t₁)]
+
+    # Additional losses
+    # TODO: handle the norm loss and IC loss for vector NN outputs; the
+    # x_0, x_end and u0 handling will need to be adjusted for this as well.
+
+    function norm_loss(phi, θ)
+        loss = P(0)
+        for t in ts
+            # define the integrand as a function of x only (t fixed) and
+            # perform ∫ p̂(x, t) dx over [x_0, x_end]
+            phi_normloss(x, θ) = u0 isa Number ? first(phi([x, t], θ)) : phi([x, t], θ)
+            I_est = solve(IntegralProblem(phi_normloss, x_0, x_end, θ), norm_loss_alg,
+                reltol=1e-8, abstol=1e-8, maxiters=10)[1]
+            loss += abs2(I_est - P(1))
+        end
+        return loss
+    end
+
+    function combined_additional(phi, θ, _)
+        λ_norm * norm_loss(phi, θ)
+    end
+
+    # Discretization: GridTraining only (the strategy field is currently ignored)
+    discretization = PhysicsInformedNN(
+        chain,
+        GridTraining([dx, dt]);
+        init_params=initial_parameters,
+        additional_loss=combined_additional
+    )
+
+    @named pdesys = PDESystem(eq, bcs, domains, [X, T], [p̂(X, T)])
+    opt_prob = discretize(pdesys, discretization)
+    phi = discretization.phi
+
+    sym = NeuralPDE.symbolic_discretize(pdesys, discretization)
+    pde_losses = sym.loss_functions.pde_loss_functions
+    bc_losses = sym.loss_functions.bc_loss_functions
+
+    cb = function (p, l)
+        (!verbose) && return false
+        println("loss = ", l)
+        println("pde = ", map(f -> f(p.u), pde_losses))
+        println("bc = ", map(f -> f(p.u), bc_losses))
+        println("norm = ", norm_loss(phi, p.u))
+        return false
+    end
+
+    res = Optimization.solve(
+        opt_prob,
+        optimalg;
+        callback=cb,
+        maxiters=maxiters,
+        kwargs...
+    )
+
+    # TODO: postprocessing / solution wrapping
+    return res, phi
+end
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index d5e92fbd7..014cbde06 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -91,11 +91,13 @@ include("PDE_BPINN.jl")
 include("dgm.jl")
 
 include("NN_SDE_solve.jl")
+include("NN_SDE_weaksolve.jl")
 
 export PINOODE
 export NNODE, NNDAE
 export BNNODE, ahmc_bayesian_pinn_ode, ahmc_bayesian_pinn_pde
 export NNSDE
+export SDEPINN
 export PhysicsInformedNN, discretize
 export BPINNsolution, BayesianPINN
 export DeepGalerkin
diff --git a/test/NN_SDE_weaksolve_tests.jl b/test/NN_SDE_weaksolve_tests.jl
new file mode 100644
index 000000000..d682edf3b
--- /dev/null
+++ b/test/NN_SDE_weaksolve_tests.jl
@@ -0,0 +1,136 @@
+@testitem "OU process" tags = [:nnsde2] begin
+    using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, Optimisers
+    using OrdinaryDiffEq, Random, Distributions, Integrals, Cubature
+    using OptimizationOptimJL: BFGS
+    Random.seed!(100)
+
+    α = -1
+    β = 1
+    u0 = 0.5
+    t0 = 0.0
+    f(u, p, t) = α * u
+    g(u, p, t) = β
+    tspan = (0.0, 1.0)
+    prob = SDEProblem(f, g, u0, tspan)
+
+    # Neural network
+    inn = 20
+    chain = Lux.Chain(Dense(2, inn, Lux.tanh),
+        Dense(inn, inn, Lux.tanh),
+        Dense(inn, 1, Lux.logcosh)) |> f64
+
+    # problem setting
+    dx = 0.01
+    x_0 = -4.0
+    x_end = 4.0
+    σ_var_bc = 0.05
+
+    alg = SDEPINN(
+        chain=chain,
+        optimalg=BFGS(),
+        norm_loss_alg=HCubatureJL(),
+        x_0=x_0,
+        x_end=x_end,
+        distrib=Normal(u0, σ_var_bc)
+    )
+
+    sol_OU, phi = solve(
+        prob,
+        alg,
+        maxiters=500
+    )
+
+    # OU analytic solution: Normal with mean u0·e^{-t} and
+    # stationary variance σ² = β²/(2|α|) = 1/2
+    σ² = 0.5
+    analytic_sol_func(x, t) = pdf(Normal(u0 * exp(-t), sqrt(σ² * (1 - exp(-2t)))), x)
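+    # Worked derivation (standard OU result): for dXₜ = αXₜ dt + β dWₜ with
+    # X₀ = u0 and α < 0, the transition density is Gaussian,
+    #     Xₜ ~ Normal(u0·e^{αt}, (β²/(2|α|))·(1 - e^{2αt})),
+    # so for α = -1, β = 1 the mean is u0·e^{-t} and the variance is
+    # ½(1 - e^{-2t}), which is exactly what analytic_sol_func encodes.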
+    xs = collect(x_0:dx:x_end)
+
+    # test from t = 0.1, not 0.0, because the analytic solution degenerates
+    # to a Dirac delta at t = 0
+    ts = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0]
+
+    u_real = [[analytic_sol_func(x, t) for x in xs] for t in ts]
+    u_predict = [[first(phi([x, t], sol_OU.u)) for x in xs] for t in ts]  # NeuralPDE predictions
+
+    # MSE across all x
+    errs = u_real .- u_predict
+    @test mean(vcat([abs2.(err_i) for err_i in errs]...)) < 0.01
+
+    # using Plots
+    # plotly()
+    # plots_got = []
+    # for i in 1:length(ts)
+    #     plot(xs, u_real[i], label="analytic t=$(ts[i])")
+    #     push!(plots_got, plot!(xs, u_predict[i], label="predict t=$(ts[i])"))
+    # end
+    # plot(plots_got..., legend=:outerbottomright)
+end
+
+@testitem "GBM SDE" tags = [:nnsde2] begin
+    using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, Optimisers
+    using OrdinaryDiffEq, Random, Distributions, Integrals, Cubature
+    using OptimizationOptimJL: BFGS
+    Random.seed!(100)
+
+    μ = 0.2
+    σ = 0.3
+    f(x, p, t) = μ * x
+    g(x, p, t) = σ * x
+    u0 = 1.0
+    tspan = (0.0, 1.0)
+    prob = SDEProblem(f, g, u0, tspan)
+
+    # Neural network
+    inn = 20
+    chain = Lux.Chain(Dense(2, inn, Lux.tanh),
+        Dense(inn, inn, Lux.tanh),
+        Dense(inn, 1, Lux.logcosh)) |> f64
+
+    # problem setting (results depend on the assumed x range)
+    dx = 0.01
+    x_0 = 0.0
+    x_end = 3.0
+    σ_var_bc = 0.05
+    alg = SDEPINN(
+        chain=chain,
+        optimalg=BFGS(),
+        norm_loss_alg=HCubatureJL(),
+        x_0=x_0,
+        x_end=x_end,
+
+        # initial PDF: pdf(LogNormal(log(u0), σ_var_bc), x); for GBM, a Normal
+        # initial distribution also gives good results with absorbing BCs.
+        distrib=LogNormal(log(u0), σ_var_bc)
+    )
+
+    sol_GBM, phi = solve(
+        prob,
+        alg,
+        maxiters=500
+    )
+
+    # GBM analytic solution: log Xₜ ~ Normal(log(u0) + (μ - σ²/2)t, σ²t)
+    analytic_sol_func(x, t) = pdf(LogNormal(log(u0) + (μ - 0.5 * σ^2) * t, sqrt(t) * σ), x)
+    xs = collect(x_0:dx:x_end)
+
+    # test from t = 0.1, not 0.0, because the analytic solution degenerates
+    # to a Dirac delta at t = 0
+    ts = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0]
+
+    u_real = [[analytic_sol_func(x, t) for x in xs] for t in ts]
+    u_predict = [[first(phi([x, t], sol_GBM.u)) for x in xs] for t in ts]
+
+    # MSE across all x
+    errs = u_real .- u_predict
+    @test mean(vcat([abs2.(err_i) for err_i in errs]...)) < 0.01
+
+    # Compare with the analytic GBM solution
+    # using Plots
+    # plotly()
+    # plots_got = []
+    # for i in 1:length(ts)
+    #     plot(xs, u_real[i], label="analytic t=$(ts[i])")
+    #     push!(plots_got, plot!(xs, u_predict[i], label="predict t=$(ts[i])"))
+    # end
+    # plot(plots_got..., legend=:outerbottomright)
+end
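+
+# Note on the truncated domain: GBM has a natural zero-flux boundary at x = 0
+# (g(0) = σ·0 = 0, so J(0, t) = 0 for any bounded density), which makes the
+# reflecting BC at x_0 = 0.0 consistent with the true dynamics. The right wall
+# is a genuine truncation; for these parameters the analytic mass beyond
+# x_end = 3.0 stays below roughly 1e-3 for t ≤ 1, so the reflecting BC there
+# should introduce negligible error, but longer tspans may need a larger x_end.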