Skip to content

Commit 152ded4

Browse files
Merge pull request #937 from AstitvaAggarwal/sdepinn
New loss
2 parents 393b42e + b681693 commit 152ded4

File tree

9 files changed

+319
-126
lines changed

9 files changed

+319
-126
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Distributions = "0.25.107"
6767
DocStringExtensions = "0.9.3"
6868
DomainSets = "0.7"
6969
ExplicitImports = "1.10.1"
70+
FastGaussQuadrature = "1.0.2"
7071
Flux = "0.14.22"
7172
ForwardDiff = "0.10.36"
7273
Functors = "0.4.12, 0.5"
@@ -116,6 +117,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
116117
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
117118
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
118119
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
120+
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
119121
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
120122
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
121123
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -132,4 +134,4 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
132134
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
133135

134136
[targets]
135-
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]
137+
test = ["Aqua", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "FastGaussQuadrature", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]

src/BPINN_ode.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
"""
44
BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000,
55
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
6-
phystd = [0.05], phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
6+
phystd = [0.05], phynewstd = (ode_params)->[0.05], dataset = [], physdt = 1 / 20.0,
77
MCMCkwargs = (; n_leapfrog=30), nchains = 1, init_params = nothing,
88
Adaptorkwargs = (; Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
99
Metric = DiagEuclideanMetric),
10-
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
11-
progress = false, verbose = false)
10+
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false, estim_collocate = false, progress = false, verbose = false)
1211
1312
Algorithm for solving ordinary differential equations using a Bayesian neural network. This
1413
is a specialization of the physics-informed neural network which is used as a solver for a
@@ -43,7 +42,7 @@ sol = solve(prob, Tsit5(); saveat = 0.05)
4342
u = sol.u[1:100]
4443
time = sol.t[1:100]
4544
x̂ = u .+ (u .* 0.2) .* randn(size(u))
46-
dataset = [x̂, time]
45+
dataset = [x̂, time, 0.05 .* ones(length(time))]
4746
4847
chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
4948
@@ -86,8 +85,8 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
8685
param <: Union{Nothing, Vector{<:Distribution}}
8786
l2std::Vector{Float64}
8887
phystd::Vector{Float64}
89-
phynewstd::Vector{Float64}
90-
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
88+
phynewstd
89+
dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}}
9190
physdt::Float64
9291
MCMCkwargs <: NamedTuple
9392
nchains::Int
@@ -103,7 +102,7 @@ end
103102

104103
function BNNODE(chain, kernel = HMC; strategy = nothing, draw_samples = 1000,
105104
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
106-
phynewstd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
105+
phynewstd = (ode_params) -> [0.05], dataset = [], physdt = 1 / 20.0,
107106
MCMCkwargs = (n_leapfrog = 30,), nchains = 1, init_params = nothing,
108107
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
109108
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),

src/NeuralPDE.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian,
4949
using Distributions: Distributions, Distribution, MvNormal, Normal, dim, logpdf
5050
using LogDensityProblems: LogDensityProblems
5151
using MCMCChains: MCMCChains, Chains, sample
52-
using MonteCarloMeasurements: Particles, pmean
52+
using MonteCarloMeasurements: Particles
5353

5454
import LuxCore: initialparameters, initialstates, parameterlength
5555

src/advancedHMC_MCMC.jl

Lines changed: 59 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
prob <: SciMLBase.ODEProblem
44
smodel <: StatefulLuxLayer
55
strategy <: AbstractTrainingStrategy
6-
dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}
6+
dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}}
77
priors <: Vector{<:Distribution}
88
phystd::Vector{Float64}
9-
phynewstd::Vector{Float64}
9+
phynewstd::Function
1010
l2std::Vector{Float64}
1111
autodiff::Bool
1212
physdt::Float64
@@ -74,32 +74,37 @@ suggested extra loss function for ODE solver case
7474
"""
7575
@views function L2loss2(ltd::LogTargetDensity, θ)
7676
ltd.extraparams ≤ 0 && return false # XXX: type-stability?
77-
77+
u0 = ltd.prob.u0
7878
f = ltd.prob.f
79-
t = ltd.dataset[end]
80-
u1 = ltd.dataset[2]
81-
û = ltd.dataset[1]
79+
t = ltd.dataset[end - 1]
80+
û = ltd.dataset[1:(end - 2)]
81+
quadrature_weights = ltd.dataset[end]
8282

8383
nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], ltd.autodiff)
8484

8585
ode_params = ltd.extraparams == 1 ? θ[((length(θ) - ltd.extraparams) + 1)] :
8686
θ[((length(θ) - ltd.extraparams) + 1):length(θ)]
87+
phynewstd = ltd.phynewstd(ode_params)
8788

