Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
29e123e
Too Easy.
AstitvaAggarwal Apr 13, 2025
6bf8c49
remove Integrals from test deps
AstitvaAggarwal Apr 13, 2025
45c1e39
tests. (new is better in more noise)
AstitvaAggarwal Apr 13, 2025
83a642d
minor change
AstitvaAggarwal Apr 13, 2025
5cfc6e7
.
AstitvaAggarwal Apr 13, 2025
d8f602b
.
AstitvaAggarwal Apr 13, 2025
30fa615
Likelihood probabilites are not driven to 0.
AstitvaAggarwal Apr 14, 2025
85b350c
.
AstitvaAggarwal Apr 14, 2025
3c59fde
more samples
AstitvaAggarwal Apr 14, 2025
f2aafdb
.
AstitvaAggarwal Apr 14, 2025
70e956d
fixed tests
AstitvaAggarwal Apr 14, 2025
41691eb
tests.
AstitvaAggarwal Apr 16, 2025
f21969c
.
AstitvaAggarwal Apr 16, 2025
a968cc8
std for new loss is parametric
AstitvaAggarwal Apr 26, 2025
c66772a
Changes to API
AstitvaAggarwal Apr 26, 2025
2fcf75e
tests.
AstitvaAggarwal Apr 27, 2025
cb85411
tests-2
AstitvaAggarwal Apr 27, 2025
2423224
tests-3
AstitvaAggarwal Apr 27, 2025
cc1348f
Update BPINN_tests.jl
AstitvaAggarwal Apr 27, 2025
4a1ca0e
BPINN_PDE loss corrected
AstitvaAggarwal Apr 27, 2025
107165e
NNODE improvements & L2Data!=additional_loss
AstitvaAggarwal May 3, 2025
a4d1fb7
spelling check
AstitvaAggarwal May 3, 2025
d47c19c
tests
AstitvaAggarwal May 3, 2025
0bc0ec1
tests-1
AstitvaAggarwal May 3, 2025
65c4b08
tests-3
AstitvaAggarwal May 5, 2025
8ed4e18
format
AstitvaAggarwal May 6, 2025
c242419
Update src/BPINN_ode.jl
AstitvaAggarwal May 6, 2025
95140dd
cubature over L2 instead of L1
AstitvaAggarwal May 6, 2025
6cb24c5
Merge branch 'sdepinn' of https://github.com/AstitvaAggarwal/NeuralPD…
AstitvaAggarwal May 6, 2025
f24df29
bpinn remains in non squared within logpdf?
AstitvaAggarwal May 6, 2025
e8dfd9a
changes from reviews.
AstitvaAggarwal May 16, 2025
0b7123a
docstrings, support preexisting tutorials
AstitvaAggarwal May 16, 2025
d594ac3
Update BPINN_tests.jl
AstitvaAggarwal May 16, 2025
e31d4e2
Update BPINN_ode.jl
AstitvaAggarwal May 16, 2025
b681693
Merge branch 'SciML:master' into sdepinn
AstitvaAggarwal May 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Distributions = "0.25.107"
DocStringExtensions = "0.9.3"
DomainSets = "0.7"
ExplicitImports = "1.10.1"
FastGaussQuadrature = "1.0.2"
Flux = "0.14.22"
ForwardDiff = "0.10.36"
Functors = "0.4.12, 0.5"
Expand All @@ -92,6 +93,7 @@ Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.87"
PolyChaos = "0.2.11"
Printf = "1.10"
QuasiMonteCarlo = "0.3.2"
Random = "1"
Expand All @@ -116,6 +118,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand All @@ -126,10 +129,11 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PolyChaos = "8d666b04-775d-5f6e-b778-5ac7c70f65a3"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "FastGaussQuadrature", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "PolyChaos", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]
6 changes: 3 additions & 3 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
phystd = [0.05], phynewstd = (ode_params)->[0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric),
Expand Down Expand Up @@ -86,7 +86,7 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
param <: Union{Nothing, Vector{<:Distribution}}
l2std::Vector{Float64}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
phynewstd::Function
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
physdt::Float64
MCMCkwargs <: NamedTuple
Expand All @@ -103,7 +103,7 @@ end

function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
phynewstd = (ode_params) -> [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Expand Down
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian,
using Distributions: Distributions, Distribution, MvNormal, Normal, dim, logpdf
using LogDensityProblems: LogDensityProblems
using MCMCChains: MCMCChains, Chains, sample
using MonteCarloMeasurements: Particles, pmean
using MonteCarloMeasurements: Particles

import LuxCore: initialparameters, initialstates, parameterlength

Expand Down
44 changes: 26 additions & 18 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
priors <: Vector{<:Distribution}
phystd::Vector{Float64}
phynewstd::Vector{Float64}
phynewstd::Function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specialize?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure how we can specialize functions...
(Im keeping a function for std in BPINNs as selecting the right std can be tricky and usually depends on the problem)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

::F will

l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
Expand Down Expand Up @@ -76,30 +76,35 @@ suggested extra loss function for ODE solver case
ltd.extraparams ≤ 0 && return false # XXX: type-stability?

f = ltd.prob.f
t = ltd.dataset[end]
u1 = ltd.dataset[2]
= ltd.dataset[1]
t = ltd.dataset[end - 1]
= ltd.dataset[1:(end - 2)]
quadrature_weights = ltd.dataset[end]

nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], ltd.autodiff)

ode_params = ltd.extraparams == 1 ? θ[((length(θ) - ltd.extraparams) + 1)] :
θ[((length(θ) - ltd.extraparams) + 1):length(θ)]
phynewstd = ltd.phynewstd(ode_params)

physsol = if length(ltd.prob.u0) == 1
[f(û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
[f(û[1][i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
else
[f([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
[f([û[j][i] for j in 1:length(û)], ode_params, tᵢ)
for (i, tᵢ) in enumerate(t)]
end
# form of NN output matrix output dim x n
deri_physsol = reduce(hcat, physsol)
T = promote_type(eltype(deri_physsol), eltype(nnsol))

physlogprob = T(0)
loss_vals = nnsol .- deri_physsol
# For BPINNs, quadrature is NOT applied to the timewise logpdfs; it isn't being driven to zero.
# For GridTraining / the trapezoidal rule, quadrature_weights is dt .* ones(T, length(t)).
for i in 1:length(ltd.prob.u0)
physlogprob += logpdf(
MvNormal(deri_physsol[i, :],
Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
nnsol[i, :]
MvNormal(loss_vals[i, :] .* quadrature_weights,
Diagonal(abs2.(T(phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
zeros(length(t))
)
end
return physlogprob
Expand All @@ -112,7 +117,7 @@ L2 loss loglikelihood(needed for ODE parameter estimation).
(ltd.dataset isa Vector{Nothing} || ltd.extraparams == 0) && return 0

# matrix(each row corresponds to vector u's rows)
nn = ltd(ltd.dataset[end], θ[1:(length(θ) - ltd.extraparams)])
nn = ltd(ltd.dataset[end - 1], θ[1:(length(θ) - ltd.extraparams)])
T = eltype(nn)

L2logprob = zero(T)
Expand Down Expand Up @@ -150,24 +155,26 @@ end
# Physics log-density term for `GridTraining`.
#
# Builds a uniform time grid over `tspan` with spacing `strategy.dx`, appends the
# dataset's time points when a dataset is present, and returns the sum of
# `innerdiff` evaluated on those points.
function getlogpdf(strategy::GridTraining, ltd::LogTargetDensity, f, autodiff::Bool,
        tspan, ode_params, θ)
    ts = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2])
    # The dataset is laid out as (x̂..., t, W): the time points are the second-to-last
    # entry, the last entry holds the quadrature weights.
    t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end - 1])
    return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
end

# Physics log-density term for `StochasticTraining`.
#
# Draws `strategy.points` uniform random time points inside `tspan`, appends the
# dataset's time points when a dataset is present, and returns the sum of
# `innerdiff` evaluated on those points.
function getlogpdf(strategy::StochasticTraining, ltd::LogTargetDensity,
        f, autodiff::Bool, tspan, ode_params, θ)
    T = promote_type(eltype(tspan[1]), eltype(tspan[2]))
    # Affine map of rand(T, n) in [0, 1) onto the time span.
    samples = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1]
    # The dataset is laid out as (x̂..., t, W): the time points are the second-to-last
    # entry, the last entry holds the quadrature weights.
    t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end - 1])
    return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
end

# Physics log-density term for `QuadratureTraining`.
#
# Integrates `innerdiff` over `tspan` with `QuadGKJL`, using the tolerances stored on
# `strategy`, and returns the sum of the components of the integral.
function getlogpdf(strategy::QuadratureTraining, ltd::LogTargetDensity, f, autodiff::Bool,
        tspan, ode_params, θ)
    n_out = length(ltd.prob.u0)
    domain = (tspan[1], tspan[2])

    # Integrand at a single time point; per the original note it has the shape of the
    # NN output, hence `nout = n_out` below.
    function pointwise(tᵢ::Number, p)
        return innerdiff(ltd, f, autodiff, [tᵢ], p, ode_params)
    end

    quadprob = IntegralProblem(pointwise, domain, θ; nout = n_out)
    result = solve(
        quadprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol)

    # Collapse the per-output integrals into one scalar.
    return sum(result.u)
end

Expand All @@ -185,7 +192,7 @@ function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f,
append!(ts, temp_data)
end

t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end])
t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end - 1])
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
end

Expand Down Expand Up @@ -264,7 +271,7 @@ end
"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
l2std = [0.05], phystd = [0.05], phynewstd = (ode_params)->[0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Expand Down Expand Up @@ -337,7 +344,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
~2/3 of draw samples)
* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
* `phynewstd`: standard deviation of new loss func term
* `phynewstd`: Function in ode_params that gives the standard deviation of the new loss function terms.
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
BPINN are Normal Distributions by default.
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
Expand Down Expand Up @@ -368,7 +375,8 @@ Incase you are only solving the Equations for solution, do not provide dataset
function ahmc_bayesian_pinn_ode(
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
phystd = [0.05], phynewstd = (ode_params) -> [0.05],
priorsNNw = (0.0, 2.0), param = [], nchains = 1,
autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Expand All @@ -381,8 +389,8 @@ function ahmc_bayesian_pinn_ode(
strategy = strategy == GridTraining ? strategy(physdt) : strategy

if dataset != [nothing] &&
(length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}")
(length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
error("Invalid dataset. dataset would be timeseries (x̂,t,W) where type: Vector{Vector{AbstractFloat}")
end

if dataset != [nothing] && param == []
Expand Down
70 changes: 65 additions & 5 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,17 @@ Networks 9, no. 5 (1998): 987-1000.
strategy <: Union{Nothing, AbstractTrainingStrategy}
param_estim
additional_loss <: Union{Nothing, Function}
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
estim_collocate::Bool
kwargs
end

# Keyword constructor for `NNODE`.
#
# `dataset` (default `[nothing]`) and `estim_collocate` (default `false`) are forwarded
# to the struct fields of the same name; any remaining keyword arguments are stored in
# `kwargs`. Positional order of the inner call follows the struct's field order.
function NNODE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
        batch = true, param_estim = false, additional_loss = nothing,
        dataset = [nothing], estim_collocate = false, kwargs...)
    # Chains that are not Lux layers are converted through `FromFluxAdaptor`.
    chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
    return NNODE(chain, opt, init_params, autodiff, batch,
        strategy, param_estim, additional_loss, dataset, estim_collocate, kwargs)
end

"""
Expand Down Expand Up @@ -263,6 +266,44 @@ function generate_loss(::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
spaces only. Use StochasticTraining instead.")
end

"""
L2 loss (needed for ODE parameter estimation).
"""
function generate_L2lossData(dataset, phi, n_output)
    # Arguments:
    #   dataset  - `[nothing]`, or a vector of vectors whose second-to-last entry holds
    #              the time points and whose first `n_output` entries hold the observed
    #              values of each output.
    #   phi      - callable `phi(t, θ)`; row `i` of its result is compared to `dataset[i]`.
    #   n_output - number of output rows to compare.
    # Returns `0` when there is no dataset, otherwise a closure `(θ, _) -> loss` giving the
    # sum of squared differences between the rows of `phi(t, θ)` and the observations.
    dataset isa Vector{Nothing} && return 0
    return function (θ, _)
        # Evaluate the network once per loss call rather than once per output row.
        nn = phi(dataset[end - 1], θ)
        return sum(sum(abs2, nn[i, :] .- dataset[i]) for i in 1:n_output)
    end
end

"""
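Additional loss used when `estim_collocate` is true (together with `param_estim`): the sum
of squares of the quadrature-weighted residuals between `ode_dfdx` of the network and the
ODE right-hand side `f` evaluated on the dataset values. Returns `0` when no dataset is given.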
"""
function generate_L2loss2(f, autodiff, dataset, phi, n_output)
    # Without a dataset there is nothing to build a loss from.
    if dataset isa Vector{Nothing}
        return 0
    end

    # Dataset layout: observed states first, then the time points, then the quadrature
    # weights as the last entry.
    n_states = length(dataset) - 2
    u_obs = dataset[1:n_states]
    ts = dataset[end - 1]
    weights = dataset[end]

    # Closure with the (θ, _) signature used by the other loss terms.
    function L2loss2(θ, _)
        dudt_nn = ode_dfdx(phi, ts, θ, autodiff)
        p_est = θ.p

        # Right-hand side `f` evaluated on the observed data at every time point.
        rhs_cols = if n_output == 1
            [f(u_obs[1][k], p_est, tk) for (k, tk) in enumerate(ts)]
        else
            [f([u_obs[j][k] for j in 1:n_states], p_est, tk)
             for (k, tk) in enumerate(ts)]
        end

        # Same layout as the NN output: output dim x n time points.
        residual = dudt_nn .- reduce(hcat, rhs_cols)

        # Quadrature weights are applied to the timewise residuals before squaring.
        # For GridTraining / trapezoidal rule the weights are dt .* ones(T, length(t)).
        return sum(sum(abs2, residual[i, :] .* weights) for i in 1:n_output)
    end

    return L2loss2
end

@concrete struct NNODEInterpolation
phi <: ODEPhi
θ
Expand Down Expand Up @@ -307,7 +348,8 @@ function SciMLBase.__solve(
)
(; u0, tspan, f, p) = prob
t0 = tspan[1]
(; param_estim, chain, opt, autodiff, init_params, batch, additional_loss) = alg
# add estim_collocate, dataset (or nothing) in NNODE
(; param_estim, estim_collocate, dataset, chain, opt, autodiff, init_params, batch, additional_loss, estim_collocate) = alg

phi, init_params = generate_phi_θ(chain, t0, u0, init_params)

Expand Down Expand Up @@ -336,12 +378,30 @@ function SciMLBase.__solve(

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)

(param_estim && additional_loss === nothing) &&
throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true)."))
if dataset != [nothing] &&
(length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
error("Invalid dataset. dataset would be timeseries (x̂,t,W) where type: Vector{Vector{AbstractFloat}")
end

if dataset == [nothing] && param_estim
error("Dataset is Required for Parameter Estimation.")
elseif dataset == [nothing] && estim_collocate
error("Dataset Required for Parameter Estimation using new loss.")
end

n_output = length(u0)
L2lossData = generate_L2lossData(dataset, phi, n_output)
L2loss2 = generate_L2loss2(f, autodiff, dataset, phi, n_output)

# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
L2_loss = inner_f(θ, phi)

if param_estim && estim_collocate
L2_loss = L2_loss + L2lossData(θ, phi) + L2loss2(θ, phi)
elseif param_estim
L2_loss = L2_loss + L2lossData(θ, phi)
end
if additional_loss !== nothing
L2_loss = L2_loss + additional_loss(phi, θ)
end
Expand Down
5 changes: 3 additions & 2 deletions test/BPINN_PDE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,9 @@ end
end
end

@parameters x, t, α
@variables u(..)
@parameters α
@variables x, t
@syms u(x, t)
Dt = Differential(t)
Dx = Differential(x)
Dx2 = Differential(x)^2
Expand Down
Loading
Loading