Skip to content

Better BPINN ode Solver #853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
@@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
@@ -112,6 +114,8 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
@@ -120,6 +124,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
end

@@ -186,7 +191,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
@@ -211,7 +217,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)
verbose = verbose,
estim_collocate = estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
@@ -220,7 +227,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
if chain isa Lux.AbstractExplicitLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in (draw_samples - numensemble):draw_samples]
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
5 changes: 0 additions & 5 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
@@ -69,11 +69,6 @@
# + L2loss2(Tar, θ)
end

# function L2loss2(Tar::PDELogTargetDensity, θ)
# return Tar.full_loglikelihood(setparameters(Tar, θ),
# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
@@ -361,7 +356,7 @@
# append Ode params to all paramvector - initial_θ
if ninv > 0
# shift ode params(initialise ode params by prior means)
# check if means or user specified is better

Check warning on line 359 in src/PDE_BPINN.jl

GitHub Actions / Spell Check with Typos

"speified" should be "specified".
initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in 1:ninv])
priors = vcat(priors, param)
nparameters += ninv
89 changes: 76 additions & 13 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
@@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
physdt::Float64
extraparams::Int
init_params::I
estim_collocate::Bool

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
init_params::AbstractVector, estim_collocate)
new{
typeof(chain),
Nothing,
@@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple)
init_params::NamedTuple, estim_collocate)
new{
typeof(chain),
typeof(st),
@@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
end

@@ -83,7 +86,12 @@ end
vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new

# Log-density of the target at parameter vector `θ`.
#
# Always sums the physics log-likelihood, the prior log-density and the data
# log-likelihood. When `Tar.estim_collocate` is set, the extra collocation term
# `L2loss2` is added on top.
function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
    # Base terms are shared by both modes; evaluate them once.
    logp = physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
    return Tar.estim_collocate ? logp + L2loss2(Tar, θ) : logp
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
@@ -92,6 +100,55 @@ function LogDensityProblems.capabilities(::LogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end

"""
suggested extra loss function for ODE solver case
"""
function L2loss2(Tar::LogTargetDensity, θ)
    # This term is only defined when ODE parameters are being estimated;
    # otherwise it contributes nothing to the log-density.
    Tar.extraparams > 0 || return zero(eltype(θ))

    f = Tar.prob.f
    nstates = length(Tar.prob.u0)

    # Timepoints at which the physics is enforced (last dataset entry).
    t = Tar.dataset[end]
    # Observations of the first state variable.
    # NOTE(review): assumes the dataset is laid out as [u₁, …, uₙ, t] with one
    # vector per state variable — confirm against the callers.
    û = Tar.dataset[1]

    # θ = [NN parameters; ODE parameters]
    nnparam_count = length(θ) - Tar.extraparams
    nnsol = NNodederi(Tar, t, θ[1:nnparam_count], Tar.autodiff)

    # A single ODE parameter is handed to `f` as a scalar, several as a vector.
    ode_params = Tar.extraparams == 1 ? θ[end] : θ[(nnparam_count + 1):end]

    # Evaluate the ODE right-hand side on the observed data.
    physsol = if nstates == 1
        [f(û[i], ode_params, t[i]) for i in axes(û, 1)]
    else
        # Assemble the full observed state at each timepoint from every state
        # variable in the dataset (not just the first two).
        [f([Tar.dataset[j][i] for j in 1:nstates], ode_params, t[i])
         for i in eachindex(û)]
    end
    # form of NN output matrix: output dim x n
    deri_physsol = reduce(hcat, physsol)

    physlogprob = zero(eltype(θ))
    for i in 1:nstates
        nn_i = nnsol[i, :]
        # can add phystd[i] for u[i]
        # NOTE(review): the std is l2std[i] scaled by a fixed factor 4.0; the
        # rationale for that factor is not visible here.
        σ² = abs2(Tar.l2std[i] * 4.0)
        physlogprob += logpdf(
            MvNormal(deri_physsol[i, :],
                LinearAlgebra.Diagonal(fill(σ², length(nn_i)))),
            nn_i)
    end
    return physlogprob
end

"""
L2 loss loglikelihood(needed for ODE parameter estimation).
"""
@@ -247,7 +304,7 @@ function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector,

vals = nnsol .- physsol

# N dimensional vector if N outputs for NN(each row has logpdf of i[i] where u is vector of dependant variables)
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependant variables)
return [logpdf(
MvNormal(vals[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .*
@@ -442,7 +499,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
progress = false, verbose = false,
estim_collocate = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
@@ -467,7 +525,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
# Lux-Named Tuple
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractExplicitLayer neural networks are supported")
error("Only Lux.AbstractExplicitLayer Neural networks are supported")
end

if nchains > Threads.nthreads()
@@ -500,7 +558,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
t0 = prob.tspan[1]
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
@@ -515,6 +573,9 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, initial_θ))
end

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
@@ -565,12 +626,14 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, samples[end]))
if estim_collocate
@info("Current gradient loss against dataset Log-likelihood : ", L2loss2(ℓπ, samples[end]))
end

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
mcmc_chain = MCMCChains.Chains(matrix_samples)
return mcmc_chain, samples, stats
end
end
14 changes: 2 additions & 12 deletions test/BPINN_PDEinvsol_tests.jl
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ using ComponentArrays

Random.seed!(100)

@testset "Example 1: 2D Periodic System with parameter estimation" begin
@testset "Example 1: 1D Periodic System with parameter estimation" begin
# Cos(pi*t) periodic curve
@parameters t, p
@variables u(..)
@@ -59,17 +59,7 @@ Random.seed!(100)
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])