88-
physsol = if length(ltd.prob.u0) == 1
89-
[f(û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
89+
physsol = if length(u0) == 1
90+
[f(û[1][i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
9091
else
91-
[f([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)]
92+
[f([û[j][i] for j in eachindex(u0)], ode_params, tᵢ)
93+
for (i, tᵢ) in enumerate(t)]
9294
end
9395
# form of NN output matrix output dim x n
9496
deri_physsol = reduce(hcat, physsol)
9597
T = promote_type(eltype(deri_physsol), eltype(nnsol))
9698

9799
physlogprob = T(0)
98-
for i in 1:length(ltd.prob.u0)
100+
# for BPINNs, quadrature is NOT applied on timewise logpdfs; it isn't being driven to zero.
101+
# Gridtraining/trapezoidal rule quadrature_weights is dt.*ones(T, length(t))
102+
# dims of phynewstd is same as u0 due to BNNODE being an out-of-place ODE solver.
103+
for i in eachindex(u0)
99104
physlogprob += logpdf(
100-
MvNormal(deri_physsol[i, :],
101-
Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))),
102-
nnsol[i, :]
105+
MvNormal((nnsol[i, :] .- deri_physsol[i, :]) .* quadrature_weights,
106+
Diagonal(abs2.(T(phynewstd[i]) .* ones(T, length(t))))),
107+
zeros(length(t))
103108
)
104109
end
105110
return physlogprob
@@ -109,10 +114,10 @@ end
109114
L2 loss loglikelihood(needed for ODE parameter estimation).
110115
"""
111116
@views function L2LossData(ltd::LogTargetDensity, θ)
112-
(ltd.dataset isa Vector{Nothing} || ltd.extraparams == 0) && return 0
117+
(isempty(ltd.dataset) || ltd.extraparams == 0) && return 0
113118

114119
# matrix(each row corresponds to vector u's rows)
115-
nn = ltd(ltd.dataset[end], θ[1:(length(θ) - ltd.extraparams)])
120+
nn = ltd(ltd.dataset[end - 1], θ[1:(length(θ) - ltd.extraparams)])
116121
T = eltype(nn)
117122

118123
L2logprob = zero(T)
@@ -150,24 +155,26 @@ end
150155
function getlogpdf(strategy::GridTraining, ltd::LogTargetDensity, f, autodiff::Bool,
151156
tspan, ode_params, θ)
152157
ts = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2])
153-
t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end])
158+
t = isempty(ltd.dataset) ? ts : vcat(ts, ltd.dataset[end - 1])
154159
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
155160
end
156161

157162
function getlogpdf(strategy::StochasticTraining, ltd::LogTargetDensity,
158163
f, autodiff::Bool, tspan, ode_params, θ)
159164
T = promote_type(eltype(tspan[1]), eltype(tspan[2]))
160165
samples = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1]
161-
t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end])
166+
t = isempty(ltd.dataset) ? samples : vcat(samples, ltd.dataset[end - 1])
162167
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
163168
end
164169

165170
function getlogpdf(strategy::QuadratureTraining, ltd::LogTargetDensity, f, autodiff::Bool,
166171
tspan, ode_params, θ)
172+
# integrand is shape of NN output
167173
integrand(t::Number, θ) = innerdiff(ltd, f, autodiff, [t], θ, ode_params)
168174
intprob = IntegralProblem(
169175
integrand, (tspan[1], tspan[2]), θ; nout = length(ltd.prob.u0))
170176
sol = solve(intprob, QuadGKJL(); strategy.abstol, strategy.reltol)
177+
# sum over losses for all NN outputs
171178
return sum(sol.u)
172179
end
173180

@@ -185,7 +192,7 @@ function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f,
185192
append!(ts, temp_data)
186193
end
187194

188-
t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end])
195+
t = isempty(ltd.dataset) ? ts : vcat(ts, ltd.dataset[end - 1])
189196
return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params))
190197
end
191198

@@ -202,23 +209,21 @@ MvNormal likelihood at each `ti` in time `t` for ODE collocation residue with NN
202209

203210
# this is a vector{vector{dx,dy}}(handle case single u(float passed))
204211
if length(out[:, 1]) == 1
205-
physsol = [f(out[:, i][1], ode_params, t[i]) for i in 1:length(out[1, :])]
212+
physsol = [f(out[:, i][1], ode_params, t[i]) for i in eachindex(t)]
206213
else
207-
physsol = [f(out[:, i], ode_params, t[i]) for i in 1:length(out[1, :])]
214+
physsol = [f(out[:, i], ode_params, t[i]) for i in eachindex(t)]
208215
end
209216
physsol = reduce(hcat, physsol)
210217

211218
nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], autodiff)
212-
213-
vals = nnsol .- physsol
214-
T = eltype(vals)
219+
T = eltype(nnsol)
215220

216221
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector
217222
# of dependant variables)
218223
return [logpdf(
219-
MvNormal(vals[i, :],
220-
Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(vals[i, :]))))),
221-
zeros(T, length(vals[i, :]))
224+
MvNormal((nnsol[i, :] .- physsol[i, :]),
225+
Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(t))))),
226+
zeros(T, length(t))
222227
) for i in 1:length(ltd.prob.u0)]
223228
end
224229

@@ -262,9 +267,9 @@ function kernelchoice(Kernel, MCMCkwargs)
262267
end
263268

264269
"""
265-
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
270+
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [],
266271
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
267-
l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
272+
l2std = [0.05], phystd = [0.05], phynewstd = (ode_params)->[0.05], priorsNNw = (0.0, 2.0),
268273
param = [], nchains = 1, autodiff = false, Kernel = HMC,
269274
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
270275
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
@@ -294,7 +299,7 @@ time = sol.t[1:100]
294299
295300
### dataset and BPINN create
296301
x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
297-
dataset = [x̂, time]
302+
dataset = [x̂, time, 0.05 .* ones(length(time))]
298303
299304
chain1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1)
300305
@@ -318,26 +323,32 @@ fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chain1,
318323
319324
## NOTES
320325
321-
Dataset is required for accurate Parameter estimation + solving equations
322-
Incase you are only solving the Equations for solution, do not provide dataset
326+
Dataset is required for accurate Parameter estimation in Inverse Problems.
327+
In case you are only solving non-parametric ODEs for a solution, do not provide a dataset.
323328
324329
## Positional Arguments
325330
326-
* `prob`: DEProblem(out of place and the function signature should be f(u,p,t).
331+
* `prob`: ODEProblem (out of place; the function signature should be `f(u,p,t)`).
327332
* `chain`: Lux Neural Network which would be made the Bayesian PINN.
328333
329334
## Keyword Arguments
330335
331336
* `strategy`: The training strategy used to choose the points for the evaluations. By
332337
default GridTraining is used with given physdt discretization.
338+
* `dataset`: Is either an empty Vector or a nested Vector of the form `[x̂, t, W]` where `x̂` are dependent variable observations, `t` are time points and `W` are quadrature weights for the domain.
339+
The dataset is used to compute the L2 loss against the data and also for the new loss function.
340+
For multiple dependent variables, there will be multiple vectors with the last two vectors in dataset still being for `t`, `W`.
341+
Is empty by default assuming a forward problem is being solved.
333342
* `init_params`: initial parameter values for BPINN (ideally for multiple chains different
334343
initializations preferred)
335344
* `nchains`: number of chains you want to sample
336345
* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are
337346
~2/3 of draw samples)
338347
* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
339348
* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
340-
* `phynewstd`: standard deviation of new loss func term
349+
* `phynewstd`: A function that gives the standard deviation of the new loss function at each iteration.
350+
It takes the ODE parameters as input and returns a vector of standard deviations.
351+
Is (ode_params) -> [0.05] by default.
341352
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
342353
BPINN are Normal Distributions by default.
343354
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
@@ -357,6 +368,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
357368
* `max_depth`: Maximum doubling tree depth (NUTS)
358369
* `Δ_max`: Maximum divergence during doubling tree (NUTS)
359370
Refer: https://turinglang.org/AdvancedHMC.jl/stable/
371+
* `estim_collocate`: A boolean value to indicate whether to use the new loss function or not. This is only relevant for ODE parameter estimation.
360372
* `progress`: controls whether to show the progress meter or not.
361373
* `verbose`: controls the verbosity. (Sample call args in AHMC)
362374
@@ -366,9 +378,10 @@ Incase you are only solving the Equations for solution, do not provide dataset
366378
releases.
367379
"""
368380
function ahmc_bayesian_pinn_ode(
369-
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
381+
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [],
370382
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
371-
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
383+
phystd = [0.05], phynewstd = (ode_params) -> [0.05],
384+
priorsNNw = (0.0, 2.0), param = [], nchains = 1,
372385
autodiff = false, Kernel = HMC,
373386
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
374387
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
@@ -380,15 +393,15 @@ function ahmc_bayesian_pinn_ode(
380393

381394
strategy = strategy == GridTraining ? strategy(physdt) : strategy
382395

383-
if dataset != [nothing] &&
384-
(length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
385-
error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}")
396+
if !isempty(dataset) &&
397+
(length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
398+
error("Invalid dataset. The dataset would be a timeseries (x̂,t,W) with type: Vector{Vector{AbstractFloat}}")
386399
end
387400

388-
if dataset != [nothing] && param == []
389-
println("Dataset is only needed for Parameter Estimation + Forward Problem, not in only Forward Problem case.")
390-
elseif dataset == [nothing] && param != []
391-
error("Dataset Required for Parameter Estimation.")
401+
if !isempty(dataset) && isempty(param)
402+
println("Dataset is only needed for Inverse problems performing Parameter Estimation, not in only Forward Problem case.")
403+
elseif isempty(dataset) && !isempty(param)
404+
error("Dataset Required for Inverse problems performing Parameter Estimation.")
392405
end
393406

394407
initial_nnθ, chain, st = generate_ltd(chain, init_params)
@@ -461,7 +474,8 @@ function ahmc_bayesian_pinn_ode(
461474

462475
MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
463476
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
464-
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
477+
samples,
478+
stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
465479
progress = progress, verbose = verbose)
466480

467481
samplesc[i] = samples
@@ -479,7 +493,8 @@ function ahmc_bayesian_pinn_ode(
479493

480494
MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
481495
Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
482-
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
496+
samples,
497+
stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
483498
adaptor; progress = progress, verbose = verbose)
484499

485500
if verbose

0 commit comments

Comments
 (0)