33 prob <: SciMLBase.ODEProblem
44 smodel <: StatefulLuxLayer
55 strategy <: AbstractTrainingStrategy
6- dataset <: Union{Vector{Nothing} , Vector{<:Vector{<:AbstractFloat}}}
6+ dataset <: Union{Vector, Vector{<:Vector{<:AbstractFloat}}}
77 priors <: Vector{<:Distribution}
88 phystd:: Vector{Float64}
9- phynewstd:: Vector{Float64}
9+ phynewstd:: Function
1010 l2std:: Vector{Float64}
1111 autodiff:: Bool
1212 physdt:: Float64
@@ -74,32 +74,37 @@ suggested extra loss function for ODE solver case
7474"""
7575@views function L2loss2 (ltd:: LogTargetDensity , θ)
7676 ltd. extraparams ≤ 0 && return false # XXX : type-stability?
77-
77+ u0 = ltd . prob . u0
7878 f = ltd. prob. f
79- t = ltd. dataset[end ]
80- u1 = ltd. dataset[2 ]
81- û = ltd. dataset[1 ]
79+ t = ltd. dataset[end - 1 ]
80+ û = ltd. dataset[1 : ( end - 2 ) ]
81+ quadrature_weights = ltd. dataset[end ]
8282
8383 nnsol = ode_dfdx (ltd, t, θ[1 : (length (θ) - ltd. extraparams)], ltd. autodiff)
8484
8585 ode_params = ltd. extraparams == 1 ? θ[((length (θ) - ltd. extraparams) + 1 )] :
8686 θ[((length (θ) - ltd. extraparams) + 1 ): length (θ)]
87+ phynewstd = ltd. phynewstd (ode_params)
8788
88- physsol = if length (ltd . prob . u0) == 1
89- [f (û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate (t)]
89+ physsol = if length (u0) == 1
90+ [f (û[1 ][ i], ode_params, tᵢ) for (i, tᵢ) in enumerate (t)]
9091 else
91- [f ([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate (t)]
92+ [f ([û[j][i] for j in eachindex (u0)], ode_params, tᵢ)
93+ for (i, tᵢ) in enumerate (t)]
9294 end
9395 # form of NN output matrix output dim x n
9496 deri_physsol = reduce (hcat, physsol)
9597 T = promote_type (eltype (deri_physsol), eltype (nnsol))
9698
9799 physlogprob = T (0 )
98- for i in 1 : length (ltd. prob. u0)
100+ # for BPINNS Quadrature is NOT applied on timewise logpdfs, it isnt being driven to zero.
101+ # Gridtraining/trapezoidal rule quadrature_weights is dt.*ones(T, length(t))
102+ # dims of phynewstd is same as u0 due to BNNODE being an out-of-place ODE solver.
103+ for i in eachindex (u0)
99104 physlogprob += logpdf (
100- MvNormal (deri_physsol[i, :],
101- Diagonal (abs2 .(T (ltd . phynewstd[i]) .* ones (T, length (nnsol[i, :] ))))),
102- nnsol[i, :]
105+ MvNormal ((nnsol[i, :] .- deri_physsol[i, :]) .* quadrature_weights ,
106+ Diagonal (abs2 .(T (phynewstd[i]) .* ones (T, length (t ))))),
107+ zeros ( length (t))
103108 )
104109 end
105110 return physlogprob
@@ -109,10 +114,10 @@ end
109114L2 loss loglikelihood(needed for ODE parameter estimation).
110115"""
111116@views function L2LossData (ltd:: LogTargetDensity , θ)
112- (ltd. dataset isa Vector{Nothing} || ltd. extraparams == 0 ) && return 0
117+ (isempty ( ltd. dataset) || ltd. extraparams == 0 ) && return 0
113118
114119 # matrix(each row corresponds to vector u's rows)
115- nn = ltd (ltd. dataset[end ], θ[1 : (length (θ) - ltd. extraparams)])
120+ nn = ltd (ltd. dataset[end - 1 ], θ[1 : (length (θ) - ltd. extraparams)])
116121 T = eltype (nn)
117122
118123 L2logprob = zero (T)
@@ -150,24 +155,26 @@ end
150155function getlogpdf (strategy:: GridTraining , ltd:: LogTargetDensity , f, autodiff:: Bool ,
151156 tspan, ode_params, θ)
152157 ts = collect (eltype (strategy. dx), tspan[1 ]: (strategy. dx): tspan[2 ])
153- t = ltd. dataset isa Vector{Nothing} ? ts : vcat (ts, ltd. dataset[end ])
158+ t = isempty ( ltd. dataset) ? ts : vcat (ts, ltd. dataset[end - 1 ])
154159 return sum (innerdiff (ltd, f, autodiff, t, θ, ode_params))
155160end
156161
157162function getlogpdf (strategy:: StochasticTraining , ltd:: LogTargetDensity ,
158163 f, autodiff:: Bool , tspan, ode_params, θ)
159164 T = promote_type (eltype (tspan[1 ]), eltype (tspan[2 ]))
160165 samples = (tspan[2 ] - tspan[1 ]) .* rand (T, strategy. points) .+ tspan[1 ]
161- t = ltd. dataset isa Vector{Nothing} ? samples : vcat (samples, ltd. dataset[end ])
166+ t = isempty ( ltd. dataset) ? samples : vcat (samples, ltd. dataset[end - 1 ])
162167 return sum (innerdiff (ltd, f, autodiff, t, θ, ode_params))
163168end
164169
165170function getlogpdf (strategy:: QuadratureTraining , ltd:: LogTargetDensity , f, autodiff:: Bool ,
166171 tspan, ode_params, θ)
172+ # integrand is shape of NN output
167173 integrand (t:: Number , θ) = innerdiff (ltd, f, autodiff, [t], θ, ode_params)
168174 intprob = IntegralProblem (
169175 integrand, (tspan[1 ], tspan[2 ]), θ; nout = length (ltd. prob. u0))
170176 sol = solve (intprob, QuadGKJL (); strategy. abstol, strategy. reltol)
177+ # sum over losses for all NN outputs
171178 return sum (sol. u)
172179end
173180
@@ -185,7 +192,7 @@ function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f,
185192 append! (ts, temp_data)
186193 end
187194
188- t = ltd. dataset isa Vector{Nothing} ? ts : vcat (ts, ltd. dataset[end ])
195+ t = isempty ( ltd. dataset) ? ts : vcat (ts, ltd. dataset[end - 1 ])
189196 return sum (innerdiff (ltd, f, autodiff, t, θ, ode_params))
190197end
191198
@@ -202,23 +209,21 @@ MvNormal likelihood at each `ti` in time `t` for ODE collocation residue with NN
202209
203210 # this is a vector{vector{dx,dy}}(handle case single u(float passed))
204211 if length (out[:, 1 ]) == 1
205- physsol = [f (out[:, i][1 ], ode_params, t[i]) for i in 1 : length (out[ 1 , :] )]
212+ physsol = [f (out[:, i][1 ], ode_params, t[i]) for i in eachindex (t )]
206213 else
207- physsol = [f (out[:, i], ode_params, t[i]) for i in 1 : length (out[ 1 , :] )]
214+ physsol = [f (out[:, i], ode_params, t[i]) for i in eachindex (t )]
208215 end
209216 physsol = reduce (hcat, physsol)
210217
211218 nnsol = ode_dfdx (ltd, t, θ[1 : (length (θ) - ltd. extraparams)], autodiff)
212-
213- vals = nnsol .- physsol
214- T = eltype (vals)
219+ T = eltype (nnsol)
215220
216221 # N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector
217222 # of dependant variables)
218223 return [logpdf (
219- MvNormal (vals [i, :],
220- Diagonal (abs2 .(T (ltd. phystd[i]) .* ones (T, length (vals[i, :] ))))),
221- zeros (T, length (vals[i, :] ))
224+ MvNormal ((nnsol [i, :] .- physsol[i, :]) ,
225+ Diagonal (abs2 .(T (ltd. phystd[i]) .* ones (T, length (t ))))),
226+ zeros (T, length (t ))
222227 ) for i in 1 : length (ltd. prob. u0)]
223228end
224229
@@ -262,9 +267,9 @@ function kernelchoice(Kernel, MCMCkwargs)
262267end
263268
264269"""
265- ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing ],
270+ ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [],
266271 init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0,
267- l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
272+ l2std = [0.05], phystd = [0.05], phynewstd = (ode_params)-> [0.05], priorsNNw = (0.0, 2.0),
268273 param = [], nchains = 1, autodiff = false, Kernel = HMC,
269274 Adaptorkwargs = (Adaptor = StanHMCAdaptor,
270275 Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
@@ -294,7 +299,7 @@ time = sol.t[1:100]
294299
295300### dataset and BPINN create
296301x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
297- dataset = [x̂, time]
302+ dataset = [x̂, time, 0.05 .* ones(length(time)) ]
298303
299304chain1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1))
300305
@@ -318,26 +323,32 @@ fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chain1,
318323
319324## NOTES
320325
321- Dataset is required for accurate Parameter estimation + solving equations
322- Incase you are only solving the Equations for solution, do not provide dataset
326+ Dataset is required for accurate Parameter estimation in Inverse Problems.
327+ In case you are only solving non-parametric ODE equations for a solution, do not provide a dataset.
323328
324329## Positional Arguments
325330
326- * `prob`: DEProblem (out of place and the function signature should be f(u,p,t).
331+ * `prob`: ODEProblem (out of place; the function signature should be `f(u,p,t)`).
327332* `chain`: Lux Neural Network which would be made the Bayesian PINN.
328333
329334## Keyword Arguments
330335
331336* `strategy`: The training strategy used to choose the points for the evaluations. By
332337 default GridTraining is used with given physdt discretization.
338+ * `dataset`: Either an empty Vector or a nested Vector of the form `[x̂, t, W]`, where `x̂` are the dependent-variable observations, `t` are the time points and `W` are the quadrature weights for the domain.
339+ The dataset is used to compute the L2 loss against the data and also for the new loss function.
340+ For multiple dependent variables there is one observation vector per variable, and the last two vectors in the dataset are still `t` and `W`.
341+ It is empty by default, which assumes a forward problem is being solved.
333342* `init_params`: initial parameter values for BPINN (ideally for multiple chains different
334343 initializations preferred)
335344* `nchains`: number of chains you want to sample
336345* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are
337346 ~2/3 of draw samples)
338347* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
339348* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
340- * `phynewstd`: standard deviation of new loss func term
349+ * `phynewstd`: A function that gives the standard deviation of the new loss function at each iteration.
350+ It takes the ODE parameters as input and returns a vector of standard deviations, one entry per component of `u0`.
351+ Defaults to `(ode_params) -> [0.05]`.
341352* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of
342353 BPINN are Normal Distributions by default.
343354* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
@@ -357,6 +368,7 @@ Incase you are only solving the Equations for solution, do not provide dataset
357368 * `max_depth`: Maximum doubling tree depth (NUTS)
358369 * `Δ_max`: Maximum divergence during doubling tree (NUTS)
359370 Refer: https://turinglang.org/AdvancedHMC.jl/stable/
371+ * `estim_collocate`: A boolean value to indicate whether to use the new loss function or not. This is only relevant for ODE parameter estimation.
360372* `progress`: controls whether to show the progress meter or not.
361373* `verbose`: controls the verbosity. (Sample call args in AHMC)
362374
@@ -366,9 +378,10 @@ Incase you are only solving the Equations for solution, do not provide dataset
366378 releases.
367379"""
368380function ahmc_bayesian_pinn_ode (
369- prob:: SciMLBase.ODEProblem , chain; strategy = GridTraining, dataset = [nothing ],
381+ prob:: SciMLBase.ODEProblem , chain; strategy = GridTraining, dataset = [],
370382 init_params = nothing , draw_samples = 1000 , physdt = 1 / 20.0 , l2std = [0.05 ],
371- phystd = [0.05 ], phynewstd = [0.05 ], priorsNNw = (0.0 , 2.0 ), param = [], nchains = 1 ,
383+ phystd = [0.05 ], phynewstd = (ode_params) -> [0.05 ],
384+ priorsNNw = (0.0 , 2.0 ), param = [], nchains = 1 ,
372385 autodiff = false , Kernel = HMC,
373386 Adaptorkwargs = (Adaptor = StanHMCAdaptor,
374387 Metric = DiagEuclideanMetric, targetacceptancerate = 0.8 ),
@@ -380,15 +393,15 @@ function ahmc_bayesian_pinn_ode(
380393
381394 strategy = strategy == GridTraining ? strategy (physdt) : strategy
382395
383- if dataset != [ nothing ] &&
384- (length (dataset) < 2 || ! (dataset isa Vector{<: Vector{<:AbstractFloat} }))
385- error (" Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}" )
396+ if ! isempty (dataset) &&
397+ (length (dataset) < 3 || ! (dataset isa Vector{<: Vector{<:AbstractFloat} }))
398+ error (" Invalid dataset. The dataset would be a timeseries (x̂,t,W) with type: Vector{Vector{AbstractFloat} }" )
386399 end
387400
388- if dataset != [ nothing ] && param == []
389- println (" Dataset is only needed for Parameter Estimation + Forward Problem , not in only Forward Problem case." )
390- elseif dataset == [ nothing ] && param != []
391- error (" Dataset Required for Parameter Estimation." )
401+ if ! isempty ( dataset) && isempty ( param)
402+ println (" Dataset is only needed for Inverse problems performing Parameter Estimation , not in only Forward Problem case." )
403+ elseif isempty ( dataset) && ! isempty (param)
404+ error (" Dataset Required for Inverse problems performing Parameter Estimation." )
392405 end
393406
394407 initial_nnθ, chain, st = generate_ltd (chain, init_params)
@@ -461,7 +474,8 @@ function ahmc_bayesian_pinn_ode(
461474
462475 MCMC_alg = kernelchoice (Kernel, MCMCkwargs)
463476 Kernel = AdvancedHMC. make_kernel (MCMC_alg, integrator)
464- samples, stats = sample (hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
477+ samples,
478+ stats = sample (hamiltonian, Kernel, initial_θ, draw_samples, adaptor;
465479 progress = progress, verbose = verbose)
466480
467481 samplesc[i] = samples
@@ -479,7 +493,8 @@ function ahmc_bayesian_pinn_ode(
479493
480494 MCMC_alg = kernelchoice (Kernel, MCMCkwargs)
481495 Kernel = AdvancedHMC. make_kernel (MCMC_alg, integrator)
482- samples, stats = sample (hamiltonian, Kernel, initial_θ, draw_samples,
496+ samples,
497+ stats = sample (hamiltonian, Kernel, initial_θ, draw_samples,
483498 adaptor; progress = progress, verbose = verbose)
484499
485500 if verbose
0 commit comments