discretization = BayesianPINN([chainl], QuadratureTraining(), param_estim = true,
dataset = [dataset, nothing])

ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 1500,
bcstd = [0.05],
phystd = [0.01], l2std = [0.01],
priorsNNw = (0.0, 1.0),
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])
# alternative to QuadratureTraining [WIP]

discretization = BayesianPINN([chainl], GridTraining([0.02]), param_estim = true,
dataset = [dataset, nothing])
203 changes: 183 additions & 20 deletions test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
@@ -44,8 +44,8 @@
# testing points
t = time
# Mean of last 500 sampled parameter's curves[Ensemble predictions]
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:2500]
luxar = [chainlux(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:length(fhsamples)]
luxar = [chainlux(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

@@ -54,8 +54,8 @@
@test mean(abs.(physsol1 .- meanscurve)) < 0.005

#--------------------- solve() call
@test mean(abs.(x̂1 .- sol1lux.ensemblesol[1])) < 0.05
@test mean(abs.(physsol0_1 .- sol1lux.ensemblesol[1])) < 0.05
@test mean(abs.(x̂1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025
@test mean(abs.(physsol0_1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025
end

@testset "Example 2 - with parameter estimation" begin
@@ -111,19 +111,20 @@
# testing points
t = time
# Mean of last 500 sampled parameter's curves(flux and lux chains)[Ensemble predictions]
θ = [vector_to_parameters(fhsamples[i][1:(end - 1)], θinit) for i in 2000:2500]
luxar = [chainlux1(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsamples[i][1:(end - 1)], θinit)
for i in 2000:length(fhsamples)]
luxar = [chainlux1(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

# --------------------- ahmc_bayesian_pinn_ode() call
@test mean(abs.(physsol1 .- meanscurve)) < 0.15

# ESTIMATED ODE PARAMETERS (NN1 AND NN2)
@test abs(p - mean([fhsamples[i][23] for i in 2000:2500])) < abs(0.35 * p)
@test abs(p - mean([fhsamples[i][23] for i in 2000:length(fhsamples)])) < abs(0.35 * p)

#-------------------------- solve() call
@test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2
@test mean(abs.(physsol1_1 .- pmean(sol2lux.ensemblesol[1]))) < 8e-2

# ESTIMATED ODE PARAMETERS (NN1 AND NN2)
@test abs(p - sol2lux.estimated_de_params[1]) < abs(0.15 * p)
@@ -136,19 +137,16 @@
p = -5.0
prob = ODEProblem(linear, u0, tspan, p)
linear_analytic = (u0, p, t) -> exp(t / p) * (u0 + sin(t))

# SOLUTION AND CREATE DATASET
sol = solve(prob, Tsit5(); saveat = 0.1)
u = sol.u
time = sol.t
x̂ = u .+ (u .* 0.2) .* randn(size(u))
dataset = [x̂, time]
t = sol.t
physsol1 = [linear_analytic(prob.u0, p, t[i]) for i in eachindex(t)]
physsol1 = [linear_analytic(prob.u0, p, time[i]) for i in eachindex(time)]

ta0 = range(tspan[1], tspan[2], length = 501)
u1 = [linear_analytic(u0, p, ti) for ti in ta0]
time1 = vec(collect(Float64, ta0))
# separate set of points for testing the solve() call (it uses saveat 1/50 hence here length 501)

Check warning on line 148 in test/BPINN_Tests.jl

GitHub Actions / Spell Check with Typos

"seperate" should be "separate".
time1 = vec(collect(Float64, range(tspan[1], tspan[2], length = 501)))
physsol2 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)]

chainlux12 = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
@@ -193,13 +191,15 @@
t = sol.t
#------------------------------ ahmc_bayesian_pinn_ode() call
# Mean of last 500 sampled parameter's curves(lux chains)[Ensemble predictions]
θ = [vector_to_parameters(fhsampleslux12[i], θinit) for i in 1000:1500]
luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsampleslux12[i], θinit)
for i in 1000:length(fhsampleslux12)]
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit) for i in 1000:1500]
luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit)
for i in 1000:length(fhsampleslux22)]
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

@@ -209,12 +209,12 @@
@test mean(abs.(physsol1 .- meanscurve2_2)) < 5e-2

# estimated parameters(lux chain)
param1 = mean(i[62] for i in fhsampleslux22[1000:1500])
param1 = mean(i[62] for i in fhsampleslux22[1000:length(fhsampleslux22)])
@test abs(param1 - p) < abs(0.3 * p)

#-------------------------- solve() call
# (lux chain)
@test mean(abs.(physsol2 .- sol3lux_pestim.ensemblesol[1])) < 0.15
@test mean(abs.(physsol2 .- pmean(sol3lux_pestim.ensemblesol[1]))) < 0.15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pmean typo?

Copy link
Contributor Author

@AstitvaAggarwal AstitvaAggarwal Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope — the mean is required because the solution's standard deviations differ across the domain points, and sometimes these uncertainties are large enough for the tests to fail. So I just take the means for testing.

# estimated parameters(lux chain)
param1 = sol3lux_pestim.estimated_de_params[1]
@test abs(param1 - p) < abs(0.45 * p)
@@ -247,3 +247,166 @@
alg = BNNODE(chainflux, draw_samples = 2500)
@test alg.chain isa Lux.AbstractExplicitLayer
end

# Example 3 repeated with the extra collocation objective (`estim_collocate = true`),
# compared against a run on the same noisy data without it.
@testset "Example 3 but with the new objective" begin
    linear = (u, p, t) -> u / p + exp(t / p) * cos(t)
    tspan = (0.0, 10.0)
    u0 = 0.0
    p = -5.0
    prob = ODEProblem(linear, u0, tspan, p)
    linear_analytic = (u0, p, t) -> exp(t / p) * (u0 + sin(t))

    # SOLUTION AND CREATE DATASET
    sol = solve(prob, Tsit5(); saveat = 0.1)
    u = sol.u
    time = sol.t
    # noisy observations of the solution (additive Gaussian noise)
    x̂ = u .+ (0.3 .* randn(size(u)))
    dataset = [x̂, time]
    physsol1 = [linear_analytic(prob.u0, p, time[i]) for i in eachindex(time)]

    # separate set of points for testing the solve() call
    # (it uses saveat 1/50 hence here length 501)
    time1 = vec(collect(Float64, range(tspan[1], tspan[2], length = 501)))
    physsol2 = [linear_analytic(prob.u0, p, time1[i]) for i in eachindex(time1)]

    chainlux12 = Lux.Chain(
        Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
    θinit, st = Lux.setup(Random.default_rng(), chainlux12)

    # baseline run: standard objective
    fh_mcmc_chainlux12, fhsampleslux12, fhstatslux12 = ahmc_bayesian_pinn_ode(
        prob, chainlux12,
        dataset = dataset,
        draw_samples = 1000,
        l2std = [0.1],
        phystd = [0.03],
        priorsNNw = (0.0, 1.0),
        param = [Normal(-7, 3)])

    # same setup with the new objective enabled
    fh_mcmc_chainlux22, fhsampleslux22, fhstatslux22 = ahmc_bayesian_pinn_ode(
        prob, chainlux12,
        dataset = dataset,
        draw_samples = 1000,
        l2std = [0.1],
        phystd = [0.03],
        priorsNNw = (0.0, 1.0),
        param = [Normal(-7, 3)],
        estim_collocate = true)

    alg = BNNODE(chainlux12,
        dataset = dataset,
        draw_samples = 1000,
        l2std = [0.1],
        phystd = [0.03],
        priorsNNw = (0.0, 1.0),
        param = [Normal(-7, 3)],
        estim_collocate = true)

    sol3lux_pestim = solve(prob, alg)

    # testing timepoints
    t = sol.t
    #------------------------------ ahmc_bayesian_pinn_ode() call
    # Mean of the curves from samples 750 onward (lux chains)[Ensemble predictions]
    θ = [vector_to_parameters(fhsampleslux12[i][1:(end - 1)], θinit)
         for i in 750:length(fhsampleslux12)]
    luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
    luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
    meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

    θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit)
         for i in 750:length(fhsampleslux22)]
    luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
    luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
    meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

    @test mean(abs.(sol.u .- meanscurve2_2)) < 6e-2
    @test mean(abs.(physsol1 .- meanscurve2_2)) < 6e-2
    # the new objective must fit better than the baseline
    @test mean(abs.(sol.u .- meanscurve2_1)) > mean(abs.(sol.u .- meanscurve2_2))
    @test mean(abs.(physsol1 .- meanscurve2_1)) >
          mean(abs.(physsol1 .- meanscurve2_2))

    # estimated parameters(lux chain)
    # The single ODE parameter is the last entry of each sample: the chain above has
    # 61 NN parameters, so this is index 62.
    param2 = mean(s[end] for s in fhsampleslux22[750:end])
    @test abs(param2 - p) < abs(0.25 * p)

    param1 = mean(s[end] for s in fhsampleslux12[750:end])
    @test abs(param1 - p) < abs(0.75 * p)
    @test abs(param2 - p) < abs(param1 - p)

    #-------------------------- solve() call
    # (lux chain)
    @test mean(abs.(physsol2 .- pmean(sol3lux_pestim.ensemblesol[1]))) < 0.1
    # estimated parameters(lux chain)
    param3 = sol3lux_pestim.estimated_de_params[1]
    @test abs(param3 - p) < abs(0.2 * p)
