-
-
Notifications
You must be signed in to change notification settings - Fork 228
New loss #937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New loss #937
Changes from 30 commits
29e123e
6bf8c49
45c1e39
83a642d
5cfc6e7
d8f602b
30fa615
85b350c
3c59fde
f2aafdb
70e956d
41691eb
f21969c
a968cc8
c66772a
2fcf75e
cb85411
2423224
cc1348f
4a1ca0e
107165e
a4d1fb7
d47c19c
0bc0ec1
65c4b08
8ed4e18
c242419
95140dd
6cb24c5
f24df29
e8dfd9a
0b7123a
d594ac3
e31d4e2
b681693
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
| dataset <: Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}} | ||
| priors <: Vector{<:Distribution} | ||
| phystd::Vector{Float64} | ||
| phynewstd::Vector{Float64} | ||
| phynewstd::Function | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Specialize?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure how we can specialize functions...
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| l2std::Vector{Float64} | ||
| autodiff::Bool | ||
| physdt::Float64 | ||
|
|
@@ -74,32 +74,37 @@ suggested extra loss function for ODE solver case | |
| """ | ||
| @views function L2loss2(ltd::LogTargetDensity, θ) | ||
| ltd.extraparams ≤ 0 && return false # XXX: type-stability? | ||
|
|
||
| u0 = ltd.prob.u0 | ||
| f = ltd.prob.f | ||
| t = ltd.dataset[end] | ||
| u1 = ltd.dataset[2] | ||
| û = ltd.dataset[1] | ||
| t = ltd.dataset[end - 1] | ||
| û = ltd.dataset[1:(end - 2)] | ||
| quadrature_weights = ltd.dataset[end] | ||
|
|
||
| nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], ltd.autodiff) | ||
|
|
||
| ode_params = ltd.extraparams == 1 ? θ[((length(θ) - ltd.extraparams) + 1)] : | ||
| θ[((length(θ) - ltd.extraparams) + 1):length(θ)] | ||
| phynewstd = ltd.phynewstd(ode_params) | ||
|
|
||
| physsol = if length(ltd.prob.u0) == 1 | ||
| [f(û[i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] | ||
| physsol = if length(u0) == 1 | ||
| [f(û[1][i], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] | ||
| else | ||
| [f([û[i], u1[i]], ode_params, tᵢ) for (i, tᵢ) in enumerate(t)] | ||
| [f([û[j][i] for j in eachindex(u0)], ode_params, tᵢ) | ||
| for (i, tᵢ) in enumerate(t)] | ||
| end | ||
| # form of NN output matrix output dim x n | ||
| deri_physsol = reduce(hcat, physsol) | ||
| T = promote_type(eltype(deri_physsol), eltype(nnsol)) | ||
|
|
||
| physlogprob = T(0) | ||
| for i in 1:length(ltd.prob.u0) | ||
| # for BPINNS Quadrature is NOT applied on timewise logpdfs, it isnt being driven to zero. | ||
| # Gridtraining/trapezoidal rule quadrature_weights is dt.*ones(T, length(t)) | ||
| # dims of phynewstd is same as u0 due to BNNODE being an out-of-place ODE solver. | ||
| for i in eachindex(u0) | ||
| physlogprob += logpdf( | ||
| MvNormal(deri_physsol[i, :], | ||
| Diagonal(abs2.(T(ltd.phynewstd[i]) .* ones(T, length(nnsol[i, :]))))), | ||
| nnsol[i, :] | ||
| MvNormal((nnsol[i, :] .- deri_physsol[i, :]) .* quadrature_weights, | ||
| Diagonal(abs2.(T(phynewstd[i]) .* ones(T, length(t))))), | ||
| zeros(length(t)) | ||
| ) | ||
| end | ||
| return physlogprob | ||
|
|
@@ -112,7 +117,7 @@ L2 loss loglikelihood(needed for ODE parameter estimation). | |
| (ltd.dataset isa Vector{Nothing} || ltd.extraparams == 0) && return 0 | ||
|
|
||
| # matrix(each row corresponds to vector u's rows) | ||
| nn = ltd(ltd.dataset[end], θ[1:(length(θ) - ltd.extraparams)]) | ||
| nn = ltd(ltd.dataset[end - 1], θ[1:(length(θ) - ltd.extraparams)]) | ||
| T = eltype(nn) | ||
|
|
||
| L2logprob = zero(T) | ||
|
|
@@ -150,24 +155,26 @@ end | |
| function getlogpdf(strategy::GridTraining, ltd::LogTargetDensity, f, autodiff::Bool, | ||
| tspan, ode_params, θ) | ||
| ts = collect(eltype(strategy.dx), tspan[1]:(strategy.dx):tspan[2]) | ||
| t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end]) | ||
| t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end - 1]) | ||
| return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) | ||
| end | ||
|
|
||
| function getlogpdf(strategy::StochasticTraining, ltd::LogTargetDensity, | ||
| f, autodiff::Bool, tspan, ode_params, θ) | ||
| T = promote_type(eltype(tspan[1]), eltype(tspan[2])) | ||
| samples = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1] | ||
| t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end]) | ||
| t = ltd.dataset isa Vector{Nothing} ? samples : vcat(samples, ltd.dataset[end - 1]) | ||
| return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) | ||
| end | ||
|
|
||
| function getlogpdf(strategy::QuadratureTraining, ltd::LogTargetDensity, f, autodiff::Bool, | ||
| tspan, ode_params, θ) | ||
| # integrand is shape of NN output | ||
| integrand(t::Number, θ) = innerdiff(ltd, f, autodiff, [t], θ, ode_params) | ||
| intprob = IntegralProblem( | ||
| integrand, (tspan[1], tspan[2]), θ; nout = length(ltd.prob.u0)) | ||
| sol = solve(intprob, QuadGKJL(); strategy.abstol, strategy.reltol) | ||
| # sum over losses for all NN outputs | ||
| return sum(sol.u) | ||
| end | ||
|
|
||
|
|
@@ -185,7 +192,7 @@ function getlogpdf(strategy::WeightedIntervalTraining, ltd::LogTargetDensity, f, | |
| append!(ts, temp_data) | ||
| end | ||
|
|
||
| t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end]) | ||
| t = ltd.dataset isa Vector{Nothing} ? ts : vcat(ts, ltd.dataset[end - 1]) | ||
| return sum(innerdiff(ltd, f, autodiff, t, θ, ode_params)) | ||
| end | ||
|
|
||
|
|
@@ -202,23 +209,21 @@ MvNormal likelihood at each `ti` in time `t` for ODE collocation residue with NN | |
|
|
||
| # this is a vector{vector{dx,dy}}(handle case single u(float passed)) | ||
| if length(out[:, 1]) == 1 | ||
| physsol = [f(out[:, i][1], ode_params, t[i]) for i in 1:length(out[1, :])] | ||
| physsol = [f(out[:, i][1], ode_params, t[i]) for i in eachindex(t)] | ||
| else | ||
| physsol = [f(out[:, i], ode_params, t[i]) for i in 1:length(out[1, :])] | ||
| physsol = [f(out[:, i], ode_params, t[i]) for i in eachindex(t)] | ||
| end | ||
| physsol = reduce(hcat, physsol) | ||
|
|
||
| nnsol = ode_dfdx(ltd, t, θ[1:(length(θ) - ltd.extraparams)], autodiff) | ||
|
|
||
| vals = nnsol .- physsol | ||
| T = eltype(vals) | ||
| T = eltype(nnsol) | ||
|
|
||
| # N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector | ||
| # of dependant variables) | ||
| return [logpdf( | ||
| MvNormal(vals[i, :], | ||
| Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(vals[i, :]))))), | ||
| zeros(T, length(vals[i, :])) | ||
| MvNormal((nnsol[i, :] .- physsol[i, :]), | ||
| Diagonal(abs2.(T(ltd.phystd[i]) .* ones(T, length(t))))), | ||
| zeros(T, length(t)) | ||
| ) for i in 1:length(ltd.prob.u0)] | ||
| end | ||
|
|
||
|
|
@@ -264,7 +269,7 @@ end | |
| """ | ||
| ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing], | ||
| init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0f0, | ||
| l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), | ||
| l2std = [0.05], phystd = [0.05], phynewstd = (ode_params)->[0.05], priorsNNw = (0.0, 2.0), | ||
| param = [], nchains = 1, autodiff = false, Kernel = HMC, | ||
| Adaptorkwargs = (Adaptor = StanHMCAdaptor, | ||
| Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), | ||
|
|
@@ -337,7 +342,7 @@ Incase you are only solving the Equations for solution, do not provide dataset | |
| ~2/3 of draw samples) | ||
| * `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset | ||
| * `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System | ||
| * `phynewstd`: standard deviation of new loss func term | ||
| * `phynewstd`: Function in ode_params that gives the standard deviation of the new loss function terms. | ||
| * `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of | ||
| BPINN are Normal Distributions by default. | ||
| * `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems. | ||
|
|
@@ -368,7 +373,8 @@ Incase you are only solving the Equations for solution, do not provide dataset | |
| function ahmc_bayesian_pinn_ode( | ||
| prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing], | ||
| init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05], | ||
| phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1, | ||
| phystd = [0.05], phynewstd = (ode_params) -> [0.05], | ||
| priorsNNw = (0.0, 2.0), param = [], nchains = 1, | ||
| autodiff = false, Kernel = HMC, | ||
| Adaptorkwargs = (Adaptor = StanHMCAdaptor, | ||
| Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), | ||
|
|
@@ -381,8 +387,8 @@ function ahmc_bayesian_pinn_ode( | |
| strategy = strategy == GridTraining ? strategy(physdt) : strategy | ||
|
|
||
| if dataset != [nothing] && | ||
| (length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})) | ||
| error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}") | ||
| (length(dataset) < 3 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})) | ||
| error("Invalid dataset. dataset would be timeseries (x̂,t,W) where type: Vector{Vector{AbstractFloat}") | ||
| end | ||
|
|
||
| if dataset != [nothing] && param == [] | ||
|
|
@@ -461,7 +467,8 @@ function ahmc_bayesian_pinn_ode( | |
|
|
||
| MCMC_alg = kernelchoice(Kernel, MCMCkwargs) | ||
| Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) | ||
| samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; | ||
| samples, | ||
| stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; | ||
| progress = progress, verbose = verbose) | ||
|
|
||
| samplesc[i] = samples | ||
|
|
@@ -479,7 +486,8 @@ function ahmc_bayesian_pinn_ode( | |
|
|
||
| MCMC_alg = kernelchoice(Kernel, MCMCkwargs) | ||
| Kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator) | ||
| samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, | ||
| samples, | ||
| stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, | ||
| adaptor; progress = progress, verbose = verbose) | ||
|
|
||
| if verbose | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.