@@ -101,19 +101,25 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
101101 verbose:: Bool
102102end
103103
104- function BNNODE(chain, kernel = HMC; strategy = nothing , draw_samples = 1000 ,
104+ function BNNODE(
105+ chain, kernel = HMC; strategy = nothing , draw_samples = 1000 ,
105106 priorsNNw = (0.0 , 2.0 ), param = nothing , l2std = [0.05 ], phystd = [0.05 ],
106107 phynewstd = (ode_params) -> [0.05 ], dataset = [], physdt = 1 / 20.0 ,
107108 MCMCkwargs = (n_leapfrog = 30 ,), nchains = 1 , init_params = nothing ,
108- Adaptorkwargs = (Adaptor = StanHMCAdaptor,
109- Metric = DiagEuclideanMetric, targetacceptancerate = 0.8 ),
109+ Adaptorkwargs = (
110+ Adaptor = StanHMCAdaptor,
111+ Metric = DiagEuclideanMetric, targetacceptancerate = 0.8 ,
112+ ),
110113 Integratorkwargs = (Integrator = Leapfrog,),
111114 numensemble = floor(Int, draw_samples / 3 ),
112- estim_collocate = false , autodiff = false , progress = false , verbose = false )
115+ estim_collocate = false , autodiff = false , progress = false , verbose = false
116+ )
113117 chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
114- return BNNODE(chain, kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd,
118+ return BNNODE(
119+ chain, kernel, strategy, draw_samples, priorsNNw, param, l2std, phystd,
115120 phynewstd, dataset, physdt, MCMCkwargs, nchains, init_params, Adaptorkwargs,
116- Integratorkwargs, numensemble, estim_collocate, autodiff, progress, verbose)
121+ Integratorkwargs, numensemble, estim_collocate, autodiff, progress, verbose
122+ )
117123end
118124
119125"""
@@ -155,33 +161,38 @@ contains fields related to that).
155161 timepoints
156162end
157163
158- function SciMLBase. __solve(prob:: SciMLBase.ODEProblem , alg:: BNNODE , args... ; dt = nothing ,
164+ function SciMLBase. __solve(
165+ prob:: SciMLBase.ODEProblem , alg:: BNNODE , args... ; dt = nothing ,
159166 timeseries_errors = true , save_everystep = true , adaptive = false ,
160167 abstol = 1.0f-6 , reltol = 1.0f-3 , verbose = false , saveat = 1 / 50.0 ,
161- maxiters = nothing )
168+ maxiters = nothing
169+ )
162170 (; chain, param, strategy, draw_samples, numensemble, verbose) = alg
163171
164172 # ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
165173 param = param === nothing ? [] : param
166174 strategy = strategy === nothing ? GridTraining : strategy
167175
168- @assert alg. draw_samples≥ 0 " Number of samples to be drawn has to be >=0."
176+ @assert alg. draw_samples ≥ 0 " Number of samples to be drawn has to be >=0."
169177
170178 mcmcchain, samples,
171- statistics = ahmc_bayesian_pinn_ode(
179+ statistics = ahmc_bayesian_pinn_ode(
172180 prob, chain; strategy, alg. dataset, alg. draw_samples, alg. init_params,
173181 alg. physdt, alg. l2std, alg. phystd, alg. phynewstd,
174182 alg. priorsNNw, param, alg. nchains, alg. autodiff,
175183 Kernel = alg. kernel, alg. Adaptorkwargs, alg. Integratorkwargs,
176- alg. MCMCkwargs, alg. progress, alg. verbose, alg. estim_collocate)
184+ alg. MCMCkwargs, alg. progress, alg. verbose, alg. estim_collocate
185+ )
177186
178187 fullsolution = BPINNstats(mcmcchain, samples, statistics)
179188 ninv = length(param)
180189 t = collect(eltype(saveat), prob. tspan[1 ]: saveat: prob. tspan[2 ])
181190
182191 θinit, st = LuxCore. setup(Random. default_rng(), chain)
183- θ = [vector_to_parameters(samples[i][1 : (end - ninv)], θinit)
184- for i in (draw_samples - numensemble): draw_samples]
192+ θ = [
193+ vector_to_parameters(samples[i][1 : (end - ninv)], θinit)
194+ for i in (draw_samples - numensemble): draw_samples
195+ ]
185196
186197 luxar = [chain(t' , θ[i], st)[1] for i in 1:numensemble]
187198 # only need for size
@@ -205,27 +216,31 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt
205216 for r in 1:numoutput
206217 ensem_r = hcat(output_matrices[r]...)'
207218 ensemblecurve_r = prob. u0[r] .+
208- [Particles(ensem_r[:, i]) for i in 1 : length(t)] .*
209- (t .- prob. tspan[1 ])
219+ [Particles(ensem_r[:, i]) for i in 1 : length(t)] .*
220+ (t .- prob. tspan[1 ])
210221 push!(ensemblecurves, ensemblecurve_r)
211222 end
212223
213224 else
214225 ensemblecurve = prob. u0 .+
215- [Particles(reduce(vcat, luxar)[:, i]) for i in 1 : length(t)] .*
216- (t .- prob. tspan[1 ])
226+ [Particles(reduce(vcat, luxar)[:, i]) for i in 1 : length(t)] .*
227+ (t .- prob. tspan[1 ])
217228 push!(ensemblecurves, ensemblecurve)
218229 end
219230
220231 nnparams = length(θinit)
221- estimnnparams = [Particles(reduce(hcat, samples[(end - numensemble): end ])[i, :])
222- for i in 1 : nnparams]
232+ estimnnparams = [
233+ Particles(reduce(hcat, samples[(end - numensemble): end ])[i, :])
234+ for i in 1 : nnparams
235+ ]
223236
224237 if ninv == 0
225238 estimated_params = [nothing ]
226239 else
227- estimated_params = [Particles(reduce(hcat, samples[(end - numensemble): end ])[i, :])
228- for i in (nnparams + 1 ): (nnparams + ninv)]
240+ estimated_params = [
241+ Particles(reduce(hcat, samples[(end - numensemble): end ])[i, :])
242+ for i in (nnparams + 1 ): (nnparams + ninv)
243+ ]
229244 end
230245
231246 return BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
0 commit comments