end

# Lotka-Volterra parameter estimation: checks that enabling `estim_collocate` moves
# every estimated ODE parameter closer to its true value than the baseline run.
@testset "Example 4 - improvement" begin
    function lotka_volterra(u, p, t)
        # Model parameters.
        α, β, γ, δ = p
        # Current state.
        x, y = u

        # Evaluate differential equations.
        dx = (α - β * y) * x # prey
        dy = (δ * x - γ) * y # predator

        return [dx, dy]
    end

    # initial-value problem.
    u0 = [1.0, 1.0]
    p = [1.5, 1.0, 3.0, 1.0]
    tspan = (0.0, 4.0)
    prob = ODEProblem(lotka_volterra, u0, tspan, p)

    # Solve using OrdinaryDiffEq.jl solver
    dt = 0.2
    solution = solve(prob, Tsit5(); saveat = dt)

    times = solution.t
    # states as a (2 x n) matrix; `reduce(hcat, …)` avoids splatting the solution
    u = reduce(hcat, solution.u)
    # noisy observations of prey and predator
    x = u[1, :] + (0.8 .* randn(length(u[1, :])))
    y = u[2, :] + (0.8 .* randn(length(u[2, :])))
    dataset = [x, y, times]

    chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh),
        Lux.Dense(6, 2))

    # same prior for each of the four ODE parameters
    priors = fill(Normal(2, 0.5), 4)

    # baseline: standard objective
    alg1 = BNNODE(chain;
        dataset = dataset,
        draw_samples = 1000,
        l2std = [0.2, 0.2],
        phystd = [0.1, 0.1],
        priorsNNw = (0.0, 1.0),
        param = priors)

    # same setup with the new objective enabled
    alg2 = BNNODE(chain;
        dataset = dataset,
        draw_samples = 1000,
        l2std = [0.2, 0.2],
        phystd = [0.1, 0.1],
        priorsNNw = (0.0, 1.0),
        param = priors,
        estim_collocate = true)

    @time sol_pestim1 = solve(prob, alg1; saveat = dt)
    @time sol_pestim2 = solve(prob, alg2; saveat = dt)

    # NOTE(review): presumably needed so the estimated parameters can be compared with
    # `>`; this switches a global setting and is not reset afterwards — confirm.
    unsafe_comparisons(true)
    bitvec = abs.(p .- sol_pestim1.estimated_de_params) .>
             abs.(p .- sol_pestim2.estimated_de_params)
    # every parameter must be estimated better with the new objective
    @test all(bitvec)
end