From b49cf3e8cf2162706824735f0662559d6f838d55 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 14 Mar 2023 19:13:41 +0000
Subject: [PATCH 001/206] refactor ADVI, change gradient operation interface

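The gradient interface no longer takes `(vo, alg, q, model, θ, out, ...)`;
`grad!` now takes a closure over the variational parameters plus an AD
backend type. A rough usage sketch (`rebuild` maps a flat parameter vector
to a distribution, and `estimate_elbo` is a hypothetical objective
evaluation, not part of this patch):

    grad!(ADBackend(), λ, out) do λ′
        q = rebuild(λ′)        # λ ↦ q
        -estimate_elbo(q)      # minimize the negative objective
    end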
---
 Project.toml           |   1 +
 src/AdvancedVI.jl      | 181 ++++++++++++++---------------------------
 src/advi.jl            |  47 -----------
 src/estimators/advi.jl |  29 +++++++
 src/utils.jl           |  15 ++++
 5 files changed, 107 insertions(+), 166 deletions(-)
 create mode 100644 src/estimators/advi.jl

diff --git a/Project.toml b/Project.toml
index 28adc66a5..71a2cbdc3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index e203a13ca..d42683d09 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -33,20 +33,12 @@ function __init__()
         export ZygoteAD
 
         function AdvancedVI.grad!(
-            vo,
-            alg::VariationalInference{<:AdvancedVI.ZygoteAD},
-            q,
-            model,
-            θ::AbstractVector{<:Real},
+            f::Function,
+            ::Type{<:ZygoteAD},
+            λ::AbstractVector{<:Real},
             out::DiffResults.MutableDiffResult,
-            args...
         )
-            f(θ) = if (q isa Distribution)
-                - vo(alg, update(q, θ), model, args...)
-            else
-                - vo(alg, q(θ), model, args...)
-            end
-            y, back = Zygote.pullback(f, θ)
+            y, back = Zygote.pullback(f, λ)
             dy = first(back(1.0))
             DiffResults.value!(out, y)
             DiffResults.gradient!(out, dy)
@@ -58,21 +50,13 @@ function __init__()
         export ReverseDiffAD
 
         function AdvancedVI.grad!(
-            vo,
-            alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
-            q,
-            model,
-            θ::AbstractVector{<:Real},
+            f::Function,
+            ::Type{<:ReverseDiffAD},
+            λ::AbstractVector{<:Real},
             out::DiffResults.MutableDiffResult,
-            args...
         )
-            f(θ) = if (q isa Distribution)
-                - vo(alg, update(q, θ), model, args...)
-            else
-                - vo(alg, q(θ), model, args...)
-            end
-            tp = AdvancedVI.tape(f, θ)
-            ReverseDiff.gradient!(out, tp, θ)
+            tp = AdvancedVI.tape(f, λ)
+            ReverseDiff.gradient!(out, tp, λ)
             return out
         end
     end
@@ -81,26 +65,18 @@ function __init__()
         export EnzymeAD
 
         function AdvancedVI.grad!(
-            vo,
-            alg::VariationalInference{<:AdvancedVI.EnzymeAD},
-            q,
-            model,
-            θ::AbstractVector{<:Real},
+            f::Function,
+            ::Type{<:EnzymeAD},
+            λ::AbstractVector{<:Real},
             out::DiffResults.MutableDiffResult,
-            args...
         )
-            f(θ) = if (q isa Distribution)
-                - vo(alg, update(q, θ), model, args...)
-            else
-                - vo(alg, q(θ), model, args...)
-            end
             # Use `Enzyme.ReverseWithPrimal` once it is released:
             # https://github.com/EnzymeAD/Enzyme.jl/pull/598
-            y = f(θ)
+            y = f(λ)
             DiffResults.value!(out, y)
             dy = DiffResults.gradient(out)
             fill!(dy, 0)
-            Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
+            Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
             return out
         end
     end
@@ -109,16 +85,8 @@ end
 export
     vi,
     ADVI,
-    ELBO,
-    elbo,
     TruncatedADAGrad,
-    DecayedADAGrad,
-    VariationalInference
-
-abstract type VariationalInference{AD} end
-
-getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
-getADtype(::VariationalInference{AD}) where AD = AD
+    DecayedADAGrad
 
 abstract type VariationalObjective end
 
@@ -126,13 +94,11 @@ const VariationalPosterior = Distribution{Multivariate, Continuous}
 
 
 """
-    grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)
+    grad!(f, adtype, λ, out)
 
-Computes the gradients used in `optimize!`. Default implementation is provided for 
+Computes the gradient of the objective `f`. A default implementation is provided for
 `VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
 This implicitly also gives a default implementation of `optimize!`.
-
-Variance reduction techniques, e.g. control variates, should be implemented in this function.
 """
 function grad! end
 
@@ -157,51 +123,36 @@ function update end
 
 # default implementations
 function grad!(
-    vo,
-    alg::VariationalInference{<:ForwardDiffAD},
-    q,
-    model,
-    θ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    args...
+    f::Function,
+    adtype::Type{<:ForwardDiffAD},
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
 )
-    f(θ_) = if (q isa Distribution)
-        - vo(alg, update(q, θ_), model, args...)
-    else
-        - vo(alg, q(θ_), model, args...)
-    end
-
     # Set chunk size and do ForwardMode.
-    chunk_size = getchunksize(typeof(alg))
+    chunk_size = getchunksize(adtype)
     config = if chunk_size == 0
-        ForwardDiff.GradientConfig(f, θ)
+        ForwardDiff.GradientConfig(f, λ)
     else
-        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
+        ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size))
     end
-    ForwardDiff.gradient!(out, f, θ, config)
+    ForwardDiff.gradient!(out, f, λ, config)
 end
 
 function grad!(
-    vo,
-    alg::VariationalInference{<:TrackerAD},
-    q,
-    model,
-    θ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    args...
+    f::Function,
+    ::Type{<:TrackerAD},
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
 )
-    θ_tracked = Tracker.param(θ)
-    y = if (q isa Distribution)
-        - vo(alg, update(q, θ_tracked), model, args...)
-    else
-        - vo(alg, q(θ_tracked), model, args...)
-    end
+    λ_tracked = Tracker.param(λ)
+    y = f(λ_tracked)
     Tracker.back!(y, 1.0)
 
     DiffResults.value!(out, Tracker.data(y))
-    DiffResults.gradient!(out, Tracker.grad(θ_tracked))
+    DiffResults.gradient!(out, Tracker.grad(λ_tracked))
 end
 
+abstract type AbstractGradientEstimator end
 
 """
     optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())
@@ -210,61 +161,53 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer
 the steps.
 """
 function optimize!(
-    vo,
-    alg::VariationalInference,
-    q,
-    model,
-    θ::AbstractVector{<:Real};
-    optimizer = TruncatedADAGrad()
+    grad_estimator::AbstractGradientEstimator,
+    rebuild::Function,
+    ℓπ::Function,
+    n_max_iter::Int,
+    λ::AbstractVector{<:Real};
+    optimizer = TruncatedADAGrad(),
+    rng       = Random.GLOBAL_RNG
 )
-    # TODO: should we always assume `samples_per_step` and `max_iters` for all algos?
-    alg_name = alg_str(alg)
-    samples_per_step = alg.samples_per_step
-    max_iters = alg.max_iters
-    
-    num_params = length(θ)
+    obj_name = objective(grad_estimator)
 
     # TODO: really need a better way to warn the user about potentially
     # not using the correct accumulator
-    if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc))
+    if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc))
         # this message should only occur once in the optimization process
-        @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ)
+        @info "[$obj_name] Should only be seen once: optimizer created for λ" objectid(λ)
     end
 
-    diff_result = DiffResults.GradientResult(θ)
+    grad_buf = DiffResults.GradientResult(λ)
 
     i = 0
-    prog = if PROGRESS[]
-        ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0)
-    else
-        0
-    end
+    prog = ProgressMeter.Progress(
+        n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[])
 
     # add criterion? A running mean maybe?
-    time_elapsed = @elapsed while (i < max_iters) # & converged
-        grad!(vo, alg, q, model, θ, diff_result, samples_per_step)
-
-        # apply update rule
-        Δ = DiffResults.gradient(diff_result)
-        Δ = apply!(optimizer, θ, Δ)
-        @. θ = θ - Δ
+    time_elapsed = @elapsed begin
+        for i = 1:n_max_iter
+            stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf)
+            
+            # apply update rule
+            Δλ = DiffResults.gradient(grad_buf)
+            Δλ = apply!(optimizer, λ, Δλ)
+            @. λ = λ - Δλ
+
+            stat′ = (Δλ=norm(Δλ),)
+            stats = merge(stats, stat′)
         
-        AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result)
-        PROGRESS[] && (ProgressMeter.next!(prog))
-
-        i += 1
+            AdvancedVI.DEBUG && @debug "Step $i" stats...
+            pm_next!(prog, stats)
+        end
     end
-
-    return θ
+    return λ
 end
 
 # objectives
-include("objectives.jl")
+include("estimators/advi.jl")
 
 # optimisers
 include("optimisers.jl")
 
-# VI algorithms
-include("advi.jl")
-
 end # module
diff --git a/src/advi.jl b/src/advi.jl
index 7f9e73460..be9823db4 100644
--- a/src/advi.jl
+++ b/src/advi.jl
@@ -50,50 +50,3 @@ function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = Truncate
     return θ
 end
 
-# WITHOUT updating parameters inside ELBO
-function (elbo::ELBO)(
-    rng::Random.AbstractRNG,
-    alg::ADVI,
-    q::VariationalPosterior,
-    logπ::Function,
-    num_samples
-)
-    #   𝔼_q(z)[log p(xᵢ, z)]
-    # = ∫ log p(xᵢ, z) q(z) dz
-    # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ   (since change of variables)
-    # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ                   (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ))
-    # = 𝔼_q̃(ϕ)[log p(xᵢ, z)]
-
-    #   𝔼_q(z)[log q(z)]
-    # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ     (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ))
-    # = 𝔼_q̃(ϕ) [log q(f(ϕ))]
-    # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|]
-    # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|]
-    # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|]
-
-    # Finally, the ELBO is given by
-    # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)]
-    #      = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ))
-
-    # If f: supp(p(z | x)) → ℝ then
-    # ELBO = 𝔼[log p(x, z) - log q(z)]
-    #      = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃))
-    #      = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃))
-
-    # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
-    z, logjac = rand_and_logjac(rng, q)
-    res = (logπ(z) + logjac) / num_samples
-
-    if q isa TransformedDistribution
-        res += entropy(q.dist)
-    else
-        res += entropy(q)
-    end
-    
-    for i = 2:num_samples
-        z, logjac = rand_and_logjac(rng, q)
-        res += (logπ(z) + logjac) / num_samples
-    end
-
-    return res
-end
diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
new file mode 100644
index 000000000..c5a83957b
--- /dev/null
+++ b/src/estimators/advi.jl
@@ -0,0 +1,29 @@
+
+struct ADVI <: AbstractGradientEstimator
+    n_samples::Int
+end
+
+objective(::ADVI) = "ELBO"
+
+function estimate_gradient!(
+    rng::Random.AbstractRNG,
+    estimator::ADVI,
+    λ::Vector{<:Real},
+    rebuild::Function,
+    logπ::Function,
+    out::DiffResults.MutableDiffResult)
+
+    n_samples = estimator.n_samples
+
+    grad!(ADBackend(), λ, out) do λ′
+        q = rebuild(λ′)
+        zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples)
+        
+        elbo = mapreduce(+, eachcol(zs)) do zᵢ
+            (logπ(zᵢ) + ∑logjac)
+        end / n_samples
+        -elbo
+    end
+    nelbo = DiffResults.value(out)
+    (elbo=-nelbo,)
+end
diff --git a/src/utils.jl b/src/utils.jl
index bb4c1f18f..87cc0856a 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -13,3 +13,18 @@ function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDis
     y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
     return y, logjac
 end
+
+function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int)
+    x = rand(rng, dist, n_samples)
+    return x, zero(eltype(x))
+end
+
+function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int)
+    x = rand(rng, dist.dist, n_samples)
+    y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
+    return y, logjac
+end
+
+function pm_next!(pm, stats::NamedTuple)
+    ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
+end

From 88e0b79758c2f207b9d3c7120b469af837049fec Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 14 Mar 2023 19:56:47 +0000
Subject: [PATCH 002/206] remove unused file, remove unused dependency

---
 Project.toml      | 1 -
 src/objectives.jl | 7 -------
 2 files changed, 8 deletions(-)
 delete mode 100644 src/objectives.jl

diff --git a/Project.toml b/Project.toml
index 71a2cbdc3..28adc66a5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,7 +9,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/src/objectives.jl b/src/objectives.jl
deleted file mode 100644
index 5a6b61b0c..000000000
--- a/src/objectives.jl
+++ /dev/null
@@ -1,7 +0,0 @@
-struct ELBO <: VariationalObjective end
-
-function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...)
-    return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...)
-end
-
-const elbo = ELBO()

From c2fb3f8d08c15b16fa2e84a359b0d9bda3bf45b2 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Wed, 15 Mar 2023 18:53:50 +0000
Subject: [PATCH 003/206] fix ADVI ELBO computation and make it more efficient

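The batched `rand_and_logjac` returns the log-jacobian summed over the
whole batch, so it must enter the Monte Carlo average once overall,
divided by the number of samples, rather than being added to every term:

    elbo ≈ (1/n) ∑ᵢ log π(zᵢ) + ∑logdetjac/n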
---
 src/estimators/advi.jl | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
index c5a83957b..44f65909a 100644
--- a/src/estimators/advi.jl
+++ b/src/estimators/advi.jl
@@ -17,13 +17,17 @@ function estimate_gradient!(
 
     grad!(ADBackend(), λ, out) do λ′
         q = rebuild(λ′)
-        zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples)
-        
-        elbo = mapreduce(+, eachcol(zs)) do zᵢ
-            (logπ(zᵢ) + ∑logjac)
-        end / n_samples
+        zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples)
+
+        𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ
+            logπ(zᵢ) / n_samples
+        end
+        𝔼logdetjac = ∑logdetjac/n_samples
+
+        elbo = 𝔼logπ + 𝔼logdetjac 
         -elbo
     end
     nelbo = DiffResults.value(out)
     (elbo=-nelbo,)
 end
+

From 83161fdf7fd18d9f686483da38174148ad305c9f Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Wed, 15 Mar 2023 19:20:51 +0000
Subject: [PATCH 004/206] fix missing entropy regularization term

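With the reparametrization used here the objective decomposes as

    ELBO = 𝔼_q̃(η)[ log π(f(η)) + log|det J_f(η)| ] + ℍ(q̃)

(cf. the derivation comments removed from src/advi.jl in patch 001), so an
estimator that drops the entropy term ℍ(q̃) targets the wrong objective.
This patch restores the term.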
---
 src/estimators/advi.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
index 44f65909a..ad45efbbe 100644
--- a/src/estimators/advi.jl
+++ b/src/estimators/advi.jl
@@ -24,7 +24,7 @@ function estimate_gradient!(
         end
         𝔼logdetjac = ∑logdetjac/n_samples
 
-        elbo = 𝔼logπ + 𝔼logdetjac 
+        elbo = 𝔼logπ + 𝔼logdetjac + entropy(q)
         -elbo
     end
     nelbo = DiffResults.value(out)

From efa810687738f4d297ff8b25aaadf28e37ba2080 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 18 Mar 2023 01:04:02 +0000
Subject: [PATCH 005/206] add LogDensityProblems interface

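The ADVI constructor now requires the target to implement the
LogDensityProblems.jl interface. A minimal example target (hypothetical,
for illustration; only the zeroth-order capability checked by the
constructor is needed, since differentiation happens on the outside):

    using Distributions, LinearAlgebra, LogDensityProblems

    struct NormalTarget end
    LogDensityProblems.logdensity(::NormalTarget, z) =
        logpdf(MvNormal(zeros(2), Diagonal(ones(2))), z)
    LogDensityProblems.dimension(::NormalTarget) = 2
    LogDensityProblems.capabilities(::Type{NormalTarget}) =
        LogDensityProblems.LogDensityOrder{0}()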
---
 Project.toml           |  1 +
 src/AdvancedVI.jl      |  5 +++--
 src/estimators/advi.jl | 19 ++++++++++++++++---
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index 28adc66a5..6ad4b6895 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index d42683d09..e1ac752f1 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -7,6 +7,8 @@ using DocStringExtensions
 
 using ProgressMeter, LinearAlgebra
 
+using LogDensityProblems
+
 using ForwardDiff
 using Tracker
 
@@ -163,7 +165,6 @@ the steps.
 function optimize!(
     grad_estimator::AbstractGradientEstimator,
     rebuild::Function,
-    ℓπ::Function,
     n_max_iter::Int,
     λ::AbstractVector{<:Real};
     optimizer = TruncatedADAGrad(),
@@ -187,7 +188,7 @@ function optimize!(
     # add criterion? A running mean maybe?
     time_elapsed = @elapsed begin
         for i = 1:n_max_iter
-            stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf)
+            stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf)
             
             # apply update rule
             Δλ = DiffResults.gradient(grad_buf)
diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
index ad45efbbe..5a8652b6d 100644
--- a/src/estimators/advi.jl
+++ b/src/estimators/advi.jl
@@ -1,8 +1,22 @@
 
-struct ADVI <: AbstractGradientEstimator
+struct ADVI{Tlogπ} <: AbstractGradientEstimator
+    ℓπ::Tlogπ
     n_samples::Int
 end
 
+function ADVI(ℓπ, n_samples; kwargs...)
+    # ADVI requires gradients of the log density
+    cap = LogDensityProblems.capabilities(ℓπ)
+    if cap === nothing
+        throw(
+            ArgumentError(
+                "The log density function does not support the LogDensityProblems.jl interface",
+            ),
+        )
+    end
+    ADVI(Base.Fix1(LogDensityProblems.logdensity, ℓπ), n_samples)
+end
+
 objective(::ADVI) = "ELBO"
 
 function estimate_gradient!(
@@ -10,7 +24,6 @@ function estimate_gradient!(
     estimator::ADVI,
     λ::Vector{<:Real},
     rebuild::Function,
-    logπ::Function,
     out::DiffResults.MutableDiffResult)
 
     n_samples = estimator.n_samples
@@ -20,7 +33,7 @@ function estimate_gradient!(
         zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples)
 
         𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ
-            logπ(zᵢ) / n_samples
+            estimator.ℓπ(zᵢ) / n_samples
         end
         𝔼logdetjac = ∑logdetjac/n_samples
 

From 4ae2fbfa832662b5adaa7e3d423cb312cb87b4c9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 18 Mar 2023 02:22:32 +0000
Subject: [PATCH 006/206] refactor: use bijectors directly instead of
 transformed distributions

This avoids having to reconstruct transformed distributions on every
step. Using the bijectors directly also avoids going through many
abstraction layers that could break.

Instead, the transformed distribution can be constructed just once when
returning the VI result, as sketched below.
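Sketch (hypothetical names: `prob` is the target, `b⁻¹` maps unconstrained
draws to the target support, `rebuild` maps parameters to a distribution):

    λ = optimize!(ADVI(prob, b⁻¹, 10), rebuild, n_max_iter, λ₀)
    q = Bijectors.transformed(rebuild(λ), b⁻¹)  # built once, at the very end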
---
 src/estimators/advi.jl | 43 ++++++++++++++++++++++++++----------------
 src/utils.jl           | 30 -----------------------------
 2 files changed, 27 insertions(+), 46 deletions(-)

diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
index 5a8652b6d..9784e9248 100644
--- a/src/estimators/advi.jl
+++ b/src/estimators/advi.jl
@@ -1,22 +1,32 @@
 
-struct ADVI{Tlogπ} <: AbstractGradientEstimator
+struct ADVI{Tlogπ, B <: Union{Function, Bijectors.Inverse{<:Bijectors.Bijector}}} <: AbstractGradientEstimator
+    # Automatic differentiation variational inference
+    # 
+    # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017).
+    # Automatic differentiation variational inference.
+    # Journal of machine learning research.
+
     ℓπ::Tlogπ
+    b⁻¹::B
     n_samples::Int
-end
 
-function ADVI(ℓπ, n_samples; kwargs...)
-    # ADVI requires gradients of the log density
-    cap = LogDensityProblems.capabilities(ℓπ)
-    if cap === nothing
-        throw(
-            ArgumentError(
-                "The log density function does not support the LogDensityProblems.jl interface",
-            ),
-        )
+    function ADVI(prob, b⁻¹::B, n_samples; kwargs...) where {B <: Bijectors.Inverse{<:Bijectors.Bijector}}
+        # Could check whether the support of b⁻¹ and ℓπ match
+        cap = LogDensityProblems.capabilities(prob)
+        if cap === nothing
+            throw(
+                ArgumentError(
+                    "The log density function does not support the LogDensityProblems.jl interface",
+                ),
+            )
+        end
+        ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
+        new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹, n_samples)
     end
-    ADVI(Base.Fix1(LogDensityProblems.logdensity, ℓπ), n_samples)
 end
 
+ADVI(prob, n_samples; kwargs...) = ADVI(prob, identity, n_samples; kwargs...)
+
 objective(::ADVI) = "ELBO"
 
 function estimate_gradient!(
@@ -29,18 +39,19 @@ function estimate_gradient!(
     n_samples = estimator.n_samples
 
     grad!(ADBackend(), λ, out) do λ′
-        q = rebuild(λ′)
-        zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples)
+        q_η = rebuild(λ′)
+        ηs  = rand(rng, q_η, estimator.n_samples)
+
+        zs, ∑logdetjac = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηs)
 
         𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ
             estimator.ℓπ(zᵢ) / n_samples
         end
         𝔼logdetjac = ∑logdetjac/n_samples
 
-        elbo = 𝔼logπ + 𝔼logdetjac + entropy(q)
+        elbo = 𝔼logπ + 𝔼logdetjac + entropy(q_η)
         -elbo
     end
     nelbo = DiffResults.value(out)
     (elbo=-nelbo,)
 end
-
diff --git a/src/utils.jl b/src/utils.jl
index 87cc0856a..e69de29bb 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,30 +0,0 @@
-using Distributions
-
-using Bijectors: Bijectors
-
-
-function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution)
-    x = rand(rng, dist)
-    return x, zero(eltype(x))
-end
-
-function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution)
-    x = rand(rng, dist.dist)
-    y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
-    return y, logjac
-end
-
-function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int)
-    x = rand(rng, dist, n_samples)
-    return x, zero(eltype(x))
-end
-
-function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int)
-    x = rand(rng, dist.dist, n_samples)
-    y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
-    return y, logjac
-end
-
-function pm_next!(pm, stats::NamedTuple)
-    ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
-end

From 1cadb51a011eeaf0b7d3e05aee7e45494bc2439a Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 8 Jun 2023 00:54:02 +0100
Subject: [PATCH 007/206] fix type restrictions

---
 src/estimators/advi.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
index 9784e9248..b4b3a9d0a 100644
--- a/src/estimators/advi.jl
+++ b/src/estimators/advi.jl
@@ -1,5 +1,5 @@
 
-struct ADVI{Tlogπ, B <: Union{Function, Bijectors.Inverse{<:Bijectors.Bijector}}} <: AbstractGradientEstimator
+struct ADVI{Tlogπ, B} <: AbstractGradientEstimator
     # Automatic differentiation variational inference
     # 
     # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017).
@@ -33,7 +33,7 @@ function estimate_gradient!(
     rng::Random.AbstractRNG,
     estimator::ADVI,
     λ::Vector{<:Real},
-    rebuild::Function,
+    rebuild,
     out::DiffResults.MutableDiffResult)
 
     n_samples = estimator.n_samples

From 3474e8d2c97032f7a384d3b88cb7cc47bdae12f3 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 8 Jun 2023 00:54:23 +0100
Subject: [PATCH 008/206] remove unused file

---
 src/advi.jl | 52 ----------------------------------------------------
 1 file changed, 52 deletions(-)
 delete mode 100644 src/advi.jl

diff --git a/src/advi.jl b/src/advi.jl
deleted file mode 100644
index be9823db4..000000000
--- a/src/advi.jl
+++ /dev/null
@@ -1,52 +0,0 @@
-using StatsFuns
-using DistributionsAD
-using Bijectors
-using Bijectors: TransformedDistribution
-
-
-"""
-$(TYPEDEF)
-
-Automatic Differentiation Variational Inference (ADVI) with automatic differentiation
-backend `AD`.
-
-# Fields
-
-$(TYPEDFIELDS)
-"""
-struct ADVI{AD} <: VariationalInference{AD}
-    "Number of samples used to estimate the ELBO in each optimization step."
-    samples_per_step::Int
-    "Maximum number of gradient steps."
-    max_iters::Int
-end
-
-function ADVI(samples_per_step::Int=1, max_iters::Int=1000)
-    return ADVI{ADBackend()}(samples_per_step, max_iters)
-end
-
-alg_str(::ADVI) = "ADVI"
-
-function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad())
-    θ = copy(θ_init)
-    optimize!(elbo, alg, q, model, θ; optimizer = optimizer)
-
-    # If `q` is a mean-field approx we use the specialized `update` function
-    if q isa Distribution
-        return update(q, θ)
-    else
-        # Otherwise we assume it's a mapping θ → q
-        return q(θ)
-    end
-end
-
-
-function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad())
-    θ = copy(θ_init)
-
-    # `model` assumed to be callable z ↦ p(x, z)
-    optimize!(elbo, alg, q, model, θ; optimizer = optimizer)
-
-    return θ
-end
-

From 03a27679f98790f943b784d0f6282035ecdc8abe Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 8 Jun 2023 03:19:03 +0100
Subject: [PATCH 009/206] fix use of with_logabsdet_jacobian

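`with_logabsdet_jacobian` transforms a single sample, so the log-determinant
has to be accumulated per draw inside the Monte Carlo sum rather than
computed once on the whole sample matrix as before.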
---
 src/estimators/advi.jl | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
index b4b3a9d0a..701ec1ef1 100644
--- a/src/estimators/advi.jl
+++ b/src/estimators/advi.jl
@@ -10,7 +10,7 @@ struct ADVI{Tlogπ, B} <: AbstractGradientEstimator
     b⁻¹::B
     n_samples::Int
 
-    function ADVI(prob, b⁻¹::B, n_samples; kwargs...) where {B <: Bijectors.Inverse{<:Bijectors.Bijector}}
+    function ADVI(prob, b⁻¹, n_samples; kwargs...)
         # Could check whether the support of b⁻¹ and ℓπ match
         cap = LogDensityProblems.capabilities(prob)
         if cap === nothing
@@ -42,14 +42,12 @@ function estimate_gradient!(
         q_η = rebuild(λ′)
         ηs  = rand(rng, q_η, estimator.n_samples)
 
-        zs, ∑logdetjac = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηs)
-
-        𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ
-            estimator.ℓπ(zᵢ) / n_samples
+        𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ
+            zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηᵢ)
+            (estimator.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
         end
-        𝔼logdetjac = ∑logdetjac/n_samples
 
-        elbo = 𝔼logπ + 𝔼logdetjac + entropy(q_η)
+        elbo = 𝔼ℓ + entropy(q_η)
         -elbo
     end
     nelbo = DiffResults.value(out)

From 09c44fb639864167e6548db89b7ad0196d04ddfc Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 8 Jun 2023 03:29:42 +0100
Subject: [PATCH 010/206] restructure project; move the main VI routine to its
 own file

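The routine is also renamed from `optimize!` to `optimize`. Rough usage
after this patch (a sketch; `prob` is any LogDensityProblems-compatible
target and the mean-field `rebuild` map is illustrative):

    using Distributions, LinearAlgebra

    d  = 5                        # number of latent variables
    λ₀ = randn(2d)                # flat variational parameters
    rebuild(λ) = MvNormal(λ[1:d], Diagonal(exp.(λ[d+1:end])))
    λ★ = optimize(ADVI(prob, 10), rebuild, 1_000, λ₀;
                  optimizer = TruncatedADAGrad())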
---
 src/AdvancedVI.jl | 60 +++++++-----------------------------------
 src/vi.jl         | 66 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 50 deletions(-)
 create mode 100644 src/vi.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index e1ac752f1..d3612cb10 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -12,6 +12,13 @@ using LogDensityProblems
 using ForwardDiff
 using Tracker
 
+using Bijectors: Bijectors
+
+using Distributions
+using DistributionsAD
+
+using StatsFuns
+
 const PROGRESS = Ref(true)
 function turnprogress(switch::Bool)
     @info("[AdvancedVI]: global PROGRESS is set as $switch")
@@ -154,61 +161,14 @@ function grad!(
     DiffResults.gradient!(out, Tracker.grad(λ_tracked))
 end
 
+# estimators
 abstract type AbstractGradientEstimator end
 
-"""
-    optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())
-
-Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute
-the steps.
-"""
-function optimize!(
-    grad_estimator::AbstractGradientEstimator,
-    rebuild::Function,
-    n_max_iter::Int,
-    λ::AbstractVector{<:Real};
-    optimizer = TruncatedADAGrad(),
-    rng       = Random.GLOBAL_RNG
-)
-    obj_name = objective(grad_estimator)
-
-    # TODO: really need a better way to warn the user about potentially
-    # not using the correct accumulator
-    if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc))
-        # this message should only occur once in the optimization process
-        @info "[$obj_name] Should only be seen once: optimizer created for λ" objectid(λ)
-    end
-
-    grad_buf = DiffResults.GradientResult(λ)
-
-    i = 0
-    prog = ProgressMeter.Progress(
-        n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[])
-
-    # add criterion? A running mean maybe?
-    time_elapsed = @elapsed begin
-        for i = 1:n_max_iter
-            stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf)
-            
-            # apply update rule
-            Δλ = DiffResults.gradient(grad_buf)
-            Δλ = apply!(optimizer, λ, Δλ)
-            @. λ = λ - Δλ
-
-            stat′ = (Δλ=norm(Δλ),)
-            stats = merge(stats, stat′)
-        
-            AdvancedVI.DEBUG && @debug "Step $i" stats...
-            pm_next!(prog, stats)
-        end
-    end
-    return λ
-end
-
-# objectives
 include("estimators/advi.jl")
 
 # optimisers
 include("optimisers.jl")
 
+include("vi.jl")
+
 end # module
diff --git a/src/vi.jl b/src/vi.jl
new file mode 100644
index 000000000..aceb3f2dd
--- /dev/null
+++ b/src/vi.jl
@@ -0,0 +1,66 @@
+
+function pm_next!(pm, stats::NamedTuple)
+    ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
+end
+
+"""
+    optimize(grad_estimator::AbstractGradientEstimator, rebuild, n_max_iter, λ; optimizer = TruncatedADAGrad())
+
+Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute
+the steps.
+"""
+function optimize(
+    grad_estimator::AbstractGradientEstimator,
+    rebuild::Function,
+    n_max_iter::Int,
+    λ::AbstractVector{<:Real};
+    optimizer = TruncatedADAGrad(),
+    rng       = Random.GLOBAL_RNG
+)
+    obj_name = objective(grad_estimator)
+
+    # TODO: really need a better way to warn the user about potentially
+    # not using the correct accumulator
+    if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc))
+        # this message should only occur once in the optimization process
+        @info "[$obj_name] Should only be seen once: optimizer created for λ" objectid(λ)
+    end
+
+    grad_buf = DiffResults.GradientResult(λ)
+
+    i = 0
+    prog = ProgressMeter.Progress(
+        n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[])
+
+    # add criterion? A running mean maybe?
+    time_elapsed = @elapsed begin
+        for i = 1:n_max_iter
+            stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf)
+            
+            # apply update rule
+            Δλ = DiffResults.gradient(grad_buf)
+            Δλ = apply!(optimizer, λ, Δλ)
+            @. λ = λ - Δλ
+
+            stat′ = (Δλ=norm(Δλ),)
+            stats = merge(stats, stat′)
+        
+            AdvancedVI.DEBUG && @debug "Step $i" stats...
+            pm_next!(prog, stats)
+        end
+    end
+    return λ
+end
+
+# function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG)
+#     θ = copy(θ_init)
+#     optimize!(grad_estimator, rebuild, n_max_iter, λ, optimizer = optimizer, rng = rng)
+
+#     # If `q` is a mean-field approx we use the specialized `update` function
+#     if q isa Distribution
+#         return update(q, θ)
+#     else
+#         # Otherwise we assume it's a mapping θ → q
+#         return q(θ)
+#     end
+# end

From b7407ceecd7f6c8e3fc7a4c443995347fd4659f5 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 8 Jun 2023 03:31:35 +0100
Subject: [PATCH 011/206] remove redundant import

---
 src/AdvancedVI.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index d3612cb10..32b114bad 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -12,8 +12,6 @@ using LogDensityProblems
 using ForwardDiff
 using Tracker
 
-using Bijectors: Bijectors
-
 using Distributions
 using DistributionsAD
 

From 40401494ef032b1c9623856ed668373b251aaccb Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 9 Jun 2023 00:51:56 +0100
Subject: [PATCH 012/206] restructure project into more modular objective
 estimators

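ADVI becomes one particular composition of an energy estimator with an
entropy estimator, so either part can now be swapped independently.
Sketch (`prob` is assumed to be a LogDensityProblems-compatible target):

    obj    = ELBO(ADVIEnergy(prob), ClosedFormEntropy(), 10)  # what ADVI(prob, identity, 10) builds
    obj_mc = ELBO(ADVIEnergy(prob), MonteCarloEntropy(), 10)  # Monte Carlo entropy instead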
---
 src/AdvancedVI.jl                  |  8 ++---
 src/estimators/advi.jl             | 55 ------------------------------
 src/objectives/elbo/advi_energy.jl | 35 +++++++++++++++++++
 src/objectives/elbo/elbo.jl        | 44 ++++++++++++++++++++++++
 src/objectives/elbo/entropy.jl     | 18 ++++++++++
 src/vi.jl                          | 10 +++---
 6 files changed, 105 insertions(+), 65 deletions(-)
 delete mode 100644 src/estimators/advi.jl
 create mode 100644 src/objectives/elbo/advi_energy.jl
 create mode 100644 src/objectives/elbo/elbo.jl
 create mode 100644 src/objectives/elbo/entropy.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 32b114bad..dfb229300 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -95,8 +95,6 @@ export
     TruncatedADAGrad,
     DecayedADAGrad
 
-abstract type VariationalObjective end
-
 const VariationalPosterior = Distribution{Multivariate, Continuous}
 
 
@@ -160,9 +158,11 @@ function grad!(
 end
 
 # estimators
-abstract type AbstractGradientEstimator end
+abstract type AbstractVariationalObjective end
 
-include("estimators/advi.jl")
+include("objectives/elbo/elbo.jl")
+include("objectives/elbo/advi_energy.jl")
+include("objectives/elbo/entropy.jl")
 
 # optimisers
 include("optimisers.jl")
diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl
deleted file mode 100644
index 701ec1ef1..000000000
--- a/src/estimators/advi.jl
+++ /dev/null
@@ -1,55 +0,0 @@
-
-struct ADVI{Tlogπ, B} <: AbstractGradientEstimator
-    # Automatic differentiation variational inference
-    # 
-    # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017).
-    # Automatic differentiation variational inference.
-    # Journal of machine learning research.
-
-    ℓπ::Tlogπ
-    b⁻¹::B
-    n_samples::Int
-
-    function ADVI(prob, b⁻¹, n_samples; kwargs...)
-        # Could check whether the support of b⁻¹ and ℓπ match
-        cap = LogDensityProblems.capabilities(prob)
-        if cap === nothing
-            throw(
-                ArgumentError(
-                    "The log density function does not support the LogDensityProblems.jl interface",
-                ),
-            )
-        end
-        ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
-        new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹, n_samples)
-    end
-end
-
-ADVI(prob, n_samples; kwargs...) = ADVI(prob, identity, n_samples; kwargs...)
-
-objective(::ADVI) = "ELBO"
-
-function estimate_gradient!(
-    rng::Random.AbstractRNG,
-    estimator::ADVI,
-    λ::Vector{<:Real},
-    rebuild,
-    out::DiffResults.MutableDiffResult)
-
-    n_samples = estimator.n_samples
-
-    grad!(ADBackend(), λ, out) do λ′
-        q_η = rebuild(λ′)
-        ηs  = rand(rng, q_η, estimator.n_samples)
-
-        𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ
-            zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηᵢ)
-            (estimator.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
-        end
-
-        elbo = 𝔼ℓ + entropy(q_η)
-        -elbo
-    end
-    nelbo = DiffResults.value(out)
-    (elbo=-nelbo,)
-end
diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl
new file mode 100644
index 000000000..b27b752e2
--- /dev/null
+++ b/src/objectives/elbo/advi_energy.jl
@@ -0,0 +1,35 @@
+
+struct ADVIEnergy{Tlogπ, B} <: AbstractEnergyEstimator
+    # Automatic differentiation variational inference
+    # 
+    # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017).
+    # Automatic differentiation variational inference.
+    # Journal of machine learning research.
+
+    ℓπ::Tlogπ
+    b⁻¹::B
+
+    function ADVIEnergy(prob, b⁻¹)
+        # Could check whether the support of b⁻¹ and ℓπ match
+        cap = LogDensityProblems.capabilities(prob)
+        if cap === nothing
+            throw(
+                ArgumentError(
+                    "The log density function does not support the LogDensityProblems.jl interface",
+                ),
+            )
+        end
+        ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
+        new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹)
+    end
+end
+
+ADVIEnergy(prob) = ADVIEnergy(prob, identity)
+
+function (energy::ADVIEnergy)(q, ηs::AbstractMatrix)
+    n_samples = size(ηs, 2)
+    mapreduce(+, eachcol(ηs)) do ηᵢ
+        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(energy.b⁻¹, ηᵢ)
+        (energy.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
+    end
+end
diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
new file mode 100644
index 000000000..2954ae8e9
--- /dev/null
+++ b/src/objectives/elbo/elbo.jl
@@ -0,0 +1,44 @@
+
+abstract type AbstractEnergyEstimator  end
+abstract type AbstractEntropyEstimator end
+
+struct ELBO{EnergyEst  <: AbstractEnergyEstimator,
+            EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
+    # Evidence Lower Bound
+    # 
+    # Jordan, Michael I., et al.
+    # "An introduction to variational methods for graphical models."
+    # Machine learning 37 (1999): 183-233.
+
+    energy_estimator::EnergyEst
+    entropy_estimator::EntropyEst
+    n_samples::Int
+end
+
+Base.string(::ELBO) = "ELBO"
+
+function ADVI(ℓπ, b⁻¹, n_samples::Int)
+    ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples)
+end
+
+function estimate_gradient!(
+    rng::Random.AbstractRNG,
+    objective::ELBO,
+    λ::Vector{<:Real},
+    rebuild,
+    out::DiffResults.MutableDiffResult)
+
+    n_samples = objective.n_samples
+
+    grad!(ADBackend(), λ, out) do λ′
+        q_η = rebuild(λ′)
+        ηs  = rand(rng, q_η, n_samples)
+
+        𝔼ℓ   = objective.energy_estimator(q_η, ηs)
+        ℍ    = objective.entropy_estimator(q_η, ηs)
+        elbo = 𝔼ℓ + ℍ
+        -elbo
+    end
+    nelbo = DiffResults.value(out)
+    (elbo=-nelbo,)
+end
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
new file mode 100644
index 000000000..d7fb70544
--- /dev/null
+++ b/src/objectives/elbo/entropy.jl
@@ -0,0 +1,18 @@
+
+struct ClosedFormEntropy <: AbstractEntropyEstimator
+end
+
+function (::ClosedFormEntropy)(q, ηs::AbstractMatrix)
+    entropy(q)
+end
+
+struct MonteCarloEntropy <: AbstractEntropyEstimator
+end
+
+function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
+    n_samples = size(ηs, 2)
+    mapreduce(+, eachcol(ηs)) do ηᵢ
+        -logpdf(q, ηᵢ) / n_samples
+    end
+end
+
diff --git a/src/vi.jl b/src/vi.jl
index aceb3f2dd..4bf4595fc 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -10,32 +10,30 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer
 the steps.
 """
 function optimize(
-    grad_estimator::AbstractGradientEstimator,
+    objective::AbstractVariationalObjective,
     rebuild::Function,
     n_max_iter::Int,
     λ::AbstractVector{<:Real};
     optimizer = TruncatedADAGrad(),
     rng       = Random.GLOBAL_RNG
 )
-    obj_name = objective(grad_estimator)
-
     # TODO: really need a better way to warn the user about potentially
     # not using the correct accumulator
     if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc))
         # this message should only occur once in the optimization process
-        @info "[$obj_name] Should only be seen once: optimizer created for λ" objectid(λ)
+        @info "[$(string(objective))] Should only be seen once: optimizer created for λ" objectid(λ)
     end
 
     grad_buf = DiffResults.GradientResult(λ)
 
     i = 0
     prog = ProgressMeter.Progress(
-        n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[])
+        n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[])
 
     # add criterion? A running mean maybe?
     time_elapsed = @elapsed begin
         for i = 1:n_max_iter
-            stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf)
+            stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
             
             # apply update rule
             Δλ = DiffResults.gradient(grad_buf)

From 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 9 Jun 2023 01:18:02 +0100
Subject: [PATCH 013/206] migrate to AbstractDifferentiation

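All backends now go through a single entry point provided by
AbstractDifferentiation.jl. For reference, a sketch (loading ForwardDiff
is what makes `AD.ForwardDiffBackend` available):

    using ForwardDiff
    import AbstractDifferentiation as AD

    y, grads = AD.value_and_gradient(AD.ForwardDiffBackend(),
                                     x -> sum(abs2, x), ones(3))
    # `grads` is a tuple holding one gradient per differentiated argument,
    # hence the `first(grad)` in the estimator below.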
---
 Project.toml                |   3 +-
 src/AdvancedVI.jl           | 101 ++----------------------------------
 src/objectives/elbo/elbo.jl |  10 ++--
 src/vi.jl                   |   8 ++-
 4 files changed, 13 insertions(+), 109 deletions(-)

diff --git a/Project.toml b/Project.toml
index e73037ecb..6964c135e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
 version = "0.2.3"
 
 [deps]
+AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -15,7 +16,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 
 [compat]
 Bijectors = "0.11, 0.12"
@@ -27,7 +27,6 @@ ProgressMeter = "1.0.0"
 Requires = "0.5, 1.0"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
-Tracker = "0.2.3"
 julia = "1.6"
 
 [extras]
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index dfb229300..809d86c6e 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -9,14 +9,16 @@ using ProgressMeter, LinearAlgebra
 
 using LogDensityProblems
 
-using ForwardDiff
-using Tracker
-
 using Distributions
 using DistributionsAD
 
 using StatsFuns
 
+using ForwardDiff
+import AbstractDifferentiation as AD
+
+value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...)
+
 const PROGRESS = Ref(true)
 function turnprogress(switch::Bool)
     @info("[AdvancedVI]: global PROGRESS is set as $switch")
@@ -35,58 +37,6 @@ function __init__()
         Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
         Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
     end
-    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-        include("compat/zygote.jl")
-        export ZygoteAD
-
-        function AdvancedVI.grad!(
-            f::Function,
-            ::Type{<:ZygoteAD},
-            λ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-        )
-            y, back = Zygote.pullback(f, λ)
-            dy = first(back(1.0))
-            DiffResults.value!(out, y)
-            DiffResults.gradient!(out, dy)
-            return out
-        end
-    end
-    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-        include("compat/reversediff.jl")
-        export ReverseDiffAD
-
-        function AdvancedVI.grad!(
-            f::Function,
-            ::Type{<:ReverseDiffAD},
-            λ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-        )
-            tp = AdvancedVI.tape(f, λ)
-            ReverseDiff.gradient!(out, tp, λ)
-            return out
-        end
-    end
-    @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
-        include("compat/enzyme.jl")
-        export EnzymeAD
-
-        function AdvancedVI.grad!(
-            f::Function,
-            ::Type{<:EnzymeAD},
-            λ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-        )
-            # Use `Enzyme.ReverseWithPrimal` once it is released:
-            # https://github.com/EnzymeAD/Enzyme.jl/pull/598
-            y = f(λ)
-            DiffResults.value!(out, y)
-            dy = DiffResults.gradient(out)
-            fill!(dy, 0)
-            Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
-            return out
-        end
-    end
 end
 
 export
@@ -97,16 +47,6 @@ export
 
 const VariationalPosterior = Distribution{Multivariate, Continuous}
 
-
-"""
-    grad!(f, adtype, λ, out)
-
-Computes the gradient of the objective `f`. A default implementation is provided for
-`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
-This implicitly also gives a default implementation of `optimize!`.
-"""
-function grad! end
-
 """
     vi(model, alg::VariationalInference)
     vi(model, alg::VariationalInference, q::VariationalPosterior)
@@ -126,37 +66,6 @@ function vi end
 
 function update end
 
-# default implementations
-function grad!(
-    f::Function,
-    adtype::Type{<:ForwardDiffAD},
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
-)
-    # Set chunk size and do ForwardMode.
-    chunk_size = getchunksize(adtype)
-    config = if chunk_size == 0
-        ForwardDiff.GradientConfig(f, λ)
-    else
-        ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size))
-    end
-    ForwardDiff.gradient!(out, f, λ, config)
-end
-
-function grad!(
-    f::Function,
-    ::Type{<:TrackerAD},
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
-)
-    λ_tracked = Tracker.param(λ)
-    y = f(λ_tracked)
-    Tracker.back!(y, 1.0)
-
-    DiffResults.value!(out, Tracker.data(y))
-    DiffResults.gradient!(out, Tracker.grad(λ_tracked))
-end
-
 # estimators
 abstract type AbstractVariationalObjective end
 
diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
index 2954ae8e9..213cc725f 100644
--- a/src/objectives/elbo/elbo.jl
+++ b/src/objectives/elbo/elbo.jl
@@ -22,15 +22,14 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int)
 end
 
 function estimate_gradient!(
+    adbackend::AD.AbstractBackend,
     rng::Random.AbstractRNG,
     objective::ELBO,
     λ::Vector{<:Real},
-    rebuild,
-    out::DiffResults.MutableDiffResult)
+    rebuild)
 
     n_samples = objective.n_samples
-
-    grad!(ADBackend(), λ, out) do λ′
+    nelbo, grad = value_and_gradient(λ; adbackend) do λ′
         q_η = rebuild(λ′)
         ηs  = rand(rng, q_η, n_samples)
 
@@ -39,6 +38,5 @@ function estimate_gradient!(
         elbo = 𝔼ℓ + ℍ
         -elbo
     end
-    nelbo = DiffResults.value(out)
-    (elbo=-nelbo,)
+    first(grad), (elbo=-nelbo,)
 end
diff --git a/src/vi.jl b/src/vi.jl
index 4bf4595fc..7b7858b84 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -15,7 +15,8 @@ function optimize(
     n_max_iter::Int,
     λ::AbstractVector{<:Real};
     optimizer = TruncatedADAGrad(),
-    rng       = Random.GLOBAL_RNG
+    rng       = Random.default_rng(),
+    adbackend = AD.ForwardDiffBackend()
 )
     # TODO: really need a better way to warn the user about potentially
     # not using the correct accumulator
@@ -24,8 +25,6 @@ function optimize(
         @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ)
     end
 
-    grad_buf = DiffResults.GradientResult(λ)
-
     i = 0
     prog = ProgressMeter.Progress(
         n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[])
@@ -33,10 +32,9 @@ function optimize(
     # add criterion? A running mean maybe?
     time_elapsed = @elapsed begin
         for i = 1:n_max_iter
-            stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
+            Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild)
             
             # apply update rule
-            Δλ = DiffResults.gradient(grad_buf)
             Δλ = apply!(optimizer, λ, Δλ)
             @. λ = λ - Δλ
 

From 93a16d8bc6aac9725081ea4c414ffd9343e6e79e Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 00:42:36 +0100
Subject: [PATCH 014/206] add pre-packaged location-scale variational family,
 add functors

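Rough usage of the new families (a sketch; the dimension is illustrative):

    using LinearAlgebra

    d    = 3
    q_mf = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
    q_fr = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))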
---
 Project.toml                        |  2 ++
 src/AdvancedVI.jl                   | 19 +++++++++++++----
 src/distributions/location_scale.jl | 33 +++++++++++++++++++++++++++++
 3 files changed, 50 insertions(+), 4 deletions(-)
 create mode 100644 src/distributions/location_scale.jl

diff --git a/Project.toml b/Project.toml
index 6964c135e..88342f192 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -16,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 
 [compat]
 Bijectors = "0.11, 0.12"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 809d86c6e..8c33f74a9 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -2,6 +2,8 @@ module AdvancedVI
 
 using Random: Random
 
+using Functors
+
 using Distributions, DistributionsAD, Bijectors
 using DocStringExtensions
 
@@ -13,8 +15,9 @@ using Distributions
 using DistributionsAD
 
 using StatsFuns
+import StatsBase: entropy
 
-using ForwardDiff
+using ForwardDiff, Tracker
 import AbstractDifferentiation as AD
 
 value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...)
@@ -40,13 +43,18 @@ function __init__()
 end
 
 export
-    vi,
+    optimize,
+    ELBO,
     ADVI,
+    ADVIEnergy,
+    ClosedFormEntropy,
+    MonteCarloEntropy,
+    LocationScale,
+    FullRankGaussian,
+    MeanFieldGaussian,
     TruncatedADAGrad,
     DecayedADAGrad
 
-const VariationalPosterior = Distribution{Multivariate, Continuous}
-
 """
     vi(model, alg::VariationalInference)
     vi(model, alg::VariationalInference, q::VariationalPosterior)
@@ -73,6 +81,9 @@ include("objectives/elbo/elbo.jl")
 include("objectives/elbo/advi_energy.jl")
 include("objectives/elbo/entropy.jl")
 
+# Variational Families
+include("distributions/location_scale.jl")
+
 # optimisers
 include("optimisers.jl")
 
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
new file mode 100644
index 000000000..3aba53c55
--- /dev/null
+++ b/src/distributions/location_scale.jl
@@ -0,0 +1,33 @@
+
+LocationScale(μ::LinearAlgebra.AbstractVector,
+              L::Union{<: LinearAlgebra.AbstractTriangular,
+                       <: LinearAlgebra.Diagonal},
+              q₀::Distributions.ContinuousMultivariateDistribution) =
+                  transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))
+
+function location_scale_entropy(
+    q₀::Distributions.ContinuousMultivariateDistribution,
+    locscale_bijector::Bijectors.ComposedFunction)
+    # entropy of a location-scale transform: ℍ(q₀) plus the log-determinant
+    # of the scale (the shift does not affect the entropy)
+    scale = locscale_bijector.inner.a
+    entropy(q₀) + first(logabsdet(scale))
+end
+
+function entropy(q_trans::MultivariateTransformed{
+    <: Distributions.ContinuousMultivariateDistribution,
+    <: Bijectors.ComposedFunction{
+        <: Bijectors.Shift,
+        <: Bijectors.Scale}})
+    q_base = q_trans.dist
+    scale  = q_trans.transform.inner.a
+    entropy(q_base) + first(logabsdet(scale))
+end
+
+function FullRankGaussian(μ::AbstractVector,
+                          L::LinearAlgebra.AbstractTriangular)
+    q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ)))
+    LocationScale(μ, L, q₀)
+end
+
+function MeanFieldGaussian(μ::AbstractVector,
+                           L::LinearAlgebra.Diagonal)
+    q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ)))
+    LocationScale(μ, L, q₀)
+end

From 2b6e9ebed556dd67bb9325a5b04228637e1e03df Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 21:04:19 +0100
Subject: [PATCH 015/206] Revert "migrate to AbstractDifferentiation"

This reverts commit 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691.
---
 Project.toml                |   2 +-
 src/AdvancedVI.jl           | 101 ++++++++++++++++++++++++++++++++++--
 src/objectives/elbo/elbo.jl |  10 ++--
 src/vi.jl                   |   8 +--
 4 files changed, 108 insertions(+), 13 deletions(-)

diff --git a/Project.toml b/Project.toml
index 88342f192..9a3303f58 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,7 +3,6 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
 version = "0.2.3"
 
 [deps]
-AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -29,6 +28,7 @@ ProgressMeter = "1.0.0"
 Requires = "0.5, 1.0"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
+Tracker = "0.2.3"
 julia = "1.6"
 
 [extras]
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 8c33f74a9..116bb63c4 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -11,17 +11,15 @@ using ProgressMeter, LinearAlgebra
 
 using LogDensityProblems
 
+using ForwardDiff
+using Tracker
+
 using Distributions
 using DistributionsAD
 
 using StatsFuns
 import StatsBase: entropy
 
-using ForwardDiff, Tracker
-import AbstractDifferentiation as AD
-
-value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...)
-
 const PROGRESS = Ref(true)
 function turnprogress(switch::Bool)
     @info("[AdvancedVI]: global PROGRESS is set as $switch")
@@ -40,6 +38,58 @@ function __init__()
         Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
         Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
     end
+    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+        include("compat/zygote.jl")
+        export ZygoteAD
+
+        function AdvancedVI.grad!(
+            f::Function,
+            ::Type{<:ZygoteAD},
+            λ::AbstractVector{<:Real},
+            out::DiffResults.MutableDiffResult,
+        )
+            y, back = Zygote.pullback(f, λ)
+            dy = first(back(1.0))
+            DiffResults.value!(out, y)
+            DiffResults.gradient!(out, dy)
+            return out
+        end
+    end
+    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
+        include("compat/reversediff.jl")
+        export ReverseDiffAD
+
+        function AdvancedVI.grad!(
+            f::Function,
+            ::Type{<:ReverseDiffAD},
+            λ::AbstractVector{<:Real},
+            out::DiffResults.MutableDiffResult,
+        )
+            tp = AdvancedVI.tape(f, λ)
+            ReverseDiff.gradient!(out, tp, λ)
+            return out
+        end
+    end
+    @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+        include("compat/enzyme.jl")
+        export EnzymeAD
+
+        function AdvancedVI.grad!(
+            f::Function,
+            ::Type{<:EnzymeAD},
+            λ::AbstractVector{<:Real},
+            out::DiffResults.MutableDiffResult,
+        )
+            # Use `Enzyme.ReverseWithPrimal` once it is released:
+            # https://github.com/EnzymeAD/Enzyme.jl/pull/598
+            y = f(λ)
+            DiffResults.value!(out, y)
+            dy = DiffResults.gradient(out)
+            fill!(dy, 0)
+            Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
+            return out
+        end
+    end
 end
 
 export
@@ -55,6 +105,16 @@ export
     TruncatedADAGrad,
     DecayedADAGrad
 
+
+"""
+    grad!(f, adtype, λ, out)
+
+Compute the value and gradient of the objective `f` at `λ` under the AD backend
+`adtype`, storing both in `out`. Default implementations are provided for
+`ForwardDiffAD` and `TrackerAD`; this implicitly also gives a default implementation of `optimize!`.
+"""
+function grad! end
+
 """
     vi(model, alg::VariationalInference)
     vi(model, alg::VariationalInference, q::VariationalPosterior)
@@ -74,6 +134,37 @@ function vi end
 
 function update end
 
+# default implementations
+function grad!(
+    f::Function,
+    adtype::Type{<:ForwardDiffAD},
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
+)
+    # Set chunk size and do ForwardMode.
+    chunk_size = getchunksize(adtype)
+    config = if chunk_size == 0
+        ForwardDiff.GradientConfig(f, λ)
+    else
+        ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size))
+    end
+    ForwardDiff.gradient!(out, f, λ, config)
+end
+
+function grad!(
+    f::Function,
+    ::Type{<:TrackerAD},
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
+)
+    λ_tracked = Tracker.param(λ)
+    y = f(λ_tracked)
+    Tracker.back!(y, 1.0)
+
+    DiffResults.value!(out, Tracker.data(y))
+    DiffResults.gradient!(out, Tracker.grad(λ_tracked))
+end
+
 # estimators
 abstract type AbstractVariationalObjective end
 
diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
index 213cc725f..2954ae8e9 100644
--- a/src/objectives/elbo/elbo.jl
+++ b/src/objectives/elbo/elbo.jl
@@ -22,14 +22,15 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int)
 end
 
 function estimate_gradient!(
-    adbackend::AD.AbstractBackend,
     rng::Random.AbstractRNG,
     objective::ELBO,
     λ::Vector{<:Real},
-    rebuild)
+    rebuild,
+    out::DiffResults.MutableDiffResult)
 
     n_samples = objective.n_samples
-    nelbo, grad = value_and_gradient(λ; adbackend) do λ′
+
+    grad!(ADBackend(), λ, out) do λ′
         q_η = rebuild(λ′)
         ηs  = rand(rng, q_η, n_samples)
 
@@ -38,5 +39,6 @@ function estimate_gradient!(
         elbo = 𝔼ℓ + ℍ
         -elbo
     end
-    first(grad), (elbo=-nelbo,)
+    nelbo = DiffResults.value(out)
+    (elbo=-nelbo,)
 end
diff --git a/src/vi.jl b/src/vi.jl
index 7b7858b84..4bf4595fc 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -15,8 +15,7 @@ function optimize(
     n_max_iter::Int,
     λ::AbstractVector{<:Real};
     optimizer = TruncatedADAGrad(),
-    rng       = Random.default_rng(),
-    adbackend = AD.ForwardDiffBackend()
+    rng       = Random.GLOBAL_RNG
 )
     # TODO: really need a better way to warn the user about potentially
     # not using the correct accumulator
@@ -25,6 +24,8 @@ function optimize(
         @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ)
     end
 
+    grad_buf = DiffResults.GradientResult(λ)
+
     i = 0
     prog = ProgressMeter.Progress(
         n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[])
@@ -32,9 +33,10 @@ function optimize(
     # add criterion? A running mean maybe?
     time_elapsed = @elapsed begin
         for i = 1:n_max_iter
-            Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild)
+            stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
             
             # apply update rule
+            Δλ = DiffResults.gradient(grad_buf)
             Δλ = apply!(optimizer, λ, Δλ)
             @. λ = λ - Δλ
 

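Note on the reverted interface: `grad!` now takes the objective closure, the AD
backend type, the flat parameters, and a DiffResults buffer, so the objective
value and its gradient come out of a single pass. A standalone sketch of the
same pattern using only ForwardDiff and DiffResults (the quadratic `f` is a toy
stand-in for the negative-ELBO closure):

    using ForwardDiff, DiffResults

    f(λ) = sum(abs2, λ) / 2               # toy objective

    λ   = randn(3)
    out = DiffResults.GradientResult(λ)   # holds value and gradient together

    # One forward-mode sweep fills both fields of `out`.
    ForwardDiff.gradient!(out, f, λ)

    DiffResults.value(out)      # f(λ)
    DiffResults.gradient(out)   # ∇f(λ), here equal to λ
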
From 1bfec36961c437cf000234bd29504fd49848d676 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 21:41:25 +0100
Subject: [PATCH 016/206] fix use optimized MvNormal specialization, add logpdf
 for Loc.Scale.

---
 Project.toml                        |  2 +
 src/AdvancedVI.jl                   | 23 +++++++-----
 src/distributions/location_scale.jl | 57 +++++++++++++++++++----------
 3 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9a3303f58..38a5026a8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,10 +7,12 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 116bb63c4..d5a06fcef 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -4,20 +4,23 @@ using Random: Random
 
 using Functors
 
-using Distributions, DistributionsAD, Bijectors
 using DocStringExtensions
 
-using ProgressMeter, LinearAlgebra
+using ProgressMeter
+using LinearAlgebra
+using LinearAlgebra: AbstractTriangular
 
 using LogDensityProblems
 
 using ForwardDiff
 using Tracker
 
-using Distributions
-using DistributionsAD
+using FillArrays
+using PDMats
+using Distributions, DistributionsAD
+using Distributions: ContinuousMultivariateDistribution
+using Bijectors
 
-using StatsFuns
 import StatsBase: entropy
 
 const PROGRESS = Ref(true)
@@ -29,7 +32,6 @@ end
 const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
 
 include("ad.jl")
-include("utils.jl")
 
 using Requires
 function __init__()
@@ -116,9 +118,9 @@ This implicitly also gives a default implementation of `optimize!`.
 function grad! end
 
 """
-    vi(model, alg::VariationalInference)
-    vi(model, alg::VariationalInference, q::VariationalPosterior)
-    vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray)
+    optimize(model, alg::VariationalInference)
+    optimize(model, alg::VariationalInference, q::VariationalPosterior)
+    optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray)
 
 Constructs the variational posterior from the `model` and performs the optimization
 following the configuration of the given `VariationalInference` instance.
@@ -130,7 +132,7 @@ following the configuration of the given `VariationalInference` instance.
 - `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior`
 - `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior
 """
-function vi end
+function optimize end
 
 function update end
 
@@ -178,6 +180,7 @@ include("distributions/location_scale.jl")
 # optimisers
 include("optimisers.jl")
 
+include("utils.jl")
 include("vi.jl")
 
 end # module
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 3aba53c55..365ae15e8 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -1,33 +1,52 @@
 
-LocationScale(μ::LinearAlgebra.AbstractVector,
-              L::Union{<: LinearAlgebra.AbstractTriangular,
-                       <: LinearAlgebra.Diagonal},
-              q₀::Distributions.ContinuousMultivariateDistribution) =
-                  transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))
+function LocationScale(μ::AbstractVector,
+                       L::Union{<: AbstractTriangular,
+                                <: Diagonal},
+                       q₀::ContinuousMultivariateDistribution)
+    @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
+    transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))
+end
 
 function location_scale_entropy(
-    q₀::Distributions.ContinuousMultivariateDistribution,
+    q₀::ContinuousMultivariateDistribution,
     locscale_bijector::Bijectors.ComposedFunction)
 end
 
-function entropy(q_trans::MultivariateTransformed{
-    <: Distributions.ContinuousMultivariateDistribution,
-    <: Bijectors.ComposedFunction{
-        <: Bijectors.Shift,
-        <: Bijectors.Scale}})
+function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
+                                                  <: Bijectors.ComposedFunction{
+                                                      <: Bijectors.Shift,
+                                                      <: Bijectors.Scale}})
     q_base = q_trans.dist
     scale  = q_trans.transform.inner.a
     entropy(q_base) + first(logabsdet(scale))
 end
 
-function FullRankGaussian(μ::AbstractVector,
-                          L::LinearAlgebra.AbstractTriangular)
-    q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ)))
-    LocationScale(μ, L, q₀)
+function logpdf(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
+                                                 <: Bijectors.ComposedFunction{
+                                                     <: Bijectors.Shift,
+                                                     <: Bijectors.Scale}},
+                z::AbstractVector)
+    q_base  = q_trans.dist
+    reparam = q_trans.transform
+    scale   = q_trans.transform.inner.a
+    η       = inverse(reparam)(z)
+    logpdf(q_base, η) - first(logabsdet(scale))
+end
+
+function FullRankGaussian(μ::AbstractVector{T},
+                          L::AbstractTriangular{T,S}) where {T <: Real, S}
+    @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
+    n_dims = length(μ)
+    q_base = MvNormal(FillArrays.Zeros{T}(n_dims),
+                      PDMats.ScalMat{T}(n_dims, one(T)))
+    LocationScale(μ, L, q_base)
 end
 
-function MeanFieldGaussian(μ::AbstractVector,
-                           L::LinearAlgebra.Diagonal)
-    q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ)))
-    LocationScale(μ, L, q₀)
+function MeanFieldGaussian(μ::AbstractVector{T},
+                           L::Diagonal{T,V}) where {T <: Real, V}
+    @assert (length(μ) == size(L,1))
+    n_dims = length(μ)
+    q_base = MvNormal(FillArrays.Zeros{T}(n_dims),
+                      PDMats.ScalMat{T}(n_dims, one(T)))
+    LocationScale(μ, L, q_base)
 end

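The closed-form entropy above uses the identity ℍ[Lη + μ] = ℍ[η] + log|det L|:
shifting by μ leaves entropy unchanged, and scaling by L adds the log absolute
determinant. A quick numerical check with Distributions.jl (the dimension and
scale matrix are arbitrary):

    using Distributions, LinearAlgebra

    d = 3
    μ = [0.5, -1.0, 2.0]
    L = LowerTriangular([1.0 0.0 0.0; 0.3 2.0 0.0; -0.1 0.4 0.7])

    H_base   = entropy(MvNormal(zeros(d), Diagonal(ones(d))))
    H_scaled = H_base + first(logabsdet(L))

    # The same distribution constructed directly as N(μ, LLᵀ):
    H_direct = entropy(MvNormal(μ, Matrix(Symmetric(L * L'))))

    H_scaled ≈ H_direct   # true up to floating-point error
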
From 1003606283efd6b8cf340e74dced65d8ea72b296 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 21:52:53 +0100
Subject: [PATCH 017/206] remove dead code

---
 src/distributions/location_scale.jl | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 365ae15e8..1f7bad85c 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -7,11 +7,6 @@ function LocationScale(μ::AbstractVector,
     transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))
 end
 
-function location_scale_entropy(
-    q₀::ContinuousMultivariateDistribution,
-    locscale_bijector::Bijectors.ComposedFunction)
-end
-
 function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
                                                   <: Bijectors.ComposedFunction{
                                                       <: Bijectors.Shift,

From 60a9987ed259b906da9cdd6e38ed33102497f389 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 21:56:30 +0100
Subject: [PATCH 018/206] fix location-scale logpdf

- Full Monte Carlo ELBO estimation now works. I checked.
---
 src/AdvancedVI.jl                   |  3 ++-
 src/distributions/location_scale.jl | 20 +++++++++++---------
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index d5a06fcef..9b9d3ab2a 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -21,7 +21,8 @@ using Distributions, DistributionsAD
 using Distributions: ContinuousMultivariateDistribution
 using Bijectors
 
-import StatsBase: entropy
+using StatsBase
+using StatsBase: entropy
 
 const PROGRESS = Ref(true)
 function turnprogress(switch::Bool)
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 1f7bad85c..dd9b5f2a3 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -7,20 +7,22 @@ function LocationScale(μ::AbstractVector,
     transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))
 end
 
-function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
-                                                  <: Bijectors.ComposedFunction{
-                                                      <: Bijectors.Shift,
-                                                      <: Bijectors.Scale}})
+function StatsBase.entropy(
+    q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
+                                     <: Bijectors.ComposedFunction{
+                                         <: Bijectors.Shift,
+                                         <: Bijectors.Scale}})
     q_base = q_trans.dist
     scale  = q_trans.transform.inner.a
     entropy(q_base) + first(logabsdet(scale))
 end
 
-function logpdf(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
-                                                 <: Bijectors.ComposedFunction{
-                                                     <: Bijectors.Shift,
-                                                     <: Bijectors.Scale}},
-                z::AbstractVector)
+function Distributions.logpdf(
+    q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
+                                     <: Bijectors.ComposedFunction{
+                                         <: Bijectors.Shift,
+                                         <: Bijectors.Scale}},
+    z::AbstractVector)
     q_base  = q_trans.dist
     reparam = q_trans.transform
     scale   = q_trans.transform.inner.a

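The qualified names are the point of this patch: with `using StatsBase: entropy`,
an unqualified `function entropy(...)` inside the module would define a new
function rather than extend the generic (Julia rejects the method definition
unless the name is `import`ed or qualified). Writing `StatsBase.entropy(...)`
and `Distributions.logpdf(...)` adds methods that every downstream caller of
the generics can see. In isolation, with a made-up type:

    using StatsBase

    struct Toy          # a type this sketch owns
        p::Float64
    end

    # Extends the generic: entropy(Toy(0.5)) now works for every caller.
    StatsBase.entropy(t::Toy) = -t.p * log(t.p) - (1 - t.p) * log(1 - t.p)
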
From cd84f02898d7cf82c530f98a91579f0b01935f33 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 22:21:22 +0100
Subject: [PATCH 019/206] add sticking-the-landing (STL) estimator

---
 src/objectives/elbo/elbo.jl    | 36 ++++++++++++++++++++++++----------
 src/objectives/elbo/entropy.jl | 35 ++++++++++++++++++++++++++++-----
 2 files changed, 56 insertions(+), 15 deletions(-)

diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
index 2954ae8e9..343581d82 100644
--- a/src/objectives/elbo/elbo.jl
+++ b/src/objectives/elbo/elbo.jl
@@ -21,23 +21,39 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int)
     ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples)
 end
 
+function (elbo::ELBO)(q_η::ContinuousMultivariateDistribution;
+                      rng = Random.default_rng(),
+                      n_samples::Int = elbo.n_samples,
+                      q_η_entropy::ContinuousMultivariateDistribution = q_η)
+    ηs = rand(rng, q_η, n_samples)
+    𝔼ℓ = elbo.energy_estimator(q_η, ηs)
+    ℍ  = elbo.entropy_estimator(q_η_entropy, ηs)
+    𝔼ℓ + ℍ
+end
+
 function estimate_gradient!(
     rng::Random.AbstractRNG,
-    objective::ELBO,
+    elbo::ELBO{EnergyEst, EntropyEst},
     λ::Vector{<:Real},
     rebuild,
-    out::DiffResults.MutableDiffResult)
-
-    n_samples = objective.n_samples
+    out::DiffResults.MutableDiffResult) where {EnergyEst  <: AbstractEnergyEstimator,
+                                               EntropyEst <: AbstractEntropyEstimator}
+
+    # Gradient-stopping for computing the sticking-the-landing control variate
+    q_η_stop = if EntropyEst isa MonteCarloEntropy{true}
+        rebuild(λ)
+    else
+        nothing
+    end
 
     grad!(ADBackend(), λ, out) do λ′
         q_η = rebuild(λ′)
-        ηs  = rand(rng, q_η, n_samples)
-
-        𝔼ℓ   = objective.energy_estimator(q_η, ηs)
-        ℍ    = objective.entropy_estimator(q_η, ηs)
-        elbo = 𝔼ℓ + ℍ
-        -elbo
+        q_η_entropy = if EntropyEst isa MonteCarloEntropy{true}
+            q_η_stop
+        else
+            q_η
+        end
+        -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy)
     end
     nelbo = DiffResults.value(out)
     (elbo=-nelbo,)
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index d7fb70544..8efb7c711 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -1,13 +1,38 @@
 
-struct ClosedFormEntropy <: AbstractEntropyEstimator
-end
+struct ClosedFormEntropy <: AbstractEntropyEstimator end
 
-function (::ClosedFormEntropy)(q, ηs::AbstractMatrix)
+function (::ClosedFormEntropy)(q, ::AbstractMatrix)
     entropy(q)
 end
 
-struct MonteCarloEntropy <: AbstractEntropyEstimator
-end
+struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end
+
+MonteCarloEntropy() = MonteCarloEntropy{false}()
+
+"""
+  Sticking the Landing Control Variate
+
+  # Explanation
+
+  This eatimator forms a control variate of the form of
+ 
+    c(z)  = 𝔼-logq(z) + logq(z) = ℍ[q] - logq(z)
+ 
+   Adding this to the closed-form entropy ELBO estimator yields:
+ 
+     ELBO - c(z) = 𝔼logπ(z) + ℍ[q] - c(z) = 𝔼logπ(z) - logq(z),
+
+   which has the same expectation, but lower variance when π ≈ q,
+   and higher variance when π ≉ q.
+
+   # Reference
+
+   Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud.
+   "Sticking the landing: Simple, lower-variance gradient estimators for
+   variational inference."
+   Advances in Neural Information Processing Systems 30 (2017).
+"""
+StickingTheLandingEntropy() = MonteCarloEntropy{true}()
 
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
     n_samples = size(ηs, 2)

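Operationally, "sticking the landing" evaluates the -logq(z) entropy term under
a copy of q whose parameters are excluded from differentiation, which is what
the `q_η_stop = rebuild(λ)` branch above implements. A toy sketch of the
stop-gradient pattern for a one-dimensional Gaussian family (ForwardDiff and
the `restructure` map here are illustrative choices, not the package API):

    using ForwardDiff, Distributions

    restructure(λ) = Normal(λ[1], exp(λ[2]))   # toy variational family
    logπ(z)        = logpdf(Normal(0, 1), z)   # toy target

    λ = [0.1, -0.2]
    ε = 0.7                    # fixed base sample (reparameterization trick)

    # Built from the numeric λ outside the closure, so no gradient flows
    # through the entropy term below: this is the control variate at work.
    q_stop = restructure(λ)

    function neg_stl_elbo(λ′)
        q = restructure(λ′)
        z = mean(q) + std(q) * ε               # differentiable sample
        -(logπ(z) - logpdf(q_stop, z))
    end

    ForwardDiff.gradient(neg_stl_elbo, λ)
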
From 768641b1979f4e63125780e53f48e21794bbcdd2 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 22:41:50 +0100
Subject: [PATCH 020/206] migrate to Optimisers.jl

---
 Project.toml      |  1 +
 src/AdvancedVI.jl | 11 +++--------
 src/vi.jl         | 27 ++++++++++++++++-----------
 3 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/Project.toml b/Project.toml
index 38a5026a8..ba807698f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 9b9d3ab2a..5a02501b4 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -4,6 +4,8 @@ using Random: Random
 
 using Functors
 
+using Optimisers
+
 using DocStringExtensions
 
 using ProgressMeter
@@ -12,8 +14,7 @@ using LinearAlgebra: AbstractTriangular
 
 using LogDensityProblems
 
-using ForwardDiff
-using Tracker
+using ForwardDiff, Tracker
 
 using FillArrays
 using PDMats
@@ -24,12 +25,6 @@ using Bijectors
 using StatsBase
 using StatsBase: entropy
 
-const PROGRESS = Ref(true)
-function turnprogress(switch::Bool)
-    @info("[AdvancedVI]: global PROGRESS is set as $switch")
-    PROGRESS[] = switch
-end
-
 const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
 
 include("ad.jl")
diff --git a/src/vi.jl b/src/vi.jl
index 4bf4595fc..6c8b26d1a 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -10,12 +10,13 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer
 the steps.
 """
 function optimize(
-    objective::AbstractVariationalObjective,
-    rebuild::Function,
+    objective ::AbstractVariationalObjective,
+    rebuild,
     n_max_iter::Int,
-    λ::AbstractVector{<:Real};
-    optimizer = TruncatedADAGrad(),
-    rng       = Random.GLOBAL_RNG
+    λ         ::AbstractVector{<:Real};
+    optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(),
+    rng       ::Random.AbstractRNG      = Random.GLOBAL_RNG,
+    progress  ::Bool                    = true
 )
     # TODO: really need a better way to warn the user about potentially
     # not using the correct accumulator
@@ -24,21 +25,25 @@ function optimize(
         @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ)
     end
 
+    optstate = Optimisers.init(optimizer, λ)
     grad_buf = DiffResults.GradientResult(λ)
 
     i = 0
     prog = ProgressMeter.Progress(
-        n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[])
+        n_max_iter;
+        desc      = "[$(string(objective))] Optimizing...",
+        barlen    = 0,
+        enabled   = progress,
+        showspeed = true)
 
     # add criterion? A running mean maybe?
     time_elapsed = @elapsed begin
         for i = 1:n_max_iter
             stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
-            
-            # apply update rule
-            Δλ = DiffResults.gradient(grad_buf)
-            Δλ = apply!(optimizer, λ, Δλ)
-            @. λ = λ - Δλ
+            g     = DiffResults.gradient(grad_buf)
+
+            optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
+            Optimisers.subtract!(λ, Δλ)
 
             stat′ = (Δλ=norm(Δλ),)
             stats = merge(stats, stat′)

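The three Optimisers.jl calls adopted here form a self-contained rule-level
loop: `init` builds per-parameter state, `apply!` turns a raw gradient into an
update, and `subtract!` applies it in place. The same loop on a toy quadratic
(the `Descent` rule and step count are arbitrary):

    using Optimisers

    function toy_descent(λ; n_iter = 100)
        rule  = Optimisers.Descent(0.1)
        state = Optimisers.init(rule, λ)
        for _ in 1:n_iter
            g = copy(λ)                        # ∇(‖λ‖²/2) = λ
            state, Δλ = Optimisers.apply!(rule, state, λ, g)
            Optimisers.subtract!(λ, Δλ)        # λ .-= Δλ, in place
        end
        λ
    end

    toy_descent([2.0, -3.0])   # ≈ [0.0, 0.0]
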
From ca02fa315486a0977327f3e2824cd87b40b1908a Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 10 Jun 2023 22:42:38 +0100
Subject: [PATCH 021/206] remove execution time measurement (replace later with
 something else)

---
 src/vi.jl | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/src/vi.jl b/src/vi.jl
index 6c8b26d1a..e5062defd 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -36,23 +36,20 @@ function optimize(
         enabled   = progress,
         showspeed = true)
 
-    # add criterion? A running mean maybe?
-    time_elapsed = @elapsed begin
-        for i = 1:n_max_iter
-            stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
-            g     = DiffResults.gradient(grad_buf)
+    for i = 1:n_max_iter
+        stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
+        g     = DiffResults.gradient(grad_buf)
 
-            optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
-            Optimisers.subtract!(λ, Δλ)
+        optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
+        Optimisers.subtract!(λ, Δλ)
 
-            stat′ = (Δλ=norm(Δλ),)
-            stats = merge(stats, stat′)
+        stat′ = (Δλ=norm(Δλ),)
+        stats = merge(stats, stat′)
         
-            AdvancedVI.DEBUG && @debug "Step $i" stats...
+        AdvancedVI.DEBUG && @debug "Step $i" stats...
             pm_next!(prog, stats)
-        end
     end
-    return λ
+    λ
 end
 
 # function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG)

From a48377f016c82461000ba10c35803a5181f4b4a9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 12 Jun 2023 21:47:22 +0100
Subject: [PATCH 022/206] fix use multiple dispatch for deciding whether to
 stop entropy grad.

---
 src/objectives/elbo/elbo.jl    | 21 +++++++--------------
 src/objectives/elbo/entropy.jl |  4 ++++
 src/vi.jl                      |  4 ++--
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
index 343581d82..cebd7d823 100644
--- a/src/objectives/elbo/elbo.jl
+++ b/src/objectives/elbo/elbo.jl
@@ -15,6 +15,8 @@ struct ELBO{EnergyEst  <: AbstractEnergyEstimator,
     n_samples::Int
 end
 
+skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator)
+
 Base.string(::ELBO) = "ELBO"
 
 function ADVI(ℓπ, b⁻¹, n_samples::Int)
@@ -33,28 +35,19 @@ end
 
 function estimate_gradient!(
     rng::Random.AbstractRNG,
-    elbo::ELBO{EnergyEst, EntropyEst},
+    elbo::ELBO,
     λ::Vector{<:Real},
     rebuild,
-    out::DiffResults.MutableDiffResult) where {EnergyEst  <: AbstractEnergyEstimator,
-                                               EntropyEst <: AbstractEntropyEstimator}
+    out::DiffResults.MutableDiffResult)
 
     # Gradient-stopping for computing the sticking-the-landing control variate
-    q_η_stop = if EntropyEst isa MonteCarloEntropy{true}
-        rebuild(λ)
-    else
-        nothing
-    end
+    q_η_stop = skip_entropy_gradient(elbo) ? rebuild(λ) : nothing
 
     grad!(ADBackend(), λ, out) do λ′
         q_η = rebuild(λ′)
-        q_η_entropy = if EntropyEst isa MonteCarloEntropy{true}
-            q_η_stop
-        else
-            q_η
-        end
+        q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η
         -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy)
     end
     nelbo = DiffResults.value(out)
-    (elbo=-nelbo,)
+    out, (elbo=-nelbo,)
 end
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 8efb7c711..50f498d6e 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -5,6 +5,8 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix)
     entropy(q)
 end
 
+skip_entropy_gradient(::ClosedFormEntropy) = false
+
 struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end
 
 MonteCarloEntropy() = MonteCarloEntropy{false}()
@@ -34,6 +36,8 @@ MonteCarloEntropy() = MonteCarloEntropy{false}()
 """
 StickingTheLandingEntropy() = MonteCarloEntropy{true}()
 
+skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding
+
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
     n_samples = size(ηs, 2)
     mapreduce(+, eachcol(ηs)) do ηᵢ
diff --git a/src/vi.jl b/src/vi.jl
index e5062defd..8b8fe14fa 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -37,8 +37,8 @@ function optimize(
         showspeed = true)
 
     for i = 1:n_max_iter
-        stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
-        g     = DiffResults.gradient(grad_buf)
+        grad_buf, stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
+        g = DiffResults.gradient(grad_buf)
 
         optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
         Optimisers.subtract!(λ, Δλ)

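Lifting the flag into the type domain also fixes a latent bug: the previous
check `EntropyEst isa MonteCarloEntropy{true}` compared the *type* `EntropyEst`
against an instance pattern and was therefore always false. Dispatching on the
type parameter makes the branch both correct and resolvable at compile time.
The pattern in isolation (names are stand-ins):

    struct MCEntropy{STL} end                  # mirrors MonteCarloEntropy

    skip_grad(::MCEntropy{S}) where {S} = S

    skip_grad(MCEntropy{true}())    # true
    skip_grad(MCEntropy{false}())   # false
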
From 0b40ccf6ef10e6ebef9d6372e407731bb4dc2ca0 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 12 Jun 2023 22:16:30 +0100
Subject: [PATCH 023/206] add termination decision, callback arguments

---
 src/vi.jl | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/vi.jl b/src/vi.jl
index 8b8fe14fa..1a4d57ecb 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -16,7 +16,9 @@ function optimize(
     λ         ::AbstractVector{<:Real};
     optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(),
     rng       ::Random.AbstractRNG      = Random.GLOBAL_RNG,
-    progress  ::Bool                    = true
+    progress  ::Bool                    = true,
+    callback!                           = nothing,
+    terminate                           = (args...) -> false,
 )
     # TODO: really need a better way to warn the user about potentially
     # not using the correct accumulator
@@ -28,6 +30,7 @@ function optimize(
     optstate = Optimisers.init(optimizer, λ)
     grad_buf = DiffResults.GradientResult(λ)
 
+    q = rebuild(λ)
     i = 0
     prog = ProgressMeter.Progress(
         n_max_iter;
@@ -43,11 +46,22 @@ function optimize(
         optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
         Optimisers.subtract!(λ, Δλ)
 
-        stat′ = (Δλ=norm(Δλ),)
+        stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g))
         stats = merge(stats, stat′)
+        q     = rebuild(λ)
+
+        if !isnothing(callback!)
+            stat′  = callback!(q, stats)
+            stats = !isnothing(stat′) ? merge(stat′, stats) : stats
+        end
         
         AdvancedVI.DEBUG && @debug "Step $i" stats...
             pm_next!(prog, stats)
+
+        # Termination decision is work in progress
+        if terminate(rng, q, objective, stats)
+            break
+        end
     end
     λ
 end

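Hypothetical shapes for the two new hooks, matching how `optimize` invokes them
above (`callback!` receives the rebuilt `q` and the per-step stats and may
return extra stats to merge; `terminate` receives `(rng, q, objective, stats)`
and returns `true` to stop early; the bodies below are placeholders):

    my_callback(q, stats) = (mean_q_norm = sum(abs2, mean(q)),)

    my_terminate(rng, q, objective, stats) = stats.gradient_norm < 1e-6

    # optimize(objective, rebuild, n_max_iter, λ;
    #          callback! = my_callback, terminate = my_terminate)
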
From 21db3fb842d226148ee23b758c0756e332132066 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 12 Jun 2023 22:35:48 +0100
Subject: [PATCH 024/206] add Base.show to modules

---
 src/objectives/elbo/advi_energy.jl | 2 ++
 src/objectives/elbo/elbo.jl        | 6 +++++-
 src/objectives/elbo/entropy.jl     | 4 ++++
 src/vi.jl                          | 1 -
 4 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl
index b27b752e2..078a157ed 100644
--- a/src/objectives/elbo/advi_energy.jl
+++ b/src/objectives/elbo/advi_energy.jl
@@ -26,6 +26,8 @@ end
 
 ADVIEnergy(prob) = ADVIEnergy(prob, identity)
 
+Base.show(io::IO, energy::ADVIEnergy) = print(io, "ADVIEnergy()")
+
 function (energy::ADVIEnergy)(q, ηs::AbstractMatrix)
     n_samples = size(ηs, 2)
     mapreduce(+, eachcol(ηs)) do ηᵢ
diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
index cebd7d823..b26516d96 100644
--- a/src/objectives/elbo/elbo.jl
+++ b/src/objectives/elbo/elbo.jl
@@ -17,7 +17,11 @@ end
 
 skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator)
 
-Base.string(::ELBO) = "ELBO"
+Base.show(io::IO, elbo::ELBO) = print(
+    io,
+    "ELBO(energy_estimator=$(elbo.energy_estimator), " *
+    "entropy_estimator=$(elbo.entropy_estimator)), " *
+    "n_samples=$(elbo.n_samples))")
 
 function ADVI(ℓπ, b⁻¹, n_samples::Int)
     ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples)
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 50f498d6e..ddeb64a9c 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -11,6 +11,8 @@ struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end
 
 MonteCarloEntropy() = MonteCarloEntropy{false}()
 
+Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()")
+
 """
   Sticking the Landing Control Variate
 
@@ -38,6 +40,8 @@ StickingTheLandingEntropy() = MonteCarloEntropy{true}()
 
 skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding
 
+Base.show(io::IO, entropy::MonteCarloEntropy{true}) = print(io, "StickingTheLandingEntropy()")
+
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
     n_samples = size(ηs, 2)
     mapreduce(+, eachcol(ηs)) do ηᵢ
diff --git a/src/vi.jl b/src/vi.jl
index 1a4d57ecb..605464b69 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -34,7 +34,6 @@ function optimize(
     i = 0
     prog = ProgressMeter.Progress(
         n_max_iter;
-        desc      = "[$(string(objective))] Optimizing...",
         barlen    = 0,
         enabled   = progress,
         showspeed = true)

From 25c51b4796b2e550d1ee9747e5ccbf81a48aff38 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 12 Jun 2023 23:03:25 +0100
Subject: [PATCH 025/206] add interface calling `restructure`, rename rebuild
 -> restructure

---
 src/objectives/elbo/elbo.jl |  6 ++--
 src/vi.jl                   | 61 ++++++++++++++++++-------------------
 2 files changed, 32 insertions(+), 35 deletions(-)

diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
index b26516d96..b3bad3c09 100644
--- a/src/objectives/elbo/elbo.jl
+++ b/src/objectives/elbo/elbo.jl
@@ -41,14 +41,14 @@ function estimate_gradient!(
     rng::Random.AbstractRNG,
     elbo::ELBO,
     λ::Vector{<:Real},
-    rebuild,
+    restructure,
     out::DiffResults.MutableDiffResult)
 
     # Gradient-stopping for computing the sticking-the-landing control variate
-    q_η_stop = skip_entropy_gradient(elbo) ? rebuild(λ) : nothing
+    q_η_stop = skip_entropy_gradient(elbo) ? restructure(λ) : nothing
 
     grad!(ADBackend(), λ, out) do λ′
-        q_η = rebuild(λ′)
+        q_η = restructure(λ′)
         q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η
         -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy)
     end
diff --git a/src/vi.jl b/src/vi.jl
index 605464b69..f1f4bc255 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -11,9 +11,9 @@ the steps.
 """
 function optimize(
     objective ::AbstractVariationalObjective,
-    rebuild,
-    n_max_iter::Int,
-    λ         ::AbstractVector{<:Real};
+    restructure,
+    λ         ::AbstractVector{<:Real},
+    n_max_iter::Int;
     optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(),
     rng       ::Random.AbstractRNG      = Random.GLOBAL_RNG,
     progress  ::Bool                    = true,
@@ -30,50 +30,47 @@ function optimize(
     optstate = Optimisers.init(optimizer, λ)
     grad_buf = DiffResults.GradientResult(λ)
 
-    q = rebuild(λ)
-    i = 0
-    prog = ProgressMeter.Progress(
-        n_max_iter;
-        barlen    = 0,
-        enabled   = progress,
-        showspeed = true)
+    prog = ProgressMeter.Progress(n_max_iter;
+                                  barlen    = 0,
+                                  enabled   = progress,
+                                  showspeed = true)
+    stats = Vector{NamedTuple}(undef, n_max_iter)
 
-    for i = 1:n_max_iter
-        grad_buf, stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf)
+    for t = 1:n_max_iter
+        grad_buf, stat = estimate_gradient!(rng, objective, λ, restructure, grad_buf)
         g = DiffResults.gradient(grad_buf)
 
         optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
         Optimisers.subtract!(λ, Δλ)
 
         stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g))
-        stats = merge(stats, stat′)
-        q     = rebuild(λ)
+        stat  = merge(stat, stat′)
+        q     = restructure(λ)
 
         if !isnothing(callback!)
-            stat′  = callback!(q, stats)
-            stats = !isnothing(stat′) ? merge(stat′, stats) : stats
+            stat′ = callback!(q, stat)
+            stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
-        AdvancedVI.DEBUG && @debug "Step $i" stats...
-            pm_next!(prog, stats)
+        AdvancedVI.DEBUG && @debug "Step $i" stat...
+
+        pm_next!(prog, stat)
+        stats[t] = stat
 
         # Termination decision is work in progress
-        if terminate(rng, q, objective, stats)
+        if terminate(rng, q, objective, stat)
+            stats = stats[1:t]
             break
         end
     end
-    λ
+    λ, stats
 end
 
-# function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG)
-#     θ = copy(θ_init)
-#     optimize!(grad_estimator, rebuild, n_max_iter, λ, optimizer = optimizer, rng = rng)
-
-#     # If `q` is a mean-field approx we use the specialized `update` function
-#     if q isa Distribution
-#         return update(q, θ)
-#     else
-#         # Otherwise we assume it's a mapping θ → q
-#         return q(θ)
-#     end
-# end
+function optimize(objective::AbstractVariationalObjective,
+                  q,
+                  n_max_iter::Int;
+                  kwargs...)
+    λ, restructure = Optimisers.destructure(q)
+    λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...)
+    restructure(λ), stats
+end

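`Optimisers.destructure` is what makes the convenience method possible: it
flattens every trainable leaf of `q` into a single vector and returns the
inverse map. A standalone roundtrip with a toy struct (`Functors.@functor`
marks which fields are trainable):

    using Functors, Optimisers

    struct ToyGauss{T}
        μ   ::Vector{T}
        logσ::Vector{T}
    end
    Functors.@functor ToyGauss

    q = ToyGauss([0.0, 1.0], [-1.0, -1.0])
    λ, restructure = Optimisers.destructure(q)

    λ                     # [0.0, 1.0, -1.0, -1.0]
    restructure(λ .+ 1)   # ToyGauss([1.0, 2.0], [0.0, 0.0])
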
From fc200462e0a6929ca580d6cabaad27afd179b30f Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 13 Jun 2023 00:33:47 +0100
Subject: [PATCH 026/206] add estimator state interface, add control variate
 interface to ADVI

---
 src/AdvancedVI.jl           | 12 ++++++-
 src/objectives/elbo/advi.jl | 64 +++++++++++++++++++++++++++++++++++++
 src/objectives/elbo/elbo.jl | 57 ---------------------------------
 src/vi.jl                   | 22 +++++++------
 4 files changed, 88 insertions(+), 67 deletions(-)
 create mode 100644 src/objectives/elbo/advi.jl
 delete mode 100644 src/objectives/elbo/elbo.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 5a02501b4..f2eb2317c 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -166,7 +166,17 @@ end
 # estimators
 abstract type AbstractVariationalObjective end
 
-include("objectives/elbo/elbo.jl")
+function estimate_gradient end
+
+abstract type AbstractEnergyEstimator  end
+abstract type AbstractEntropyEstimator end
+abstract type AbstractControlVariate end
+
+init(::Nothing) = nothing
+
+update(::Nothing, ::Nothing) = (nothing, nothing)
+
+include("objectives/elbo/advi.jl")
 include("objectives/elbo/advi_energy.jl")
 include("objectives/elbo/entropy.jl")
 
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
new file mode 100644
index 000000000..66e5f320a
--- /dev/null
+++ b/src/objectives/elbo/advi.jl
@@ -0,0 +1,64 @@
+
+struct ADVI{EnergyEst   <: AbstractEnergyEstimator,
+            EntropyEst  <: AbstractEntropyEstimator,
+            ControlVar  <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective
+    energy_estimator::EnergyEst
+    entropy_estimator::EntropyEst
+    control_variate::ControlVar
+    n_samples::Int
+end
+
+skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator)
+
+init(advi::ADVI) = init(advi.control_variate)
+
+Base.show(io::IO, advi::ADVI) = print(
+    io,
+    "ADVI(energy_estimator=$(advi.energy_estimator), " *
+    "entropy_estimator=$(advi.entropy_estimator)), " *
+    "n_samples=$(advi.n_samples))")
+
+function ADVI(energy_estimator::AbstractEnergyEstimator,
+              entropy_estimator::AbstractEntropyEstimator,
+              n_samples::Int)
+    ADVI(energy_estimator, entropy_estimator, nothing, n_samples)
+end
+
+function ADVI(ℓπ, b⁻¹, n_samples::Int)
+    ADVI(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples)
+end
+
+function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
+                      rng       ::Random.AbstractRNG = Random.default_rng(),
+                      n_samples ::Int                = advi.n_samples,
+                      ηs        ::AbstractMatrix     = rand(rng, q_η, n_samples),
+                      q_η_entropy::ContinuousMultivariateDistribution = q_η)
+    𝔼ℓ = advi.energy_estimator(q_η, ηs)
+    ℍ  = advi.entropy_estimator(q_η_entropy, ηs)
+    𝔼ℓ + ℍ
+end
+
+function estimate_gradient(
+    rng::Random.AbstractRNG,
+    advi::ADVI,
+    est_state,
+    λ::Vector{<:Real},
+    restructure,
+    out::DiffResults.MutableDiffResult)
+
+    # Gradient-stopping for computing the sticking-the-landing control variate
+    q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing
+
+    grad!(ADBackend(), λ, out) do λ′
+        q_η = restructure(λ′)
+        q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η
+        -advi(q_η; rng, q_η_entropy)
+    end
+    nelbo = DiffResults.value(out)
+    stat  = (elbo=-nelbo,)
+
+    est_state, stat′ = update(advi.control_variate, est_state)
+    stat = !isnothing(stat′) ? merge(stat′, stat) : stat 
+
+    out, est_state, stat
+end
diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl
deleted file mode 100644
index b3bad3c09..000000000
--- a/src/objectives/elbo/elbo.jl
+++ /dev/null
@@ -1,57 +0,0 @@
-
-abstract type AbstractEnergyEstimator  end
-abstract type AbstractEntropyEstimator end
-
-struct ELBO{EnergyEst  <: AbstractEnergyEstimator,
-            EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
-    # Evidence Lower Bound
-    # 
-    # Jordan, Michael I., et al.
-    # "An introduction to variational methods for graphical models."
-    # Machine learning 37 (1999): 183-233.
-
-    energy_estimator::EnergyEst
-    entropy_estimator::EntropyEst
-    n_samples::Int
-end
-
-skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator)
-
-Base.show(io::IO, elbo::ELBO) = print(
-    io,
-    "ELBO(energy_estimator=$(elbo.energy_estimator), " *
-    "entropy_estimator=$(elbo.entropy_estimator)), " *
-    "n_samples=$(elbo.n_samples))")
-
-function ADVI(ℓπ, b⁻¹, n_samples::Int)
-    ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples)
-end
-
-function (elbo::ELBO)(q_η::ContinuousMultivariateDistribution;
-                      rng = Random.default_rng(),
-                      n_samples::Int = elbo.n_samples,
-                      q_η_entropy::ContinuousMultivariateDistribution = q_η)
-    ηs = rand(rng, q_η, n_samples)
-    𝔼ℓ = elbo.energy_estimator(q_η, ηs)
-    ℍ  = elbo.entropy_estimator(q_η_entropy, ηs)
-    𝔼ℓ + ℍ
-end
-
-function estimate_gradient!(
-    rng::Random.AbstractRNG,
-    elbo::ELBO,
-    λ::Vector{<:Real},
-    restructure,
-    out::DiffResults.MutableDiffResult)
-
-    # Gradient-stopping for computing the sticking-the-landing control variate
-    q_η_stop = skip_entropy_gradient(elbo) ? restructure(λ) : nothing
-
-    grad!(ADBackend(), λ, out) do λ′
-        q_η = restructure(λ′)
-        q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η
-        -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy)
-    end
-    nelbo = DiffResults.value(out)
-    out, (elbo=-nelbo,)
-end
diff --git a/src/vi.jl b/src/vi.jl
index f1f4bc255..ebb246bee 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -27,8 +27,9 @@ function optimize(
         @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ)
     end
 
-    optstate = Optimisers.init(optimizer, λ)
-    grad_buf = DiffResults.GradientResult(λ)
+    opt_state = Optimisers.init(optimizer, λ)
+    est_state = init(objective)
+    grad_buf  = DiffResults.GradientResult(λ)
 
     prog = ProgressMeter.Progress(n_max_iter;
                                   barlen    = 0,
@@ -37,22 +38,25 @@ function optimize(
     stats = Vector{NamedTuple}(undef, n_max_iter)
 
     for t = 1:n_max_iter
-        grad_buf, stat = estimate_gradient!(rng, objective, λ, restructure, grad_buf)
-        g = DiffResults.gradient(grad_buf)
+        stat = (iteration=t,)
 
-        optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g)
+        grad_buf, est_state, stat′ = estimate_gradient(rng, objective, est_state, λ, restructure, grad_buf)
+        g    = DiffResults.gradient(grad_buf)
+        stat = merge(stat, stat′)
+
+        opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g)
         Optimisers.subtract!(λ, Δλ)
+        stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g))
+        stat = merge(stat, stat′)
 
-        stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g))
-        stat  = merge(stat, stat′)
-        q     = restructure(λ)
+        q    = restructure(λ)
 
         if !isnothing(callback!)
             stat′ = callback!(q, stat)
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
-        AdvancedVI.DEBUG && @debug "Step $i" stat...
+        AdvancedVI.DEBUG && @debug "Step $t" stat...
 
         pm_next!(prog, stat)
         stats[t] = stat

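The resulting contract for a custom objective, as exercised by `optimize`
above: `init(objective)` builds estimator state, and `estimate_gradient`
threads it through, returning the filled gradient buffer, the new state, and a
stats NamedTuple. A sketch against that interface (this assumes the package's
`grad!`/`ADBackend` internals and is not standalone):

    struct ToyObjective <: AbstractVariationalObjective end

    init(::ToyObjective) = nothing         # stateless estimator

    function estimate_gradient(rng, ::ToyObjective, est_state, λ, restructure, out)
        grad!(ADBackend(), λ, out) do λ′
            sum(abs2, λ′)                  # toy surrogate loss
        end
        out, est_state, (loss = DiffResults.value(out),)
    end
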
From 6faa807f067ff77856c307ef4baa11865616deae Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 13 Jun 2023 00:39:05 +0100
Subject: [PATCH 027/206] fix `show(advi)` to show control variate

---
 src/objectives/elbo/advi.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 66e5f320a..de2c683b0 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -15,7 +15,8 @@ init(advi::ADVI) = init(advi.control_variate)
 Base.show(io::IO, advi::ADVI) = print(
     io,
     "ADVI(energy_estimator=$(advi.energy_estimator), " *
-    "entropy_estimator=$(advi.entropy_estimator)), " *
+    "entropy_estimator=$(advi.entropy_estimator), " *
+    (!isnothing(advi.control_variate) ? "control_variate=$(advi.control_variate), " : "") *
     "n_samples=$(advi.n_samples))")
 
 function ADVI(energy_estimator::AbstractEnergyEstimator,

From 7095d276f5947b855289099a0ce56f2106c8b16c Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 13 Jun 2023 00:39:45 +0100
Subject: [PATCH 028/206] fix simplify `show(advi.control_variate)`

---
 src/objectives/elbo/advi.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index de2c683b0..dc2962eeb 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -16,7 +16,7 @@ Base.show(io::IO, advi::ADVI) = print(
     io,
     "ADVI(energy_estimator=$(advi.energy_estimator), " *
     "entropy_estimator=$(advi.entropy_estimator), " *
-    (!isnothing(advi.control_variate) ? "control_variate=$(advi.control_variate), " : "") *
+    "control_variate=$(advi.control_variate), " *
     "n_samples=$(advi.n_samples))")
 
 function ADVI(energy_estimator::AbstractEnergyEstimator,

From 9169ae262f8ac289d8e7355f8642584e18da3614 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 13 Jun 2023 00:51:48 +0100
Subject: [PATCH 029/206] fix type piracy by wrapping location-scale bijected
 distribution

---
 src/distributions/location_scale.jl | 67 ++++++++++++++++-------------
 1 file changed, 38 insertions(+), 29 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index dd9b5f2a3..f3c95d0c1 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -1,41 +1,51 @@
 
-function LocationScale(μ::AbstractVector,
-                       L::Union{<: AbstractTriangular,
-                                <: Diagonal},
-                       q₀::ContinuousMultivariateDistribution)
-    @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
-    transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))
+import Base: rand, _rand!
+
+struct LocationScale{ReparamMvDist <: Bijectors.TransformedDistribution} <: ContinuousMultivariateDistribution
+    q_trans::ReparamMvDist
+
+    function LocationScale(μ::AbstractVector,
+                           L::Union{<: AbstractTriangular,
+                                    <: Diagonal},
+                           q₀::ContinuousMultivariateDistribution)
+        @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
+        q_trans = transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))       
+        new{typeof(q_trans)}(q_trans)
+    end
+
+    function LocationScale(q_trans::Bijectors.TransformedDistribution)
+        new{typeof(q_trans)}(q_trans)
+    end
 end
 
-function StatsBase.entropy(
-    q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
-                                     <: Bijectors.ComposedFunction{
-                                         <: Bijectors.Shift,
-                                         <: Bijectors.Scale}})
-    q_base = q_trans.dist
-    scale  = q_trans.transform.inner.a
+Functors.@functor LocationScale
+
+Base.length(q::LocationScale) = length(q.q_trans)
+Base.size(q::LocationScale) = size(q.q_trans)
+
+function StatsBase.entropy(q::LocationScale)
+    q_base = q.q_trans.dist
+    scale  = q.q_trans.transform.inner.a
     entropy(q_base) + first(logabsdet(scale))
 end
 
-function Distributions.logpdf(
-    q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution,
-                                     <: Bijectors.ComposedFunction{
-                                         <: Bijectors.Shift,
-                                         <: Bijectors.Scale}},
-    z::AbstractVector)
-    q_base  = q_trans.dist
-    reparam = q_trans.transform
-    scale   = q_trans.transform.inner.a
-    η       = inverse(reparam)(z)
-    logpdf(q_base, η) - first(logabsdet(scale))
-end
+
+Distributions.logpdf(q::LocationScale, z::AbstractVector) = logpdf(q.q_trans, z)
+
+_logpdf(q::LocationScale, y::AbstractVector) = _logpdf(q.q_trans, y)
+
+rand(q::LocationScale) = rand(q.q_trans)
+
+rand(rng::Random.AbstractRNG, q::LocationScale, num_samples::Int) = rand(rng, q.q_trans, num_samples)
+
+_rand!(rng::Random.AbstractRNG, q::LocationScale, x::AbstractVector{<:Real}) = _rand!(rng, q.q_trans, x)
+
 
 function FullRankGaussian(μ::AbstractVector{T},
                           L::AbstractTriangular{T,S}) where {T <: Real, S}
     @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
     n_dims = length(μ)
-    q_base = MvNormal(FillArrays.Zeros{T}(n_dims),
-                      PDMats.ScalMat{T}(n_dims, one(T)))
+    q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T)))
     LocationScale(μ, L, q_base)
 end
 
@@ -43,7 +53,6 @@ function MeanFieldGaussian(μ::AbstractVector{T},
                            L::Diagonal{T,V}) where {T <: Real, V}
     @assert (length(μ) == size(L,1))
     n_dims = length(μ)
-    q_base = MvNormal(FillArrays.Zeros{T}(n_dims),
-                      PDMats.ScalMat{T}(n_dims, one(T)))
+    q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T)))
     LocationScale(μ, L, q_base)
 end

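"Type piracy" here means adding methods to functions the package does not own
(`Distributions.logpdf`, `StatsBase.entropy`) for types it also does not own
(Bijectors' transformed distributions), which silently changes behavior for
every other user of those types. Wrapping in an owned struct keeps the new
methods local. The shape of the fix, reduced to a toy:

    using Distributions

    # Owned wrapper: methods on `Wrapped` cannot collide with other packages.
    struct Wrapped{D <: ContinuousMultivariateDistribution} <: ContinuousMultivariateDistribution
        inner::D
    end

    # Fine: we own `Wrapped`, even though `logpdf` belongs to Distributions.
    Distributions.logpdf(q::Wrapped, z::AbstractVector) = logpdf(q.inner, z)
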
From 3db73011a430fb3aa5830264be687d860410f483 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 26 Jun 2023 23:01:27 +0100
Subject: [PATCH 030/206] remove old AdvancedVI custom optimizers

---
 Project.toml      |  1 +
 src/AdvancedVI.jl | 15 +++-----
 src/optimisers.jl | 94 -----------------------------------------------
 src/vi.jl         | 11 +-----
 4 files changed, 8 insertions(+), 113 deletions(-)
 delete mode 100644 src/optimisers.jl

diff --git a/Project.toml b/Project.toml
index ba807698f..d2708915f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -20,6 +20,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 
 [compat]
 Bijectors = "0.11, 0.12"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index f2eb2317c..76c6d859d 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,9 +1,12 @@
+
 module AdvancedVI
 
-using Random: Random
+using UnPack
 
-using Functors
+import Random: AbstractRNG, default_rng
+import Distributions: logpdf, _logpdf, rand, _rand!
 
+using Functors
 using Optimisers
 
 using DocStringExtensions
@@ -31,11 +34,6 @@ include("ad.jl")
 
 using Requires
 function __init__()
-    @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
-        apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
-        Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
-        Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
-    end
     @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
         include("compat/zygote.jl")
         export ZygoteAD
@@ -183,9 +181,6 @@ include("objectives/elbo/entropy.jl")
 # Variational Families
 include("distributions/location_scale.jl")
 
-# optimisers
-include("optimisers.jl")
-
 include("utils.jl")
 include("vi.jl")
 
diff --git a/src/optimisers.jl b/src/optimisers.jl
deleted file mode 100644
index 8077f98cb..000000000
--- a/src/optimisers.jl
+++ /dev/null
@@ -1,94 +0,0 @@
-const ϵ = 1e-8
-
-"""
-    TruncatedADAGrad(η=0.1, τ=1.0, n=100)
-
-Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated.
-
-## Parameters
-  - η: learning rate
-  - τ: constant scale factor
-  - n: number of previous gradient norms to use in the scaling.
-```
-## References
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
-
-[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E).
-"""
-mutable struct TruncatedADAGrad
-    eta::Float64
-    tau::Float64
-    n::Int
-    
-    iters::IdDict
-    acc::IdDict
-end
-
-function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100)
-    TruncatedADAGrad(η, τ, n, IdDict(), IdDict())
-end
-
-function apply!(o::TruncatedADAGrad, x, Δ)
-    T = eltype(Tracker.data(Δ))
-    
-    η = o.eta
-    τ = o.tau
-
-    g² = get!(
-        o.acc,
-        x,
-        [zeros(T, size(x)) for j = 1:o.n]
-    )::Array{typeof(Tracker.data(Δ)), 1}
-    i = get!(o.iters, x, 1)::Int
-
-    # Example: suppose i = 12 and o.n = 10
-    idx = mod(i - 1, o.n) + 1 # => idx = 2
-
-    # set the current
-    @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ
-
-    # TODO: make more efficient and stable
-    s = sum(g²)
-    
-    # increment
-    o.iters[x] += 1
-    
-    # TODO: increment (but "truncate")
-    # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1
-
-    @. Δ *= η / (τ + sqrt(s) + ϵ)
-end
-
-"""
-    DecayedADAGrad(η=0.1, pre=1.0, post=0.9)
-
-Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
-
-## Parameters
-  - η: learning rate
-  - pre: weight of new gradient norm
-  - post: weight of history of gradient norms
-```
-## References
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
-"""
-mutable struct DecayedADAGrad
-    eta::Float64
-    pre::Float64
-    post::Float64
-
-    acc::IdDict
-end
-
-DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict())
-
-function apply!(o::DecayedADAGrad, x, Δ)
-    T = eltype(Tracker.data(Δ))
-    
-    η = o.eta
-    acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x))
-    @. acc = o.post * acc + o.pre * Δ^2
-    @. Δ *= η / (√acc + ϵ)
-end
diff --git a/src/vi.jl b/src/vi.jl
index ebb246bee..842f187eb 100644
--- a/src/vi.jl
+++ b/src/vi.jl
@@ -14,19 +14,12 @@ function optimize(
     restructure,
     λ         ::AbstractVector{<:Real},
     n_max_iter::Int;
-    optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(),
-    rng       ::Random.AbstractRNG      = Random.GLOBAL_RNG,
+    optimizer ::Optimisers.AbstractRule = Optimisers.Adam(),
+    rng       ::AbstractRNG             = default_rng(),
     progress  ::Bool                    = true,
     callback!                           = nothing,
     terminate                           = (args...) -> false,
 )
-    # TODO: really need a better way to warn the user about potentially
-    # not using the correct accumulator
-    if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc))
-        # this message should only occurr once in the optimization process
-        @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ)
-    end
-
     opt_state = Optimisers.init(optimizer, λ)
     est_state = init(objective)
     grad_buf  = DiffResults.GradientResult(λ)

From e6a082aadbd3fa92e60fedf5373f2efbb1875ecc Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 26 Jun 2023 23:47:04 +0100
Subject: [PATCH 031/206] fix Location Scale to not depend on Bijectors

---
 src/distributions/location_scale.jl | 101 +++++++++++++++++-----------
 1 file changed, 61 insertions(+), 40 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index f3c95d0c1..c46b5111f 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -1,58 +1,79 @@
 
-import Base: rand, _rand!
-
-struct LocationScale{ReparamMvDist <: Bijectors.TransformedDistribution} <: ContinuousMultivariateDistribution
-    q_trans::ReparamMvDist
-
-    function LocationScale(μ::AbstractVector,
-                           L::Union{<: AbstractTriangular,
-                                    <: Diagonal},
-                           q₀::ContinuousMultivariateDistribution)
+struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution
+    location::L
+    scale   ::S
+    dist    ::D
+    epsilon ::R
+
+    function VILocationScale(μ::AbstractVector{<:Real},
+                             L::Union{<:AbstractTriangular{<:Real},
+                                      <:Diagonal{<:Real}},
+                             q_base::ContinuousUnivariateDistribution,
+                             epsilon::Real)
+        # Restricting all the arguments to have the same types creates problems 
+        # with dual-variable-based AD frameworks.
         @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
-        q_trans = transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L))       
-        new{typeof(q_trans)}(q_trans)
-    end
-
-    function LocationScale(q_trans::Bijectors.TransformedDistribution)
-        new{typeof(q_trans)}(q_trans)
+        new{typeof(μ), typeof(L), typeof(q_base), typeof(epsilon)}(μ, L, q_base, epsilon)
     end
 end
 
-Functors.@functor LocationScale
+Functors.@functor VILocationScale (location, scale)
 
-Base.length(q::LocationScale) = length(q.q_trans)
-Base.size(q::LocationScale) = size(q.q_trans)
+Base.length(q::VILocationScale) = length(q.location)
+Base.size(q::VILocationScale) = size(q.location)
 
-function StatsBase.entropy(q::LocationScale)
-    q_base = q.q_trans.dist
-    scale  = q.q_trans.transform.inner.a
-    entropy(q_base) + first(logabsdet(scale))
+function StatsBase.entropy(q::VILocationScale)
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    n_dims*entropy(dist) + first(logabsdet(scale))
 end
 
+function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
+    @unpack location, scale, dist = q
+    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale))
+end
 
-Distributions.logpdf(q::LocationScale, z::AbstractVector) = logpdf(q.q_trans, z)
-
-_logpdf(q::LocationScale, y::AbstractVector) = _logpdf(q.q_trans, y)
+function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
+    @unpack location, scale, dist = q
+    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale))
+end
 
-rand(q::LocationScale) = rand(q.q_trans)
+function rand(q::VILocationScale)
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    scale*rand(dist, n_dims) + location
+end
 
-rand(rng::Random.AbstractRNG, q::LocationScale, num_samples::Int) = rand(rng, q.q_trans, num_samples)
+function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) 
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    scale*rand(dist, n_dims, num_samples) .+ location
+end
 
-_rand!(rng::Random.AbstractRNG, q::LocationScale, x::AbstractVector{<:Real}) = _rand!(rng, q.q_trans, x)
+function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real})
+    @unpack location, scale, dist = q
+    rand!(rng, dist, x)
+    x .= scale*x
+    return x .+= location
+end
 
+function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
+    @unpack location, scale, dist = q
+    rand!(rng, dist, x)
+    x .= scale*x
+    return x .+= location
+end
 
-function FullRankGaussian(μ::AbstractVector{T},
-                          L::AbstractTriangular{T,S}) where {T <: Real, S}
-    @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
-    n_dims = length(μ)
-    q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T)))
-    LocationScale(μ, L, q_base)
+function VIFullRankGaussian(μ::AbstractVector{T},
+                            L::AbstractTriangular{T},
+                            epsilon::Real = eps(T)) where {T <: Real}
+    q_base = Normal{T}(zero(T), one(T))
+    VILocationScale(μ, L, q_base, epsilon)
 end
 
-function MeanFieldGaussian(μ::AbstractVector{T},
-                           L::Diagonal{T,V}) where {T <: Real, V}
-    @assert (length(μ) == size(L,1))
-    n_dims = length(μ)
-    q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T)))
-    LocationScale(μ, L, q_base)
+function VIMeanFieldGaussian(μ::AbstractVector{T},
+                             L::Diagonal{T},
+                             epsilon::Real = eps(T)) where {T <: Real}
+    q_base = Normal{T}(zero(T), one(T))
+    VILocationScale(μ, L, q_base, epsilon)
 end

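The new family implements the reparameterization z = location + scale*u, with u drawn i.i.d. per dimension from the univariate base `dist`; this is what lets gradients flow through the samples. A minimal usage sketch, assuming this branch of AdvancedVI is loaded:

```julia
using LinearAlgebra, Random

μ = zeros(5)
L = Diagonal(ones(5))
q = VIMeanFieldGaussian(μ, L)   # VILocationScale(μ, L, Normal(0, 1), eps(Float64)) underneath

# Reparameterized sampling: z = L*u .+ μ with u ~ Normal(0, 1), i.i.d. per dimension
z = rand(Random.default_rng(), q, 100)   # 5 × 100 matrix, one sample per column
```
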
From a034ebdec0e42d63211fe8e1c23d4b4e714a30bb Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 00:50:33 +0100
Subject: [PATCH 032/206] fix RNG namespace

---
 src/objectives/elbo/advi.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index dc2962eeb..311a94f30 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -30,9 +30,9 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int)
 end
 
 function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
-                      rng       ::Random.AbstractRNG = Random.default_rng(),
-                      n_samples ::Int                = advi.n_samples,
-                      ηs        ::AbstractMatrix     = rand(rng, q_η, n_samples),
+                      rng       ::AbstractRNG    = default_rng(),
+                      n_samples ::Int            = advi.n_samples,
+                      ηs        ::AbstractMatrix = rand(rng, q_η, n_samples),
                       q_η_entropy::ContinuousMultivariateDistribution = q_η)
     𝔼ℓ = advi.energy_estimator(q_η, ηs)
     ℍ  = advi.entropy_estimator(q_η_entropy, ηs)
@@ -40,7 +40,7 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
 end
 
 function estimate_gradient(
-    rng::Random.AbstractRNG,
+    rng::AbstractRNG,
     advi::ADVI,
     est_state,
     λ::Vector{<:Real},

From e19abd3d06291090f45b4b8b118e7be3003343c5 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 03:08:46 +0100
Subject: [PATCH 033/206] fix location scale logpdf bug

---
 src/distributions/location_scale.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index c46b5111f..c1803ffef 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -23,19 +23,19 @@ Base.length(q::VILocationScale) = length(q.location)
 Base.size(q::VILocationScale) = size(q.location)
 
 function StatsBase.entropy(q::VILocationScale)
     @unpack location, scale, dist = q
     n_dims = length(location)
     n_dims*entropy(dist) + first(logabsdet(scale))
 end
 
 function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale))
+    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale))
+    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function rand(q::VILocationScale)

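The sign fix follows from the change of variables z = location + scale*u: if u has density q₀, then q(z) = q₀(scale⁻¹(z − location)) / |det scale|, so the log-determinant must be subtracted, not added. A quick numerical check against Distributions.MvNormal, assuming the patched VIFullRankGaussian above:

```julia
using Distributions, LinearAlgebra

Σ = let A = randn(3, 3); Hermitian(A*A' + 3I) end   # random positive-definite covariance
L = cholesky(Σ).L
μ = randn(3)

q = VIFullRankGaussian(μ, L)
z = randn(3)
logpdf(q, z) ≈ logpdf(MvNormal(μ, Σ), z)   # true with `- first(logabsdet(scale))`
```
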
From 680c1864ecfe2a2867e9f48fe4bbf1ca37065aa3 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 03:12:19 +0100
Subject: [PATCH 034/206] add Accessors dependency

---
 Project.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Project.toml b/Project.toml
index d2708915f..add1e3912 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
 version = "0.2.3"
 
 [deps]
+Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"

From 4c6cabf688af0552a307c22b821901cc792676be Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 03:12:44 +0100
Subject: [PATCH 035/206] add location scale, autodiff tests

---
 test/ad.jl            | 22 +++++++++++++++++++++
 test/distributions.jl | 45 +++++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl      | 32 ++++++++----------------------
 3 files changed, 75 insertions(+), 24 deletions(-)
 create mode 100644 test/ad.jl
 create mode 100644 test/distributions.jl

diff --git a/test/ad.jl b/test/ad.jl
new file mode 100644
index 000000000..c084165ce
--- /dev/null
+++ b/test/ad.jl
@@ -0,0 +1,22 @@
+
+using ReTest
+using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote
+using AdvancedVI: grad!
+
+@testset "ad" begin
+    @testset "$(string(adsymbol))" for adsymbol ∈ [
+        :forwarddiff, :reversediff, :tracker, :enzyme, :zygote]
+        D = 10
+        A = randn(D, D)
+        λ = randn(D)
+        AdvancedVI.setadbackend(adsymbol)
+        grad_buf = DiffResults.GradientResult(λ)
+        AdvancedVI.grad!(AdvancedVI.ADBackend(), λ, grad_buf) do λ′
+            λ′'*A*λ′ / 2
+        end
+        ∇ = DiffResults.gradient(grad_buf)
+        f = DiffResults.value(grad_buf)
+        @test ∇ ≈ (A + A')*λ/2
+        @test f ≈ λ'*A*λ / 2
+    end
+end
diff --git a/test/distributions.jl b/test/distributions.jl
new file mode 100644
index 000000000..ab9617aa8
--- /dev/null
+++ b/test/distributions.jl
@@ -0,0 +1,45 @@
+
+using ReTest
+using Distributions
+using Distributions: _logpdf
+using LinearAlgebra
+using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian
+
+@testset "distributions" begin
+    @testset "$(string(covtype)) Gaussian $(realtype)" for
+        covtype  = [:diagonal, :fullrank],
+        realtype = [Float32,     Float64]
+
+        ϵ            = 1e-2
+        n_dims       = 10
+        n_montecarlo = 1_000_000
+
+        μ  = randn(realtype, n_dims)
+        L₀ = randn(realtype, n_dims, n_dims)
+        Σ  = if covtype == :fullrank
+            Σ = (L₀*L₀' + ϵ*I) |> Hermitian
+        else
+            Diagonal(exp.(randn(realtype, n_dims)))
+        end
+
+        L = cholesky(Σ).L
+        q = if covtype == :fullrank
+            VIFullRankGaussian(μ, L |> LowerTriangular)
+        else
+            VIMeanFieldGaussian(μ, L |> Diagonal)
+        end
+        q_true = MvNormal(μ, Σ)
+
+        z = randn(n_dims)
+        @test logpdf(q, z)  ≈ logpdf(q_true, z)
+        @test _logpdf(q, z) ≈ _logpdf(q_true, z)
+        @test entropy(q)    ≈ entropy(q_true)
+
+        z_samples  = rand(q, n_montecarlo)
+        threesigma = L
+        @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
+        @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
+        @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a305c25e5..440741974 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,28 +1,12 @@
-using Test
-using Distributions, DistributionsAD
-using AdvancedVI
 
-include("optimisers.jl")
+using ReTest: @testset, @test
+#using Random
+#using Statistics
+#using Distributions, DistributionsAD
 
-target = MvNormal(ones(2))
-logπ(z) = logpdf(target, z)
-advi = ADVI(10, 1000)
+println("Environment variables for testing")
+println(ENV)
 
-# Using a function z ↦ q(⋅∣z)
-getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4]))
-q = vi(logπ, advi, getq, randn(4))
-
-xs = rand(target, 10)
-@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05
-
-# OR: implement `update` and pass a `Distribution`
-function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real})
-    return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end]))
-end
-
-q0 = TuringDiagMvNormal(zeros(2), ones(2))
-q = vi(logπ, advi, q0, randn(4))
-
-xs = rand(target, 10)
-@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05
+include("ad.jl")
+include("distributions.jl")
 

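The sampling checks in test/distributions.jl rest on a standard Monte Carlo argument: empirical moments of N i.i.d. samples converge at rate O(1/√N), so rtol = 1e-2 is comfortable at N = 10⁶. A self-contained sketch of the same pattern, using only Distributions:

```julia
using Distributions, LinearAlgebra, Statistics

μ = randn(4)
Σ = Diagonal(exp.(randn(4)))
z = rand(MvNormal(μ, Σ), 1_000_000)   # 4 × 10⁶ samples, one per column

isapprox(vec(mean(z, dims=2)), μ;       rtol=1e-2)   # sample mean ≈ μ
isapprox(vec(var(z, dims=2)),  diag(Σ); rtol=1e-2)   # sample variance ≈ diag(Σ)
```
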
From 06db2f02233e8e4e6010be6473ea7f356742a4a3 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 03:15:03 +0100
Subject: [PATCH 036/206] add Accessors import statement

---
 src/AdvancedVI.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 76c6d859d..5800cd93d 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,7 +1,7 @@
 
 module AdvancedVI
 
-using UnPack
+using UnPack, Accessors
 
 import Random: AbstractRNG, default_rng
 import Distributions: logpdf, _logpdf, rand, rand!, _rand!
@@ -179,6 +179,7 @@ include("objectives/elbo/advi_energy.jl")
 include("objectives/elbo/entropy.jl")
 
 # Variational Families
+
 include("distributions/location_scale.jl")
 
 include("utils.jl")

From 12de2bda787624b862772fc0b4fa55729ebb6ff9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 20:12:48 +0100
Subject: [PATCH 037/206] remove optimiser tests

---
 test/optimisers.jl | 17 -----------------
 1 file changed, 17 deletions(-)
 delete mode 100644 test/optimisers.jl

diff --git a/test/optimisers.jl b/test/optimisers.jl
deleted file mode 100644
index fae652ed0..000000000
--- a/test/optimisers.jl
+++ /dev/null
@@ -1,17 +0,0 @@
-using Random, Test, LinearAlgebra, ForwardDiff
-using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply!
-
-θ = randn(10, 10)
-@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)]
-    θ_fit = randn(10, 10)
-    loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1))
-    for t = 1:10^4
-        x = rand(10)
-        Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit)
-        Δ = apply!(opt, θ_fit, Δ)
-        @. θ_fit = θ_fit - Δ
-    end
-    @test loss(rand(10, 100), θ_fit) < 0.01
-    @test length(opt.acc) == 1
-end
-

From bbb2cc649fce6caddb751d0e5743d2fc2a814ad2 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 20:12:59 +0100
Subject: [PATCH 038/206] refactor: slightly generalize the distribution tests
 for the future

---
 test/distributions.jl | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/test/distributions.jl b/test/distributions.jl
index ab9617aa8..07b3efdfe 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -6,8 +6,9 @@ using LinearAlgebra
 using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian
 
 @testset "distributions" begin
-    @testset "$(string(covtype)) Gaussian $(realtype)" for
-        covtype  = [:diagonal, :fullrank],
+    @testset "$(string(covtype)) $(basedist) $(realtype)" for
+        basedist = [:gaussian],
+        covtype  = [:meanfield, :fullrank],
         realtype = [Float32,     Float64]
 
@@ -24,12 +25,14 @@ using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian
         end
 
         L = cholesky(Σ).L
-        q = if covtype == :fullrank
+        q = if covtype == :fullrank  && basedist == :gaussian
             VIFullRankGaussian(μ, L |> LowerTriangular)
-        else
+        elseif covtype == :meanfield && basedist == :gaussian
             VIMeanFieldGaussian(μ, L |> Diagonal)
         end
-        q_true = MvNormal(μ, Σ)
+        q_true = if basedist == :gaussian
+            MvNormal(μ, Σ)
+        end
 
         z = randn(n_dims)
         @test logpdf(q, z)  ≈ logpdf(q_true, z)

From 197484655468ec5bab362380fb58d896a082b150 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 23:10:51 +0100
Subject: [PATCH 039/206] migrate to SimpleUnPack, migrate to ADTypes

---
 Project.toml              |   3 +-
 src/AdvancedVI.jl         | 150 ++++++++++----------------------------
 src/ad.jl                 |  46 ------------
 src/compat/enzyme.jl      |  19 ++++-
 src/compat/reversediff.jl |  21 +++---
 src/compat/zygote.jl      |  16 +++-
 test/ad.jl                |  14 ++--
 7 files changed, 90 insertions(+), 179 deletions(-)
 delete mode 100644 src/ad.jl

diff --git a/Project.toml b/Project.toml
index 93e3a52ac..2fcc845e8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
 version = "0.2.4"
 
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -18,10 +19,10 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
 
 [compat]
 Bijectors = "0.11, 0.12, 0.13"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 5800cd93d..573f7179b 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,7 +1,8 @@
 
 module AdvancedVI
 
-using UnPack, Accessors
+using SimpleUnPack: @unpack
+using Accessors
 
 import Random: AbstractRNG, default_rng
 import Distributions: logpdf, _logpdf, rand, rand!, _rand!
@@ -17,6 +18,8 @@ using LinearAlgebra: AbstractTriangular
 
 using LogDensityProblems
 
+using ADTypes
+using ADTypes: AbstractADType
 using ForwardDiff, Tracker
 
 using FillArrays
@@ -30,78 +33,19 @@ using StatsBase: entropy
 
 const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
 
-include("ad.jl")
-
 using Requires
 function __init__()
     @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
         include("compat/zygote.jl")
-        export ZygoteAD
-
-        function AdvancedVI.grad!(
-            f::Function,
-            ::Type{<:ZygoteAD},
-            λ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-        )
-            y, back = Zygote.pullback(f, λ)
-            dy = first(back(1.0))
-            DiffResults.value!(out, y)
-            DiffResults.gradient!(out, dy)
-            return out
-        end
     end
     @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
         include("compat/reversediff.jl")
-        export ReverseDiffAD
-
-        function AdvancedVI.grad!(
-            f::Function,
-            ::Type{<:ReverseDiffAD},
-            λ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-        )
-            tp = AdvancedVI.tape(f, λ)
-            ReverseDiff.gradient!(out, tp, λ)
-            return out
-        end
     end
     @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
         include("compat/enzyme.jl")
-        export EnzymeAD
-
-        function AdvancedVI.grad!(
-            f::Function,
-            ::Type{<:EnzymeAD},
-            λ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-        )
-            # Use `Enzyme.ReverseWithPrimal` once it is released:
-            # https://github.com/EnzymeAD/Enzyme.jl/pull/598
-            y = f(λ)
-            DiffResults.value!(out, y)
-            dy = DiffResults.gradient(out)
-            fill!(dy, 0)
-            Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
-            return out
-        end
     end
 end
 
-export
-    optimize,
-    ELBO,
-    ADVI,
-    ADVIEnergy,
-    ClosedFormEntropy,
-    MonteCarloEntropy,
-    LocationScale,
-    FullRankGaussian,
-    MeanFieldGaussian,
-    TruncatedADAGrad,
-    DecayedADAGrad
-
-
 """
     grad!(f, λ, out)
 
@@ -111,55 +55,7 @@ This implicitly also gives a default implementation of `optimize!`.
 """
 function grad! end
 
-"""
-    optimize(model, alg::VariationalInference)
-    optimize(model, alg::VariationalInference, q::VariationalPosterior)
-    optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray)
-
-Constructs the variational posterior from the `model` and performs the optimization
-following the configuration of the given `VariationalInference` instance.
-
-# Arguments
-- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations
-- `alg`: the VI algorithm used
-- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists.
-- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior`
-- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior
-"""
-function optimize end
-
-function update end
-
-# default implementations
-function grad!(
-    f::Function,
-    adtype::Type{<:ForwardDiffAD},
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
-)
-    # Set chunk size and do ForwardMode.
-    chunk_size = getchunksize(adtype)
-    config = if chunk_size == 0
-        ForwardDiff.GradientConfig(f, λ)
-    else
-        ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size))
-    end
-    ForwardDiff.gradient!(out, f, λ, config)
-end
-
-function grad!(
-    f::Function,
-    ::Type{<:TrackerAD},
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
-)
-    λ_tracked = Tracker.param(λ)
-    y = f(λ_tracked)
-    Tracker.back!(y, 1.0)
-
-    DiffResults.value!(out, Tracker.data(y))
-    DiffResults.gradient!(out, Tracker.grad(λ_tracked))
-end
+include("grad.jl")
 
 # estimators
 abstract type AbstractVariationalObjective end
@@ -170,6 +66,9 @@ abstract type AbstractEnergyEstimator  end
 abstract type AbstractEntropyEstimator end
 abstract type AbstractControlVariate end
 
+function init   end
+function update end
+
 init(::Nothing) = nothing
 
 update(::Nothing, ::Nothing) = (nothing, nothing)
@@ -178,11 +77,42 @@ include("objectives/elbo/advi.jl")
 include("objectives/elbo/advi_energy.jl")
 include("objectives/elbo/entropy.jl")
 
+export
+    ELBO,
+    ADVI,
+    ADVIEnergy,
+    ClosedFormEntropy,
+    MonteCarloEntropy
+
 # Variational Families
 
 include("distributions/location_scale.jl")
 
+export
+    VIFullRankGaussian,
+    VIMeanFieldGaussian
+
+"""
+    optimize(model, alg::VariationalInference)
+    optimize(model, alg::VariationalInference, q::VariationalPosterior)
+    optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray)
+
+Constructs the variational posterior from the `model` and performs the optimization
+following the configuration of the given `VariationalInference` instance.
+
+# Arguments
+- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations
+- `alg`: the VI algorithm used
+- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists.
+- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior`
+- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior
+"""
+function optimize end
+
+include("optimize.jl")
+
+export optimize
+
 include("utils.jl")
-include("vi.jl")
 
 end # module
diff --git a/src/ad.jl b/src/ad.jl
deleted file mode 100644
index 62e785e1b..000000000
--- a/src/ad.jl
+++ /dev/null
@@ -1,46 +0,0 @@
-##############################
-# Global variables/constants #
-##############################
-const ADBACKEND = Ref(:forwarddiff)
-setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
-function setadbackend(::Val{:forward_diff})
-    Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
-    setadbackend(Val(:forwarddiff))
-end
-function setadbackend(::Val{:forwarddiff})
-    ADBACKEND[] = :forwarddiff
-end
-
-function setadbackend(::Val{:reverse_diff})
-    Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.",  :setadbackend)
-    setadbackend(Val(:tracker))
-end
-function setadbackend(::Val{:tracker})
-    ADBACKEND[] = :tracker
-end
-
-const ADSAFE = Ref(false)
-function setadsafe(switch::Bool)
-    @info("[AdvancedVI]: global ADSAFE is set as $switch")
-    ADSAFE[] = switch
-end
-
-const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically
-
-function setchunksize(chunk_size::Int)
-    @info("[AdvancedVI]: AD chunk size is set as $chunk_size")
-    CHUNKSIZE[] = chunk_size
-end
-
-abstract type ADBackend end
-struct ForwardDiffAD{chunk} <: ADBackend end
-getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk
-
-struct TrackerAD <: ADBackend end
-
-ADBackend() = ADBackend(ADBACKEND[])
-ADBackend(T::Symbol) = ADBackend(Val(T))
-
-ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
-ADBackend(::Val{:tracker}) = TrackerAD
-ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl
index c6bb9ac39..cab50862e 100644
--- a/src/compat/enzyme.jl
+++ b/src/compat/enzyme.jl
@@ -1,5 +1,16 @@
-struct EnzymeAD <: ADBackend end
-ADBackend(::Val{:enzyme}) = EnzymeAD
-function setadbackend(::Val{:enzyme})
-    ADBACKEND[] = :enzyme
+
+function AdvancedVI.grad!(
+    f::Function,
+    ::AutoEnzyme,
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult,
+    )
+    # Use `Enzyme.ReverseWithPrimal` once it is released:
+    # https://github.com/EnzymeAD/Enzyme.jl/pull/598
+    y = f(λ)
+    DiffResults.value!(out, y)
+    dy = DiffResults.gradient(out)
+    fill!(dy, 0)
+    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
+    return out
 end
diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl
index 721d03618..4d8f87d8c 100644
--- a/src/compat/reversediff.jl
+++ b/src/compat/reversediff.jl
@@ -1,16 +1,19 @@
 using .ReverseDiff: compile, GradientTape
 using .ReverseDiff.DiffResults: GradientResult
 
-struct ReverseDiffAD{cache} <: ADBackend end
-const RDCache = Ref(false)
-setcache(b::Bool) = RDCache[] = b
-getcache() = RDCache[]
-ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()}
-function setadbackend(::Val{:reversediff})
-    ADBACKEND[] = :reversediff
-end
-
 tape(f, x) = GradientTape(f, x)
 function taperesult(f, x)
     return tape(f, x), GradientResult(x)
 end
+
+# Precompiled tapes are not properly supported yet.
+function AdvancedVI.grad!(
+    f::Function,
+    ::AutoReverseDiff,
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult,
+    )
+    tp = tape(f, λ)
+    ReverseDiff.gradient!(out, tp, λ)
+    return out
+end
diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl
index 40022e215..f1a29b87f 100644
--- a/src/compat/zygote.jl
+++ b/src/compat/zygote.jl
@@ -1,5 +1,13 @@
-struct ZygoteAD <: ADBackend end
-ADBackend(::Val{:zygote}) = ZygoteAD
-function setadbackend(::Val{:zygote})
-    ADBACKEND[] = :zygote
+
+function AdvancedVI.grad!(
+    f::Function,
+    ::AutoZygote,
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult,
+    )
+    y, back = Zygote.pullback(f, λ)
+    dy = first(back(1.0))
+    DiffResults.value!(out, y)
+    DiffResults.gradient!(out, dy)
+    return out
 end
diff --git a/test/ad.jl b/test/ad.jl
index c084165ce..6b587598e 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -1,17 +1,21 @@
 
 using ReTest
 using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote
-using AdvancedVI: grad!
+using ADTypes
 
 @testset "ad" begin
-    @testset "$(string(adsymbol))" for adsymbol ∈ [
-        :forwarddiff, :reversediff, :tracker, :enzyme, :zygote]
+    @testset "$(adname)" for (adname, adsymbol) ∈ Dict(
+          :ForwardDiffAuto => AutoForwardDiff(),
+          :ForwardDiff     => AutoForwardDiff(10),
+          :ReverseDiff     => AutoReverseDiff(),
+          :Zygote          => AutoZygote(),
+          :Tracker         => AutoTracker(),
+        )
         D = 10
         A = randn(D, D)
         λ = randn(D)
-        AdvancedVI.setadbackend(adsymbol)
         grad_buf = DiffResults.GradientResult(λ)
-        AdvancedVI.grad!(AdvancedVI.ADBackend(), λ, grad_buf) do λ′
+        AdvancedVI.grad!(adsymbol, λ, grad_buf) do λ′
             λ′'*A*λ′ / 2
         end
         ∇ = DiffResults.gradient(grad_buf)

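After this migration the backend is an ordinary `ADTypes.AbstractADType` value passed to `grad!`, rather than global mutable state set through `setadbackend`. A sketch of the resulting call pattern, mirroring the updated test above (assumes this branch of AdvancedVI and the chosen AD package are loaded):

```julia
using ADTypes, ForwardDiff, DiffResults

λ   = randn(4)
buf = DiffResults.GradientResult(λ)

# grad!(f, adtype, λ, out) writes the value and gradient of f at λ into out.
AdvancedVI.grad!(AutoForwardDiff(), λ, buf) do λ′
    sum(abs2, λ′) / 2
end

DiffResults.value(buf)      # ‖λ‖²/2
DiffResults.gradient(buf)   # λ
```
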
From 19c62c888fafbed9271e66cf1c7ced7b11a90457 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 23:11:12 +0100
Subject: [PATCH 040/206] rename vi.jl to optimize.jl

---
 src/{vi.jl => optimize.jl} | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
 rename src/{vi.jl => optimize.jl} (89%)

diff --git a/src/vi.jl b/src/optimize.jl
similarity index 89%
rename from src/vi.jl
rename to src/optimize.jl
index 842f187eb..071849007 100644
--- a/src/vi.jl
+++ b/src/optimize.jl
@@ -19,6 +19,7 @@ function optimize(
     progress  ::Bool                    = true,
     callback!                           = nothing,
     terminate                           = (args...) -> false,
+    adback::AbstractADType              = AutoForwardDiff(), 
 )
     opt_state = Optimisers.init(optimizer, λ)
     est_state = init(objective)
@@ -33,7 +34,8 @@ function optimize(
     for t = 1:n_max_iter
         stat = (iteration=t,)
 
-        grad_buf, est_state, stat′ = estimate_gradient(rng, objective, est_state, λ, restructure, grad_buf)
+        grad_buf, est_state, stat′ = estimate_gradient(
+            rng, adback, objective, est_state, λ, restructure, grad_buf)
         g    = DiffResults.gradient(grad_buf)
         stat = merge(stat, stat′)
 
@@ -51,6 +53,9 @@ function optimize(
         
         AdvancedVI.DEBUG && @debug "Step $t" stat...
 
+        q    = project_domain(q)
+        λ, _ = Optimisers.destructure(q)
+
         pm_next!(prog, stat)
         stats[t] = stat
 

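The `adback` keyword (defaulting to `AutoForwardDiff()`) now flows from `optimize` into `estimate_gradient` on every iteration. A hedged sketch of the intended call shape; `objective` and `q₀` are placeholders for an objective and an initial variational family built as elsewhere in this series:

```julia
using Optimisers, ADTypes

λ₀, restructure = Optimisers.destructure(q₀)   # flatten q₀ into parameters plus a rebuilder

q, stats = optimize(
    objective, restructure, λ₀, 10_000;
    optimizer = Optimisers.Adam(),
    adback    = AutoForwardDiff(),
)
```
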
From 63da51de8870575971b8e70e28dfc6c2265c5e30 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 23:11:25 +0100
Subject: [PATCH 041/206] fix estimate_gradient to use adtypes

---
 src/objectives/elbo/advi.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 311a94f30..ed834273d 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -41,6 +41,7 @@ end
 
 function estimate_gradient(
     rng::AbstractRNG,
+    adback::AbstractADType,
     advi::ADVI,
     est_state,
     λ::Vector{<:Real},
@@ -50,7 +51,7 @@ function estimate_gradient(
     # Gradient-stopping for computing the sticking-the-landing control variate
     q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing
 
-    grad!(ADBackend(), λ, out) do λ′
+    grad!(adback, λ, out) do λ′
         q_η = restructure(λ′)
         q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η
         -advi(q_η; rng, q_η_entropy)

From 65ab47395fa4fe88b6b65323325c68b5c0ee078a Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 13 Jul 2023 23:17:20 +0100
Subject: [PATCH 042/206] add exact inference tests

---
 test/distributions.jl         |  5 +--
 test/exact.jl                 | 64 +++++++++++++++++++++++++++++++++++
 test/exact/normallognormal.jl | 52 ++++++++++++++++++++++++++++
 test/runtests.jl              | 13 +++----
 4 files changed, 124 insertions(+), 10 deletions(-)
 create mode 100644 test/exact.jl
 create mode 100644 test/exact/normallognormal.jl

diff --git a/test/distributions.jl b/test/distributions.jl
index 07b3efdfe..074cad7cc 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -1,9 +1,6 @@
 
 using ReTest
-using Distributions
 using Distributions: _logpdf
-using LinearAlgebra
-using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian
 
 @testset "distributions" begin
     @testset "$(string(covtype)) $(basedist) $(realtype)" for
@@ -17,7 +14,7 @@ using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian
         n_montecarlo = 1_000_000
 
         μ  = randn(realtype, n_dims)
-        L₀ = randn(realtype, n_dims, n_dims)
+        L₀ = randn(realtype, n_dims, n_dims) |> LowerTriangular
         Σ  = if covtype == :fullrank
             Σ = (L₀*L₀' + ϵ*I) |> Hermitian
         else
diff --git a/test/exact.jl b/test/exact.jl
new file mode 100644
index 000000000..27b92c04f
--- /dev/null
+++ b/test/exact.jl
@@ -0,0 +1,64 @@
+
+const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress"
+
+using ReTest
+using Turing, LogDensityProblems
+using Optimisers
+using Distributions
+using LinearAlgebra
+using SimpleUnPack: @unpack
+
+struct TestModel{M,L,S}
+    model::M
+    μ_true::L
+    L_true::S
+    n_dims::Int
+    is_meanfield::Bool
+end
+
+include("inference/normallognormal.jl")
+
+@testset "exact" begin
+    @testset "$(modelname) $(realtype)"  for
+        realtype ∈ [Float32, Float64],
+        (modelname, modelconstr) ∈ Dict(
+            :NormalLogNormalMeanField => normallognormal_meanfield,
+            :NormalLogNormalFullRank  => normallognormal_fullrank,
+        )
+
+        T = 10000
+        modelstats = modelconstr(realtype)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        b    = Bijectors.bijector(model)
+        b⁻¹  = inverse(b)
+        prob = DynamicPPL.LogDensityFunction(model)
+
+        μ₀ = zeros(realtype, n_dims)
+        L₀ = if is_meanfield
+            ones(realtype, n_dims) |> Diagonal
+        else
+            diagm(ones(realtype, n_dims)) |> LowerTriangular
+        end
+        q₀ = if is_meanfield
+            AdvancedVI.VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8))
+        else
+            AdvancedVI.VIFullRankGaussian(μ₀, L₀, realtype(1e-8))
+        end
+
+        Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+
+        objective = AdvancedVI.ADVI(prob, b⁻¹, 10)
+        q, stats  = AdvancedVI.optimize(
+            objective, q₀, T;
+            optimizer = Optimisers.AdaGrad(1e-1),
+            progress  = PROGRESS,
+        )
+
+        μ  = q.location
+        L  = q.scale
+        Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+        @test Δλ ≤ Δλ₀/√T
+    end
+end
diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl
new file mode 100644
index 000000000..4e9e14046
--- /dev/null
+++ b/test/exact/normallognormal.jl
@@ -0,0 +1,52 @@
+
+function normallognormal_fullrank(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ_x  = randn(rng, realtype)
+    σ_x  = π
+    μ_y  = randn(rng, realtype, n_dims)
+    L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
+    ϵ    = realtype(1.0)
+    Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
+
+    Turing.@model function normallognormal()
+        x ~ LogNormal(μ_x, σ_x)
+        y ~ MvNormal(μ_y, Σ_y)
+    end
+    model = normallognormal()
+
+    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
+    Σ[1,1]         = σ_x^2
+    Σ[2:end,2:end] = Σ_y
+    Σ = Σ |> Hermitian
+
+    μ = vcat(μ_x, μ_y)
+    L = cholesky(Σ).L |> LowerTriangular
+
+    TestModel(model, μ, L, n_dims+1, false)
+end
+
+function normallognormal_meanfield(realtype)
+    n_dims = 5
+
+    μ_x  = randn(realtype)
+    σ_x  = π
+    μ_y  = randn(realtype, n_dims)
+    ϵ    = realtype(1.0)
+    Σ_y  = Diagonal(exp.(randn(realtype, n_dims)))
+
+    Turing.@model function normallognormal()
+        x ~ LogNormal(μ_x, σ_x)
+        y ~ MvNormal(μ_y, Σ_y)
+    end
+    model = normallognormal()
+
+    σ²        = Vector{realtype}(undef, n_dims+1)
+    σ²[1]     = σ_x^2
+    σ²[2:end] = diag(Σ_y)
+
+    μ = vcat(μ_x, μ_y)
+    L = sqrt.(σ²) |> Diagonal
+
+    TestModel(model, μ, L, n_dims+1, true)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 440741974..26f9a06fe 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,12 +1,13 @@
 
 using ReTest: @testset, @test
-#using Random
-#using Statistics
-#using Distributions, DistributionsAD
-
-println("Environment variables for testing")
-println(ENV)
+using Random
+using Random: default_rng
+using Statistics
+using Distributions, DistributionsAD
+using LinearAlgebra
+using AdvancedVI
 
 include("ad.jl")
 include("distributions.jl")
+include("exact.jl")
 

From 3e5a4520835f0d182b8f7c4aaef0529ff37498e6 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 00:28:18 +0100
Subject: [PATCH 043/206] remove Turing dependency in tests

---
 test/exact.jl                 |  9 ++++---
 test/exact/normallognormal.jl | 47 +++++++++++++++++++++++------------
 test/runtests.jl              |  9 ++++++-
 3 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/test/exact.jl b/test/exact.jl
index 27b92c04f..d5283e8e9 100644
--- a/test/exact.jl
+++ b/test/exact.jl
@@ -2,9 +2,11 @@
 const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
 
 using ReTest
-using Turing, LogDensityProblems
+using Bijectors
+using LogDensityProblems
 using Optimisers
 using Distributions
+using PDMats
 using LinearAlgebra
 using SimpleUnPack: @unpack
 
@@ -16,7 +18,7 @@ struct TestModel{M,L,S}
     is_meanfield::Bool
 end
 
-include("inference/normallognormal.jl")
+include("exact/normallognormal.jl")
 
 @testset "exact" begin
     @testset "$(modelname) $(realtype)"  for
@@ -32,7 +34,6 @@ include("inference/normallognormal.jl")
 
         b    = Bijectors.bijector(model)
         b⁻¹  = inverse(b)
-        prob = DynamicPPL.LogDensityFunction(model)
 
         μ₀ = zeros(realtype, n_dims)
         L₀ = if is_meanfield
@@ -48,7 +49,7 @@ include("inference/normallognormal.jl")
 
         Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
 
-        objective = AdvancedVI.ADVI(prob, b⁻¹, 10)
+        objective = AdvancedVI.ADVI(model, b⁻¹, 10)
         q, stats  = AdvancedVI.optimize(
             objective, q₀, T;
             optimizer = Optimisers.AdaGrad(1e-1),
diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl
index 4e9e14046..e39ec2cb9 100644
--- a/test/exact/normallognormal.jl
+++ b/test/exact/normallognormal.jl
@@ -1,4 +1,31 @@
 
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
+
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+end
+
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
+
 function normallognormal_fullrank(realtype; rng = default_rng())
     n_dims = 5
 
@@ -9,11 +36,7 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     ϵ    = realtype(1.0)
     Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
 
-    Turing.@model function normallognormal()
-        x ~ LogNormal(μ_x, σ_x)
-        y ~ MvNormal(μ_y, Σ_y)
-    end
-    model = normallognormal()
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
 
     Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
     Σ[1,1]         = σ_x^2
@@ -33,20 +56,12 @@ function normallognormal_meanfield(realtype)
     σ_x  = π
     μ_y  = randn(realtype, n_dims)
     ϵ    = realtype(1.0)
-    Σ_y  = Diagonal(exp.(randn(realtype, n_dims)))
-
-    Turing.@model function normallognormal()
-        x ~ LogNormal(μ_x, σ_x)
-        y ~ MvNormal(μ_y, Σ_y)
-    end
-    model = normallognormal()
+    σ_y  = exp.(randn(realtype, n_dims))
 
-    σ²        = Vector{realtype}(undef, n_dims+1)
-    σ²[1]     = σ_x^2
-    σ²[2:end] = diag(Σ_y)
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
 
     μ = vcat(μ_x, μ_y)
-    L = sqrt.(σ²) |> Diagonal
+    L = vcat(σ_x, σ_y) |> Diagonal
 
     TestModel(model, μ, L, n_dims+1, true)
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 26f9a06fe..0b86222b1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,13 +1,20 @@
 
+using Comonicon
 using ReTest: @testset, @test
+using ReTest: retest
 using Random
 using Random: default_rng
 using Statistics
-using Distributions, DistributionsAD
+using Distributions
 using LinearAlgebra
 using AdvancedVI
 
+const GROUP = get(ENV, "ADVANCEDVI_TEST_GROUP", "AdvancedVI")
+
 include("ad.jl")
 include("distributions.jl")
 include("exact.jl")
 
+@main function runtests(patterns...; dry::Bool = false)
+    retest(patterns...; dry = dry, verbose = Inf)
+end
+

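With the Turing dependency gone, a test target only needs the three LogDensityProblems methods implemented above for `NormalLogNormal`. A minimal standalone example of the same interface for a hypothetical isotropic-Gaussian target (`IsoGaussian` is illustrative, not part of the package):

```julia
using LogDensityProblems

struct IsoGaussian
    μ::Vector{Float64}
end

# Log density of N(μ, I)
function LogDensityProblems.logdensity(m::IsoGaussian, θ)
    -sum(abs2, θ .- m.μ)/2 - length(m.μ)*log(2π)/2
end

LogDensityProblems.dimension(m::IsoGaussian) = length(m.μ)

LogDensityProblems.capabilities(::Type{<:IsoGaussian}) =
    LogDensityProblems.LogDensityOrder{0}()
```
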
From 3117cec8952b80b58e205726f2abe9f77ffddf80 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 02:44:22 +0100
Subject: [PATCH 044/206] remove unused projection

---
 src/optimize.jl | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 071849007..2acfbc0b4 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -43,7 +43,6 @@ function optimize(
         Optimisers.subtract!(λ, Δλ)
         stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g))
         stat = merge(stat, stat′)
-
         q    = restructure(λ)
 
         if !isnothing(callback!)
@@ -53,9 +52,6 @@ function optimize(
         
         AdvancedVI.DEBUG && @debug "Step $t" stat...
 
-        q    = project_domain(q)
-        λ, _ = Optimisers.destructure(q)
-
         pm_next!(prog, stat)
         stats[t] = stat
 

From b1ca9cf5cfad2345c92481c7519b12e1520776ef Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 03:03:57 +0100
Subject: [PATCH 045/206] remove redundant `ADVIEnergy` object (now baked into
 `ADVI`)

---
 src/AdvancedVI.jl                  |  2 +-
 src/objectives/elbo/advi.jl        | 38 ++++++++++++++++++++----------
 src/objectives/elbo/advi_energy.jl | 37 -----------------------------
 3 files changed, 26 insertions(+), 51 deletions(-)
 delete mode 100644 src/objectives/elbo/advi_energy.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 573f7179b..502112c76 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -74,7 +74,6 @@ init(::Nothing) = nothing
 update(::Nothing, ::Nothing) = (nothing, nothing)
 
 include("objectives/elbo/advi.jl")
-include("objectives/elbo/advi_energy.jl")
 include("objectives/elbo/entropy.jl")
 
 export
@@ -82,6 +81,7 @@ export
     ADVI,
-    ADVIEnergy,
     ClosedFormEntropy,
+    StickingTheLandingEntropy,
     MonteCarloEntropy
 
 # Variational Families
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index ed834273d..9cd2433ee 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -1,32 +1,41 @@
 
-struct ADVI{EnergyEst   <: AbstractEnergyEstimator,
+struct ADVI{Tlogπ, B,
             EntropyEst  <: AbstractEntropyEstimator,
             ControlVar  <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective
-    energy_estimator::EnergyEst
+    ℓπ::Tlogπ
+    b⁻¹::B
     entropy_estimator::EntropyEst
     control_variate::ControlVar
     n_samples::Int
+
+    function ADVI(prob, b⁻¹, entropy_estimator, control_variate, n_samples)
+        cap = LogDensityProblems.capabilities(prob)
+        if cap === nothing
+            throw(
+                ArgumentError(
+                    "The log density function does not support the LogDensityProblems.jl interface",
+                ),
+            )
+        end
+        ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
+        new{typeof(ℓπ), typeof(b⁻¹), typeof(entropy_estimator), typeof(control_variate)}(
+            ℓπ, b⁻¹, entropy_estimator, control_variate, n_samples
+        )
+    end
 end
 
 skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator)
 
 init(advi::ADVI) = init(advi.control_variate)
 
-Base.show(io::IO, advi::ADVI) = print(
-    io,
-    "ADVI(energy_estimator=$(advi.energy_estimator), " *
-    "entropy_estimator=$(advi.entropy_estimator), " *
-    "control_variate=$(advi.control_variate), " *
-    "n_samples=$(advi.n_samples))")
-
-function ADVI(energy_estimator::AbstractEnergyEstimator,
+function ADVI(ℓπ, b⁻¹,
               entropy_estimator::AbstractEntropyEstimator,
               n_samples::Int)
-    ADVI(energy_estimator, entropy_estimator, nothing, n_samples)
+    ADVI(ℓπ, b⁻¹, entropy_estimator, nothing, n_samples)
 end
 
 function ADVI(ℓπ, b⁻¹, n_samples::Int)
-    ADVI(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples)
+    ADVI(ℓπ, b⁻¹, ClosedFormEntropy(), nothing, n_samples)
 end
 
 function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
@@ -34,7 +43,10 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
                       n_samples ::Int            = advi.n_samples,
                       ηs        ::AbstractMatrix = rand(rng, q_η, n_samples),
                       q_η_entropy::ContinuousMultivariateDistribution = q_η)
-    𝔼ℓ = advi.energy_estimator(q_η, ηs)
+    𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ
+        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b⁻¹, ηᵢ)
+        (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
+    end
     ℍ  = advi.entropy_estimator(q_η_entropy, ηs)
     𝔼ℓ + ℍ
 end
diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl
deleted file mode 100644
index 078a157ed..000000000
--- a/src/objectives/elbo/advi_energy.jl
+++ /dev/null
@@ -1,37 +0,0 @@
-
-struct ADVIEnergy{Tlogπ, B} <: AbstractEnergyEstimator
-    # Automatic differentiation variational inference
-    # 
-    # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017).
-    # Automatic differentiation variational inference.
-    # Journal of machine learning research.
-
-    ℓπ::Tlogπ
-    b⁻¹::B
-
-    function ADVIEnergy(prob, b⁻¹)
-        # Could check whether the support of b⁻¹ and ℓπ match
-        cap = LogDensityProblems.capabilities(prob)
-        if cap === nothing
-            throw(
-                ArgumentError(
-                    "The log density function does not support the LogDensityProblems.jl interface",
-                ),
-            )
-        end
-        ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
-        new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹)
-    end
-end
-
-ADVIEnergy(prob) = ADVIEnergy(prob, identity)
-
-Base.show(io::IO, energy::ADVIEnergy) = print(io, "ADVIEnergy()")
-
-function (energy::ADVIEnergy)(q, ηs::AbstractMatrix)
-    n_samples = size(ηs, 2)
-    mapreduce(+, eachcol(ηs)) do ηᵢ
-        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(energy.b⁻¹, ηᵢ)
-        (energy.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
-    end
-end

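With the energy folded in, the objective evaluates ELBO ≈ (1/M) Σᵢ [ℓπ(b⁻¹(ηᵢ)) + log|det J_{b⁻¹}(ηᵢ)|] + ℍ(q) in a single estimator. A hedged sketch reusing the illustrative `IsoGaussian` target from the sketch after patch 043, with the identity bijector since its support is already unconstrained:

```julia
using LinearAlgebra

prob = IsoGaussian(zeros(2))
advi = ADVI(prob, identity, 16)   # 16 Monte Carlo samples per ELBO estimate

q = VIMeanFieldGaussian(zeros(2), Diagonal(ones(2)))
advi(q)   # ≈ 0 in expectation: q matches the target exactly, so ELBO = log Z = 0
```
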
From fcbb729378e3e4e16e6288a9336511f2b616b557 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 03:04:21 +0100
Subject: [PATCH 046/206] add more tests, fix rng seed for tests

---
 test/exact.jl                 | 69 +++++++++++++++++++++++++++--------
 test/exact/normallognormal.jl | 15 ++++----
 test/runtests.jl              |  2 +-
 3 files changed, 61 insertions(+), 25 deletions(-)

diff --git a/test/exact.jl b/test/exact.jl
index d5283e8e9..637a95ed4 100644
--- a/test/exact.jl
+++ b/test/exact.jl
@@ -21,15 +21,22 @@ end
 include("exact/normallognormal.jl")
 
 @testset "exact" begin
-    @testset "$(modelname) $(realtype)"  for
+    @testset "$(modelname) $(objname) $(realtype)"  for
         realtype ∈ [Float32, Float64],
         (modelname, modelconstr) ∈ Dict(
             :NormalLogNormalMeanField => normallognormal_meanfield,
             :NormalLogNormalFullRank  => normallognormal_fullrank,
+        ),
+        (objname, objective) ∈ Dict(
+            :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, b⁻¹,                              M),
+            :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M),
+            :ADVIFullMonteCarlo     => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(),         M),
         )
-
+        seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+        rng  = Philox4x(UInt64, seed, 8)
+        
         T = 10000
-        modelstats = modelconstr(realtype)
+        modelstats = modelconstr(realtype; rng)
         @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
         b    = Bijectors.bijector(model)
@@ -42,24 +49,54 @@ include("exact/normallognormal.jl")
             diagm(ones(realtype, n_dims)) |> LowerTriangular
         end
         q₀ = if is_meanfield
-            AdvancedVI.VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8))
+            VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8))
         else
-            AdvancedVI.VIFullRankGaussian(μ₀, L₀, realtype(1e-8))
+            VIFullRankGaussian(μ₀, L₀, realtype(1e-8))
         end
 
-        Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+        obj = objective(model, b⁻¹, 10)
 
-        objective = AdvancedVI.ADVI(model, b⁻¹, 10)
-        q, stats  = AdvancedVI.optimize(
-            objective, q₀, T;
-            optimizer = Optimisers.AdaGrad(1e-1),
-            progress  = PROGRESS,
-        )
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            q, stats  = optimize(
+                obj, q₀, T;
+                optimizer = Optimisers.AdaGrad(1e-0),
+                progress  = PROGRESS,
+                rng       = rng,
+            )
 
-        μ  = q.location
-        L  = q.scale
-        Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+            μ  = q.location
+            L  = q.scale
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ Δλ₀/√T
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
 
-        @test Δλ ≤ Δλ₀/√T
+        @testset "determinism" begin
+            rng      = Philox4x(UInt64, seed, 8)
+            q, stats = optimize(
+                obj, q₀, T;
+                optimizer = Optimisers.AdaGrad(1e-2),
+                progress  = PROGRESS,
+                rng       = rng,
+            )
+            μ  = q.location
+            L  = q.scale
+
+            rng_repl = Philox4x(UInt64, seed, 8)
+            q, stats = optimize(
+                obj, q₀, T;
+                optimizer = Optimisers.AdaGrad(1e-2),
+                progress  = PROGRESS,
+                rng       = rng_repl,
+            )
+            μ_repl = q.location
+            L_repl = q.scale
+            @test μ == μ_repl
+            @test L == L_repl
+        end
     end
 end
+
diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl
index e39ec2cb9..7c5c000de 100644
--- a/test/exact/normallognormal.jl
+++ b/test/exact/normallognormal.jl
@@ -30,10 +30,10 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     n_dims = 5
 
     μ_x  = randn(rng, realtype)
-    σ_x  = π
+    σ_x  = ℯ
     μ_y  = randn(rng, realtype, n_dims)
     L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
-    ϵ    = realtype(1.0)
+    ϵ    = realtype(n_dims)
     Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
 
     model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
@@ -49,14 +49,13 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     TestModel(model, μ, L, n_dims+1, false)
 end
 
-function normallognormal_meanfield(realtype)
+function normallognormal_meanfield(realtype; rng = default_rng())
     n_dims = 5
 
-    μ_x  = randn(realtype)
-    σ_x  = π
-    μ_y  = randn(realtype, n_dims)
-    ϵ    = realtype(1.0)
-    σ_y  = exp.(randn(realtype, n_dims))
+    μ_x  = randn(rng, realtype)
+    σ_x  = ℯ
+    μ_y  = randn(rng, realtype, n_dims)
+    σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
 
     model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 0b86222b1..b571f8b81 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,7 @@
 using Comonicon
 using ReTest: @testset, @test
 using Random
-using Random: default_rng
+using Random123
 using Statistics
 using Distributions
 using LinearAlgebra

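The determinism test leans on Random123's counter-based Philox4x generators: two instances constructed from the same seed tuple reproduce exactly the same stream, independent of any global RNG state. A minimal sketch:

```julia
using Random123

seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
rng₁ = Philox4x(UInt64, seed, 8)
rng₂ = Philox4x(UInt64, seed, 8)

rand(rng₁, 3) == rand(rng₂, 3)   # true: identical seeds give identical streams
```
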
From 0f6f6a429ba74e491943ad96fa52ff9f897cc862 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 03:04:35 +0100
Subject: [PATCH 047/206] add more tests, fix seed for tests

---
 test/distributions.jl | 37 +++++++++++++++++++++++++------------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/test/distributions.jl b/test/distributions.jl
index 074cad7cc..073fff644 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -8,17 +8,19 @@ using Distributions: _logpdf
         covtype  = [:meanfield, :fullrank],
         realtype = [Float32,     Float64]
 
+        seed         = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+        rng          = Philox4x(UInt64, seed, 8)
         realtype     = Float64
         ϵ            = 1e-2
         n_dims       = 10
         n_montecarlo = 1000_000
 
-        μ  = randn(realtype, n_dims)
-        L₀ = randn(realtype, n_dims, n_dims) |> LowerTriangular
+        μ  = randn(rng, realtype, n_dims)
+        L₀ = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
         Σ  = if covtype == :fullrank
             Σ = (L₀*L₀' + ϵ*I) |> Hermitian
         else
-            Diagonal(exp.(randn(realtype, n_dims)))
+            Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1))
         end
 
         L = cholesky(Σ).L
@@ -31,15 +33,26 @@ using Distributions: _logpdf
             MvNormal(μ, Σ)
         end
 
-        z = randn(n_dims)
-        @test logpdf(q, z)  ≈ logpdf(q_true, z)
-        @test _logpdf(q, z) ≈ _logpdf(q_true, z)
-        @test entropy(q)    ≈ entropy(q_true)
+        @testset "logpdf" begin
+            z = randn(rng, realtype, n_dims)
+            @test logpdf(q, z)  ≈ logpdf(q_true, z)
+            @test _logpdf(q, z) ≈ _logpdf(q_true, z)
+            @test eltype(logpdf(q, z))  == realtype
+            @test eltype(_logpdf(q, z)) == realtype
+        end
+
+        @testset "entropy" begin
+            @test eltype(entropy(q)) == realtype
+            @test entropy(q)         ≈ entropy(q_true)
+        end
 
-        z_samples  = rand(q, n_montecarlo)
-        threesigma = L
-        @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
-        @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
-        @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+        @testset "sampling" begin
+            z_samples  = rand(rng, q, n_montecarlo)
+            threesigma = L
+            @test eltype(z_samples) == realtype
+            @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
+            @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
+            @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+        end
     end
 end

From f5f5863b55af07ea1009528e5b8e1fdb1bfc96df Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 03:16:49 +0100
Subject: [PATCH 048/206] fix non-determinism bug

---
 src/distributions/location_scale.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index c1803ffef..e9e8c743f 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -47,7 +47,7 @@ end
 function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) 
     @unpack location, scale, dist = q
     n_dims = length(location)
-    scale*rand(dist, n_dims, num_samples) .+ location
+    scale*rand(rng, dist, n_dims, num_samples) .+ location
 end
 
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real})

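The one-character bug is easy to reintroduce because `rand(dist, dims...)` silently falls back to the global RNG instead of erroring. An illustration of the two call patterns:

```julia
using Random, Distributions

dist = Normal()
rng  = MersenneTwister(1)

rand(dist, 2, 2)        # uses the global RNG: the bug pattern fixed above
rand(rng, dist, 2, 2)   # driven by the caller's rng: the corrected call
```
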
From ade0d1007c1507fb0359d744fa640349314e325d Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 04:56:29 +0100
Subject: [PATCH 049/206] fix test hyperparameters so that tests pass, minor
 cleanups

---
 src/distributions/location_scale.jl | 12 ++++++++++++
 src/objectives/elbo/advi.jl         |  6 ++++++
 src/optimize.jl                     |  6 ++++--
 test/exact.jl                       | 10 +++++-----
 test/exact/normallognormal.jl       |  2 +-
 5 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index e9e8c743f..dc9c1b279 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -1,4 +1,16 @@
 
+"""
+
+The location-scale variational family broadly represents various variational
+families using `location` and `scale` variational parameters.
+
+For example, a multivariate Student-t variational family with ``\\nu`` degrees
+of freedom can be constructed as:
+```julia
+q₀ = VILocationScale(μ, L, TDist(ν), eps(Float32))
+```
+
+"""
 struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution
     location::L
     scale   ::S
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 9cd2433ee..b9b1185f7 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -24,6 +24,12 @@ struct ADVI{Tlogπ, B,
     end
 end
 
+Base.show(io::IO, advi::ADVI) =
+    print(io,
+          "ADVI(entropy_estimator=$(advi.entropy_estimator), " *
+          "control_variate=$(advi.control_variate), " *
+          "n_samples=$(advi.n_samples))")
+
 skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator)
 
 init(advi::ADVI) = init(advi.control_variate)
diff --git a/src/optimize.jl b/src/optimize.jl
index 2acfbc0b4..dcd1c4399 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -41,9 +41,11 @@ function optimize(
 
         opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g)
         Optimisers.subtract!(λ, Δλ)
+
         stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g))
         stat = merge(stat, stat′)
-        q    = restructure(λ)
+
+        q = restructure(λ)
 
         if !isnothing(callback!)
             stat′ = callback!(q, stat)
@@ -56,7 +58,7 @@ function optimize(
         stats[t] = stat
 
         # Termination decision is work in progress
-        if terminate(rng, q, objective, stat)
+        if terminate(rng, λ, q, objective, stat)
             stats = stats[1:t]
             break
         end
diff --git a/test/exact.jl b/test/exact.jl
index 637a95ed4..d1be4626e 100644
--- a/test/exact.jl
+++ b/test/exact.jl
@@ -49,9 +49,9 @@ include("exact/normallognormal.jl")
             diagm(ones(realtype, n_dims)) |> LowerTriangular
         end
         q₀ = if is_meanfield
-            VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8))
+            VIMeanFieldGaussian(μ₀, L₀)
         else
-            VIFullRankGaussian(μ₀, L₀, realtype(1e-8))
+            VIFullRankGaussian(μ₀, L₀)
         end
 
         obj = objective(model, b⁻¹, 10)
@@ -60,7 +60,7 @@ include("exact/normallognormal.jl")
             Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
             q, stats  = optimize(
                 obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-0),
+                optimizer = Optimisers.AdaGrad(1e-1),
                 progress  = PROGRESS,
                 rng       = rng,
             )
@@ -78,7 +78,7 @@ include("exact/normallognormal.jl")
             rng      = Philox4x(UInt64, seed, 8)
             q, stats = optimize(
                 obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-2),
+                optimizer = Optimisers.AdaGrad(1e-1),
                 progress  = PROGRESS,
                 rng       = rng,
             )
@@ -88,7 +88,7 @@ include("exact/normallognormal.jl")
             rng_repl = Philox4x(UInt64, seed, 8)
             q, stats = optimize(
                 obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-2),
+                optimizer = Optimisers.AdaGrad(1e-1),
                 progress  = PROGRESS,
                 rng       = rng_repl,
             )
diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl
index 7c5c000de..18e8b4a34 100644
--- a/test/exact/normallognormal.jl
+++ b/test/exact/normallognormal.jl
@@ -33,7 +33,7 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     σ_x  = ℯ
     μ_y  = randn(rng, realtype, n_dims)
     L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
-    ϵ    = realtype(n_dims)
+    ϵ    = realtype(n_dims*2)
     Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
 
     model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
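
The enlarged ridge term is easy to sanity-check in isolation. A small sketch under the same construction as the test model, with an arbitrary seed and dimension:

```julia
using LinearAlgebra, Random

# Sketch of the covariance construction above: the ϵ*I ridge (enlarged to
# 2*n_dims by this patch) keeps L₀*L₀' safely positive definite even when
# the random factor L₀ is poorly conditioned.
rng = Xoshiro(0)
n   = 5
L₀  = LowerTriangular(randn(rng, n, n))
Σ   = Hermitian(L₀*L₀' + 2n*I)
@assert isposdef(Σ)
```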

From 0caf7a9ef768ce97c7498c981d5ef60ee673488f Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 19:37:45 +0100
Subject: [PATCH 050/206] fix minor reorganization

---
 src/AdvancedVI.jl             |   9 +--
 test/exact.jl                 | 102 ----------------------------------
 test/exact/normallognormal.jl |  66 ----------------------
 test/runtests.jl              |   4 +-
 4 files changed, 4 insertions(+), 177 deletions(-)
 delete mode 100644 test/exact.jl
 delete mode 100644 test/exact/normallognormal.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 502112c76..86c9fc44a 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,7 +1,7 @@
 
 module AdvancedVI
 
-using SimpleUnPack: @unpack
+using SimpleUnPack: @unpack, @pack!
 using Accessors
 
 import Random: AbstractRNG, default_rng
@@ -60,17 +60,14 @@ include("grad.jl")
 # estimators
 abstract type AbstractVariationalObjective end
 
+function init              end
 function estimate_gradient end
 
-abstract type AbstractEnergyEstimator  end
+# ADVI-specific interfaces
 abstract type AbstractEntropyEstimator end
 abstract type AbstractControlVariate end
 
-function init   end
 function update end
-
-init(::Nothing) = nothing
-
 update(::Nothing, ::Nothing) = (nothing, nothing)
 
 include("objectives/elbo/advi.jl")
diff --git a/test/exact.jl b/test/exact.jl
deleted file mode 100644
index d1be4626e..000000000
--- a/test/exact.jl
+++ /dev/null
@@ -1,102 +0,0 @@
-
-const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
-
-using ReTest
-using Bijectors
-using LogDensityProblems
-using Optimisers
-using Distributions
-using PDMats
-using LinearAlgebra
-using SimpleUnPack: @unpack
-
-struct TestModel{M,L,S}
-    model::M
-    μ_true::L
-    L_true::S
-    n_dims::Int
-    is_meanfield::Bool
-end
-
-include("exact/normallognormal.jl")
-
-@testset "exact" begin
-    @testset "$(modelname) $(objname) $(realtype)"  for
-        realtype ∈ [Float32, Float64],
-        (modelname, modelconstr) ∈ Dict(
-            :NormalLogNormalMeanField => normallognormal_meanfield,
-            :NormalLogNormalFullRank  => normallognormal_fullrank,
-        ),
-        (objname, objective) ∈ Dict(
-            :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, b⁻¹,                              M),
-            :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M),
-            :ADVIFullMonteCarlo     => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(),         M),
-        )
-        seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-        rng  = Philox4x(UInt64, seed, 8)
-        
-        T = 10000
-        modelstats = modelconstr(realtype; rng)
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        b    = Bijectors.bijector(model)
-        b⁻¹  = inverse(b)
-
-        μ₀ = zeros(realtype, n_dims)
-        L₀ = if is_meanfield
-            ones(realtype, n_dims) |> Diagonal
-        else
-            diagm(ones(realtype, n_dims)) |> LowerTriangular
-        end
-        q₀ = if is_meanfield
-            VIMeanFieldGaussian(μ₀, L₀)
-        else
-            VIFullRankGaussian(μ₀, L₀)
-        end
-
-        obj = objective(model, b⁻¹, 10)
-
-        @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-            q, stats  = optimize(
-                obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-1),
-                progress  = PROGRESS,
-                rng       = rng,
-            )
-
-            μ  = q.location
-            L  = q.scale
-            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
-
-            @test Δλ ≤ Δλ₀/√T
-            @test eltype(μ) == eltype(μ_true)
-            @test eltype(L) == eltype(L_true)
-        end
-
-        @testset "determinism" begin
-            rng      = Philox4x(UInt64, seed, 8)
-            q, stats = optimize(
-                obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-1),
-                progress  = PROGRESS,
-                rng       = rng,
-            )
-            μ  = q.location
-            L  = q.scale
-
-            rng_repl = Philox4x(UInt64, seed, 8)
-            q, stats = optimize(
-                obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-1),
-                progress  = PROGRESS,
-                rng       = rng_repl,
-            )
-            μ_repl = q.location
-            L_repl = q.scale
-            @test μ == μ_repl
-            @test L == L_repl
-        end
-    end
-end
-
diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl
deleted file mode 100644
index 18e8b4a34..000000000
--- a/test/exact/normallognormal.jl
+++ /dev/null
@@ -1,66 +0,0 @@
-
-struct NormalLogNormal{MX,SX,MY,SY}
-    μ_x::MX
-    σ_x::SX
-    μ_y::MY
-    Σ_y::SY
-end
-
-function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
-end
-
-function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
-end
-
-function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-
-function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
-end
-
-function normallognormal_fullrank(realtype; rng = default_rng())
-    n_dims = 5
-
-    μ_x  = randn(rng, realtype)
-    σ_x  = ℯ
-    μ_y  = randn(rng, realtype, n_dims)
-    L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
-    ϵ    = realtype(n_dims*2)
-    Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
-
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
-
-    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
-    Σ[1,1]         = σ_x^2
-    Σ[2:end,2:end] = Σ_y
-    Σ = Σ |> Hermitian
-
-    μ = vcat(μ_x, μ_y)
-    L = cholesky(Σ).L |> LowerTriangular
-
-    TestModel(model, μ, L, n_dims+1, false)
-end
-
-function normallognormal_meanfield(realtype; rng = default_rng())
-    n_dims = 5
-
-    μ_x  = randn(rng, realtype)
-    σ_x  = ℯ
-    μ_y  = randn(rng, realtype, n_dims)
-    σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
-
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
-
-    μ = vcat(μ_x, μ_y)
-    L = vcat(σ_x, σ_y) |> Diagonal
-
-    TestModel(model, μ, L, n_dims+1, true)
-end
diff --git a/test/runtests.jl b/test/runtests.jl
index b571f8b81..ddc1d09cf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,11 +8,9 @@ using Distributions
 using LinearAlgebra
 using AdvancedVI
 
-const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC")
-
 include("ad.jl")
 include("distributions.jl")
-include("exact.jl")
+include("advi_locscale.jl")
 
 @main function runtests(patterns...; dry::Bool = false)
     retest(patterns...; dry = dry, verbose = Inf)
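
After this reorganization, `init` and `estimate_gradient` are the two functions a variational objective implements. The following is a rough sketch of a hypothetical objective wired into that interface; the argument list mirrors the call site in `optimize.jl`, and `ToyObjective` is an illustrative name, not part of the package:

```julia
using Random, DiffResults
using ADTypes: AbstractADType
using AdvancedVI

# Hypothetical objective implementing the two-function interface: `init`
# sets up estimator state and `estimate_gradient` writes the next gradient
# estimate into `out`.
struct ToyObjective <: AdvancedVI.AbstractVariationalObjective end

AdvancedVI.init(::ToyObjective) = nothing

function AdvancedVI.estimate_gradient(
    rng::Random.AbstractRNG, adbackend::AbstractADType, ::ToyObjective,
    est_state, λ::Vector{<:Real}, restructure,
    out::DiffResults.MutableDiffResult
)
    DiffResults.value!(out, sum(abs2, λ))  # toy objective value
    DiffResults.gradient!(out, 2λ)         # ... and its exact gradient
    out, est_state, (toy_objective = DiffResults.value(out),)
end
```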

From 5658cbf10e3f6e64d7b03380d4c026951cb3f0c2 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 19:40:59 +0100
Subject: [PATCH 051/206] add missing files

---
 test/Project.toml              |  20 +++++++
 test/advi_locscale.jl          | 102 +++++++++++++++++++++++++++++++++
 test/models/normallognormal.jl |  66 +++++++++++++++++++++
 3 files changed, 188 insertions(+)
 create mode 100644 test/Project.toml
 create mode 100644 test/advi_locscale.jl
 create mode 100644 test/models/normallognormal.jl

diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 000000000..2f38c88fa
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,20 @@
+[deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Random123 = "74087812-796a-5b5d-8853-05524746bad3"
+ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
new file mode 100644
index 000000000..2beb05470
--- /dev/null
+++ b/test/advi_locscale.jl
@@ -0,0 +1,102 @@
+
+const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
+
+using ReTest
+using Bijectors
+using LogDensityProblems
+using Optimisers
+using Distributions
+using PDMats
+using LinearAlgebra
+using SimpleUnPack: @unpack
+
+struct TestModel{M,L,S}
+    model::M
+    μ_true::L
+    L_true::S
+    n_dims::Int
+    is_meanfield::Bool
+end
+
+include("models/normallognormal.jl")
+
+@testset "exact" begin
+    @testset "$(modelname) $(objname) $(realtype)"  for
+        realtype ∈ [Float32, Float64],
+        (modelname, modelconstr) ∈ Dict(
+            :NormalLogNormalMeanField => normallognormal_meanfield,
+            :NormalLogNormalFullRank  => normallognormal_fullrank,
+        ),
+        (objname, objective) ∈ Dict(
+            :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, b⁻¹,                              M),
+            :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M),
+            :ADVIFullMonteCarlo     => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(),         M),
+        )
+        seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+        rng  = Philox4x(UInt64, seed, 8)
+        
+        T = 10000
+        modelstats = modelconstr(realtype; rng)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        b    = Bijectors.bijector(model)
+        b⁻¹  = inverse(b)
+
+        μ₀ = zeros(realtype, n_dims)
+        L₀ = if is_meanfield
+            ones(realtype, n_dims) |> Diagonal
+        else
+            diagm(ones(realtype, n_dims)) |> LowerTriangular
+        end
+        q₀ = if is_meanfield
+            VIMeanFieldGaussian(μ₀, L₀)
+        else
+            VIFullRankGaussian(μ₀, L₀)
+        end
+
+        obj = objective(model, b⁻¹, 10)
+
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            q, stats  = optimize(
+                obj, q₀, T;
+                optimizer = Optimisers.AdaGrad(1e-1),
+                progress  = PROGRESS,
+                rng       = rng,
+            )
+
+            μ  = q.location
+            L  = q.scale
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ Δλ₀/√T
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng      = Philox4x(UInt64, seed, 8)
+            q, stats = optimize(
+                obj, q₀, T;
+                optimizer = Optimisers.AdaGrad(1e-1),
+                progress  = PROGRESS,
+                rng       = rng,
+            )
+            μ  = q.location
+            L  = q.scale
+
+            rng_repl = Philox4x(UInt64, seed, 8)
+            q, stats = optimize(
+                obj, q₀, T;
+                optimizer = Optimisers.AdaGrad(1e-1),
+                progress  = PROGRESS,
+                rng       = rng_repl,
+            )
+            μ_repl = q.location
+            L_repl = q.scale
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
+
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
new file mode 100644
index 000000000..18e8b4a34
--- /dev/null
+++ b/test/models/normallognormal.jl
@@ -0,0 +1,66 @@
+
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
+
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+end
+
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
+
+function normallognormal_fullrank(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ_x  = randn(rng, realtype)
+    σ_x  = ℯ
+    μ_y  = randn(rng, realtype, n_dims)
+    L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
+    ϵ    = realtype(n_dims*2)
+    Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
+
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
+
+    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
+    Σ[1,1]         = σ_x^2
+    Σ[2:end,2:end] = Σ_y
+    Σ = Σ |> Hermitian
+
+    μ = vcat(μ_x, μ_y)
+    L = cholesky(Σ).L |> LowerTriangular
+
+    TestModel(model, μ, L, n_dims+1, false)
+end
+
+function normallognormal_meanfield(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ_x  = randn(rng, realtype)
+    σ_x  = ℯ
+    μ_y  = randn(rng, realtype, n_dims)
+    σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
+
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
+
+    μ = vcat(μ_x, μ_y)
+    L = vcat(σ_x, σ_y) |> Diagonal
+
+    TestModel(model, μ, L, n_dims+1, true)
+end
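
The `Stacked` bijector in `models/normallognormal.jl` is what matches the variational family's unconstrained support to the model. A small sketch of that mechanism with arbitrary parameters:

```julia
using Bijectors, Distributions, LinearAlgebra

# Sketch of how the Stacked bijector matches supports: the first coordinate
# (LogNormal, positive support) gets a log transform, while the Gaussian
# block stays unconstrained.
b = Bijectors.Stacked(
    Bijectors.bijector.([LogNormal(0.0, 1.0), MvNormal(zeros(2), Diagonal(ones(2)))]),
    [1:1, 2:3])
b⁻¹ = inverse(b)
z   = b⁻¹([-1.0, 0.3, -0.5])
@assert z[1] > 0  # mapped back onto the positive support of the LogNormal
```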

From c712a9762afdbc60468953bfeab1ad076a6cc2f9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 19:51:14 +0100
Subject: [PATCH 052/206] fix add missing file, rename adbackend argument

---
 src/grad.jl     | 30 ++++++++++++++++++++++++++++++
 src/optimize.jl |  2 +-
 2 files changed, 31 insertions(+), 1 deletion(-)
 create mode 100644 src/grad.jl

diff --git a/src/grad.jl b/src/grad.jl
new file mode 100644
index 000000000..e68e16234
--- /dev/null
+++ b/src/grad.jl
@@ -0,0 +1,30 @@
+
+# default implementations
+function grad!(
+    f::Function,
+    adtype::AutoForwardDiff{chunksize},
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
+) where {chunksize}
+    # Set chunk size and do ForwardMode.
+    config = if isnothing(chunksize)
+        ForwardDiff.GradientConfig(f, λ)
+    else
+        ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunksize))
+    end
+    ForwardDiff.gradient!(out, f, λ, config)
+end
+
+function grad!(
+    f::Function,
+    ::AutoTracker,
+    λ::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
+)
+    λ_tracked = Tracker.param(λ)
+    y = f(λ_tracked)
+    Tracker.back!(y, 1.0)
+
+    DiffResults.value!(out, Tracker.data(y))
+    DiffResults.gradient!(out, Tracker.grad(λ_tracked))
+end
diff --git a/src/optimize.jl b/src/optimize.jl
index dcd1c4399..169959259 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -19,7 +19,7 @@ function optimize(
     progress  ::Bool                    = true,
     callback!                           = nothing,
     terminate                           = (args...) -> false,
-    adback::AbstractADType              = AutoForwardDiff(), 
+    adbackend::AbstractADType           = AutoForwardDiff(), 
 )
     opt_state = Optimisers.init(optimizer, λ)
     est_state = init(objective)
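
A minimal sketch of the `grad!` convention introduced in `src/grad.jl`, assuming `AdvancedVI` is loaded: the objective goes in the first slot (so `do`-block syntax works at the call site), the backend is an `ADTypes` struct, and value and gradient land in a `DiffResults` buffer.

```julia
using ADTypes, DiffResults
import AdvancedVI

# Differentiate a toy objective through the default ForwardDiff backend.
λ   = randn(3)
out = DiffResults.GradientResult(λ)
AdvancedVI.grad!(λ′ -> sum(abs2, λ′), AutoForwardDiff(), λ, out)
DiffResults.value(out)     # sum(abs2, λ)
DiffResults.gradient(out)  # 2λ
```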

From bee839d91399ce9cc2d776f907dd9197e14aa241 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 20:03:16 +0100
Subject: [PATCH 053/206] fix errors

---
 src/AdvancedVI.jl           | 2 ++
 src/objectives/elbo/advi.jl | 4 ++--
 src/optimize.jl             | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 86c9fc44a..4010b1fe8 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -63,6 +63,8 @@ abstract type AbstractVariationalObjective end
 function init              end
 function estimate_gradient end
 
+init(::Nothing) = nothing
+
 # ADVI-specific interfaces
 abstract type AbstractEntropyEstimator end
 abstract type AbstractControlVariate end
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index b9b1185f7..1fb6b0c63 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -59,7 +59,7 @@ end
 
 function estimate_gradient(
     rng::AbstractRNG,
-    adback::AbstractADType,
+    adbackend::AbstractADType,
     advi::ADVI,
     est_state,
     λ::Vector{<:Real},
@@ -69,7 +69,7 @@ function estimate_gradient(
     # Gradient-stopping for computing the sticking-the-landing control variate
     q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing
 
-    grad!(adback, λ, out) do λ′
+    grad!(adbackend, λ, out) do λ′
         q_η = restructure(λ′)
         q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η
         -advi(q_η; rng, q_η_entropy)
diff --git a/src/optimize.jl b/src/optimize.jl
index 169959259..8b36df04d 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -35,7 +35,7 @@ function optimize(
         stat = (iteration=t,)
 
         grad_buf, est_state, stat′ = estimate_gradient(
-            rng, adback, objective, est_state, λ, restructure, grad_buf)
+            rng, adbackend, objective, est_state, λ, restructure, grad_buf)
         g    = DiffResults.gradient(grad_buf)
         stat = merge(stat, stat′)
 

From 913911ec74f835d566e2f19b0df16358a3fd055b Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 14 Jul 2023 20:03:23 +0100
Subject: [PATCH 054/206] rename test suite

---
 test/advi_locscale.jl | 149 +++++++++++++++++++++++-------------------
 1 file changed, 80 insertions(+), 69 deletions(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 2beb05470..342b9db1a 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -20,83 +20,94 @@ end
 
 include("models/normallognormal.jl")
 
-@testset "exact" begin
-    @testset "$(modelname) $(objname) $(realtype)"  for
-        realtype ∈ [Float32, Float64],
-        (modelname, modelconstr) ∈ Dict(
-            :NormalLogNormalMeanField => normallognormal_meanfield,
-            :NormalLogNormalFullRank  => normallognormal_fullrank,
-        ),
-        (objname, objective) ∈ Dict(
-            :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, b⁻¹,                              M),
-            :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M),
-            :ADVIFullMonteCarlo     => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(),         M),
-        )
-        seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-        rng  = Philox4x(UInt64, seed, 8)
-        
-        T = 10000
-        modelstats = modelconstr(realtype; rng)
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+@testset "advi" begin
+    @testset "locscale" begin
+        @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+            realtype ∈ [Float32, Float64],
+            (modelname, modelconstr) ∈ Dict(
+                :NormalLogNormalMeanField => normallognormal_meanfield,
+                :NormalLogNormalFullRank  => normallognormal_fullrank,
+            ),
+            (objname, objective) ∈ Dict(
+                :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, b⁻¹,                              M),
+                :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M),
+                :ADVIFullMonteCarlo     => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(),         M),
+            ),
+            (adbackname, adbackend) ∈ Dict(
+                :ForwardDiff => AutoForwardDiff(),
+                :ReverseDiff => AutoReverseDiff(),
+                :Zygote      => AutoZygote(),
+                :Enzyme      => AutoEnzyme(),
+            )
 
-        b    = Bijectors.bijector(model)
-        b⁻¹  = inverse(b)
+            seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+            rng  = Philox4x(UInt64, seed, 8)
 
-        μ₀ = zeros(realtype, n_dims)
-        L₀ = if is_meanfield
-            ones(realtype, n_dims) |> Diagonal
-        else
-            diagm(ones(realtype, n_dims)) |> LowerTriangular
-        end
-        q₀ = if is_meanfield
-            VIMeanFieldGaussian(μ₀, L₀)
-        else
-            VIFullRankGaussian(μ₀, L₀)
-        end
+            T = 10000
+            modelstats = modelconstr(realtype; rng)
+            @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
-        obj = objective(model, b⁻¹, 10)
+            b    = Bijectors.bijector(model)
+            b⁻¹  = inverse(b)
 
-        @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-            q, stats  = optimize(
-                obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-1),
-                progress  = PROGRESS,
-                rng       = rng,
-            )
+            μ₀ = zeros(realtype, n_dims)
+            L₀ = if is_meanfield
+                ones(realtype, n_dims) |> Diagonal
+            else
+                diagm(ones(realtype, n_dims)) |> LowerTriangular
+            end
+            q₀ = if is_meanfield
+                VIMeanFieldGaussian(μ₀, L₀)
+            else
+                VIFullRankGaussian(μ₀, L₀)
+            end
 
-            μ  = q.location
-            L  = q.scale
-            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+            obj = objective(model, b⁻¹, 10)
 
-            @test Δλ ≤ Δλ₀/√T
-            @test eltype(μ) == eltype(μ_true)
-            @test eltype(L) == eltype(L_true)
-        end
+            @testset "convergence" begin
+                Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+                q, stats  = optimize(
+                    obj, q₀, T;
+                    optimizer = Optimisers.AdaGrad(1e-1),
+                    progress  = PROGRESS,
+                    rng       = rng,
+                    adbackend = adbackend,
+                )
 
-        @testset "determinism" begin
-            rng      = Philox4x(UInt64, seed, 8)
-            q, stats = optimize(
-                obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-1),
-                progress  = PROGRESS,
-                rng       = rng,
-            )
-            μ  = q.location
-            L  = q.scale
+                μ  = q.location
+                L  = q.scale
+                Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-            rng_repl = Philox4x(UInt64, seed, 8)
-            q, stats = optimize(
-                obj, q₀, T;
-                optimizer = Optimisers.AdaGrad(1e-1),
-                progress  = PROGRESS,
-                rng       = rng_repl,
-            )
-            μ_repl = q.location
-            L_repl = q.scale
-            @test μ == μ_repl
-            @test L == L_repl
+                @test Δλ ≤ Δλ₀/√T
+                @test eltype(μ) == eltype(μ_true)
+                @test eltype(L) == eltype(L_true)
+            end
+
+            @testset "determinism" begin
+                rng      = Philox4x(UInt64, seed, 8)
+                q, stats = optimize(
+                    obj, q₀, T;
+                    optimizer = Optimisers.AdaGrad(1e-1),
+                    progress  = PROGRESS,
+                    rng       = rng,
+                    adbackend = adbackend,
+                )
+                μ  = q.location
+                L  = q.scale
+
+                rng_repl = Philox4x(UInt64, seed, 8)
+                q, stats = optimize(
+                    obj, q₀, T;
+                    optimizer = Optimisers.AdaGrad(1e-1),
+                    progress  = PROGRESS,
+                    rng       = rng_repl,
+                    adbackend = adbackend,
+                )
+                μ_repl = q.location
+                L_repl = q.scale
+                @test μ == μ_repl
+                @test L == L_repl
+            end
         end
     end
 end
-

From d50cabb0f0b7b7fac8bfd79c43ef38196b2df8c9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 15 Jul 2023 01:59:43 +0100
Subject: [PATCH 055/206] refactor renamed arguments for ADVI to be shorter

---
 Project.toml                   |  3 +-
 src/AdvancedVI.jl              |  7 ++--
 src/objectives/elbo/advi.jl    | 59 +++++++++++++++++-----------------
 src/objectives/elbo/entropy.jl | 42 ++++++++++++++----------
 test/ad.jl                     | 10 +++---
 test/advi_locscale.jl          | 18 +++++------
 6 files changed, 73 insertions(+), 66 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2fcc845e8..cf698f7a6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,7 +7,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -25,9 +24,9 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 
 [compat]
+ADTypes = "0.1"
 Bijectors = "0.11, 0.12, 0.13"
 Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
-DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
 DocStringExtensions = "0.8, 0.9"
 ForwardDiff = "0.10.3"
 ProgressMeter = "1.0.0"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 4010b1fe8..e3dd85a89 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -5,7 +5,10 @@ using SimpleUnPack: @unpack, @pack!
 using Accessors
 
 import Random: AbstractRNG, default_rng
-import Distributions: logpdf, _logpdf, rand, _rand!, _rand!
+using Distributions
+import Distributions:
+    logpdf, _logpdf, rand, _rand!, _rand!,
+    ContinuousMultivariateDistribution
 
 using Functors
 using Optimisers
@@ -24,8 +27,6 @@ using ForwardDiff, Tracker
 
 using FillArrays
 using PDMats
-using Distributions, DistributionsAD
-using Distributions: ContinuousMultivariateDistribution
 using Bijectors
 
 using StatsBase
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 1fb6b0c63..e965ea734 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -1,14 +1,28 @@
 
+"""
+    ADVI
+
+Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
+
+# Requirements
+- ``q_{\\lambda}`` implements `rand`.
+- ``\\pi`` must be differentiable
+
+Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
+"""
 struct ADVI{Tlogπ, B,
-            EntropyEst  <: AbstractEntropyEstimator,
-            ControlVar  <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective
+            EntropyEst <: AbstractEntropyEstimator,
+            ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective
     ℓπ::Tlogπ
-    b⁻¹::B
-    entropy_estimator::EntropyEst
-    control_variate::ControlVar
+    b::B
+    entropy::EntropyEst
+    cv::ControlVar
     n_samples::Int
 
-    function ADVI(prob, b⁻¹, entropy_estimator, control_variate, n_samples)
+    function ADVI(prob, n_samples::Int;
+                  entropy::AbstractEntropyEstimator = ClosedFormEntropy(),
+                  cv::Union{<:AbstractControlVariate, Nothing} = nothing,
+                  b = Bijectors.identity)
         cap = LogDensityProblems.capabilities(prob)
         if cap === nothing
             throw(
@@ -18,31 +32,16 @@ struct ADVI{Tlogπ, B,
             )
         end
         ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
-        new{typeof(ℓπ), typeof(b⁻¹), typeof(entropy_estimator), typeof(control_variate)}(
-            ℓπ, b⁻¹, entropy_estimator, control_variate, n_samples
-        )
+        new{typeof(ℓπ), typeof(b), typeof(entropy), typeof(cv)}(ℓπ, b, entropy, cv, n_samples)
     end
 end
 
 Base.show(io::IO, advi::ADVI) =
-    print(io,
-          "ADVI(entropy_estimator=$(advi.entropy_estimator), " *
-          "control_variate=$(advi.control_variate), " *
-          "n_samples=$(advi.n_samples))")
-
-skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator)
+    print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))")
 
-init(advi::ADVI) = init(advi.control_variate)
+skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy)
 
-function ADVI(ℓπ, b⁻¹,
-              entropy_estimator::AbstractEntropyEstimator,
-              n_samples::Int)
-    ADVI(ℓπ, b⁻¹, entropy_estimator, nothing, n_samples)
-end
-
-function ADVI(ℓπ, b⁻¹, n_samples::Int)
-    ADVI(ℓπ, b⁻¹, ClosedFormEntropy(), nothing, n_samples)
-end
+init(advi::ADVI) = init(advi.cv)
 
 function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
                       rng       ::AbstractRNG    = default_rng(),
@@ -50,10 +49,10 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
                       ηs        ::AbstractMatrix = rand(rng, q_η, n_samples),
                       q_η_entropy::ContinuousMultivariateDistribution = q_η)
     𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ
-        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b⁻¹, ηᵢ)
+        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ)
         (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
     end
-    ℍ  = advi.entropy_estimator(q_η_entropy, ηs)
+    ℍ  = advi.entropy(q_η_entropy, ηs)
     𝔼ℓ + ℍ
 end
 
@@ -67,17 +66,17 @@ function estimate_gradient(
     out::DiffResults.MutableDiffResult)
 
     # Gradient-stopping for computing the sticking-the-landing control variate
-    q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing
+    q_η_stop = skip_entropy_gradient(advi.entropy) ? restructure(λ) : nothing
 
     grad!(adbackend, λ, out) do λ′
         q_η = restructure(λ′)
-        q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η
+        q_η_entropy = skip_entropy_gradient(advi.entropy) ? q_η_stop : q_η
         -advi(q_η; rng, q_η_entropy)
     end
     nelbo = DiffResults.value(out)
     stat  = (elbo=-nelbo,)
 
-    est_state, stat′ = update(advi.control_variate, est_state)
+    est_state, stat′ = update(advi.cv, est_state)
     stat = !isnothing(stat′) ? merge(stat′, stat) : stat 
 
     out, est_state, stat
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index ddeb64a9c..994bdd4f1 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -14,27 +14,35 @@ MonteCarloEntropy() = MonteCarloEntropy{false}()
 Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()")
 
 """
-  Sticking the Landing Control Variate
+    StickingTheLandingEntropy()
 
-  # Explanation
+# Explanation
 
-  This eatimator forms a control variate of the form of
+The STL estimator forms a control variate of the form
  
-    c(z)  = 𝔼-logq(z) + logq(z) = ℍ[q] - logq(z)
+```math
+\\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) =
+  \\mathbb{E}\\left[ -\\log q\\left(z\\right) \\right]
+  + \\log q\\left(z\\right) = \\mathbb{H}\\left(q_{\\lambda}\\right) + \\log q_{\\lambda}\\left(z\\right),
+```
+where, for the score term, the gradient is stopped from propagating.
  
-   Adding this to the closed-form entropy ELBO estimator yields:
- 
-     ELBO - c(z) = 𝔼logπ(z) + ℍ[q] - c(z) = 𝔼logπ(z) - logq(z),
-
-   which has the same expectation, but lower variance when π ≈ q,
-   and higher variance when π ≉ q.
-
-   # Reference
-
-   Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud.
-   "Sticking the landing: Simple, lower-variance gradient estimators for
-   variational inference."
-   Advances in Neural Information Processing Systems 30 (2017).
+Subtracting this control variate from the closed-form entropy ELBO estimator yields the STL estimator:
+```math
+\\begin{aligned}
+  \\widehat{\\mathrm{ELBO}}_{\\mathrm{STL}}\\left(\\lambda\\right)
+    &\\triangleq \\mathbb{E}\\left[ \\log \\pi \\left(z\\right) \\right] - \\log q_{\\lambda} \\left(z\\right) \\\\
+    &= \\mathbb{E}\\left[ \\log \\pi\\left(z\\right) \\right] 
+      + \\mathbb{H}\\left(q_{\\lambda}\\right) - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) \\\\
+    &= \\widehat{\\mathrm{ELBO}}\\left(\\lambda\\right)
+      - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right),
+\\end{aligned}
+```
+which has the same expectation, but lower variance when ``\\pi \\approx q_{\\lambda}``,
+and higher variance when ``\\pi \\not\\approx q_{\\lambda}``.
+
+# Reference
+1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
 """
 StickingTheLandingEntropy() = MonteCarloEntropy{true}()
 
diff --git a/test/ad.jl b/test/ad.jl
index 6b587598e..1efa536b2 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -5,11 +5,11 @@ using ADTypes
 
 @testset "ad" begin
     @testset "$(adname)" for (adname, adsymbol) ∈ Dict(
-          :ForwardDiffAuto => AutoForwardDiff(),
-          :ForwardDiff     => AutoForwardDiff(10),
-          :ReverseDiff     => AutoReverseDiff(),
-          :Zygote          => AutoZygote(),
-          :Tracker         => AutoTracker(),
+          :ForwardDiff => AutoForwardDiff(),
+          :ReverseDiff => AutoReverseDiff(),
+          :Zygote      => AutoZygote(),
+          :Tracker     => AutoTracker(),
+          :Enzyme      => AutoEnzyme(),
         )
         D = 10
         A = randn(D, D)
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 342b9db1a..dadbaf25d 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -29,15 +29,15 @@ include("models/normallognormal.jl")
                 :NormalLogNormalFullRank  => normallognormal_fullrank,
             ),
             (objname, objective) ∈ Dict(
-                :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, b⁻¹,                              M),
-                :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M),
-                :ADVIFullMonteCarlo     => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(),         M),
+                :ADVIClosedFormEntropy  => (model, b, M) -> ADVI(model, M; b),
+                :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, H = StickingTheLandingEntropy()),
+                :ADVIFullMonteCarlo     => (model, b, M) -> ADVI(model, M; b, H = MonteCarloEntropy()),
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwardDiff => AutoForwardDiff(),
-                :ReverseDiff => AutoReverseDiff(),
-                :Zygote      => AutoZygote(),
-                :Enzyme      => AutoEnzyme(),
+                # :ReverseDiff => AutoReverseDiff(),
+                # :Zygote      => AutoZygote(),
+                # :Enzyme      => AutoEnzyme(),
             )
 
             seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
@@ -68,7 +68,7 @@ include("models/normallognormal.jl")
                 Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
                 q, stats  = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.AdaGrad(1e-1),
+                    optimizer = Optimisers.Adam(1e-3),
                     progress  = PROGRESS,
                     rng       = rng,
                     adbackend = adbackend,
@@ -87,7 +87,7 @@ include("models/normallognormal.jl")
                 rng      = Philox4x(UInt64, seed, 8)
                 q, stats = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.AdaGrad(1e-1),
+                    optimizer = Optimisers.Adam(1e-3),
                     progress  = PROGRESS,
                     rng       = rng,
                     adbackend = adbackend,
@@ -98,7 +98,7 @@ include("models/normallognormal.jl")
                 rng_repl = Philox4x(UInt64, seed, 8)
                 q, stats = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.AdaGrad(1e-1),
+                    optimizer = Optimisers.Adam(1e-3),
                     progress  = PROGRESS,
                     rng       = rng_repl,
                     adbackend = adbackend,
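
For reference, the entropy estimators selectable through the constructor's `entropy` keyword at this point in the series:

```julia
using AdvancedVI

# The three entropy estimators; the closed-form entropy is the default.
ClosedFormEntropy()          # uses entropy(q) of the variational family
StickingTheLandingEntropy()  # Monte Carlo estimate with a stopped score gradient
MonteCarloEntropy()          # plain Monte Carlo estimate
```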

From b134f7099062b2c7a6d7b3ec9e30867703c609da Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 15 Jul 2023 02:07:08 +0100
Subject: [PATCH 056/206] fix compile error in advi test

---
 test/advi_locscale.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index dadbaf25d..40e5dace7 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -30,8 +30,8 @@ include("models/normallognormal.jl")
             ),
             (objname, objective) ∈ Dict(
                 :ADVIClosedFormEntropy  => (model, b, M) -> ADVI(model, M; b),
-                :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, H = StickingTheLandingEntropy()),
-                :ADVIFullMonteCarlo     => (model, b, M) -> ADVI(model, M; b, H = MonteCarloEntropy()),
+                :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()),
+                :ADVIFullMonteCarlo     => (model, b, M) -> ADVI(model, M; b, entropy = MonteCarloEntropy()),
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwardDiff => AutoForwardDiff(),
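
With the keyword name settled, constructing the objective looks as follows. The target below is a hypothetical minimal implementation of the zeroth-order `LogDensityProblems` interface, which the `ADVI` constructor checks through `capabilities`:

```julia
using AdvancedVI, LogDensityProblems

# A toy standard-Gaussian target (illustrative name, not part of the package).
struct StdGauss end

LogDensityProblems.logdensity(::StdGauss, z) = -sum(abs2, z) / 2
LogDensityProblems.dimension(::StdGauss) = 2
LogDensityProblems.capabilities(::Type{StdGauss}) =
    LogDensityProblems.LogDensityOrder{0}()

obj = ADVI(StdGauss(), 10; entropy = StickingTheLandingEntropy())
```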

From a6ba379b9a97e509076ce0c7e2c2ebd4b6caa737 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 15 Jul 2023 22:33:52 +0100
Subject: [PATCH 057/206] add initial doc

---
 docs/make.jl         | 17 +++++++++++
 docs/src/advi.md     | 67 ++++++++++++++++++++++++++++++++++++++++++++
 docs/src/families.md | 58 ++++++++++++++++++++++++++++++++++++++
 docs/src/index.md    | 14 +++++++++
 4 files changed, 156 insertions(+)
 create mode 100644 docs/make.jl
 create mode 100644 docs/src/advi.md
 create mode 100644 docs/src/families.md
 create mode 100644 docs/src/index.md

diff --git a/docs/make.jl b/docs/make.jl
new file mode 100644
index 000000000..d2a01d1bf
--- /dev/null
+++ b/docs/make.jl
@@ -0,0 +1,17 @@
+using AdvancedVI
+using Documenter
+
+DocMeta.setdocmeta!(
+    AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true
+)
+
+makedocs(;
+    sitename = "AdvancedVI.jl",
+    modules  = [AdvancedVI],
+    format   = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
+    pages    = ["index.md",
+                "families.md",
+                "advi.md"],
+)
+
+deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", devbranch="main")
diff --git a/docs/src/advi.md b/docs/src/advi.md
new file mode 100644
index 000000000..4f4a2ecad
--- /dev/null
+++ b/docs/src/advi.md
@@ -0,0 +1,67 @@
+
+# [Automatic Differentiation Variational Inference](@id advi)
+The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound (ELBO) of a variational approximation ``q_{\phi,\lambda}`` to a target posterior distribution ``\pi``.
+Maximizing the ADVI objective is equivalent to solving the problem
+
+```math
+  \mathrm{minimize}_{\lambda \in \Lambda}\quad \mathrm{KL}\left(q_{\phi,\lambda} \,\middle\|\, \pi\right).
+```
+
+The key aspects of the ADVI objective are the following:
+1. It uses the reparameterization gradient estimator.
+2. It automatically matches the support of the target posterior through "bijectors."
+
+Thanks to Item 2, the user is free to choose any variational family with unconstrained support;
+a bijector automatically matches it to the potentially constrained support of the target.
+
+In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}``
+from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that
+```math
+z \sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad
+z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda}.
+```
+ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``.
+
+That is,
+
+```math
+\begin{aligned}
+\mathrm{ADVI}\left(\lambda\right)
+&\triangleq
+\mathbb{E}_{\eta \sim q_{\lambda}}\left[
+  \log \pi\left( \phi^{-1}\left( \eta \right) \right)
+  + \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert
+\right]
++ \mathbb{H}\left(q_{\lambda}\right) \\
+&=
+\mathbb{E}_{\eta \sim q_{\lambda}}\left[
+  \log \pi\left( \phi^{-1}\left( \eta \right) \right)
+\right]
++
+\mathbb{E}_{\eta \sim q_{\lambda}}\left[
+  - \log \frac{ q_{\lambda}\left( \eta \right) }{ \lvert J_{\phi^{-1}}\left(\eta\right) \rvert }
+\right] \\
+&=
+\mathbb{E}_{z \sim q_{\phi,\lambda}}\left[ \log \pi\left(z\right) \right]
++
+\mathbb{H}\left(q_{\phi,\lambda}\right)
+\end{aligned}
+```
+
+The idea of using the reparameterization gradient estimator for variational inference was first
+proposed by Titsias and Lázaro-Gredilla (2014).
+Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by
+Fjelde *et al.* (2020).
+
+
+```@docs
+ADVI
+```
+
+# References
+1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research.
+2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning (pp. 1971-1979). PMLR.
+3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). TensorFlow Distributions. arXiv preprint arXiv:1711.10604.
+4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors.jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR.
+
+
diff --git a/docs/src/families.md b/docs/src/families.md
new file mode 100644
index 000000000..f203cf18a
--- /dev/null
+++ b/docs/src/families.md
@@ -0,0 +1,58 @@
+
+# [Variational Families](@id families)
+
+## Location-Scale Variational Family
+
+### Description
+The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions whose sampling process can be represented as
+```math
+z = C u + m,
+```
+where ``C`` is the *scale* and ``m`` is the *location* variational parameter.
+This family encompasses many well-known variational families, including the Gaussian, Student-t, and Laplace families shown in the examples below.
+
+
+### Constructors
+
+```@docs
+VILocationScale
+```
+
+```@docs
+VIFullRankGaussian
+VIMeanFieldGaussian
+```
+
+### Examples
+
+A full-rank variational family can be formed by choosing a dense lower-triangular scale matrix:
+```@repl locscale
+using AdvancedVI, Distributions, LinearAlgebra
+μ = zeros(2);
+L = diagm(ones(2)) |> LowerTriangular;
+```
+
+A mean-field variational family can be formed by choosing a diagonal scale matrix:
+```@repl locscale
+μ = zeros(2);
+L = ones(2) |> Diagonal;
+```
+
+Gaussian variational family:
+```@repl locscale
+q = VIFullRankGaussian(μ, L)
+q = VIMeanFieldGaussian(μ, L)
+```
+
+Student-t variational family (`TDist` is the Student-t distribution in Distributions.jl):
+
+```@repl locscale
+ν = 3
+q = VILocationScale(μ, L, TDist(ν))
+```
+
+Multivariate Laplace family:
+```@repl locscale
+q = VILocationScale(μ, L, Laplace())
+```
+
diff --git a/docs/src/index.md b/docs/src/index.md
new file mode 100644
index 000000000..be3269217
--- /dev/null
+++ b/docs/src/index.md
@@ -0,0 +1,14 @@
+```@meta
+CurrentModule = AdvancedVI
+```
+
+# AdvancedVI
+
+Documentation for [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl).
+
+```@index
+```
+
+```@autodocs
+Modules = [AdvancedVI]
+```
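
A hypothetical end-to-end run combining the documented pieces, reusing the `StdGauss` toy target sketched earlier; the keywords follow the `optimize` signature in `src/optimize.jl`:

```julia
using AdvancedVI, Optimisers, LinearAlgebra

# A mean-field Gaussian family optimized against an ADVI objective.
μ₀ = zeros(2)
L₀ = Diagonal(ones(2))
q₀ = VIMeanFieldGaussian(μ₀, L₀)

obj      = ADVI(StdGauss(), 10)
q, stats = optimize(obj, q₀, 1_000; optimizer = Optimisers.Adam(1e-3))
```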

From 619b1c05eaf669491f82406becb9a31dba1871cc Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 15 Jul 2023 22:34:32 +0100
Subject: [PATCH 058/206] remove unused epsilon argument in location scale

---
 src/distributions/location_scale.jl | 40 +++++++++++++++++------------
 1 file changed, 23 insertions(+), 17 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index dc9c1b279..5eb371ad4 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -1,31 +1,31 @@
 
 """
+    VILocationScale{L,R,D}(location::L, scale::S, dist::D) <: ContinuousMultivariateDistribution
 
 The [location scale] variational family broadly represents various variational
 families using `location` and `scale` variational parameters.
 
-Multivariate Student-t variational family with ``\\nu``-degrees of freedom can
-be constructed as:
+It generally represents any distribution for which the sampling path can be
+represented as the following:
 ```julia
-q₀ = VILocationScale(μ, L, StudentT(ν), eps(Float32))
+  d = length(location)
+  u = rand(dist, d)
+  z = scale*u + location
 ```
-
 """
-struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution
+struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
     location::L
     scale   ::S
     dist    ::D
-    epsilon ::R
 
     function VILocationScale(μ::AbstractVector{<:Real},
                              L::Union{<:AbstractTriangular{<:Real},
                                       <:Diagonal{<:Real}},
-                             q_base::ContinuousUnivariateDistribution,
-                             epsilon::Real)
+                             q_base::ContinuousUnivariateDistribution)
         # Restricting all the arguments to have the same types creates problems 
         # with dual-variable-based AD frameworks.
         @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
-        new{typeof(μ), typeof(L), typeof(q_base), typeof(epsilon)}(μ, L, q_base, epsilon)
+        new{typeof(μ), typeof(L), typeof(q_base)}(μ, L, q_base)
     end
 end
 
@@ -76,16 +76,22 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
     return x += location
 end
 
-function VIFullRankGaussian(μ::AbstractVector{T},
-                            L::AbstractTriangular{T},
-                            epsilon::Real = eps(T)) where {T <: Real}
+"""
+    VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T})
+
+This constructs a multivariate Gaussian distribution with a full-rank covariance matrix.
+"""
+function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) where {T <: Real}
     q_base = Normal{T}(zero(T), one(T))
-    VILocationScale(μ, L, q_base, epsilon)
+    VILocationScale(μ, L, q_base)
 end
 
-function VIMeanFieldGaussian(μ::AbstractVector{T},
-                             L::Diagonal{T},
-                             epsilon::Real = eps(T)) where {T <: Real}
+"""
+    VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T})
+
+This constructs a multivariate Gaussian distribution with a diagonal covariance matrix.
+"""
+function VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T <: Real}
     q_base = Normal{T}(zero(T), one(T))
-    VILocationScale(μ, L, q_base, epsilon)
+    VILocationScale(μ, L, q_base)
 end
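
The docstring's sampling path, written out directly as a sketch with arbitrary parameters:

```julia
using Distributions, LinearAlgebra, Random

# Draw u from the base distribution, then set z = scale*u + location.
rng      = Xoshiro(0)
location = zeros(3)
scale    = LowerTriangular(Matrix(1.0I, 3, 3))
u        = rand(rng, Normal(), length(location))
z        = scale*u + location
```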

From f1c02f02909ff15ac2ddc6276af8589c97cfedf8 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 15 Jul 2023 22:39:16 +0100
Subject: [PATCH 059/206] add project file for documenter

---
 docs/Project.toml | 7 +++++++
 1 file changed, 7 insertions(+)
 create mode 100644 docs/Project.toml

diff --git a/docs/Project.toml b/docs/Project.toml
new file mode 100644
index 000000000..fc885857a
--- /dev/null
+++ b/docs/Project.toml
@@ -0,0 +1,7 @@
+[deps]
+AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+
+[compat]
+Documenter = "0.26"
\ No newline at end of file

From b0f259a4c32ad293cf0edd236b42b132d7e959b5 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 16 Jul 2023 02:55:03 +0100
Subject: [PATCH 060/206] refactor STL gradient calculation to use multiple
 dispatch

---
 src/AdvancedVI.jl                   |  6 +-
 src/distributions/location_scale.jl | 16 ++---
 src/objectives/elbo/advi.jl         | 97 +++++++++++++++++++++++------
 src/objectives/elbo/entropy.jl      | 11 +---
 test/advi_locscale.jl               |  6 +-
 test/models/normal.jl               | 51 +++++++++++++++
 test/models/normallognormal.jl      |  4 +-
 test/models/utils.jl                |  8 +++
 8 files changed, 160 insertions(+), 39 deletions(-)
 create mode 100644 test/models/normal.jl
 create mode 100644 test/models/utils.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index e3dd85a89..9f93885c1 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -73,8 +73,9 @@ abstract type AbstractControlVariate end
 function update end
 update(::Nothing, ::Nothing) = (nothing, nothing)
 
-include("objectives/elbo/advi.jl")
+# entropy.jl must precede advi.jl
 include("objectives/elbo/entropy.jl")
+include("objectives/elbo/advi.jl")
 
 export
     ELBO,
@@ -82,13 +83,14 @@ export
     ADVIEnergy,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
-    MonteCarloEntropy
+    FullMonteCarloEntropy
 
 # Variational Families
 
 include("distributions/location_scale.jl")
 
 export
+    VILocationScale,
     VIFullRankGaussian,
     VIMeanFieldGaussian
 
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 5eb371ad4..e901e8deb 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -1,8 +1,8 @@
 
 """
-    VILocationScale{L,R,D}(location::L, scale::S, dist::D) <: ContinuousMultivariateDistribution
+    VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution
 
-The [location scale] variational family broadly represents various variational
+The location scale variational family broadly represents various variational
 families using `location` and `scale` variational parameters.
 
 It generally represents any distribution for which the sampling path can be
@@ -18,14 +18,14 @@ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
     scale   ::S
     dist    ::D
 
-    function VILocationScale(μ::AbstractVector{<:Real},
-                             L::Union{<:AbstractTriangular{<:Real},
+    function VILocationScale(location::AbstractVector{<:Real},
+                             scale::Union{<:AbstractTriangular{<:Real},
                                       <:Diagonal{<:Real}},
-                             q_base::ContinuousUnivariateDistribution)
+                             dist::ContinuousUnivariateDistribution)
         # Restricting all the arguments to have the same types creates problems 
         # with dual-variable-based AD frameworks.
-        @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2))
-        new{typeof(μ), typeof(L), typeof(q_base)}(μ, L, q_base)
+        @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2))
+        new{typeof(location), typeof(scale), typeof(dist)}(location, scale, dist)
     end
 end
 
@@ -87,7 +87,7 @@ function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) whe
 end
 
 """
-    VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T})
+    VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T})
 
 This constructs a multivariate Gaussian distribution with a diagonal covariance matrix.
 """
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index e965ea734..e4e933277 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -1,9 +1,23 @@
 
 """
-    ADVI
+    ADVI(
+        prob,
+        n_samples::Int;
+        entropy::AbstractEntropyEstimator = ClosedFormEntropy(),
+        cv::Union{<:AbstractControlVariate, Nothing} = nothing,
+        b = Bijectors.identity
+    )
 
 Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
 
+# Arguments
+- `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface.
+    - `logdensity` must be differentiable by the selected AD backend.
+- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO.
+- `entropy`: The estimator for the entropy term.
+- `cv`: A control variate.
+- `b`: A bijector mapping the support of the base distribution to that of `prob`.
+
 # Requirements
 - ``q_{\\lambda}`` implements `rand`.
 - ``\\pi`` must be differentiable.
@@ -39,40 +53,87 @@ end
 Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))")
 
-skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy)
-
 init(advi::ADVI) = init(advi.cv)
 
-function (advi::ADVI)(q_η::ContinuousMultivariateDistribution;
-                      rng       ::AbstractRNG    = default_rng(),
-                      n_samples ::Int            = advi.n_samples,
-                      ηs        ::AbstractMatrix = rand(rng, q_η, n_samples),
-                      q_η_entropy::ContinuousMultivariateDistribution = q_η)
+function (advi::ADVI)(
+    rng::AbstractRNG,
+    q_η::ContinuousMultivariateDistribution,
+    ηs ::AbstractMatrix
+)
+    n_samples = size(ηs, 2)
     𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ
         zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ)
         (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
     end
-    ℍ  = advi.entropy(q_η_entropy, ηs)
+    ℍ  = advi.entropy(q_η, ηs)
     𝔼ℓ + ℍ
 end
 
-function estimate_gradient(
+"""
+    (advi::ADVI)(
+        q_η::ContinuousMultivariateDistribution;
+        rng::AbstractRNG = Random.default_rng(),
+        n_samples::Int = advi.n_samples
+    )
+
+Evaluate the ELBO using the ADVI formulation.
+
+# Arguments
+- `q_η`: Variational approximation before applying a bijector (unconstrained support).
+- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO.
+
+"""
+function (advi::ADVI)(
+    q_η::ContinuousMultivariateDistribution;
+    rng::AbstractRNG = default_rng(),
+    n_samples::Int = advi.n_samples
+)
+    ηs = rand(rng, q_η, n_samples)
+    advi(rng, q_η, ηs)
+end
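
The evaluation entry point above can be exercised with a toy target. A minimal sketch against the API as of this patch; `StdNormal` is a hypothetical order-`0` `LogDensityProblems` target, not part of the package:

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems

# Hypothetical target: a standard 2-D Gaussian implementing the
# order-0 LogDensityProblems interface required by `ADVI`.
struct StdNormal end
LogDensityProblems.logdensity(::StdNormal, z) = -sum(abs2, z) / 2
LogDensityProblems.dimension(::StdNormal) = 2
LogDensityProblems.capabilities(::Type{StdNormal}) =
    LogDensityProblems.LogDensityOrder{0}()

obj = ADVI(StdNormal(), 10)  # 10 Monte Carlo samples per estimate
q   = VIMeanFieldGaussian(zeros(2), Diagonal(ones(2)))
obj(q; n_samples = 100)      # one stochastic ELBO estimate
```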
+
+function estimate_advi_gradient_maybe_stl!(
     rng::AbstractRNG,
     adbackend::AbstractADType,
-    advi::ADVI,
-    est_state,
+    advi::ADVI{P, B, StickingTheLandingEntropy, CV},
     λ::Vector{<:Real},
     restructure,
-    out::DiffResults.MutableDiffResult)
-
-    # Gradient-stopping for computing the sticking-the-landing control variate
-    q_η_stop = skip_entropy_gradient(advi.entropy) ? restructure(λ) : nothing
+    out::DiffResults.MutableDiffResult
+) where {P, B, CV}
+    q_η_stop = restructure(λ)
+    grad!(adbackend, λ, out) do λ′
+        q_η = restructure(λ′)
+        ηs  = rand(rng, q_η, advi.n_samples)
+        -advi(rng, q_η_stop, ηs)
+    end
+end
 
+function estimate_advi_gradient_maybe_stl!(
+    rng::AbstractRNG,
+    adbackend::AbstractADType,
+    advi::ADVI{P, B, <:Union{ClosedFormEntropy, FullMonteCarloEntropy}, CV},
+    λ::Vector{<:Real},
+    restructure,
+    out::DiffResults.MutableDiffResult
+) where {P, B, CV}
     grad!(adbackend, λ, out) do λ′
         q_η = restructure(λ′)
-        q_η_entropy = skip_entropy_gradient(advi.entropy) ? q_η_stop : q_η
-        -advi(q_η; rng, q_η_entropy)
+        ηs  = rand(rng, q_η, advi.n_samples)
+        -advi(rng, q_η, ηs)
     end
+end
+
+function estimate_gradient(
+    rng::AbstractRNG,
+    adbackend::AbstractADType,
+    advi::ADVI,
+    est_state,
+    λ::Vector{<:Real},
+    restructure,
+    out::DiffResults.MutableDiffResult
+)
+    estimate_advi_gradient_maybe_stl!(
+        rng, adbackend, advi, λ, restructure, out)
     nelbo = DiffResults.value(out)
     stat  = (elbo=-nelbo,)
 
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 994bdd4f1..7f37b619b 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -7,11 +7,9 @@ end
 
 skip_entropy_gradient(::ClosedFormEntropy) = false
 
-struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end
+abstract type MonteCarloEntropy <: AbstractEntropyEstimator end
 
-MonteCarloEntropy() = MonteCarloEntropy{false}()
-
-Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()")
+struct FullMonteCarloEntropy <: MonteCarloEntropy end
 
 """
     StickingTheLandingEntropy()
@@ -44,11 +42,8 @@ and higher variance when ``\\pi \\not\\approx q_{\\lambda}``.
 # Reference
 1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
 """
-StickingTheLandingEntropy() = MonteCarloEntropy{true}()
-
-skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding
 
-Base.show(io::IO, entropy::MonteCarloEntropy{true}) = print(io, "StickingTheLandingEntropy()")
+struct StickingTheLandingEntropy <: MonteCarloEntropy end
 
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
     n_samples = size(ηs, 2)
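
The estimator this method implements is the plain Monte Carlo entropy estimate ``-\frac{1}{N}\sum_{i} \log q(\eta_i)`` over column-wise samples. A standalone check of the idea using only Distributions (no package types assumed):

```julia
using Distributions, LinearAlgebra, Statistics

q  = MvNormal(zeros(2), Diagonal(ones(2)))
ηs = rand(q, 10_000)                          # samples stored column-wise
Ĥ  = -mean(logpdf(q, ηᵢ) for ηᵢ in eachcol(ηs))
isapprox(Ĥ, entropy(q); rtol = 0.05)          # true, up to Monte Carlo error
```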
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 40e5dace7..2f19ca611 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -19,6 +19,8 @@ struct TestModel{M,L,S}
 end
 
 include("models/normallognormal.jl")
+include("models/normal.jl")
+include("models/utils.jl")
 
 @testset "advi" begin
     @testset "locscale" begin
@@ -27,11 +29,13 @@ include("models/normallognormal.jl")
             (modelname, modelconstr) ∈ Dict(
                 :NormalLogNormalMeanField => normallognormal_meanfield,
                 :NormalLogNormalFullRank  => normallognormal_fullrank,
+                :NormalMeanField          => normal_meanfield,
+                :NormalFullRank           => normal_fullrank,
             ),
             (objname, objective) ∈ Dict(
                 :ADVIClosedFormEntropy  => (model, b, M) -> ADVI(model, M; b),
                 :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()),
-                :ADVIFullMonteCarlo     => (model, b, M) -> ADVI(model, M; b, entropy = MonteCarloEntropy()),
+                :ADVIFullMonteCarlo     => (model, b, M) -> ADVI(model, M; b, entropy = FullMonteCarloEntropy()),
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwarDiff  => AutoForwardDiff(),
diff --git a/test/models/normal.jl b/test/models/normal.jl
new file mode 100644
index 000000000..a677af93c
--- /dev/null
+++ b/test/models/normal.jl
@@ -0,0 +1,51 @@
+
+struct TestMvNormal{M,S}
+    μ::M
+    Σ::S
+end
+
+function LogDensityProblems.logdensity(model::TestMvNormal, θ)
+    @unpack μ, Σ = model
+    logpdf(MvNormal(μ, Σ), θ)
+end
+
+function LogDensityProblems.dimension(model::TestMvNormal)
+    length(model.μ)
+end
+
+function LogDensityProblems.capabilities(::Type{<:TestMvNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+function Bijectors.bijector(model::TestMvNormal)
+    identity
+end
+
+function normal_fullrank(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ  = randn(rng, realtype, n_dims)
+    L₀ = sample_cholesky(rng, n_dims)
+    ϵ  = eps(realtype)*10
+    Σ  = (L₀*L₀' + ϵ*I) |> Hermitian
+
+    Σ_chol = cholesky(Σ)
+    model  = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol))
+
+    L = Σ_chol.L |> LowerTriangular
+
+    TestModel(model, μ, L, n_dims, false)
+end
+
+function normal_meanfield(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ = randn(rng, realtype, n_dims)
+    σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
+
+    model = TestMvNormal(μ, PDMats.PDiagMat(σ))
+
+    L = σ |> Diagonal
+
+    TestModel(model, μ, L, n_dims, true)
+end
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index 18e8b4a34..ca8c9a4d3 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -32,8 +32,8 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     μ_x  = randn(rng, realtype)
     σ_x  = ℯ
     μ_y  = randn(rng, realtype, n_dims)
-    L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
-    ϵ    = realtype(n_dims*2)
+    L₀_y = sample_cholesky(rng, n_dims)
+    ϵ    = eps(realtype)*10
     Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
 
     model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
diff --git a/test/models/utils.jl b/test/models/utils.jl
new file mode 100644
index 000000000..c1a9a407e
--- /dev/null
+++ b/test/models/utils.jl
@@ -0,0 +1,8 @@
+
+function sample_cholesky(rng::AbstractRNG, n_dims::Int)
+    A   = randn(rng, n_dims, n_dims) 
+    L   = tril(A)
+    idx = diagind(L)
+    @. L[idx] = log(exp(L[idx]) + 1)
+    L |> LowerTriangular
+end

From b72c2585a1d3e461d9903884d16d9b019c11e828 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 16 Jul 2023 03:49:08 +0100
Subject: [PATCH 061/206] fix type bugs, relax test threshold for the exact
 inference tests

---
 test/advi_locscale.jl          | 8 ++++----
 test/models/normal.jl          | 5 ++---
 test/models/normallognormal.jl | 5 ++---
 test/models/utils.jl           | 4 ++--
 4 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 2f19ca611..1552be5e4 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -72,7 +72,7 @@ include("models/utils.jl")
                 Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
                 q, stats  = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.Adam(1e-3),
+                    optimizer = Optimisers.Adam(1e-2),
                     progress  = PROGRESS,
                     rng       = rng,
                     adbackend = adbackend,
@@ -82,7 +82,7 @@ include("models/utils.jl")
                 L  = q.scale
                 Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-                @test Δλ ≤ Δλ₀/√T
+                @test Δλ ≤ Δλ₀/T^(1/4)
                 @test eltype(μ) == eltype(μ_true)
                 @test eltype(L) == eltype(L_true)
             end
@@ -91,7 +91,7 @@ include("models/utils.jl")
                 rng      = Philox4x(UInt64, seed, 8)
                 q, stats = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.Adam(1e-3),
+                    optimizer = Optimisers.Adam(realtype(1e-2)),
                     progress  = PROGRESS,
                     rng       = rng,
                     adbackend = adbackend,
@@ -102,7 +102,7 @@ include("models/utils.jl")
                 rng_repl = Philox4x(UInt64, seed, 8)
                 q, stats = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.Adam(1e-3),
+                    optimizer = Optimisers.Adam(realtype(1e-2)),
                     progress  = PROGRESS,
                     rng       = rng_repl,
                     adbackend = adbackend,
diff --git a/test/models/normal.jl b/test/models/normal.jl
index a677af93c..f60ad5f38 100644
--- a/test/models/normal.jl
+++ b/test/models/normal.jl
@@ -25,9 +25,8 @@ function normal_fullrank(realtype; rng = default_rng())
     n_dims = 5
 
     μ  = randn(rng, realtype, n_dims)
-    L₀ = sample_cholesky(rng, n_dims)
-    ϵ  = eps(realtype)*10
-    Σ  = (L₀*L₀' + ϵ*I) |> Hermitian
+    L₀ = sample_cholesky(rng, realtype, n_dims)
+    Σ  = L₀*L₀' |> Hermitian
 
     Σ_chol = cholesky(Σ)
     model  = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol))
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index ca8c9a4d3..cab73ccee 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -32,9 +32,8 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     μ_x  = randn(rng, realtype)
     σ_x  = ℯ
     μ_y  = randn(rng, realtype, n_dims)
-    L₀_y = sample_cholesky(rng, n_dims)
-    ϵ    = eps(realtype)*10
-    Σ_y  = (L₀_y*L₀_y' + ϵ*I) |> Hermitian
+    L₀_y = sample_cholesky(rng, realtype, n_dims)
+    Σ_y  = L₀_y*L₀_y' |> Hermitian
 
     model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
 
diff --git a/test/models/utils.jl b/test/models/utils.jl
index c1a9a407e..3d483c46d 100644
--- a/test/models/utils.jl
+++ b/test/models/utils.jl
@@ -1,6 +1,6 @@
 
-function sample_cholesky(rng::AbstractRNG, n_dims::Int)
-    A   = randn(rng, n_dims, n_dims) 
+function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int)
+    A   = randn(rng, type, n_dims, n_dims) 
     L   = tril(A)
     idx = diagind(L)
     @. L[idx] = log(exp(L[idx]) + 1)
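
`sample_cholesky` draws a random lower-triangular factor and passes its diagonal through softplus, so `L*L'` is positive definite by construction. A small usage sketch, assuming the helper as defined in this patch:

```julia
using Random, LinearAlgebra

rng = Random.default_rng()
L   = sample_cholesky(rng, Float64, 5)  # 5×5 LowerTriangular factor
Σ   = Hermitian(L * L')
isposdef(Σ)                             # true: softplus keeps diag(L) > 0
```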

From a8df9eb8b635e9805e3f307b7b5b64ccb4f1f970 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 01:33:15 +0100
Subject: [PATCH 062/206] refactor derivative utils to match
 NormalizingFlows.jl with extras

---
 Project.toml                    | 20 +++++++--
 ext/AdvancedVIEnzymeExt.jl      | 26 +++++++++++
 ext/AdvancedVIForwardDiffExt.jl | 29 ++++++++++++
 ext/AdvancedVIReverseDiffExt.jl | 23 ++++++++++
 ext/AdvancedVIZygoteExt.jl      | 24 ++++++++++
 src/AdvancedVI.jl               | 79 ++++++++++++++++++---------------
 src/compat/enzyme.jl            | 16 -------
 src/compat/reversediff.jl       | 19 --------
 src/compat/zygote.jl            | 13 ------
 src/grad.jl                     | 30 -------------
 src/objectives/elbo/advi.jl     |  6 ++-
 test/ad.jl                      |  7 ++-
 12 files changed, 167 insertions(+), 125 deletions(-)
 create mode 100644 ext/AdvancedVIEnzymeExt.jl
 create mode 100644 ext/AdvancedVIForwardDiffExt.jl
 create mode 100644 ext/AdvancedVIReverseDiffExt.jl
 create mode 100644 ext/AdvancedVIZygoteExt.jl
 delete mode 100644 src/compat/enzyme.jl
 delete mode 100644 src/compat/reversediff.jl
 delete mode 100644 src/compat/zygote.jl
 delete mode 100644 src/grad.jl

diff --git a/Project.toml b/Project.toml
index cf698f7a6..ab00d6741 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,10 +6,10 @@ version = "0.2.4"
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -21,24 +21,36 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+
+[weakdeps]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ADTypes = "0.1"
 Bijectors = "0.11, 0.12, 0.13"
+DiffResults = "1.0.3"
 Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
 DocStringExtensions = "0.8, 0.9"
-ForwardDiff = "0.10.3"
+ForwardDiff = "0.10.25"
+LogDensityProblems = "2.1.1"
+Optimisers = "0.2.16"
 ProgressMeter = "1.0.0"
 Requires = "0.5, 1.0"
+ReverseDiff = "1.14"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
-Tracker = "0.2.3"
 julia = "1.6"
 
 [extras]
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
 test = ["Pkg", "Test"]
diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl
new file mode 100644
index 000000000..8333299f0
--- /dev/null
+++ b/ext/AdvancedVIEnzymeExt.jl
@@ -0,0 +1,26 @@
+
+module AdvancedVIEnzymeExt
+
+if isdefined(Base, :get_extension)
+    using Enzyme
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+else
+    using ..Enzyme
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+end
+
+# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
+) where {T<:Real}
+    y = f(θ)
+    DiffResults.value!(out, y)
+    ∇θ = DiffResults.gradient(out)
+    fill!(∇θ, zero(T))
+    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
+    return out
+end
+
+end # module
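
The `Duplicated` shadow buffer accumulates the gradient in place, which is why it is zeroed before the call. A standalone sketch of the same call pattern, outside the extension:

```julia
using Enzyme

g(x) = sum(abs2, x) / 2
x  = randn(3)
dx = zeros(3)  # shadow buffer: must start at zero, gradients accumulate into it
Enzyme.autodiff(Enzyme.ReverseWithPrimal, g, Enzyme.Active, Enzyme.Duplicated(x, dx))
dx  # ≈ x
```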
diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl
new file mode 100644
index 000000000..e6b03af21
--- /dev/null
+++ b/ext/AdvancedVIForwardDiffExt.jl
@@ -0,0 +1,29 @@
+
+module AdvancedVIForwardDiffExt
+
+if isdefined(Base, :get_extension)
+    using ForwardDiff
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+else
+    using ..ForwardDiff
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+end
+
+# extract chunk size from AutoForwardDiff
+getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
+) where {T<:Real}
+    chunk_size = getchunksize(ad)
+    config = if isnothing(chunk_size)
+        ForwardDiff.GradientConfig(f, θ)
+    else
+        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
+    end
+    ForwardDiff.gradient!(out, f, θ, config)
+    return out
+end
+
+end
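
The `chunksize` type parameter extracted by `getchunksize` is set through ADTypes' keyword constructor; leaving it unset lets ForwardDiff choose a chunk size from the input length. A usage sketch, assuming ForwardDiff is loaded so the extension is active:

```julia
using AdvancedVI, ADTypes, DiffResults, ForwardDiff

f(θ) = sum(abs2, θ) / 2
θ    = randn(10)
out  = DiffResults.GradientResult(θ)

# Let ForwardDiff choose the chunk size…
AdvancedVI.value_and_gradient!(AutoForwardDiff(), f, θ, out)
# …or pin it via the type parameter.
AdvancedVI.value_and_gradient!(AutoForwardDiff(; chunksize = 4), f, θ, out)

DiffResults.value(out)     # 0.5‖θ‖²
DiffResults.gradient(out)  # ≈ θ
```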
diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl
new file mode 100644
index 000000000..fd7fbaabc
--- /dev/null
+++ b/ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,23 @@
+
+module AdvancedVIReverseDiffExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+    using ReverseDiff
+else
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+    using ..ReverseDiff
+end
+
+# ReverseDiff without compiled tape
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
+) where {T<:Real}
+    tp = ReverseDiff.GradientTape(f, θ)
+    ReverseDiff.gradient!(out, tp, θ)
+    return out
+end
+
+end
diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl
new file mode 100644
index 000000000..b447d0718
--- /dev/null
+++ b/ext/AdvancedVIZygoteExt.jl
@@ -0,0 +1,24 @@
+
+module AdvancedVIZygoteExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+    using Zygote
+else
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+    using ..Zygote
+end
+
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
+) where {T<:Real}
+    y, back = Zygote.pullback(f, θ)
+    ∇θ = back(one(T))
+    DiffResults.value!(out, y)
+    DiffResults.gradient!(out, first(∇θ))
+    return out
+end
+
+end
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 9f93885c1..697f3c83c 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -21,9 +21,9 @@ using LinearAlgebra: AbstractTriangular
 
 using LogDensityProblems
 
-using ADTypes
+using ADTypes, DiffResults
 using ADTypes: AbstractADType
-using ForwardDiff, Tracker
+
 
 using FillArrays
 using PDMats
@@ -34,29 +34,23 @@ using StatsBase: entropy
 
 const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
 
-using Requires
-function __init__()
-    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-        include("compat/zygote.jl")
-    end
-    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-        include("compat/reversediff.jl")
-    end
-    @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
-        include("compat/enzyme.jl")
-    end
-end
-
+# derivatives
 """
-    grad!(f, λ, out)
-
-Computes the gradients of the objective f. Default implementation is provided for 
-`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
-This implicitly also gives a default implementation of `optimize!`.
+    value_and_gradient!(
+        ad::ADTypes.AbstractADType,
+        f,
+        θ::AbstractVector{T},
+        out::DiffResults.MutableDiffResult
+    ) where {T<:Real}
+
+Compute the value and gradient of a function `f` at `θ` using the automatic
+differentiation backend `ad`.  The result is stored in `out`. 
+The function `f` must return a scalar value. The gradient is stored in `out` as a
+vector of the same length as `θ`.
 """
-function grad! end
+function value_and_gradient! end
 
-include("grad.jl")
+export value_and_gradient!
 
 # estimators
 abstract type AbstractVariationalObjective end
@@ -94,21 +88,8 @@ export
     VIFullRankGaussian,
     VIMeanFieldGaussian
 
-"""
-    optimize(model, alg::VariationalInference)
-    optimize(model, alg::VariationalInference, q::VariationalPosterior)
-    optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray)
-
-Constructs the variational posterior from the `model` and performs the optimization
-following the configuration of the given `VariationalInference` instance.
-
-# Arguments
-- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations
-- `alg`: the VI algorithm used
-- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists.
-- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior`
-- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior
-"""
+# Optimization Routine
+
 function optimize end
 
 include("optimize.jl")
@@ -117,4 +98,28 @@ export optimize
 
 include("utils.jl")
 
+
+# optional dependencies 
+if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
+    using Requires
+end
+
+using Requires
+function __init__()
+    @static if !isdefined(Base, :get_extension)
+        @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+            include("../ext/AdvancedVIZygoteExt.jl")
+        end
+        @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
+            include("../ext/AdvancedVIForwardDiffExt.jl")
+        end
+        @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
+            include("../ext/AdvancedVIReverseDiffExt.jl")
+        end
+        @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+            include("../ext/AdvancedVIEnzymeExt.jl")
+        end
+    end
+end
 end # module
+
diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl
deleted file mode 100644
index cab50862e..000000000
--- a/src/compat/enzyme.jl
+++ /dev/null
@@ -1,16 +0,0 @@
-
-function AdvancedVI.grad!(
-    f::Function,
-    ::AutoEnzyme,
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    )
-    # Use `Enzyme.ReverseWithPrimal` once it is released:
-    # https://github.com/EnzymeAD/Enzyme.jl/pull/598
-    y = f(λ)
-    DiffResults.value!(out, y)
-    dy = DiffResults.gradient(out)
-    fill!(dy, 0)
-    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
-    return out
-end
diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl
deleted file mode 100644
index 4d8f87d8c..000000000
--- a/src/compat/reversediff.jl
+++ /dev/null
@@ -1,19 +0,0 @@
-using .ReverseDiff: compile, GradientTape
-using .ReverseDiff.DiffResults: GradientResult
-
-tape(f, x) = GradientTape(f, x)
-function taperesult(f, x)
-    return tape(f, x), GradientResult(x)
-end
-
-# Precompiled tapes are not properly supported yet.
-function AdvancedVI.grad!(
-    f::Function,
-    ::AutoReverseDiff,
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    )
-    tp = tape(f, λ)
-    ReverseDiff.gradient!(out, tp, λ)
-    return out
-end
diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl
deleted file mode 100644
index f1a29b87f..000000000
--- a/src/compat/zygote.jl
+++ /dev/null
@@ -1,13 +0,0 @@
-
-function AdvancedVI.grad!(
-    f::Function,
-    ::AutoZygote,
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    )
-    y, back = Zygote.pullback(f, λ)
-    dy = first(back(1.0))
-    DiffResults.value!(out, y)
-    DiffResults.gradient!(out, dy)
-    return out
-end
diff --git a/src/grad.jl b/src/grad.jl
deleted file mode 100644
index e68e16234..000000000
--- a/src/grad.jl
+++ /dev/null
@@ -1,30 +0,0 @@
-
-# default implementations
-function grad!(
-    f::Function,
-    adtype::AutoForwardDiff{chunksize},
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
-) where {chunksize}
-    # Set chunk size and do ForwardMode.
-    config = if isnothing(chunksize)
-        ForwardDiff.GradientConfig(f, λ)
-    else
-        ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunksize))
-    end
-    ForwardDiff.gradient!(out, f, λ, config)
-end
-
-function grad!(
-    f::Function,
-    ::AutoTracker,
-    λ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult
-)
-    λ_tracked = Tracker.param(λ)
-    y = f(λ_tracked)
-    Tracker.back!(y, 1.0)
-
-    DiffResults.value!(out, Tracker.data(y))
-    DiffResults.gradient!(out, Tracker.grad(λ_tracked))
-end
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index e4e933277..d308db0a1 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -101,11 +101,12 @@ function estimate_advi_gradient_maybe_stl!(
     out::DiffResults.MutableDiffResult
 ) where {P, B, CV}
     q_η_stop = restructure(λ)
-    grad!(adbackend, λ, out) do λ′
+    f(λ′) = begin
         q_η = restructure(λ′)
         ηs  = rand(rng, q_η, advi.n_samples)
         -advi(rng, q_η_stop, ηs)
     end
+    grad!(adbackend, f, λ, out)
 end
 
 function estimate_advi_gradient_maybe_stl!(
@@ -116,11 +117,12 @@ function estimate_advi_gradient_maybe_stl!(
     restructure,
     out::DiffResults.MutableDiffResult
 ) where {P, B, CV}
-    grad!(adbackend, λ, out) do λ′
+    f(λ′) = begin
         q_η = restructure(λ′)
         ηs  = rand(rng, q_η, advi.n_samples)
         -advi(rng, q_η, ηs)
     end
+    value_and_gradient!(adbackend, f, λ, out)
 end
 
 function estimate_gradient(
diff --git a/test/ad.jl b/test/ad.jl
index 1efa536b2..9df26d9f5 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -9,15 +9,14 @@ using ADTypes
           :ReverseDiff => AutoReverseDiff(),
           :Zygote      => AutoZygote(),
           :Tracker     => AutoTracker(),
-          :Enzyme      => AutoEnzyme(),
+          # :Enzyme      => AutoEnzyme(), # Currently not tested against.
         )
         D = 10
         A = randn(D, D)
         λ = randn(D)
         grad_buf = DiffResults.GradientResult(λ)
-        AdvancedVI.grad!(adsymbol, λ, grad_buf) do λ′
-            λ′'*A*λ′ / 2
-        end
+        obj(λ′) = λ′'*A*λ′ / 2
+        AdvancedVI.value_and_gradient!(adsymbol, obj, λ, grad_buf)
         ∇ = DiffResults.gradient(grad_buf)
         f = DiffResults.value(grad_buf)
         @test ∇ ≈ (A + A')*λ/2

From e8db6a7ac62d1916969aaaeea336677ba19eafa0 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 01:34:14 +0100
Subject: [PATCH 063/206] add documentation, refactor optimize

---
 docs/Project.toml              |   2 +-
 docs/make.jl                   |  11 ++--
 docs/src/advi.md               |  36 ++++++++++-
 docs/src/families.md           |  32 ++++++----
 src/objectives/elbo/entropy.jl |  32 ----------
 src/optimize.jl                | 106 +++++++++++++++++++++------------
 6 files changed, 130 insertions(+), 89 deletions(-)

diff --git a/docs/Project.toml b/docs/Project.toml
index fc885857a..c625d07f2 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -4,4 +4,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 
 [compat]
-Documenter = "0.26"
\ No newline at end of file
+Documenter = "0.26, 0.27"
\ No newline at end of file
diff --git a/docs/make.jl b/docs/make.jl
index d2a01d1bf..b9a8eb5f1 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,4 +1,5 @@
-#using AdvancedVI
+
+using AdvancedVI
 using Documenter
 
 DocMeta.setdocmeta!(
@@ -9,9 +10,9 @@ makedocs(;
     sitename = "AdvancedVI.jl",
     modules  = [AdvancedVI],
     format   = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
-         pages    = ["index.md",
-                     "families.md",
-                     "advi.md"],
+         pages    = ["Home"     => "index.md",
+                     "Families" => "families.md",
+                     "ADVI"     => "advi.md"],
 )
 
-deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", devbranch="main")
+deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
diff --git a/docs/src/advi.md b/docs/src/advi.md
index 4f4a2ecad..0597e03c3 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -1,5 +1,8 @@
 
 # [Automatic Differentiation Variational Inference](@id advi)
+
+# Introduction
+
 The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``.
 Maximizing the ADVI objective is equivalent to solving the problem
 
@@ -17,8 +20,8 @@ bijectors will automatically match the potentially constrained support of the ta
 In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}``
 from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that
 ```math
-z &\sim  q_{\phi,\lambda} \qquad\Leftrightarrow\qquad
-z &\stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} 
+z \sim  q_{\phi,\lambda} \qquad\Leftrightarrow\qquad
+z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} 
 ```
 ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``.
 
@@ -53,15 +56,44 @@ coined by Titsias and Lázaro-Gredilla (2014).
 Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by
 Fjelde *et al.* (2020).
 
+# The `ADVI` Objective
 
 ```@docs
 ADVI
 ```
 
+# The "Sticking the Landing" Control Variate
+The STL control variate was proposed by Roeder *et al.* (2017).
+By slightly modifying the differentiation path, it implicitly forms a control variate of the form
+```math
+\mathrm{CV}_{\mathrm{STL}}\left(z\right) \triangleq \mathbb{H}\left(q_{\lambda}\right) + \log q_{\lambda}\left(z\right),
+```
+which has a mean of zero.
+ 
+Adding this to the closed-form entropy ELBO estimator yields the STL estimator:
+```math
+\begin{aligned}
+  \widehat{\mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right)
+    &\triangleq \mathbb{E}\left[ \log \pi \left(z\right) \right] - \log q_{\lambda} \left(z\right) \\
+    &= \mathbb{E}\left[ \log \pi\left(z\right) \right] 
+      + \mathbb{H}\left(q_{\lambda}\right) - \mathrm{CV}_{\mathrm{STL}}\left(z\right) \\
+    &= \widehat{\mathrm{ELBO}}\left(\lambda\right)
+      - \mathrm{CV}_{\mathrm{STL}}\left(z\right),
+\end{aligned}
+```
+which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``.
+The conditions under which the STL estimator results in lower variance are still an active subject of research.
+
+The STL control variate can be used by changing the entropy estimator as follows:
+```julia
+ADVI(prob, n_samples; entropy = StickingTheLanding(), b = bijector)
+```
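
The gradient-stopping mechanics behind the STL estimator can be illustrated outside the package. A self-contained sketch of the idea on a 1-D Gaussian, not AdvancedVI's internal code path: the score term is evaluated through a parameter copy that the AD sweep treats as a constant.

```julia
using ForwardDiff, Distributions

logπ(z) = logpdf(Normal(0, 1), z)  # toy target

function nelbo_stl(λ, λ_stop, u)
    m, logσ   = λ                  # differentiated parameters
    z         = m + exp(logσ) * u  # reparameterized sample
    m₀, logσ₀ = λ_stop             # detached copy: no gradient flows here
    -(logπ(z) - logpdf(Normal(m₀, exp(logσ₀)), z))
end

λ = [0.5, -1.0]
u = randn()
∇ = ForwardDiff.gradient(λ′ -> nelbo_stl(λ′, λ, u), λ)
```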
+
 # References
 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
 3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604.
 4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR.
+5. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
 
 
diff --git a/docs/src/families.md b/docs/src/families.md
index f203cf18a..d326ce7ae 100644
--- a/docs/src/families.md
+++ b/docs/src/families.md
@@ -1,5 +1,5 @@
 
-# [Variational Families](@id families)
+# Variational Families
 
 ## Location-Scale Variational Family
 
@@ -25,34 +25,42 @@ VIMeanFieldGaussian
 
 ### Examples
 
-A full-rank variational family can be formed by choosing
 ```@repl locscale
-using AdvancedVI, LinearAlgebra
+using AdvancedVI, LinearAlgebra, Distributions;
 μ = zeros(2);
-L = diagm(ones(2)) |> LowerTriangular;
-```
-
-A mean-field variational family can be formed by choosing 
-```@repl locscale
-μ = zeros(2);
-L = ones(2) |> Diagonal;
 ```
 
 Gaussian variational family:
 ```@repl locscale
+L = diagm(ones(2)) |> LowerTriangular;
 q = VIFullRankGaussian(μ, L)
+
+L = ones(2) |> Diagonal;
 q = VIMeanFieldGaussian(μ, L)
 ```
 
 Student-T Variational Family:
 
 ```@repl locscale
-ν = 3
-q = VILocationScale(μ, L, StudentT(ν))
+ν = 3;
+
+# Full-Rank 
+L = diagm(ones(2)) |> LowerTriangular;
+q = VILocationScale(μ, L, TDist(ν))
+
+# Mean-Field
+L = ones(2) |> Diagonal;
+q = VILocationScale(μ, L, TDist(ν))
 ```
 
 Multivariate Laplace family:
 ```@repl locscale
+# Full-Rank 
+L = diagm(ones(2)) |> LowerTriangular;
+q = VILocationScale(μ, L, Laplace())
+
+# Mean-Field
+L = ones(2) |> Diagonal;
 q = VILocationScale(μ, L, Laplace())
 ```
 
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 7f37b619b..e9f180f5c 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -11,38 +11,6 @@ abstract type MonteCarloEntropy <: AbstractEntropyEstimator end
 
 struct FullMonteCarloEntropy <: MonteCarloEntropy end
 
-"""
-    StickingTheLandingEntropy()
-
-# Explanation
-
-The STL estimator forms a control variate of the form of
- 
-```math
-\\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) =
-  \\mathbb{E}\\left[ -\\log q\\left(z\\right) \\right]
-  + \\log q\\left(z\\right) = \\mathbb{H}\\left(q_{\\lambda}\\right) + \\log q_{\\lambda}\\left(z\\right),
-```
-where, for the score term, the gradient is stopped from propagating.
- 
-Adding this to the closed-form entropy ELBO estimator yields the STL estimator:
-```math
-\\begin{aligned}
-  \\widehat{\\mathrm{ELBO}}_{\\mathrm{STL}}\\left(\\lambda\\right)
-    &\\triangleq \\mathbb{E}\\left[ \\log \\pi \\left(z\\right) \\right] - \\log q_{\\lambda} \\left(z\\right) \\\\
-    &= \\mathbb{E}\\left[ \\log \\pi\\left(z\\right) \\right] 
-      + \\mathbb{H}\\left(q_{\\lambda}\\right) - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) \\\\
-    &= \\widehat{\\mathrm{ELBO}}\\left(\\lambda\\right)
-      - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right),
-\\end{aligned}
-```
-which has the same expectation, but lower variance when ``\\pi \\approx q_{\\lambda}``,
-and higher variance when ``\\pi \\not\\approx q_{\\lambda}``.
-
-# Reference
-1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
-"""
-
 struct StickingTheLandingEntropy <: MonteCarloEntropy end
 
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
diff --git a/src/optimize.jl b/src/optimize.jl
index 8b36df04d..ef16dcceb 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -4,73 +4,105 @@ function pm_next!(pm, stats::NamedTuple)
 end
 
 """
-    optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())
+    optimize(
+        objective    ::AbstractVariationalObjective,
+        restructure,
+        λ₀           ::AbstractVector{<:Real},
+        n_max_iter   ::Int;
+        kwargs...
+    )              
 
-Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute
-the steps.
+Optimize the variational objective `objective` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`.
+
+    optimize(
+        objective ::AbstractVariationalObjective,
+        q₀,
+        n_max_iter::Int;
+        kwargs...
+    )              
+
+Optimize the variational objective `objective` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface.
+
+# Arguments
+- `objective`: Variational objective.
+- `λ₀`: Initial value of the variational parameters.
+- `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
+- `q₀`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`.
+- `n_max_iter`: Maximum number of iterations.
+
+# Keyword Arguments
+- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
+- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
+- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
+- `callback!`: Callback function called after every iteration. The signature is `callback!(; est_state, stat, restructure, λ)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If `objective` is stateful, `est_state` contains its state. (Default: `nothing`.)
+- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress)`.)
+
+# Returns
+- `λ`: Variational parameters optimizing the variational objective.
+- `stats`: Statistics gathered during inference.
+- `opt_state`: Final state of the optimiser.
 """
 function optimize(
-    objective ::AbstractVariationalObjective,
+    objective    ::AbstractVariationalObjective,
     restructure,
-    λ         ::AbstractVector{<:Real},
-    n_max_iter::Int;
-    optimizer ::Optimisers.AbstractRule = Optimisers.Adam(),
-    rng       ::AbstractRNG             = default_rng(),
-    progress  ::Bool                    = true,
-    callback!                           = nothing,
-    terminate                           = (args...) -> false,
-    adbackend::AbstractADType           = AutoForwardDiff(), 
+    λ₀           ::AbstractVector{<:Real},
+    n_max_iter   ::Int;
+    optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
+    rng          ::AbstractRNG             = default_rng(),
+    show_progress::Bool                    = true,
+    callback!                              = nothing,
+    #convergence                           = (args...) -> (false, con_state),
+    adbackend::AbstractADType              = AutoForwardDiff(), 
+    prog                                   = ProgressMeter.Progress(
+        n_max_iter;
+        desc      = "Optimizing",
+        barlen    = 31,
+        showspeed = true,
+        enabled   = show_progress
+    )              
 )
-    opt_state = Optimisers.init(optimizer, λ)
+    λ         = copy(λ₀)
+    opt_state = Optimisers.setup(optimizer, λ)
     est_state = init(objective)
+    #con_state = init(convergence)
     grad_buf  = DiffResults.GradientResult(λ)
-
-    prog = ProgressMeter.Progress(n_max_iter;
-                                  barlen    = 0,
-                                  enabled   = progress,
-                                  showspeed = true)
-    stats = Vector{NamedTuple}(undef, n_max_iter)
+    stats     = NamedTuple[]
 
     for t = 1:n_max_iter
         stat = (iteration=t,)
 
         grad_buf, est_state, stat′ = estimate_gradient(
             rng, adbackend, objective, est_state, λ, restructure, grad_buf)
-        g    = DiffResults.gradient(grad_buf)
         stat = merge(stat, stat′)
 
-        opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g)
-        Optimisers.subtract!(λ, Δλ)
-
-        stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g))
+        g            = DiffResults.gradient(grad_buf)
+        opt_state, λ = Optimisers.update!(opt_state, λ, g)
+        stat′ = (iteration=t, gradient_norm=norm(g))
         stat = merge(stat, stat′)
 
-        q = restructure(λ)
-
         if !isnothing(callback!)
-            stat′ = callback!(q, stat)
+            stat′ = callback!(; est_state, stat, restructure, λ)
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
         AdvancedVI.DEBUG && @debug "Step $t" stat...
 
         pm_next!(prog, stat)
-        stats[t] = stat
+        push!(stats, stat)
 
-        # Termination decision is work in progress
-        if terminate(rng, λ, q, objective, stat)
-            stats = stats[1:t]
-            break
-        end
+        #convergence(rng, t, restructure, λ, q, objective, stat)
+        #if terminate()
+        #    break
+        #end
     end
-    λ, stats
+    λ, map(identity, stats), opt_state
 end
 
-function optimize(objective::AbstractVariationalObjective,
-                  q,
+function optimize(objective ::AbstractVariationalObjective,
+                  q₀,
                   n_max_iter::Int;
                   kwargs...)
-    λ, restructure = Optimisers.destructure(q)
+    λ, restructure = Optimisers.destructure(q₀)
     λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...)
     restructure(λ), stats
 end

From 65a2b37d354798dd40161dc897f7124b5b68b857 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 01:49:57 +0100
Subject: [PATCH 064/206] fix bug missing extension

---
 Project.toml | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/Project.toml b/Project.toml
index ab00d6741..ffc41a4b6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -28,6 +28,12 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[extensions]
+AdvancedVIEnzymeExt = "Enzyme"
+AdvancedVIForwardDiffExt = "ForwardDiff"
+AdvancedVIReverseDiffExt = "ReverseDiff"
+AdvancedVIZygoteExt = "Zygote"
+
 [compat]
 ADTypes = "0.1"
 Bijectors = "0.11, 0.12, 0.13"

From 1a02051f6fb8e2c59b39e7faa58c91db7ca589b3 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 01:50:24 +0100
Subject: [PATCH 065/206] remove tracker from tests

---
 test/ad.jl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/ad.jl b/test/ad.jl
index 9df26d9f5..2c4f802a1 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -1,6 +1,6 @@
 
 using ReTest
-using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote
+using ForwardDiff, ReverseDiff, Enzyme, Zygote
 using ADTypes
 
 @testset "ad" begin
@@ -8,7 +8,6 @@ using ADTypes
           :ForwardDiff => AutoForwardDiff(),
           :ReverseDiff => AutoReverseDiff(),
           :Zygote      => AutoZygote(),
-          :Tracker     => AutoTracker(),
           # :Enzyme      => AutoEnzyme(), # Currently not tested against.
         )
         D = 10

From d8b5ea5a153e5a484972c8c46e98a58e0b958b95 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 01:50:50 +0100
Subject: [PATCH 066/206] remove export for internal derivative utils

---
 src/AdvancedVI.jl | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 697f3c83c..a1cf360a3 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -50,8 +50,6 @@ vector of the same length as `θ`.
 """
 function value_and_gradient! end
 
-export value_and_gradient!
-
 # estimators
 abstract type AbstractVariationalObjective end
 
@@ -104,11 +102,10 @@ if !isdefined(Base, :get_extension) # check whether :get_extension is defined in
     using Requires
 end
 
-using Requires
 function __init__()
     @static if !isdefined(Base, :get_extension)
-        @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-            include("../ext/AdvancedVIZygoteExt.jl")
+        @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+            include("../ext/AdvancedVIEnzymeExt.jl")
         end
         @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
             include("../ext/AdvancedVIForwardDiffExt.jl")
@@ -116,10 +113,11 @@ function __init__()
         @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
             include("../ext/AdvancedVIReverseDiffExt.jl")
         end
-        @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
-            include("../ext/AdvancedVIEnzymeExt.jl")
+        @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+            include("../ext/AdvancedVIZygoteExt.jl")
         end
     end
 end
-end # module
+
+end
 

From 818bc2c33fb7513681c06bb6a99cf341c97957dc Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 02:28:47 +0100
Subject: [PATCH 067/206] fix test errors, old interface

---
 src/optimize.jl       |  4 ++--
 test/advi_locscale.jl | 36 ++++++++++++++++++------------------
 test/runtests.jl      |  4 +++-
 3 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index ef16dcceb..7c876b39e 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -103,6 +103,6 @@ function optimize(objective ::AbstractVariationalObjective,
                   n_max_iter::Int;
                   kwargs...)
     λ, restructure = Optimisers.destructure(q₀)
-    λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...)
-    restructure(λ), stats
+    λ, stats, opt_state = optimize(objective, restructure, λ, n_max_iter; kwargs...)
+    restructure(λ), stats, opt_state
 end
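
With the three-value return now propagated through the convenience method, the full loop can be exercised roughly as follows; a sketch reusing the hypothetical `StdNormal` target from the earlier ADVI sketch, with ForwardDiff (the default backend) loaded:

```julia
using AdvancedVI, Optimisers, LinearAlgebra, ForwardDiff

obj = ADVI(StdNormal(), 10)
q₀  = VIMeanFieldGaussian(ones(2), Diagonal(ones(2)))

q, stats, opt_state = optimize(
    obj, q₀, 10_000;
    optimizer     = Optimisers.Adam(1e-2),
    show_progress = false,
)
q.location  # ≈ zeros(2) once converged
```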
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 1552be5e4..d4ef7aec5 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -69,13 +69,13 @@ include("models/utils.jl")
             obj = objective(model, b⁻¹, 10)
 
             @testset "convergence" begin
-                Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-                q, stats  = optimize(
+                Δλ₀         = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+                q, stats, _ = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.Adam(1e-2),
-                    progress  = PROGRESS,
-                    rng       = rng,
-                    adbackend = adbackend,
+                    optimizer     = Optimisers.Adam(1e-2),
+                    show_progress = PROGRESS,
+                    rng           = rng,
+                    adbackend     = adbackend,
                 )
 
                 μ  = q.location
@@ -88,24 +88,24 @@ include("models/utils.jl")
             end
 
             @testset "determinism" begin
-                rng      = Philox4x(UInt64, seed, 8)
-                q, stats = optimize(
+                rng         = Philox4x(UInt64, seed, 8)
+                q, stats, _ = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.Adam(realtype(1e-2)),
-                    progress  = PROGRESS,
-                    rng       = rng,
-                    adbackend = adbackend,
+                    optimizer     = Optimisers.Adam(realtype(1e-2)),
+                    show_progress = PROGRESS,
+                    rng           = rng,
+                    adbackend     = adbackend,
                 )
                 μ  = q.location
                 L  = q.scale
 
-                rng_repl = Philox4x(UInt64, seed, 8)
-                q, stats = optimize(
+                rng_repl    = Philox4x(UInt64, seed, 8)
+                q, stats, _ = optimize(
                     obj, q₀, T;
-                    optimizer = Optimisers.Adam(realtype(1e-2)),
-                    progress  = PROGRESS,
-                    rng       = rng_repl,
-                    adbackend = adbackend,
+                    optimizer     = Optimisers.Adam(realtype(1e-2)),
+                    show_progress = PROGRESS,
+                    rng           = rng_repl,
+                    adbackend     = adbackend,
                 )
                 μ_repl = q.location
                 L_repl = q.scale
diff --git a/test/runtests.jl b/test/runtests.jl
index ddc1d09cf..68225fd9e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,8 @@
 
-using Comonicon
+using ReTest
 using ReTest: @testset, @test
+
+using Comonicon
 using Random
 using Random123
 using Statistics

From 215abf34639e76b59d3d8b7ad1b64d24ec7500e0 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 02:29:06 +0100
Subject: [PATCH 068/206] fix wrong derivative interface, add documentation

---
 src/objectives/elbo/advi.jl | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index d308db0a1..8bc14bc9b 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -1,26 +1,21 @@
 
 """
-    ADVI(
-        prob,
-        n_samples::Int;
-        entropy::AbstractEntropyEstimator = ClosedFormEntropy(),
-        cv::Union{<:AbstractControlVariate, Nothing} = nothing,
-        b = Bijectors.identity
-    )
+    ADVI(prob, n_samples; kwargs...)
 
 Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
 
 # Arguments
 - `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface.
-    - `logdensity` must be differentiable by the selected AD backend.
-- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO.
-- `entropy`: The estimator for the entropy term.
-- `cv`: A control variate
-- `b`: A bijector mapping the support of the base distribution to that of `prob`.
+- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.)
+
+# Keyword Arguments
+- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`.)
+- `cv`: A control variate.
+- `b`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.)
 
 # Requirements
 - ``q_{\\lambda}`` implements `rand`.
-- ``\\pi`` must be differentiable
+- `logdensity(prob)` must be differentiable by the selected AD backend.
 
 Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 """
@@ -106,7 +101,7 @@ function estimate_advi_gradient_maybe_stl!(
         ηs  = rand(rng, q_η, advi.n_samples)
         -advi(rng, q_η_stop, ηs)
     end
-    grad!(adbackend, f, λ, out)
+    value_and_gradient!(adbackend, f, λ, out)
 end
 
 function estimate_advi_gradient_maybe_stl!(

From 88ad7680a928932be97e1f075d5cd1c0d497a651 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 02:29:25 +0100
Subject: [PATCH 069/206] update documentation

---
 docs/src/advi.md               | 17 ++++++++-----
 docs/src/families.md           | 44 +++++++++++++++++++++-------------
 src/objectives/elbo/entropy.jl |  9 +++++++
 3 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/docs/src/advi.md b/docs/src/advi.md
index 0597e03c3..37b3541bb 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -1,7 +1,7 @@
 
 # [Automatic Differentiation Variational Inference](@id advi)
 
-# Introduction
+## Introduction
 
 The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``.
 Maximizing the ADVI objective is equivalent to solving the problem
@@ -56,13 +56,13 @@ coined by Titsias and Lázaro-Gredilla (2014).
 Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by
 Fjelde *et al.* (2020).
 
-# The `ADVI` Objective
+## The `ADVI` Objective
 
 ```@docs
 ADVI
 ```
 
-# The "Sticking the Landing" Control Variate
+## The `StickingTheLanding` Control Variate
 The STL control variate was proposed by Roeder *et al.* (2017).
 By slightly modifying the differentiation path, it implicitly forms a control variate of the form
 ```math
@@ -84,12 +84,17 @@ Adding this to the closed-form entropy ELBO estimator yields the STL estimator:
 which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``.
 The conditions under which the STL estimator results in lower variance are still an active subject of research.
 
-The STL control variate can be used by changing the entropy estimator as follows:
+The STL control variate can be enabled by swapping in the following entropy estimator:
+```@docs
+StickingTheLandingEntropy
+```
+
+For example:
 ```julia
-ADVI(prob, n_samples; entropy = StickingTheLanding(), b = bijector)
+ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector)
 ```
 
-# References
+## References
 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
 3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604.
diff --git a/docs/src/families.md b/docs/src/families.md
index d326ce7ae..e6eaa91b2 100644
--- a/docs/src/families.md
+++ b/docs/src/families.md
@@ -1,18 +1,26 @@
 
-# Variational Families
+# Location-Scale Variational Family
 
-## Location-Scale Variational Family
-
-### Description
+## Description
 The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions whose sampling process can be represented as
 ```math
-z = C u + m,
+z \sim  q_{\lambda} \qquad\Leftrightarrow\qquad
+z \stackrel{d}{=} z = C u + m;\quad u \sim \varphi
 ```
-where ``C`` is the *scale* and ``m`` is the location variational parameter.
-This family encompases many 
-
+where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*.
+``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. 
+The location-scale family encompasses many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``.
+The probability density is given by
+```math
+  q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m))
+```
+and the entropy is given as
+```math
+  \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|,
+```
+where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution.
 
-### Constructors
+## Constructors
 
 ```@docs
 VILocationScale
@@ -23,15 +31,13 @@ VIFullRankGaussian
 VIMeanFieldGaussian
 ```
 
-### Examples
+## Gaussian Variational Families
 
-```@repl locscale
+Gaussian variational family:
+```julia
 using AdvancedVI, LinearAlgebra, Distributions;
 μ = zeros(2);
-```
 
-Gaussian variational family:
-```@repl locscale
 L = diagm(ones(2)) |> LowerTriangular;
 q = VIFullRankGaussian(μ, L)
 
@@ -39,9 +45,12 @@ L = ones(2) |> Diagonal;
 q = VIMeanFieldGaussian(μ, L)
 ```
 
+## Non-Gaussian Variational Families
 Student-T Variational Family:
 
-```@repl locscale
+```julia
+using AdvancedVI, LinearAlgebra, Distributions;
+μ = zeros(2);
 ν = 3;
 
 # Full-Rank 
@@ -54,7 +63,10 @@ q = VILocationScale(μ, L, TDist(ν))
 ```
 
 Multivariate Laplace family:
-```@repl locscale
+```julia
+using AdvancedVI, LinearAlgebra, Distributions;
+μ = zeros(2);
+
 # Full-Rank 
 L = diagm(ones(2)) |> LowerTriangular;
 q = VILocationScale(μ, L, Laplace())
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index e9f180f5c..0edc47f4e 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -11,6 +11,15 @@ abstract type MonteCarloEntropy <: AbstractEntropyEstimator end
 
 struct FullMonteCarloEntropy <: MonteCarloEntropy end
 
+"""
+    StickingTheLandingEntropy()
+
+The "sticking the landing" entropy estimator.
+
+# Requirements
+- `q` implements `logpdf`.
+- `logpdf(q, η)` must be differentiable by the selected AD framework.
+"""
 struct StickingTheLandingEntropy <: MonteCarloEntropy end
 
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)

From e66935bb2881a61cf137ff74899e7117c53a9f46 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 02:34:11 +0100
Subject: [PATCH 070/206] add doc build CI

---
 .github/workflows/CI.yml | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 9731f20c2..158da963c 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -61,3 +61,30 @@ jobs:
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           path-to-lcov: lcov.info
+  docs:
+    name: Documentation
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      statuses: write
+    steps:
+      - uses: actions/checkout@v3
+      - uses: julia-actions/setup-julia@v1
+        with:
+          version: '1'
+      - name: Configure doc environment
+        run: |
+          julia --project=docs/ -e '
+            using Pkg
+            Pkg.develop(PackageSpec(path=pwd()))
+            Pkg.instantiate()'
+      - uses: julia-actions/julia-buildpkg@v1
+      - uses: julia-actions/julia-docdeploy@v1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      - run: |
+          julia --project=docs -e '
+            using Documenter: DocMeta, doctest
+            using AdvancedVI
+            DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true)
+            doctest(AdvancedVI)'

From 9f1c647a6fb2b945754e808dcb608e3f19c4cae8 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 02:47:56 +0100
Subject: [PATCH 071/206] remove convergence criterion for now

---
 docs/src/families.md | 2 +-
 src/optimize.jl      | 7 -------
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/docs/src/families.md b/docs/src/families.md
index e6eaa91b2..8ae48be30 100644
--- a/docs/src/families.md
+++ b/docs/src/families.md
@@ -5,7 +5,7 @@
 The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions whose sampling process can be represented as
 ```math
 z \sim  q_{\lambda} \qquad\Leftrightarrow\qquad
-z \stackrel{d}{=} z = C u + m;\quad u \sim \varphi
+z \stackrel{d}{=} C u + m;\quad u \sim \varphi
 ```
 where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*.
 ``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. 
diff --git a/src/optimize.jl b/src/optimize.jl
index 7c876b39e..0f2d29e9b 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -51,7 +51,6 @@ function optimize(
     rng          ::AbstractRNG             = default_rng(),
     show_progress::Bool                    = true,
     callback!                              = nothing,
-    #convergence                           = (args...) -> (false, con_state),
     adbackend::AbstractADType              = AutoForwardDiff(), 
     prog                                   = ProgressMeter.Progress(
         n_max_iter;
@@ -64,7 +63,6 @@ function optimize(
     λ         = copy(λ₀)
     opt_state = Optimisers.setup(optimizer, λ)
     est_state = init(objective)
-    #con_state = init(convergence)
     grad_buf  = DiffResults.GradientResult(λ)
     stats     = NamedTuple[]
 
@@ -89,11 +87,6 @@ function optimize(
 
         pm_next!(prog, stat)
         push!(stats, stat)
-
-        #convergence(rng, t, restructure, λ, q, objective, stat)
-        #if terminate()
-        #    break
-        #end
     end
     λ, map(identity, stats), opt_state
 end

From c8b3ee3ed7ec43051631462b7674a7c1d66722d7 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 02:54:12 +0100
Subject: [PATCH 072/206] remove outdated export

---
 src/AdvancedVI.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index a1cf360a3..1677be622 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -72,7 +72,6 @@ include("objectives/elbo/advi.jl")
 export
     ELBO,
     ADVI,
-    ADVIEnergy,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
     FullMonteCarloEntropy

From afda1a19527f4197b25a50fcae8e52cdeace660b Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 20:53:42 +0100
Subject: [PATCH 073/206] update documentation

---
 docs/make.jl                          |  9 +++--
 docs/src/index.md                     | 16 ++++-----
 docs/src/{families.md => locscale.md} |  4 +--
 docs/src/started.md                   | 51 +++++++++++++++++++++++++++
 4 files changed, 67 insertions(+), 13 deletions(-)
 rename docs/src/{families.md => locscale.md} (96%)
 create mode 100644 docs/src/started.md

diff --git a/docs/make.jl b/docs/make.jl
index b9a8eb5f1..ca21b5fde 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -10,9 +10,12 @@ makedocs(;
     sitename = "AdvancedVI.jl",
     modules  = [AdvancedVI],
     format   = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
-         pages    = ["Home"     => "index.md",
-                     "Families" => "families.md",
-                     "ADVI"     => "advi.md"],
+         pages    = ["AdvancedVI"        => "index.md",
+                     "Getting Started"   => "started.md",
+                     "ELBO Maximization" => [
+                         "Automatic Differentiation VI" => "advi.md",   
+                         "Location Scale Family"        => "locscale.md",
+                     ]],
 )
 
 deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
diff --git a/docs/src/index.md b/docs/src/index.md
index be3269217..dea6d405d 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -4,11 +4,11 @@ CurrentModule = AdvancedVI
 
 # AdvancedVI
 
-Documentation for [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl).
-
-```@index
-```
-
-```@autodocs
-Modules = [AdvancedVI]
-```
+## Introduction
+[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
+VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
+`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
+
+## Provided Algorithms
+`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization:
+- [Automatic Differentiation Variational Inference](@ref advi)
diff --git a/docs/src/families.md b/docs/src/locscale.md
similarity index 96%
rename from docs/src/families.md
rename to docs/src/locscale.md
index 8ae48be30..a4bc2dc1f 100644
--- a/docs/src/families.md
+++ b/docs/src/locscale.md
@@ -1,7 +1,7 @@
 
-# Location-Scale Variational Family
+# [Location-Scale Variational Family](@id locscale)
 
-## Description
+## Introduction
 The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions whose sampling process can be represented as
 ```math
 z \sim  q_{\lambda} \qquad\Leftrightarrow\qquad
diff --git a/docs/src/started.md b/docs/src/started.md
new file mode 100644
index 000000000..faff61660
--- /dev/null
+++ b/docs/src/started.md
@@ -0,0 +1,51 @@
+
+# [Getting Started with `AdvancedVI`](@id getting_started)
+
+## General Usage
+Each VI algorithm should provide the following:
+1. A variational family
+2. A variational objective
+
+Feeding these two into `optimize` runs the inference procedure.
+
+```@docs
+optimize
+```
+
+## `ADVI` Example Using `Turing`
+
+```julia
+using Turing
+using Bijectors
+using Optimisers
+using ForwardDiff
+using ADTypes
+
+import AdvancedVI as AVI
+
+μ_y, σ_y = 1.0, 1.0
+μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]
+
+Turing.@model function normallognormal()
+    y ~ LogNormal(μ_y, σ_y)
+    z ~ MvNormal(μ_z, Σ_z)
+end
+model = normallognormal()
+b     = Bijectors.bijector(model)
+b⁻¹   = inverse(b)
+prob  = DynamicPPL.LogDensityFunction(model)
+d     = LogDensityProblems.dimension(prob)
+
+μ = randn(d)
+L = Diagonal(ones(d))
+q = AVI.MeanFieldGaussian(μ, L)
+
+n_max_iter = 10^4
+q, stats = AVI.optimize(
+    AVI.ADVI(prob, b⁻¹, 10),
+    q,
+    n_max_iter;
+    adbackend = AutoForwardDiff(),
+    optimizer = Optimisers.Adam(1e-3)
+)
+```

From 0d37acea1dd96c95d7cef427be7d84fee8d95c09 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 21:12:02 +0100
Subject: [PATCH 074/206] update documentation

---
 docs/src/started.md | 55 ++++++++++++++++++++++++++++++++-------------
 1 file changed, 39 insertions(+), 16 deletions(-)

diff --git a/docs/src/started.md b/docs/src/started.md
index faff61660..26c75a797 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -2,11 +2,13 @@
 # [Getting Started with `AdvancedVI`](@id getting_started)
 
 ## General Usage
-Each VI algorithm should provide the following:
-1. A variational family
-2. A variational objective
+Each VI algorithm provides the following:
+1. The variational families it supports.
+2. A variational objective corresponding to the algorithm.
+
+Note that each variational family is subject to its own constraints.
+Thus, please refer to the documentation of the variational inference algorithm of interest.
 
-Feeding these two into `optimize` runs the inference procedure.
+To use `AdvancedVI`, a user needs to select a variational family and a variational objective, and feed them into `optimize`.
 
 ```@docs
 optimize
@@ -14,14 +16,10 @@ optimize
 
 ## `ADVI` Example Using `Turing`
 
+In this tutorial, we'll use `Turing` to define a basic `normal-log-normal` model.
+ADVI with log bijectors is able to infer this model exactly.
 ```julia
 using Turing
-using Bijectors
-using Optimisers
-using ForwardDiff
-using ADTypes
-
-import AdvancedVI as AVI
 
 μ_y, σ_y = 1.0, 1.0
 μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]
@@ -31,18 +29,43 @@ Turing.@model function normallognormal()
     z ~ MvNormal(μ_z, Σ_z)
 end
 model = normallognormal()
+```
+
+Since `y` follows a log-normal prior, its support is the positive half-line ``\mathbb{R}_+``.
+Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
+```julia
+using Bijectors
+
 b     = Bijectors.bijector(model)
 b⁻¹   = inverse(b)
-prob  = DynamicPPL.LogDensityFunction(model)
-d     = LogDensityProblems.dimension(prob)
+```
 
+Let's now load `AdvancedVI`.
+Since ADVI relies on automatic differentiation (AD; hence the "AD" in "ADVI"), we need to load an AD library *before* loading `AdvancedVI`.
+Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
+Here, we will use `ForwardDiff`, which is selected by passing `ADTypes.AutoForwardDiff()` later on.
+```julia
+using Optimisers
+using ForwardDiff
+import AdvancedVI as AVI
+```
+We now need to select (1) a variational objective and (2) a variational family.
+Here, we will use the [ADVI objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
+```julia
+prob      = DynamicPPL.LogDensityFunction(model)
+objective = AVI.ADVI(prob, b⁻¹, 10)
+```
+For the variational family, we will use the classic mean-field Gaussian family.
+```julia
+d = LogDensityProblems.dimension(prob)
 μ = randn(d)
 L = Diagonal(ones(d))
-q = AVI.MeanFieldGaussian(μ, L)
-
+q = AVI.VIMeanFieldGaussian(μ, L)
+```
+It now remains to run inference!
+```julia
 n_max_iter = 10^4
-q, stats = AVI.optimize(
-    AVI.ADVI(prob, b⁻¹, 10),
+q, stats   = AVI.optimize(
     q,
     n_max_iter;
     adbackend = AutoForwardDiff(),

From b8b113da2b3a64395e9daaf2bbb64e9b0b602a4e Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 13 Aug 2023 21:14:17 +0100
Subject: [PATCH 075/206] update documentation

---
 docs/src/started.md | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/src/started.md b/docs/src/started.md
index 26c75a797..355e93502 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -50,10 +50,11 @@ using ForwardDiff
 import AdvancedVI as AVI
 ```
 We now need to select (1) a variational objective and (2) a variational family.
-Here, we will use the [ADVI objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
+Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
 ```julia
-prob      = DynamicPPL.LogDensityFunction(model)
-objective = AVI.ADVI(prob, b⁻¹, 10)
+prob         = DynamicPPL.LogDensityFunction(model)
+n_montecarlo = 10
+objective    = AVI.ADVI(prob, b⁻¹, n_montecarlo)
 ```
 For the variational family, we will use the classic mean-field Gaussian family.
 ```julia

From b78e713eaf46d3caab540e1d818be9930bea54dc Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Wed, 16 Aug 2023 23:35:23 +0100
Subject: [PATCH 076/206] fix type error in test

---
 test/distributions.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/distributions.jl b/test/distributions.jl
index 073fff644..9b18d0207 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -11,7 +11,7 @@ using Distributions: _logpdf
         seed         = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
         rng          = Philox4x(UInt64, seed, 8)
         realtype     = Float64
-        ϵ            = 1e-2
+        ϵ            = 1f-2
         n_dims       = 10
         n_montecarlo = 1000_000
 

From a0564b56bbe86b5885c333aa7fe2ca0e48fa0b24 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Wed, 16 Aug 2023 23:35:29 +0100
Subject: [PATCH 077/206] remove default ADType argument

---
 Project.toml    | 2 +-
 src/optimize.jl | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/Project.toml b/Project.toml
index ffc41a4b6..35650ae5d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -37,7 +37,7 @@ AdvancedVIZygoteExt = "Zygote"
 [compat]
 ADTypes = "0.1"
 Bijectors = "0.11, 0.12, 0.13"
-DiffResults = "1.0.3"
+DiffResults = "1"
 Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
 DocStringExtensions = "0.8, 0.9"
 ForwardDiff = "0.10.25"
diff --git a/src/optimize.jl b/src/optimize.jl
index 0f2d29e9b..93e6f7546 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -31,6 +31,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 - `n_max_iter`: Maximum number of iterations.
 
 # Keyword Arguments
+- `adbackend`: Automatic differentiation backend. (Type: `<: ADTypes.AbstractADType`.)
 - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
 - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
 - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
@@ -47,11 +48,11 @@ function optimize(
     restructure,
     λ₀           ::AbstractVector{<:Real},
     n_max_iter   ::Int;
+    adbackend::AbstractADType, 
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
     rng          ::AbstractRNG             = default_rng(),
     show_progress::Bool                    = true,
     callback!                              = nothing,
-    adbackend::AbstractADType              = AutoForwardDiff(), 
     prog                                   = ProgressMeter.Progress(
         n_max_iter;
         desc      = "Optimizing",

From 3795d1e05f510887df1c2900ab9f7638797ecc87 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 01:01:52 +0100
Subject: [PATCH 078/206] update README

---
 README.md | 304 +++++++++++++++---------------------------------------
 1 file changed, 81 insertions(+), 223 deletions(-)

diff --git a/README.md b/README.md
index 18ba63e50..e8718e7c5 100644
--- a/README.md
+++ b/README.md
@@ -1,250 +1,108 @@
-# AdvancedVI.jl
-A library for variational Bayesian inference in Julia.
-
-At the time of writing (05/02/2020), implementations of the variational inference (VI) interface and some algorithms are implemented in [Turing.jl](https://github.com/TuringLang/Turing.jl). The idea is to soon separate the VI functionality in Turing.jl out and into this package.
-
-The purpose of this package will then be to provide a common interface together with implementations of standard algorithms and utilities with the goal of ease of use and the ability for other packages, e.g. Turing.jl, to write a light wrapper around AdvancedVI.jl for integration. 
 
-As an example, in Turing.jl we support automatic differentiation variational inference (ADVI) but really the only piece of code tied into the Turing.jl is the conversion of a `Turing.Model` to a `logjoint(z)` function which computes `z ↦ log p(x, z)`, with `x` denoting the observations embedded in the `Turing.Model`. As long as this `logjoint(z)` method is compatible with some AD framework, e.g. `ForwardDiff.jl` or `Zygote.jl`, this is all we need from Turing.jl to be able to perform ADVI!
-
-## [WIP] Interface
-- `vi`: the main interface to the functionality in this package
-  - `vi(model, alg)`: only used when `alg` has a default variational posterior which it will provide.
-  - `vi(model, alg, q::VariationalPosterior, θ)`: `q` represents the family of variational distributions and `θ` is the initial parameters "indexing" the starting distribution. This assumes that there exists an implementation `Variational.update(q, θ)` which returns the variational posterior corresponding to parameters `θ`.
-  - `vi(model, alg, getq::Function, θ)`: here `getq(θ)` is a function returning a `VariationalPosterior` corresponding to `θ`.
-- `optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())`
-- `grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)`
-  - Different combinations of variational objectives (`vo`), VI methods (`alg`), and variational posteriors (`q`) might use different gradient estimators. `grad!` allows us to specify these different behaviors.
+# AdvancedVI.jl
+[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
+VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
+`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
+The purpose of this package is to provide a common, accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration.
+For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) by simply converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`.
 
 ## Examples
-### Variational Inference
-A very simple generative model is the following
-
-    μ ~ 𝒩(0, 1)
-    xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n
-
-where μ and xᵢ are some ℝᵈ vectors and 𝒩 denotes a d-dimensional multivariate Normal distribution.
-
-Given a set of `n` observations `[x₁, …, xₙ]` we're interested in finding the distribution `p(μ∣x₁, …, xₙ)` over the mean `μ`. We can obtain (an approximation to) this distribution that using AdvancedVI.jl!
-
-First we generate some observations and set up the problem:
-```julia
-julia> using Distributions
-
-julia> d = 2; n = 100;
-
-julia> observations = randn((d, n)); # 100 observations from 2D 𝒩(0, 1)
-
-julia> # Define generative model
-       #    μ ~ 𝒩(0, 1)
-       #    xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n
-       prior(μ) = logpdf(MvNormal(ones(d)), μ)
-prior (generic function with 1 method)
-
-julia> likelihood(x, μ) = sum(logpdf(MvNormal(μ, ones(d)), x))
-likelihood (generic function with 1 method)
-
-julia> logπ(μ) = likelihood(observations, μ) + prior(μ)
-logπ (generic function with 1 method)
-
-julia> logπ(randn(2))  # <= just checking that it works
--311.74132761437653
-```
-Now there are mainly two different ways of specifying the approximate posterior (and its family). The first is by providing a mapping from distribution parameters to the distribution `θ ↦ q(⋅∣θ)`:
-```julia
-julia> using DistributionsAD, AdvancedVI
-
-julia> # Using a function z ↦ q(⋅∣z)
-       getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[d + 1:4]))
-getq (generic function with 1 method)
-```
-Then we make the choice of algorithm, a subtype of `VariationalInference`, 
-```julia
-julia> # Perform VI
-       advi = ADVI(10, 10_000)
-ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 10000)
-```
-And finally we can perform VI! The usual inferface is to call `vi` which behind the scenes takes care of the optimization and returns the resulting variational posterior:
-```julia
-julia> q = vi(logπ, advi, getq, randn(4))
-[ADVI] Optimizing...100% Time: 0:00:01
-TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[0.16282745378074515, 0.15789310089462574], σ=[0.09519377533754399, 0.09273176907111745])
-```
-Let's have a look at the resulting ELBO:
-```julia
-julia> AdvancedVI.elbo(advi, q, logπ, 1000)
--287.7866366886285
-```
-Unfortunately, the *final* value of the ELBO is not always a very good diagnostic, though the ELBO is an important metric to keep an eye on during training since an *increase* in the ELBO means we're going in the right direction. Luckily, this is such a simple problem that we can indeed obtain a closed form solution! Because we're lazy (at least I am), we'll let [ConjugatePriors.jl](https://github.com/JuliaStats/ConjugatePriors.jl) do this for us:
-```julia
-julia> # True posterior
-       using ConjugatePriors
 
-julia> pri = MvNormal(zeros(2), ones(2));
+`AdvancedVI` expects the target distribution to be provided as a `LogDensityProblem`.
+For example, consider the normal-log-normal model:
+$$
+\begin{aligned}
+x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\
+y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right)
+\end{aligned}
+$$ 
 
-julia> true_posterior = posterior((pri, pri.Σ), MvNormal, observations)
-DiagNormal(
-dim: 2
-μ: [0.1746546592601148, 0.16457110079543008]
-Σ: [0.009900990099009901 0.0; 0.0 0.009900990099009901]
-)
+A `LogDensityProblem` for this model can be implemented as follows:
 ```
-Comparing to our variational approximation, this looks pretty good! Worth noting that in this particular case the variational posterior seems to overestimate the variance.
+using Distributions
+using LogDensityProblems
+using SimpleUnPack: @unpack
 
-To conclude, let's make a somewhat pretty picture:
-```julia
-julia> using Plots
-
-julia> p_samples = rand(true_posterior, 10_000); q_samples = rand(q, 10_000);
-
-julia> p1 = histogram(p_samples[1, :], label="p"); histogram!(q_samples[1, :], alpha=0.7, label="q")
-
-julia> title!(raw"$\mu_1$")
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
 
-julia> p2 = histogram(p_samples[2, :], label="p"); histogram!(q_samples[2, :], alpha=0.7, label="q")
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+end
 
-julia> title!(raw"$\mu_2$")
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
 
-julia> plot(p1, p2)
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
 ```
-![Histogram](hist.png?raw=true)
-
-### Simple example: using Advanced.jl to directly minimize the KL-divergence between two distributions `p(z)` and `q(z)`
-In VI we aim to approximate the true posterior `p(z ∣ x)` by some approximate variational posterior `q(z)` by maximizing the ELBO:
-
-    ELBO(q) = 𝔼_q[log p(x, z) - log q(z)]
-
-Observe that we can express the ELBO as the negative KL-divergence between `p(x, ⋅)` and `q(⋅)`:
-
-    ELBO(q) = - 𝔼_q[log (q(z) / p(x, z))]
-            = - KL(q(⋅) || p(x, ⋅))
-
-So if we apply VI to something that isn't an actual posterior, i.e. there's no data involved and we write `p(z ∣ x) = p(z)`, we're really just minimizing the KL-divergence between the distributions.
-
-Therefore, we can try out `AdvancedVI.jl` real quick by applying using the interface to minimize the KL-divergence between two distributions:
 
+Since the support of `x` is constrained to $$\mathbb{R}_+$$, while inference is best done in the unconstrained space $$\mathbb{R}$$, we need a *bijector* to match the two supports.
+This corresponds to automatic differentiation VI (ADVI; Kucukelbir *et al.*, 2015).
 ```julia
-julia> using Distributions, DistributionsAD, AdvancedVI
-
-julia> # Target distribution
-       p = MvNormal(ones(2))
-ZeroMeanDiagNormal(
-dim: 2
-μ: [0.0, 0.0]
-Σ: [1.0 0.0; 0.0 1.0]
-)
+using Bijectors
 
-julia> logπ(z) = logpdf(p, z)
-logπ (generic function with 1 method)
-
-julia> # Make a choice of VI algorithm
-       advi = ADVI(10, 1000)
-ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 1000)
-```
-Now there are two different ways of specifying the approximate posterior (and its family); the first is by providing a mapping from parameters to distribution `θ ↦ q(⋅∣θ)`:
-```julia
-julia> # Using a function z ↦ q(⋅∣z)
-       getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4]))
-getq (generic function with 1 method)
-
-julia> # Perform VI
-       q = vi(logπ, advi, getq, randn(4))
-┌ Info: [ADVI] Should only be seen once: optimizer created for θ
-└   objectid(θ) = 0x5ddb564423896704
-[ADVI] Optimizing...100% Time: 0:00:01
-TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[-0.012691337868985757, -0.0004442434543332919], σ=[1.0334797673569802, 0.9957355128767893])
-```
-Or we can check the ELBO (which in this case since, as mentioned, doesn't involve data, is the negative KL-divergence):
-```julia
-julia> AdvancedVI.elbo(advi, q, logπ, 1000)  # empirical estimate
-0.08031049170093245
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
 ```
-It's worth noting that the actual value of the ELBO doesn't really tell us too much about the quality of fit. In this particular case, because we're *directly* minimizing the KL-divergence, we can only say something useful if we reach 0, in which case we have obtained the true distribution.
 
-Let's just quickly check the mean-squared error between the `log p(z)` and `log q(z)` for a random set of samples from the target `p`:
-```julia
-julia> zs = rand(p, 100);
+A simpler approach is to use `Turing`, where a `Turing.Model` can automatically be converted into a `LogDensityProblem` with a corresponding `bijector` generated for it, as sketched below.
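+
+As a sketch (adapted from the getting-started example in the docs; the ten-dimensional `y` with identity covariance is chosen purely for illustration):
+```julia
+using Turing
+using LinearAlgebra
+
+μ_x, σ_x = 1.0, 1.0
+μ_y, Σ_y = zeros(10), Diagonal(ones(10))
+
+Turing.@model function normallognormal()
+    x ~ LogNormal(μ_x, σ_x)
+    y ~ MvNormal(μ_y, Σ_y)
+end
+model = normallognormal()
+
+prob = DynamicPPL.LogDensityFunction(model) # implements the `LogDensityProblems` interface
+b    = Bijectors.bijector(model)            # support-matching bijector
+```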
 
-julia> mean(abs2, logpdf(q, zs) - logpdf(p, zs))
-0.0014889109427524852
+Let us instantiate a random normal-log-normal model.
+```julia
+using PDMats
+
+n_dims = 10
+μ_x    = randn()
+σ_x    = exp.(randn())
+μ_y    = randn(n_dims)
+σ_y    = exp.(randn(n_dims))
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
 ```
-That doesn't look too bad!
-
-## Implementing your own training loop
-Sometimes it might be convenient to roll your own training loop rather than using `vi(...)`. Here's some psuedo-code for how one would do that when used together with Turing.jl:
 
+ADVI can be used as follows:
 ```julia
-using Turing, AdvancedVI, DiffResults
-using Turing: Variational
-
-using ProgressMeter
-
-# Assuming you have an instance of a Turing model (`model`)
-
-# 1. Create log-joint needed for ELBO evaluation
-logπ = Variational.make_logjoint(model)
-
-# 2. Define objective
-variational_objective = Variational.ELBO()
-
-# 3. Optimizer
-optimizer = Variational.DecayedADAGrad()
-
-# 4. VI-algorithm
-alg = ADVI(10, 1000)
-
-# 5. Variational distribution
-function getq(θ)
-    # ...
-end
-
-# 6. [OPTIONAL] Implement convergence criterion
-function hasconverged(args...)
-    # ...
-end
-
-# 7. [OPTIONAL] Implement a callback for tracking stats
-function callback(args...)
-    # ...
-end
-
-# 8. Train
-converged = false
-step = 1
-
-prog = ProgressMeter.Progress(num_steps, 1)
-
-diff_results = DiffResults.GradientResult(θ_init)
-
-while (step ≤ num_steps) && !converged
-    # 1. Compute gradient and objective value; results are stored in `diff_results`
-    AdvancedVI.grad!(variational_objective, alg, getq, model, diff_results)
-
-    # 2. Extract gradient from `diff_result`
-    ∇ = DiffResults.gradient(diff_result)
-
-    # 3. Apply optimizer, e.g. multiplying by step-size
-    Δ = apply!(optimizer, θ, ∇)
-
-    # 4. Update parameters
-    @. θ = θ - Δ
-
-    # 5. Do whatever analysis you want
-    callback(args...)
-
-    # 6. Update
-    converged = hasconverged(...) # or something user-defined
-    step += 1
+using LinearAlgebra
+using Optimisers
+using ADTypes, ForwardDiff
+import AdvancedVI as AVI
+
+b     = Bijectors.bijector(model)
+b⁻¹   = inverse(b)
+
+# ADVI objective 
+objective = AVI.ADVI(model, 10; b=b⁻¹)
+
+# Mean-field Gaussian variational family
+d = LogDensityProblems.dimension(model)
+μ = randn(d)
+L = Diagonal(ones(d))
+q = AVI.VIMeanFieldGaussian(μ, L)
+
+# Run inference
+n_max_iter = 10^4
+q, stats, _ = AVI.optimize(
+    objective,
+    q,
+    n_max_iter;
+    adbackend = ADTypes.AutoForwardDiff(),
+    optimizer = Optimisers.Adam(1e-3)
+)
 
-    ProgressMeter.next!(prog)
-end
+# Evaluate final ELBO with 10^3 Monte Carlo samples
+objective(q; n_samples=10^3)
 ```
 
 
 ## References
 
-- Jordan, Michael I., Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. "An introduction to variational methods for graphical models." Machine learning 37, no. 2 (1999): 183-233.
-- Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. "Variational inference: A review for statisticians." Journal of the American statistical Association 112, no. 518 (2017): 859-877.
 - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015.
-- Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882.
-- Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003.

From 28a35bcd0ce6bd4489915ae1cf37801db211b2ec Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 01:02:04 +0100
Subject: [PATCH 079/206] update make getting started example actually run
 Julia

---
 docs/Project.toml   |  13 ++++-
 docs/src/started.md | 115 +++++++++++++++++++++++++++++++++-----------
 2 files changed, 98 insertions(+), 30 deletions(-)

diff --git a/docs/Project.toml b/docs/Project.toml
index c625d07f2..182edd3e6 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,7 +1,18 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 
 [compat]
-Documenter = "0.26, 0.27"
\ No newline at end of file
+ADTypes = "0.1.6"
+Bijectors = "0.13.6"
+Documenter = "0.26, 0.27"
+LogDensityProblems = "2.1.1"
diff --git a/docs/src/started.md b/docs/src/started.md
index 355e93502..fec60f1a5 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -14,62 +14,119 @@ To use `AdvancedVI`, a user needs to select a variational family and a variational
 optimize
 ```
 
-## `ADVI` Example Using `Turing`
+## `ADVI` Example 
+In this tutorial, we will work with a basic `normal-log-normal` model.
+```math
+\begin{aligned}
+x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\
+y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right)
+\end{aligned}
+```
+ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly.
 
-In this tutorial, we'll use `Turing` to define a basic `normal-log-normal` model.
-ADVI with log bijectors is able to infer this model exactly.
-```julia
-using Turing
+Using the `LogDensityProblems` interface, the model can be defined as follows:
+```@example advi
+using Distributions
+using LogDensityProblems
+using SimpleUnPack
 
-μ_y, σ_y = 1.0, 1.0
-μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0]
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
 
-Turing.@model function normallognormal()
-    y ~ LogNormal(μ_y, σ_y)
-    z ~ MvNormal(μ_z, Σ_z)
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
-model = normallognormal()
+
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+```
+Let's now instantiate the model:
+```@example advi
+using PDMats
+
+n_dims = 10
+μ_x    = randn()
+σ_x    = exp.(randn())
+μ_y    = randn(n_dims)
+σ_y    = exp.(randn(n_dims))
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
 ```
 
 Since `y` follows a log-normal prior, its support is the positive half-line ``\mathbb{R}_+``.
 Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
-```julia
+```@example advi
 using Bijectors
 
-b     = Bijectors.bijector(model)
-b⁻¹   = inverse(b)
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
+
+b   = Bijectors.bijector(model);
+b⁻¹ = inverse(b)
 ```
 
 Let's now load `AdvancedVI`.
 Since ADVI relies on automatic differentiation (AD; hence the "AD" in "ADVI"), we need to load an AD library *before* loading `AdvancedVI`.
 Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
 Here, we will use `ForwardDiff`, which is selected by passing `ADTypes.AutoForwardDiff()` later on.
-```julia
+```@example advi
 using Optimisers
-using ForwardDiff
+using ADTypes, ForwardDiff
 import AdvancedVI as AVI
 ```
 We now need to select (1) a variational objective and (2) a variational family.
 Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
-```julia
-prob         = DynamicPPL.LogDensityFunction(model)
-n_montecarlo = 10
-objective    = AVI.ADVI(prob, b⁻¹, n_montecarlo)
+```@example advi
+n_montecarlo = 10;
+objective    = AVI.ADVI(model, n_montecarlo; b = b⁻¹)
 ```
 For the variational family, we will use the classic mean-field Gaussian family.
-```julia
-d = LogDensityProblems.dimension(prob)
-μ = randn(d)
-L = Diagonal(ones(d))
+```@example advi
+using LinearAlgebra
+
+d = LogDensityProblems.dimension(model);
+μ = randn(d);
+L = Diagonal(ones(d));
 q = AVI.VIMeanFieldGaussian(μ, L)
 ```
-It now remains to run inference!
-```julia
-n_max_iter = 10^4
-q, stats   = AVI.optimize(
+Passing `objective` and the initial variational approximation `q` to `optimize` performs inference.
+```@example advi
+n_max_iter  = 10^4
+q, stats, _ = AVI.optimize(
+    objective,
     q,
     n_max_iter;
     adbackend = AutoForwardDiff(),
     optimizer = Optimisers.Adam(1e-3)
-)
+); 
+```
+
+The selected inference procedure stores per-iteration statistics into `stats`.
+For instance, the ELBO can be plotted as follows:
+```@example advi
+using Plots
+
+t = [stat.iteration for stat ∈ stats]
+y = [stat.elbo for stat ∈ stats]
+plot(t[1:100:end], y[1:100:end])
+savefig("advi_example_elbo.svg"); nothing
+```
+![](advi_example_elbo.svg)
+Further information can be gathered by defining your own `callback!`.
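+A minimal sketch (its keyword signature mirrors the package's test suite and may evolve):
+```julia
+callback!(; stat, est_state, restructure, λ) = begin
+    # Any NamedTuple returned here is merged into this iteration's `stat`
+    (λ_norm = norm(λ),)
+end
+```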
+
+The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows:
+```@example advi
+ELBO = objective(q; n_samples=10^4)
 ```

From 620b38e7d345c60d59c08174144f1349618ff60c Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 01:02:16 +0100
Subject: [PATCH 080/206] fix remove Float32 tests for inference tests

---
 ext/AdvancedVIForwardDiffExt.jl | 2 +-
 test/advi_locscale.jl           | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl
index e6b03af21..5949bdf81 100644
--- a/ext/AdvancedVIForwardDiffExt.jl
+++ b/ext/AdvancedVIForwardDiffExt.jl
@@ -11,8 +11,8 @@ else
     using ..AdvancedVI: ADTypes, DiffResults
 end
 
-# extract chunk size from AutoForwardDiff
 getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
+
 function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
 ) where {T<:Real}
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index d4ef7aec5..e4c81402c 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -25,7 +25,7 @@ include("models/utils.jl")
 @testset "advi" begin
     @testset "locscale" begin
         @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
-            realtype ∈ [Float32, Float64],
+            realtype ∈ [Float64], # Currently only tested against Float64
             (modelname, modelconstr) ∈ Dict(
                 :NormalLogNormalMeanField => normallognormal_meanfield,
                 :NormalLogNormalFullRank  => normallognormal_fullrank,

From fa533981d6c3208e008d04f35a18ec08728ca608 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 01:54:13 +0100
Subject: [PATCH 081/206] update version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 35650ae5d..2092b0cb0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.2.4"
+version = "0.3.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

From e909f4106e919e2d834a4f73eac3ca929bd5b9dd Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 20:04:34 +0100
Subject: [PATCH 082/206] add documentation publishing url

---
 docs/make.jl | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/docs/make.jl b/docs/make.jl
index ca21b5fde..5d3716089 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -7,15 +7,16 @@ DocMeta.setdocmeta!(
 )
 
 makedocs(;
-    sitename = "AdvancedVI.jl",
     modules  = [AdvancedVI],
+    sitename = "AdvancedVI.jl",
+    repo     = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
     format   = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
-         pages    = ["AdvancedVI"        => "index.md",
-                     "Getting Started"   => "started.md",
-                     "ELBO Maximization" => [
-                         "Automatic Differentiation VI" => "advi.md",   
-                         "Location Scale Family"        => "locscale.md",
-                     ]],
+    pages    = ["AdvancedVI"        => "index.md",
+                "Getting Started"   => "started.md",
+                "ELBO Maximization" => [
+                    "Automatic Differentiation VI" => "advi.md",   
+                    "Location Scale Family"        => "locscale.md",
+                ]],
 )
 
 deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)

From 43f5b751abb963533cbb6835ca6c8315a53a41d2 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 20:17:04 +0100
Subject: [PATCH 083/206] fix wrong uuid for ForwardDiff

---
 src/AdvancedVI.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 1677be622..c45d49971 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -106,7 +106,7 @@ function __init__()
         @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
             include("../ext/AdvancedVIEnzymeExt.jl")
         end
-        @require ForwardDiff = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+        @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
             include("../ext/AdvancedVIForwardDiffExt.jl")
         end
         @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin

From 468d5ca3aa94f7c83287633beba23aa5d174ca88 Mon Sep 17 00:00:00 2001
From: Hong Ge <3279477+yebai@users.noreply.github.com>
Date: Thu, 17 Aug 2023 21:44:15 +0100
Subject: [PATCH 084/206] Update CI.yml

---
 .github/workflows/CI.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 158da963c..26f6876f5 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -20,7 +20,7 @@ jobs:
           - windows-latest
         arch:
           - x64
-          - x86
+          # - x86 # Uncomment after https://github.com/JuliaTesting/ReTest.jl/pull/52 is merged
         exclude:
           - os: macOS-latest
             arch: x86

From c07a5118a237fd5eb3a478a88fdcefe06673b366 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 21:49:26 +0100
Subject: [PATCH 085/206] refactor use `sum` and `mean` instead of abusing
 `mapreduce`

---
 src/distributions/location_scale.jl | 4 ++--
 src/objectives/elbo/advi.jl         | 5 ++---
 src/objectives/elbo/entropy.jl      | 5 ++---
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index e901e8deb..3113c679e 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -42,12 +42,12 @@ end
 
 function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale))
+    sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale))
+    sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function rand(q::VILocationScale)
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 8bc14bc9b..67af43757 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -55,10 +55,9 @@ function (advi::ADVI)(
     q_η::ContinuousMultivariateDistribution,
     ηs ::AbstractMatrix
 )
-    n_samples = size(ηs, 2)
-    𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ
+    𝔼ℓ = mean(eachcol(ηs)) do ηᵢ
         zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ)
-        (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples
+        (advi.ℓπ(zᵢ) + logdetjacᵢ)
     end
     ℍ  = advi.entropy(q_η, ηs)
     𝔼ℓ + ℍ
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 0edc47f4e..694eacefd 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -23,9 +23,8 @@ The "sticking the landing" entropy estimator.
 struct StickingTheLandingEntropy <: MonteCarloEntropy end
 
 function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
-    n_samples = size(ηs, 2)
-    mapreduce(+, eachcol(ηs)) do ηᵢ
-        -logpdf(q, ηᵢ) / n_samples
+    mean(eachcol(ηs)) do ηᵢ
+        -logpdf(q, ηᵢ)
     end
 end
 

From 13a8a445af64690b61137f6791f4f11eb6130a2b Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 17 Aug 2023 22:14:42 +0100
Subject: [PATCH 086/206] remove tests for `FullMonteCarlo`

---
 test/advi_locscale.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index e4c81402c..962d31699 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -35,7 +35,6 @@ include("models/utils.jl")
             (objname, objective) ∈ Dict(
                 :ADVIClosedFormEntropy  => (model, b, M) -> ADVI(model, M; b),
                 :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()),
-                :ADVIFullMonteCarlo     => (model, b, M) -> ADVI(model, M; b, entropy = FullMonteCarloEntropy()),
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwarDiff  => AutoForwardDiff(),

From aadf8d397aad300b6e5d502b8a90bd0f2724d778 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 01:31:58 +0100
Subject: [PATCH 087/206] add tests for the `optimize` interface

---
 test/advi_locscale.jl |  4 +--
 test/optimize.jl      | 84 +++++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl      |  2 ++
 3 files changed, 88 insertions(+), 2 deletions(-)
 create mode 100644 test/optimize.jl

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 962d31699..bf51199fa 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -38,8 +38,8 @@ include("models/utils.jl")
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwarDiff  => AutoForwardDiff(),
-                # :ReverseDiff => AutoReverseDiff(),
-                # :Zygote      => AutoZygote(),
+                :ReverseDiff => AutoReverseDiff(),
+                :Zygote      => AutoZygote(),
                 # :Enzyme      => AutoEnzyme(),
             )
 
diff --git a/test/optimize.jl b/test/optimize.jl
new file mode 100644
index 000000000..3ece467f0
--- /dev/null
+++ b/test/optimize.jl
@@ -0,0 +1,84 @@
+
+using ReTest
+using Bijectors
+using LogDensityProblems
+using Optimisers
+using Distributions
+using PDMats
+using LinearAlgebra
+using SimpleUnPack: @unpack
+
+struct TestModel{M,L,S}
+    model::M
+    μ_true::L
+    L_true::S
+    n_dims::Int
+    is_meanfield::Bool
+end
+
+include("models/normallognormal.jl")
+include("models/utils.jl")
+
+@testset "optimize" begin
+    seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+    rng  = Philox4x(UInt64, seed, 8)
+
+    T = 1000
+    modelstats = normallognormal_meanfield(Float64; rng)
+
+    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+    # Global Test Configurations
+    b⁻¹ = Bijectors.bijector(model) |> inverse
+    μ₀  = zeros(Float64, n_dims)
+    L₀  = ones(Float64, n_dims) |> Diagonal
+    q₀  = VIMeanFieldGaussian(μ₀, L₀)
+    obj = ADVI(model, 10; b=b⁻¹)
+
+    adbackend = AutoForwardDiff()
+    optimizer = Optimisers.Adam(1e-2)
+
+    rng                 = Philox4x(UInt64, seed, 8)
+    q_ref, stats_ref, _ = optimize(
+        obj, q₀, T;
+        optimizer,
+        show_progress = false,
+        rng,
+        adbackend,
+    )
+    λ_ref, _ = Optimisers.destructure(q_ref)
+
+    @testset "restructure" begin
+        λ₀, re  = Optimisers.destructure(q₀)
+
+        rng         = Philox4x(UInt64, seed, 8)
+        λ, stats, _ = optimize(
+            obj, re, λ₀, T;
+            optimizer,
+            show_progress = false,
+            rng,
+            adbackend,
+        )
+        @test λ     == λ_ref
+        @test stats == stats_ref
+    end
+
+    @testset "callback" begin
+        rng = Philox4x(UInt64, seed, 8)
+        test_values = rand(rng, T)
+
+        callback!(; stat, est_state, restructure, λ) = begin
+            (test_value = test_values[stat.iteration],)
+        end
+
+        rng         = Philox4x(UInt64, seed, 8)
+        _, stats, _ = optimize(
+            obj, q₀, T;
+            show_progress = false,
+            rng,
+            adbackend,
+            callback!
+        )
+        @test [stat.test_value for stat ∈ stats] == test_values
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 68225fd9e..6bd3bc491 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,11 +8,13 @@ using Random123
 using Statistics
 using Distributions
 using LinearAlgebra
+
 using AdvancedVI
 
 include("ad.jl")
 include("distributions.jl")
 include("advi_locscale.jl")
+include("optimize.jl")
 
 @main function runtests(patterns...; dry::Bool = false)
     retest(patterns...; dry = dry, verbose = Inf)

From 8c4e13db72524ad31bf6306219436d3b78320237 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 01:33:05 +0100
Subject: [PATCH 088/206] fix turn off Zygote tests for now

---
 test/advi_locscale.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index bf51199fa..e8b4be03d 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -39,7 +39,7 @@ include("models/utils.jl")
             (adbackname, adbackend) ∈ Dict(
                 :ForwarDiff  => AutoForwardDiff(),
                 :ReverseDiff => AutoReverseDiff(),
-                :Zygote      => AutoZygote(),
+                # :Zygote      => AutoZygote(), 
                 # :Enzyme      => AutoEnzyme(),
             )
 

From 0b708e6297d781722a582058d42f7e0917cf49bd Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 03:09:11 +0100
Subject: [PATCH 089/206] remove unused function

---
 src/objectives/elbo/entropy.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 694eacefd..022ed4f62 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -5,8 +5,6 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix)
     entropy(q)
 end
 
-skip_entropy_gradient(::ClosedFormEntropy) = false
-
 abstract type MonteCarloEntropy <: AbstractEntropyEstimator end
 
 struct FullMonteCarloEntropy <: MonteCarloEntropy end

From be61acd46d457206cbd07386377958d5afb178e3 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 03:51:34 +0100
Subject: [PATCH 090/206] refactor change bijector field name, simplify STL
 estimator

---
 Project.toml                   |  2 ++
 src/AdvancedVI.jl              |  4 +--
 src/objectives/elbo/advi.jl    | 46 +++++++---------------------------
 src/objectives/elbo/entropy.jl | 15 ++++++-----
 test/advi_locscale.jl          |  4 +--
 test/optimize.jl               |  2 +-
 6 files changed, 25 insertions(+), 48 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2092b0cb0..e099308a7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.3.0"
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -37,6 +38,7 @@ AdvancedVIZygoteExt = "Zygote"
 [compat]
 ADTypes = "0.1"
 Bijectors = "0.11, 0.12, 0.13"
+ChainRules = "1.53.0"
 DiffResults = "1"
 Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
 DocStringExtensions = "0.8, 0.9"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index c45d49971..cca220f1a 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -23,7 +23,7 @@ using LogDensityProblems
 
 using ADTypes, DiffResults
 using ADTypes: AbstractADType
-
+using ChainRules: @ignore_derivatives 
 
 using FillArrays
 using PDMats
@@ -74,7 +74,7 @@ export
     ADVI,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
-    FullMonteCarloEntropy
+    MonteCarloEntropy
 
 # Variational Families
 
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 67af43757..788449d18 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -11,7 +11,7 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017)
 # Keyword Arguments
 - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy())
 - `cv`: A control variate.
-- `b`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.)
+- `invbij`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.)
 
 # Requirements
 - ``q_{\\lambda}`` implements `rand`.
@@ -23,7 +23,7 @@ struct ADVI{Tlogπ, B,
             EntropyEst <: AbstractEntropyEstimator,
             ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective
     ℓπ::Tlogπ
-    b::B
+    invbij::B
     entropy::EntropyEst
     cv::ControlVar
     n_samples::Int
@@ -31,7 +31,7 @@ struct ADVI{Tlogπ, B,
     function ADVI(prob, n_samples::Int;
                   entropy::AbstractEntropyEstimator = ClosedFormEntropy(),
                   cv::Union{<:AbstractControlVariate, Nothing} = nothing,
-                  b = Bijectors.identity)
+                  invbij = Bijectors.identity)
         cap = LogDensityProblems.capabilities(prob)
         if cap === nothing
             throw(
@@ -41,7 +41,7 @@ struct ADVI{Tlogπ, B,
             )
         end
         ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
-        new{typeof(ℓπ), typeof(b), typeof(entropy), typeof(cv)}(ℓπ, b, entropy, cv, n_samples)
+        new{typeof(ℓπ), typeof(invbij), typeof(entropy), typeof(cv)}(ℓπ, invbij, entropy, cv, n_samples)
     end
 end
 
@@ -56,7 +56,7 @@ function (advi::ADVI)(
     ηs ::AbstractMatrix
 )
     𝔼ℓ = mean(eachcol(ηs)) do ηᵢ
-        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ)
+        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ)
         (advi.ℓπ(zᵢ) + logdetjacᵢ)
     end
     ℍ  = advi.entropy(q_η, ηs)
@@ -86,50 +86,22 @@ function (advi::ADVI)(
     advi(rng, q_η, ηs)
 end
 
-function estimate_advi_gradient_maybe_stl!(
-    rng::AbstractRNG,
-    adbackend::AbstractADType,
-    advi::ADVI{P, B, StickingTheLandingEntropy, CV},
-    λ::Vector{<:Real},
-    restructure,
-    out::DiffResults.MutableDiffResult
-) where {P, B, CV}
-    q_η_stop = restructure(λ)
-    f(λ′) = begin
-        q_η = restructure(λ′)
-        ηs  = rand(rng, q_η, advi.n_samples)
-        -advi(rng, q_η_stop, ηs)
-    end
-    value_and_gradient!(adbackend, f, λ, out)
-end
-
-function estimate_advi_gradient_maybe_stl!(
+function estimate_gradient(
     rng::AbstractRNG,
     adbackend::AbstractADType,
-    advi::ADVI{P, B, <:Union{ClosedFormEntropy, FullMonteCarloEntropy}, CV},
+    advi::ADVI,
+    est_state,
     λ::Vector{<:Real},
     restructure,
     out::DiffResults.MutableDiffResult
-) where {P, B, CV}
+)
     f(λ′) = begin
         q_η = restructure(λ′)
         ηs  = rand(rng, q_η, advi.n_samples)
         -advi(rng, q_η, ηs)
     end
     value_and_gradient!(adbackend, f, λ, out)
-end
 
-function estimate_gradient(
-    rng::AbstractRNG,
-    adbackend::AbstractADType,
-    advi::ADVI,
-    est_state,
-    λ::Vector{<:Real},
-    restructure,
-    out::DiffResults.MutableDiffResult
-)
-    estimate_advi_gradient_maybe_stl!(
-        rng, adbackend, advi, λ, restructure, out)
     nelbo = DiffResults.value(out)
     stat  = (elbo=-nelbo,)
 
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 022ed4f62..97ccda299 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -5,9 +5,13 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix)
     entropy(q)
 end
 
-abstract type MonteCarloEntropy <: AbstractEntropyEstimator end
+struct MonteCarloEntropy <: AbstractEntropyEstimator end
 
-struct FullMonteCarloEntropy <: MonteCarloEntropy end
+function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
+    mean(eachcol(ηs)) do ηᵢ
+        -logpdf(q, ηᵢ)
+    end
+end
 
 """
     StickingTheLandingEntropy()
@@ -18,11 +22,10 @@ The "sticking the landing" entropy estimator.
 - `q` implements `logpdf`.
 - `logpdf(q, η)` must be differentiable by the selected AD framework.
 """
-struct StickingTheLandingEntropy <: MonteCarloEntropy end
+struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 
-function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
-    mean(eachcol(ηs)) do ηᵢ
+function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix)
+    @ignore_derivatives mean(eachcol(ηs)) do ηᵢ
         -logpdf(q, ηᵢ)
     end
 end
-
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index e8b4be03d..71cf22d51 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -33,8 +33,8 @@ include("models/utils.jl")
                 :NormalFullRank           => normal_fullrank,
             ),
             (objname, objective) ∈ Dict(
-                :ADVIClosedFormEntropy  => (model, b, M) -> ADVI(model, M; b),
-                :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()),
+                :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹),
+                :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹, entropy = StickingTheLandingEntropy()),
             ),
             (adbackname, adbackend) ∈ Dict(
                :ForwardDiff => AutoForwardDiff(),
diff --git a/test/optimize.jl b/test/optimize.jl
index 3ece467f0..d514d2360 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -33,7 +33,7 @@ include("models/utils.jl")
     μ₀  = zeros(Float64, n_dims)
     L₀  = ones(Float64, n_dims) |> Diagonal
     q₀  = VIMeanFieldGaussian(μ₀, L₀)
-    obj = ADVI(model, 10; b=b⁻¹)
+    obj = ADVI(model, 10; invbij=b⁻¹)
 
     adbackend = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)

From fb519a501585fd279a62bce331ea81b19627ba06 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 03:51:59 +0100
Subject: [PATCH 091/206] update documentation

---
 docs/src/advi.md    | 177 +++++++++++++++++++++++++++++++++++++++++---
 docs/src/started.md |   8 +-
 2 files changed, 170 insertions(+), 15 deletions(-)

diff --git a/docs/src/advi.md b/docs/src/advi.md
index 37b3541bb..3719c89e3 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -66,34 +66,187 @@ ADVI
 The STL control variate was proposed by Roeder *et al.* (2017).
 By slightly modifying the differentiation path, it implicitly forms a control variate of the form
 ```math
-\mathrm{CV}_{\mathrm{STL}}\left(z\right) \triangleq \mathbb{H}\left(q_{\lambda}\right) + \log q_{\lambda}\left(z\right),
+\begin{aligned}
+  \mathrm{CV}_{\mathrm{STL}}\left(z\right) 
+  &\triangleq 
+  \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) \\
+  &=
+  -\nabla_{\lambda} \mathbb{E}_{u \sim \varphi} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right)
+\end{aligned}
 ```
-which has a mean of zero.
+where ``\nu = \lambda`` is set to avoid differentiating through the density of ``q_{\lambda}``.
+This vector-valued function has a mean of zero, since the first term is exactly the negative of the expectation of the second, and is therefore a valid control variate.
  
 Adding this to the closed-form entropy ELBO estimator yields the STL estimator:
 ```math
 \begin{aligned}
-  \widehat{\mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right)
-    &\triangleq \mathbb{E}\left[ \log \pi \left(z\right) \right] - \log q_{\lambda} \left(z\right) \\
-    &= \mathbb{E}\left[ \log \pi\left(z\right) \right] 
-      + \mathbb{H}\left(q_{\lambda}\right) - \mathrm{CV}_{\mathrm{STL}}\left(z\right) \\
-    &= \widehat{\mathrm{ELBO}}\left(\lambda\right)
-      - \mathrm{CV}_{\mathrm{STL}}\left(z\right),
+  \widehat{\nabla \mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right)
+    &\triangleq \mathbb{E}_{u \sim \varphi}\left[ 
+	  \nabla_{\lambda} \log \pi \left(z_{\lambda}\left(u\right)\right) 
+	  - 
+	  \nabla_{\lambda} \log q_{\nu} \left(z_{\lambda}\left(u\right)\right)
+	\right] 
+	\\
+    &= 
+	\mathbb{E}\left[ \nabla_{\lambda} \log \pi\left(z_{\lambda}\left(u\right)\right) \right] 
+    + 
+	\nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) 
+	- 
+	\mathrm{CV}_{\mathrm{STL}}\left(z\right)
+	\\
+    &= 
+	\widehat{\nabla \mathrm{ELBO}}\left(\lambda\right)
+    - 
+	\mathrm{CV}_{\mathrm{STL}}\left(z\right),
 \end{aligned}
 ```
-which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``.
+which has the same expectation as the original ADVI estimator, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``.
 The conditions under which the STL estimator results in lower variance are still an active subject of research.
 
+The main downside of the STL estimator is that it needs to evaluate and differentiate the log density of ``q_{\lambda}`` in every iteration.
+Depending on the variational family, this might be computationally inefficient or even numerically unstable.
+For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability.
+
+
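+One way to realize this differentiation path in code is to evaluate the entropy
+term under a copy of the variational approximation whose parameters are held
+constant by the AD backend. A minimal sketch, reusing the `restructure`,
+`rand`, and `value_and_gradient!` conventions of this package:
+
+```julia
+q_stop = restructure(λ)                   # ν = λ: a constant w.r.t. λ′
+f(λ′) = begin
+    q_η = restructure(λ′)
+    ηs  = rand(rng, q_η, advi.n_samples)  # gradients flow through the samples,
+    -advi(rng, q_stop, ηs)                # but not through log q_ν
+end
+value_and_gradient!(adbackend, f, λ, out)
+```
+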
 The STL control variate can be used by changing the entropy estimator using the following object:
 ```@docs
 StickingTheLandingEntropy
 ```
 
-For example:
-```julia
-ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector)
+```@setup stl
+using LogDensityProblems
+using SimpleUnPack
+using PDMats
+using Bijectors
+using LinearAlgebra
+using Plots
+
+using Optimisers
+using ADTypes, ForwardDiff
+import AdvancedVI as AVI
+
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
+
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+end
+
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+n_dims = 10
+μ_x    = randn()
+σ_x    = exp.(randn())
+μ_y    = randn(n_dims)
+σ_y    = exp.(randn(n_dims))
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+
+d  = LogDensityProblems.dimension(model);
+μ  = randn(d);
+L  = Diagonal(ones(d));
+q0 = AVI.VIMeanFieldGaussian(μ, L)
+
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
 ```
 
+Let us come back to the example in [Getting Started](@ref getting_started), where a `LogDensityProblem` is given as `model`.
+In this example, the true posterior is contained within the variational family.
+This setting is known as "perfect variational family specification."
+In this case, the STL estimator is able to converge exponentially fast to the true solution.
+
+Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows:
+```@example stl
+n_montecarlo = 1;
+b            = Bijectors.bijector(model);
+b⁻¹          = inverse(b)
+
+cfe = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
+```
+The STL estimator can instead be created as follows:
+```@example stl
+stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), invbij = b⁻¹);
+```
+
+```@setup stl
+n_max_iter = 10^4
+
+idx = [1]
+callback!(; stat, est_state, restructure, λ) = begin
+    if mod(idx[1], 100) == 1
+	    idx[:] .+= 1
+        (elbo_accurate = cfe(restructure(λ); n_samples=10^4),)
+	else
+	    idx[:] .+= 1
+        NamedTuple()
+	end
+end
+
+_, stats_cfe, _ = AVI.optimize(
+    cfe,
+    q0,
+    n_max_iter;
+	show_progress = false,
+	callback!     = callback!,
+    adbackend     = AutoForwardDiff(),
+    optimizer     = Optimisers.Adam(1e-3)
+); 
+
+idx[:] .= 1
+_, stats_stl, _ = AVI.optimize(
+    stl,
+    q0,
+    n_max_iter;
+	show_progress = false,
+	callback!     = callback!,
+    adbackend     = AutoForwardDiff(),
+    optimizer     = Optimisers.Adam(1e-3)
+); 
+
+fmc = AVI.ADVI(model, n_montecarlo; entropy = AVI.MonteCarloEntropy(), invbij = b⁻¹)
+idx[:] .= 1
+_, stats_fmc, _ = AVI.optimize(
+    fmc,
+    q0,
+    n_max_iter;
+	show_progress = false,
+	callback!     = callback!,
+    adbackend     = AutoForwardDiff(),
+    optimizer     = Optimisers.Adam(1e-3)
+); 
+
+t     = [stat.iteration     for stat ∈ stats_cfe[1:100:end]]
+y_cfe = [stat.elbo_accurate for stat ∈ stats_cfe[1:100:end]]
+y_stl = [stat.elbo_accurate for stat ∈ stats_stl[1:100:end]]
+y_fmc = [stat.elbo_accurate for stat ∈ stats_fmc[1:100:end]]
+plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1])
+plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1])
+plot!(t, y_fmc, label="ADVI FMC", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1])
+savefig("advi_stl_elbo.svg")
+nothing
+```
+![](advi_stl_elbo.svg)
+
+We can see that the STL estimator converges to a more accurate solution than the CFE estimator.
+
+
 ## References
 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
diff --git a/docs/src/started.md b/docs/src/started.md
index fec60f1a5..b89a140a7 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -90,7 +90,7 @@ We now need to select 1. a variational objective, and 2. a variational family.
 Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
 ```@example advi
 n_montecarlo = 10;
-objective   = AVI.ADVI(model, n_montecarlo; b = b⁻¹)
+objective   = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
 ```
 For the variational family, we will use the classic mean-field Gaussian family.
 ```@example advi
@@ -120,10 +120,12 @@ using Plots
 
 t = [stat.iteration for stat ∈ stats]
 y = [stat.elbo for stat ∈ stats]
-plot(t[1:100:end], y[1:100:end])
-savefig("advi_example_elbo.svg"); nothing
+plot(t, y, label="ADVI", xlabel="Iteration", ylabel="ELBO")
+savefig("advi_example_elbo.svg")
+nothing
 ```
 ![](advi_example_elbo.svg)
+
 Further information can be gathered by defining your own `callback!`.
 
 The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows:

From 8682fd92d7746e3f6741bbcb2f2029b12653ba72 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 04:00:17 +0100
Subject: [PATCH 092/206] update STL documentation

---
 docs/src/advi.md | 42 ++++++++----------------------------------
 1 file changed, 8 insertions(+), 34 deletions(-)

diff --git a/docs/src/advi.md b/docs/src/advi.md
index 3719c89e3..0d5b95683 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -63,6 +63,7 @@ ADVI
 ```
 
 ## The `StickingTheLanding` Control Variate
+
 The STL control variate was proposed by Roeder *et al.* (2017).
 By slightly modifying the differentiation path, it implicitly forms a control variate of the form
 ```math
@@ -188,63 +189,36 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i
 ```@setup stl
 n_max_iter = 10^4
 
-idx = [1]
-callback!(; stat, est_state, restructure, λ) = begin
-    if mod(idx[1], 100) == 1
-	    idx[:] .+= 1
-        (elbo_accurate = cfe(restructure(λ); n_samples=10^4),)
-	else
-	    idx[:] .+= 1
-        NamedTuple()
-	end
-end
-
 _, stats_cfe, _ = AVI.optimize(
     cfe,
     q0,
     n_max_iter;
 	show_progress = false,
-	callback!     = callback!,
     adbackend     = AutoForwardDiff(),
     optimizer     = Optimisers.Adam(1e-3)
 ); 
 
-idx[:] .= 1
 _, stats_stl, _ = AVI.optimize(
     stl,
     q0,
     n_max_iter;
 	show_progress = false,
-	callback!     = callback!,
-    adbackend     = AutoForwardDiff(),
-    optimizer     = Optimisers.Adam(1e-3)
-); 
-
-fmc = AVI.ADVI(model, n_montecarlo; entropy = AVI.MonteCarloEntropy(), invbij = b⁻¹)
-idx[:] .= 1
-_, stats_fmc, _ = AVI.optimize(
-    fmc,
-    q0,
-    n_max_iter;
-	show_progress = false,
-	callback!     = callback!,
     adbackend     = AutoForwardDiff(),
     optimizer     = Optimisers.Adam(1e-3)
 ); 
 
-t     = [stat.iteration     for stat ∈ stats_cfe[1:100:end]]
-y_cfe = [stat.elbo_accurate for stat ∈ stats_cfe[1:100:end]]
-y_stl = [stat.elbo_accurate for stat ∈ stats_stl[1:100:end]]
-y_fmc = [stat.elbo_accurate for stat ∈ stats_fmc[1:100:end]]
-plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1])
-plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1])
-plot!(t, y_fmc, label="ADVI FMC", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1])
+t     = [stat.iteration  for stat ∈ stats_cfe]
+y_cfe = [stat.elbo       for stat ∈ stats_cfe]
+y_stl = [stat.elbo       for stat ∈ stats_stl]
+plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO")
+plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO")
 savefig("advi_stl_elbo.svg")
 nothing
 ```
 ![](advi_stl_elbo.svg)
 
-We can see that the STL estimator converges to a more accurate solution than the CFE estimator.
+We can see that the noise of the STL estimator becomes smaller as VI converges.
+However, the difference in speed of convergence may not always be significant.
 
 
 ## References

From 9a16ee109a8b095e36d700389b18a35dc1355c2c Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 04:01:48 +0100
Subject: [PATCH 093/206] update STL documentation

---
 docs/src/advi.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/src/advi.md b/docs/src/advi.md
index 0d5b95683..afb780cb2 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -218,7 +218,7 @@ nothing
 ![](advi_stl_elbo.svg)
 
 We can see that the noise of the STL estimator becomes smaller as VI converges.
-However, the difference in speed of convergence may not always be significant.
+However, the speed of convergence may not always be significantly different.
 
 
 ## References

From fc74afaef98e8c31ed04c55abdce20d25a644e4d Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 18 Aug 2023 04:03:33 +0100
Subject: [PATCH 094/206] update location scale documentation

---
 docs/src/locscale.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/src/locscale.md b/docs/src/locscale.md
index a4bc2dc1f..63ff5cb4e 100644
--- a/docs/src/locscale.md
+++ b/docs/src/locscale.md
@@ -10,6 +10,7 @@ z \stackrel{d}{=} C u + m;\quad u \sim \varphi
 where ``C`` is the *scale*, ``m`` is the *location*, and ``\varphi`` is the *base distribution*.
 ``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``.
 The location-scale family encompasses many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``.
+
 The probability density is given by
 ```math
   q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m))
@@ -19,6 +20,8 @@ and the entropy is given as
   \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|,
 ```
 where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution.
+Notice that ``\mathcal{H}(\varphi)`` does not depend on ``\lambda``.
+The derivative of the entropy with respect to ``\lambda`` is therefore independent of the choice of base distribution.
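+
+As a quick numerical check of this identity, consider a diagonal Gaussian.
+A sketch assuming only `Distributions` and `LinearAlgebra` (variable names
+are illustrative):
+```julia
+using Distributions, LinearAlgebra
+
+m, C = randn(3), Diagonal([0.5, 1.0, 2.0])  # location and scale
+φ    = Normal()                             # base distribution, applied elementwise
+q    = MvNormal(m, C*C')                    # z = C*u + m with u ~ φ
+# H(q) should match d*H(φ) + log|C|:
+entropy(q) ≈ 3*entropy(φ) + first(logabsdet(C))  # true
+```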
 
 ## Constructors
 

From 4be30a1a44c70b4e9356768fd2d8ac662e7bc461 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 20 Aug 2023 00:10:48 +0100
Subject: [PATCH 095/206] fix README

---
 README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index e8718e7c5..c43748e50 100644
--- a/README.md
+++ b/README.md
@@ -11,14 +11,14 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije
 `AdvancedVI` basically expects a `LogDensityProblem`.
 For example, for the normal-log-normal model:
 $$
-\begin{aligned}
+\begin{align*}
 x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\
 y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right)
-\end{aligned}
-$$ 
+\end{align*}
+$$
 
 A `LogDensityProblem` can be implemented as 
-```
+```julia
 using LogDensityProblems
 
 struct NormalLogNormal{MX,SX,MY,SY}

From c58309dbaea25c986074b69a02f5bc6035dfcde8 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 20 Aug 2023 00:12:15 +0100
Subject: [PATCH 096/206] fix math in README

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index c43748e50..8def2d983 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije
 
 `AdvancedVI` basically expects a `LogDensityProblem`.
 For example, for the normal-log-normal model:
+
 $$
 \begin{align*}
 x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\

From 5b5bd3e9c3f4e90ac0d34f789b17c43c199ebd7d Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sun, 20 Aug 2023 03:08:20 +0100
Subject: [PATCH 097/206] add gradient to arguments of callback!, remove
 `gradient_norm` info

---
 src/objectives/elbo/advi.jl | 2 +-
 src/optimize.jl             | 8 ++++----
 test/optimize.jl            | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 788449d18..d8719fa7a 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -57,7 +57,7 @@ function (advi::ADVI)(
 )
     𝔼ℓ = mean(eachcol(ηs)) do ηᵢ
         zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ)
-        (advi.ℓπ(zᵢ) + logdetjacᵢ)
+        advi.ℓπ(zᵢ) + logdetjacᵢ
     end
     ℍ  = advi.entropy(q_η, ηs)
     𝔼ℓ + ℍ
diff --git a/src/optimize.jl b/src/optimize.jl
index 93e6f7546..43b066895 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -26,7 +26,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 # Arguments
 - `objective`: Variational Objective.
 - `λ₀`: Initial value of the variational parameters.
-- `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
+- `restruct`: Function that reconstructs the variational approximation from the flattened parameters.
 - `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`.
 - `n_max_iter`: Maximum number of iterations.
 
@@ -35,7 +35,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
 - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
 - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
-- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If `objective` is stateful, `est_state` contains its state. (Default: `nothing`.)
+- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stat, restructure, λ, g)`, where `g` is the stochastic gradient; it may return a dictionary-like object containing statistics to be displayed on the progress bar, or `nothing`. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.)
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
 
 # Returns
@@ -76,11 +76,11 @@ function optimize(
 
         g            = DiffResults.gradient(grad_buf)
         opt_state, λ = Optimisers.update!(opt_state, λ, g)
-        stat′ = (iteration=t, gradient_norm=norm(g))
+        stat′ = (iteration = t,)
         stat = merge(stat, stat′)
 
         if !isnothing(callback!)
-            stat′ = callback!(; est_state, stat, restructure, λ)
+            stat′ = callback!(; est_state, stat, λ, g, restructure)
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
diff --git a/test/optimize.jl b/test/optimize.jl
index d514d2360..920a30709 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -67,7 +67,7 @@ include("models/utils.jl")
         rng = Philox4x(UInt64, seed, 8)
         test_values = rand(rng, T)
 
-        callback!(; stat, est_state, restructure, λ) = begin
+        callback!(; stat, est_state, restructure, λ, g) = begin
             (test_value = test_values[stat.iteration],)
         end
 

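Since `gradient_norm` is no longer attached to the statistics by default, it can
be recovered through the new callback signature. A minimal sketch of such a
callback (hypothetical user code; the keyword names follow the docstring above):

```julia
using LinearAlgebra: norm

# Reattach the gradient norm to the statistics shown on the progress bar.
callback!(; est_state, stat, restructure, λ, g) = (gradient_norm = norm(g),)
```
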
From 967021d2a1aa827d9dedda00c2b3eae39638986e Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:43:43 +0100
Subject: [PATCH 098/206] fix math in README.md

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 8def2d983..83c2e8bce 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,10 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije
 For example, for the normal-log-normal model:
 
 $$
-\begin{align*}
-x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\
-y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right)
-\end{align*}
+\begin{aligned}
+x &\sim \operatorname{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
+y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
+\end{aligned}
 $$
 
 A `LogDensityProblem` can be implemented as 

From 4dab522ff2583f7a622f7c6d35f829f8daf37cf2 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:44:16 +0100
Subject: [PATCH 099/206] fix type constraint in `ZygoteExt`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 ext/AdvancedVIZygoteExt.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl
index b447d0718..c3d891bba 100644
--- a/ext/AdvancedVIZygoteExt.jl
+++ b/ext/AdvancedVIZygoteExt.jl
@@ -12,10 +12,10 @@ else
 end
 
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
-) where {T<:Real}
+    ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
+)
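+    # Seed the pullback with `one(y)`: the appropriate seed type comes from
+    # the scalar output `y`, not from the (now unconstrained) input eltype.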
     y, back = Zygote.pullback(f, θ)
-    ∇θ = back(one(T))
+    ∇θ = back(one(y))
     DiffResults.value!(out, y)
     DiffResults.gradient!(out, first(∇θ))
     return out

From 8ab2f19d208d82d720f462107b4949a16bfa3513 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:44:58 +0100
Subject: [PATCH 100/206] fix import of `Random`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 src/AdvancedVI.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index cca220f1a..a314e992f 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -4,7 +4,7 @@ module AdvancedVI
 using SimpleUnPack: @unpack, @pack!
 using Accessors
 
-import Random: AbstractRNG, default_rng
+using Random: AbstractRNG, default_rng
 using Distributions
 import Distributions:
     logpdf, _logpdf, rand, _rand!, _rand!,

From 83dec9fdc25226ed2dff13cc576981f09351a229 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:46:08 +0100
Subject: [PATCH 101/206] refactor `__init__()`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 src/AdvancedVI.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index a314e992f..348a6a305 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -101,8 +101,8 @@ if !isdefined(Base, :get_extension) # check whether :get_extension is defined in
     using Requires
 end
 
-function __init__()
-    @static if !isdefined(Base, :get_extension)
+@static if !isdefined(Base, :get_extension)
+    function __init__()
         @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
             include("../ext/AdvancedVIEnzymeExt.jl")
         end

From a3e563cd43d937602e01f36e87247068f2a0b4ab Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:47:08 +0100
Subject: [PATCH 102/206] fix type constraint in definition of
 `value_and_gradient!`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 src/AdvancedVI.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 348a6a305..42cd0dc52 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -39,9 +39,9 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
     value_and_gradient!(
         ad::ADTypes.AbstractADType,
         f,
-        θ::AbstractVector{T},
+        θ::AbstractVector{<:Real},
         out::DiffResults.MutableDiffResult
-    ) where {T<:Real}
+    )
 
 Compute the value and gradient of a function `f` at `θ` using the automatic
 differentiation backend `ad`.  The result is stored in `out`. 

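A hedged usage sketch of this interface (assuming a `ForwardDiff`-backed method
for `ADTypes.AutoForwardDiff` is available, as used elsewhere in this series):

```julia
using ADTypes, DiffResults, ForwardDiff
import AdvancedVI

θ   = randn(5)
out = DiffResults.GradientResult(θ)
AdvancedVI.value_and_gradient!(AutoForwardDiff(), x -> sum(abs2, x), θ, out)
DiffResults.value(out)     # sum of squares at θ
DiffResults.gradient(out)  # ≈ 2θ
```
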
From 5553bb950840ea9b8c6aba7794f52d58d3fce910 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:52:56 +0100
Subject: [PATCH 103/206] refactor `ZygoteExt`; use `only` instead of `first`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 ext/AdvancedVIZygoteExt.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl
index c3d891bba..7b8f8817a 100644
--- a/ext/AdvancedVIZygoteExt.jl
+++ b/ext/AdvancedVIZygoteExt.jl
@@ -17,7 +17,7 @@ function AdvancedVI.value_and_gradient!(
     y, back = Zygote.pullback(f, θ)
     ∇θ = back(one(y))
     DiffResults.value!(out, y)
-    DiffResults.gradient!(out, first(∇θ))
+    DiffResults.gradient!(out, only(∇θ))
     return out
 end
 

From 79b455746860f7957a7703d8d99fcbd79e613409 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:53:38 +0100
Subject: [PATCH 104/206] refactor type constraint in `ReverseDiffExt`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
---
 ext/AdvancedVIReverseDiffExt.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl
index fd7fbaabc..520cd9ff1 100644
--- a/ext/AdvancedVIReverseDiffExt.jl
+++ b/ext/AdvancedVIReverseDiffExt.jl
@@ -13,8 +13,8 @@ end
 
 # ReverseDiff without compiled tape
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
-) where {T<:Real}
+    ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
+)
     tp = ReverseDiff.GradientTape(f, θ)
     ReverseDiff.gradient!(out, tp, θ)
     return out

From 656b44b03f86ea83cba1d8de3953db956ffbe0ab Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Mon, 21 Aug 2023 23:56:28 +0100
Subject: [PATCH 105/206] refactor remove outdated debug mode macro

---
 src/AdvancedVI.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 42cd0dc52..ae0dc6844 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -32,8 +32,6 @@ using Bijectors
 using StatsBase
 using StatsBase: entropy
 
-const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
-
 # derivatives
 """
     value_and_gradient!(

From c7940636a8e08f5a97740f9872c4cffb4e6bed4d Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 00:10:00 +0100
Subject: [PATCH 106/206] fix remove outdated DEBUG mechanism

---
 src/optimize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 43b066895..57ee80308 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -84,7 +84,7 @@ function optimize(
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
-        AdvancedVI.DEBUG && @debug "Step $t" stat...
+        @debug "Iteration $t" stat...
 
         pm_next!(prog, stat)
         push!(stats, stat)

From 0c5cc1ce8eacc3451bf360bb3c1b0301415242d4 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 00:13:43 +0100
Subject: [PATCH 107/206] fix LaTeX in README: `operatorname` is currently
 broken

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 83c2e8bce..b3538ccf5 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ For example, for the normal-log-normal model:
 
 $$
 \begin{aligned}
-x &\sim \operatorname{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
+x &\sim \mathrm{Log\text{-}Normal}\left(\mu_x, \sigma_x^2\right) \\
 y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
 \end{aligned}
 $$

From 29d7d27ca227413275174e12f9258b13b8276fd0 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 01:04:43 +0100
Subject: [PATCH 108/206] remove `SimpleUnPack` dependency

---
 Project.toml                        |  1 -
 docs/Project.toml                   |  1 -
 docs/src/advi.md                    |  9 +++------
 docs/src/started.md                 | 11 ++++-------
 src/AdvancedVI.jl                   |  1 -
 src/distributions/location_scale.jl | 14 +++++++-------
 test/Project.toml                   |  1 -
 test/advi_locscale.jl               |  3 +--
 test/models/normal.jl               |  2 +-
 test/models/normallognormal.jl      |  7 ++++---
 test/optimize.jl                    |  1 -
 11 files changed, 20 insertions(+), 31 deletions(-)

diff --git a/Project.toml b/Project.toml
index e099308a7..29cc559f7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,7 +19,6 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
-SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 
diff --git a/docs/Project.toml b/docs/Project.toml
index 182edd3e6..1f4ba59fd 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -9,7 +9,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 
 [compat]
 ADTypes = "0.1.6"
diff --git a/docs/src/advi.md b/docs/src/advi.md
index afb780cb2..88c11feed 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -116,7 +116,6 @@ StickingTheLandingEntropy
 
 ```@setup stl
 using LogDensityProblems
-using SimpleUnPack
 using PDMats
 using Bijectors
 using LinearAlgebra
@@ -134,7 +133,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
@@ -151,17 +150,15 @@ n_dims = 10
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
 
 d  = LogDensityProblems.dimension(model);
 μ  = randn(d);
 L  = Diagonal(ones(d));
 q0 = AVI.VIMeanFieldGaussian(μ, L)
 
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
-
 function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
diff --git a/docs/src/started.md b/docs/src/started.md
index b89a140a7..4a1d26ecc 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -27,7 +27,6 @@ ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly.
 Using the `LogDensityProblems` interface, the model can be defined as follows:
 ```@example advi
 using LogDensityProblems
-using SimpleUnPack
 
 struct NormalLogNormal{MX,SX,MY,SY}
     μ_x::MX
@@ -37,7 +36,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
@@ -51,14 +50,14 @@ end
 ```
 Let's now instantiate the model:
 ```@example advi
-using PDMats
+using LinearAlgebra
 
 n_dims = 10
 μ_x    = randn()
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
 ```
 
 Since `x` follows a log-normal prior, its support is restricted to the positive half-line ``\mathbb{R}_+``.
@@ -67,7 +66,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat
 using Bijectors
 
 function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
@@ -94,8 +93,6 @@ objective   = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
 ```
 For the variational family, we will use the classic mean-field Gaussian family.
 ```@example advi
-using LinearAlgebra
-
 d = LogDensityProblems.dimension(model);
 μ = randn(d);
 L = Diagonal(ones(d));
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index ae0dc6844..5d0c3f8d2 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,7 +1,6 @@
 
 module AdvancedVI
 
-using SimpleUnPack: @unpack, @pack!
 using Accessors
 
 using Random: AbstractRNG, default_rng
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 3113c679e..73be42b9e 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -35,42 +35,42 @@ Base.length(q::VILocationScale) = length(q.location)
 Base.size(q::VILocationScale) = size(q.location)
 
 function StatsBase.entropy(q::VILocationScale)
-    @unpack  location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     n_dims*entropy(dist) + first(logabsdet(scale))
 end
 
 function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function rand(q::VILocationScale)
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     scale*rand(dist, n_dims) + location
 end
 
 function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) 
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     n_dims = length(location)
     scale*rand(rng, dist, n_dims, num_samples) .+ location
 end
 
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     rand!(rng, dist, x)
     x .= scale*x
     return x .+= location
 end
 
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
-    @unpack location, scale, dist = q
+    (; location, scale, dist) = q
     rand!(rng, dist, x)
     x .= scale*x
     return x .+= location
diff --git a/test/Project.toml b/test/Project.toml
index 2f38c88fa..277b73c87 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,7 +14,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Random123 = "74087812-796a-5b5d-8853-05524746bad3"
 ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
-SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 71cf22d51..c6aee68b9 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -8,7 +8,6 @@ using Optimisers
 using Distributions
 using PDMats
 using LinearAlgebra
-using SimpleUnPack: @unpack
 
 struct TestModel{M,L,S}
     model::M
@@ -48,7 +47,7 @@ include("models/utils.jl")
 
             T = 10000
             modelstats = modelconstr(realtype; rng)
-            @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+            (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats
 
             b    = Bijectors.bijector(model)
             b⁻¹  = inverse(b)
diff --git a/test/models/normal.jl b/test/models/normal.jl
index f60ad5f38..1dfa653ca 100644
--- a/test/models/normal.jl
+++ b/test/models/normal.jl
@@ -5,7 +5,7 @@ struct TestMvNormal{M,S}
 end
 
 function LogDensityProblems.logdensity(model::TestMvNormal, θ)
-    @unpack μ, Σ = model
+    (; μ, Σ) = model
     logpdf(MvNormal(μ, Σ), θ)
 end
 
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index cab73ccee..49da5bf63 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
@@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
 end
 
 function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
@@ -56,7 +56,8 @@ function normallognormal_meanfield(realtype; rng = default_rng())
     μ_y  = randn(rng, realtype, n_dims)
     σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
 
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
+    #model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
+    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
 
     μ = vcat(μ_x, μ_y)
     L = vcat(σ_x, σ_y) |> Diagonal
diff --git a/test/optimize.jl b/test/optimize.jl
index 920a30709..c96fa6cd3 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -6,7 +6,6 @@ using Optimisers
 using Distributions
 using PDMats
 using LinearAlgebra
-using SimpleUnPack: @unpack
 
 struct TestModel{M,L,S}
     model::M

From 75eef445a5daea37d79106851b26af292de2542b Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 01:05:08 +0100
Subject: [PATCH 109/206] fix LaTeX in docs and README

---
 README.md           | 2 +-
 docs/src/started.md | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index b3538ccf5..d9638bfd0 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ For example, for the normal-log-normal model:
 
 $$
 \begin{aligned}
-x &\sim \mathrm{Log\text{-}Normal}\left(\mu_x, \sigma_x^2\right) \\
+x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
 y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
 \end{aligned}
 $$
diff --git a/docs/src/started.md b/docs/src/started.md
index 4a1d26ecc..a129fc46f 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -18,8 +18,8 @@ optimize
 In this tutorial, we will work with a basic `normal-log-normal` model.
 ```math
 \begin{aligned}
-x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\
-y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right)
+x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
+y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
 \end{aligned}
 ```
 ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly.

From 40574f46864513ced4051867159e0660b2f4b061 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 01:10:29 +0100
Subject: [PATCH 110/206] add warning about forward-mode AD when using
 `LocationScale`

---
 docs/src/locscale.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/src/locscale.md b/docs/src/locscale.md
index 63ff5cb4e..8f14a9ad2 100644
--- a/docs/src/locscale.md
+++ b/docs/src/locscale.md
@@ -23,6 +23,9 @@ where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution.
 Notice that ``\mathcal{H}(\varphi)`` does not depend on ``\lambda``.
 The derivative of the entropy with respect to ``\lambda`` is therefore independent of the choice of base distribution.
 
+!!! warning
+	`LocationScale` and its specializations such as `VIFullRankGaussian` and `VIMeanFieldGaussian` are inefficient with forward-mode differentiation packages like `ForwardDiff`. In particular, they scale poorly with the number of dimensions. Please use reverse-mode differentiation packages such as `ReverseDiff` and `Zygote`.
+
 ## Constructors
 
 ```@docs

From 8738256bd44fc38dd49807a69f70da41fa50448c Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 01:14:04 +0100
Subject: [PATCH 111/206] fix documentation

---
 README.md           | 7 +++----
 docs/src/started.md | 2 +-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index d9638bfd0..07407fa96 100644
--- a/README.md
+++ b/README.md
@@ -8,17 +8,16 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije
 
 ## Examples
 
-`AdvancedVI` basically expects a `LogDensityProblem`.
+`AdvancedVI` expects a `LogDensityProblem`.
 For example, for the normal-log-normal model:
 
 $$
 \begin{aligned}
 x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
-y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
+y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right),
 \end{aligned}
 $$
-
-A `LogDensityProblem` can be implemented as 
+a `LogDensityProblem` can be implemented as 
 ```julia
 using LogDensityProblems
 
diff --git a/docs/src/started.md b/docs/src/started.md
index a129fc46f..b07a5bd3a 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -15,7 +15,7 @@ optimize
 ```
 
 ## `ADVI` Example 
-In this tutorial, we will work with a basic `normal-log-normal` model.
+In this tutorial, we will work with a `normal-log-normal` model.
 ```math
 \begin{aligned}
 x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\

From 817374403e58cb11e4e0e3aaee045c350d5bdfdc Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 01:18:52 +0100
Subject: [PATCH 112/206] fix remove remaining use of `@unpack`

---
 test/optimize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/optimize.jl b/test/optimize.jl
index c96fa6cd3..969304950 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -25,7 +25,7 @@ include("models/utils.jl")
     T = 1000
     modelstats = normallognormal_meanfield(Float64; rng)
 
-    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+    (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats
 
     # Global Test Configurations
     b⁻¹ = Bijectors.bijector(model) |> inverse

From e0548aecdc3468aa836d58b55aa3be60124d4782 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 21 Aug 2023 22:22:02 -0400
Subject: [PATCH 113/206] Revert "remove `SimpleUnPack` dependency"

This reverts commit 29d7d27ca227413275174e12f9258b13b8276fd0.
---
 Project.toml                        |  1 +
 docs/Project.toml                   |  1 +
 docs/src/advi.md                    |  9 ++++++---
 docs/src/started.md                 | 11 +++++++----
 src/AdvancedVI.jl                   |  1 +
 src/distributions/location_scale.jl | 14 +++++++-------
 test/Project.toml                   |  1 +
 test/advi_locscale.jl               |  3 ++-
 test/models/normal.jl               |  2 +-
 test/models/normallognormal.jl      |  7 +++----
 test/optimize.jl                    |  1 +
 11 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/Project.toml b/Project.toml
index 29cc559f7..e099308a7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,6 +19,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 
diff --git a/docs/Project.toml b/docs/Project.toml
index 1f4ba59fd..182edd3e6 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -9,6 +9,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 
 [compat]
 ADTypes = "0.1.6"
diff --git a/docs/src/advi.md b/docs/src/advi.md
index 88c11feed..afb780cb2 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -116,6 +116,7 @@ StickingTheLandingEntropy
 
 ```@setup stl
 using LogDensityProblems
+using SimpleUnPack
 using PDMats
 using Bijectors
 using LinearAlgebra
@@ -133,7 +134,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
@@ -150,15 +151,17 @@ n_dims = 10
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
 
 d  = LogDensityProblems.dimension(model);
 μ  = randn(d);
 L  = Diagonal(ones(d));
 q0 = AVI.VIMeanFieldGaussian(μ, L)
 
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+
 function Bijectors.bijector(model::NormalLogNormal)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
diff --git a/docs/src/started.md b/docs/src/started.md
index b07a5bd3a..4e2b43801 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -27,6 +27,7 @@ ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly.
 Using the `LogDensityProblems` interface, the model can be defined as follows:
 ```@example advi
 using LogDensityProblems
+using SimpleUnPack
 
 struct NormalLogNormal{MX,SX,MY,SY}
     μ_x::MX
@@ -36,7 +37,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
@@ -50,14 +51,14 @@ end
 ```
 Let's now instantiate the model:
 ```@example advi
-using LinearAlgebra
+using PDMats
 
 n_dims = 10
 μ_x    = randn()
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
+model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
 ```
 
 Since `x` follows a log-normal prior, its support is restricted to the positive half-line ``\mathbb{R}_+``.
@@ -66,7 +67,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat
 using Bijectors
 
 function Bijectors.bijector(model::NormalLogNormal)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
@@ -93,6 +94,8 @@ objective   = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
 ```
 For the variational family, we will use the classic mean-field Gaussian family.
 ```@example advi
+using LinearAlgebra
+
 d = LogDensityProblems.dimension(model);
 μ = randn(d);
 L = Diagonal(ones(d));
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 5d0c3f8d2..ae0dc6844 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,6 +1,7 @@
 
 module AdvancedVI
 
+using SimpleUnPack: @unpack, @pack!
 using Accessors
 
 using Random: AbstractRNG, default_rng
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 73be42b9e..3113c679e 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -35,42 +35,42 @@ Base.length(q::VILocationScale) = length(q.location)
 Base.size(q::VILocationScale) = size(q.location)
 
 function StatsBase.entropy(q::VILocationScale)
-    (; location, scale, dist) = q
+    @unpack  location, scale, dist = q
     n_dims = length(location)
     n_dims*entropy(dist) + first(logabsdet(scale))
 end
 
 function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
-    (; location, scale, dist) = q
+    @unpack location, scale, dist = q
     sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
-    (; location, scale, dist) = q
+    @unpack location, scale, dist = q
     sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function rand(q::VILocationScale)
-    (; location, scale, dist) = q
+    @unpack location, scale, dist = q
     n_dims = length(location)
     scale*rand(dist, n_dims) + location
 end
 
 function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) 
-    (; location, scale, dist) = q
+    @unpack location, scale, dist = q
     n_dims = length(location)
     scale*rand(rng, dist, n_dims, num_samples) .+ location
 end
 
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real})
-    (; location, scale, dist) = q
+    @unpack location, scale, dist = q
     rand!(rng, dist, x)
     x .= scale*x
     return x .+= location
 end
 
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
-    (; location, scale, dist) = q
+    @unpack location, scale, dist = q
     rand!(rng, dist, x)
     x .= scale*x
     return x .+= location
diff --git a/test/Project.toml b/test/Project.toml
index 277b73c87..2f38c88fa 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Random123 = "74087812-796a-5b5d-8853-05524746bad3"
 ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index c6aee68b9..71cf22d51 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -8,6 +8,7 @@ using Optimisers
 using Distributions
 using PDMats
 using LinearAlgebra
+using SimpleUnPack: @unpack
 
 struct TestModel{M,L,S}
     model::M
@@ -47,7 +48,7 @@ include("models/utils.jl")
 
             T = 10000
             modelstats = modelconstr(realtype; rng)
-            (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats
+            @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
             b    = Bijectors.bijector(model)
             b⁻¹  = inverse(b)
diff --git a/test/models/normal.jl b/test/models/normal.jl
index 1dfa653ca..f60ad5f38 100644
--- a/test/models/normal.jl
+++ b/test/models/normal.jl
@@ -5,7 +5,7 @@ struct TestMvNormal{M,S}
 end
 
 function LogDensityProblems.logdensity(model::TestMvNormal, θ)
-    (; μ, Σ) = model
+    @unpack μ, Σ = model
     logpdf(MvNormal(μ, Σ), θ)
 end
 
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index 49da5bf63..cab73ccee 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY}
 end
 
 function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
@@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
 end
 
 function Bijectors.bijector(model::NormalLogNormal)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
@@ -56,8 +56,7 @@ function normallognormal_meanfield(realtype; rng = default_rng())
     μ_y  = randn(rng, realtype, n_dims)
     σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
 
-    #model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
-    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
 
     μ = vcat(μ_x, μ_y)
     L = vcat(σ_x, σ_y) |> Diagonal
diff --git a/test/optimize.jl b/test/optimize.jl
index 969304950..c1a604c13 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -6,6 +6,7 @@ using Optimisers
 using Distributions
 using PDMats
 using LinearAlgebra
+using SimpleUnPack: @unpack
 
 struct TestModel{M,L,S}
     model::M

From 6ab95a096e058d21b9df1bb335d09381ce097705 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 21 Aug 2023 22:23:25 -0400
Subject: [PATCH 114/206] Revert "fix remove remaining use of `@unpack`"

This reverts commit 817374403e58cb11e4e0e3aaee045c350d5bdfdc.
---
 test/optimize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/optimize.jl b/test/optimize.jl
index c1a604c13..920a30709 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -26,7 +26,7 @@ include("models/utils.jl")
     T = 1000
     modelstats = normallognormal_meanfield(Float64; rng)
 
-    (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats
+    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
     # Global Test Configurations
     b⁻¹ = Bijectors.bijector(model) |> inverse

From f0ec242e615fb9f3f7b4b05ea2a687fa9c0e8b0c Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 18:08:01 +0100
Subject: [PATCH 115/206] fix documentation for `optimize`

---
 src/optimize.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 57ee80308..b18c8581d 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -35,7 +35,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
 - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
 - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
-- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stat, restructure, λ, g)`, where `g` is the stochastic gradient; it may return a dictionary-like object containing statistics to be displayed on the progress bar, or `nothing`. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.)
+- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stat, restructure, λ, g)`, where `g` is the stochastic gradient; it may return a dictionary-like object containing statistics to be displayed on the progress bar, or `nothing`. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.)
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
 
 # Returns
@@ -80,7 +80,7 @@ function optimize(
         stat = merge(stat, stat′)
 
         if !isnothing(callback!)
-            stat′ = callback!(; est_state, stat, λ, g, restructure)
+            stat′ = callback!(; est_state, stat, restructure, λ, g)
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         

From 1d4c1b6877296a7bdca5ed38c9d34c5be3acc827 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 18:08:13 +0100
Subject: [PATCH 116/206] add specializations of `Optimisers.destructure` for
 mean-field

* This fixes the poor performance of `ForwardDiff`
* This prevents the zero elements of the mean-field scale from being extracted
---
 docs/src/locscale.md                |  3 ---
 src/distributions/location_scale.jl | 35 ++++++++++++++++++++++++-----
 2 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/docs/src/locscale.md b/docs/src/locscale.md
index 8f14a9ad2..63ff5cb4e 100644
--- a/docs/src/locscale.md
+++ b/docs/src/locscale.md
@@ -23,9 +23,6 @@ where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution.
 Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``.
 The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.
 
-!!! warning
-	`LocationScale` and its specializations such as `VIFullRankGaussian` and `VIMeanFieldGaussian` are inefficient with forward-mode differentiation packages like `ForwardDiff`. Especially, they scale poorly with the number of dimensions. Please use reverse-mode differentation packages such as `ReverseDiff` and `Zygote`.
-
 ## Constructors
 
 ```@docs
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 3113c679e..9ae749f2d 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -19,9 +19,8 @@ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
     dist    ::D
 
     function VILocationScale(location::AbstractVector{<:Real},
-                             scale::Union{<:AbstractTriangular{<:Real},
-                                      <:Diagonal{<:Real}},
-                             dist::ContinuousUnivariateDistribution)
+                             scale   ::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}},
+                             dist    ::ContinuousUnivariateDistribution)
         # Restricting all the arguments to have the same types creates problems 
         # with dual-variable-based AD frameworks.
         @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2))
@@ -31,6 +30,32 @@ end
 
 Functors.@functor VILocationScale (location, scale)
 
+# Specialization of `Optimisers.destructure` for mean-field location-scale families.
+# These are necessary because we only want to extract the diagonal elements of 
+# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
+# is very inefficient.
+# begin
+struct RestructureMeanField{L, S<:Diagonal, D}
+    q::VILocationScale{L, S, D}
+end
+
+function (re::RestructureMeanField)(flat::AbstractVector)
+    n_dims   = div(length(flat), 2)
+    location = first(flat, n_dims)
+    scale    = Diagonal(last(flat, n_dims))
+    VILocationScale(location, scale, re.q.dist)
+end
+
+function Optimisers.destructure(
+    q::VILocationScale{L, <:Diagonal, D}
+) where {L, D}
+    @unpack location, scale, dist = q
+    flat   = vcat(location, diag(scale))
+    n_dims = length(location)
+    flat, RestructureMeanField(q)
+end
+# end
+
 Base.length(q::VILocationScale) = length(q.location)
 Base.size(q::VILocationScale) = size(q.location)
 
@@ -42,12 +67,12 @@ end
 
 function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
+    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale))
+    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
 end
 
 function rand(q::VILocationScale)

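A sketch of what this specialization buys, with hypothetical dimensions: for a mean-field family, `destructure` now yields only `2d` parameters, the location plus the diagonal of the scale, and `RestructureMeanField` rebuilds the distribution from that flat vector:

```julia
using LinearAlgebra, Optimisers
using AdvancedVI

d = 3
q = VIMeanFieldGaussian(zeros(d), Diagonal(ones(d)))

λ, re = Optimisers.destructure(q)
length(λ)   # 2d == 6: the structural zeros of the scale are never extracted
re(λ)       # rebuilds a distribution equal to q via RestructureMeanField
```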
From 231835f719f6fce86a4e0cf9935431b53cce75c7 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 20:01:41 +0100
Subject: [PATCH 117/206] add test for `Optimisers.destructure` specializations

---
 test/distributions.jl | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/test/distributions.jl b/test/distributions.jl
index 9b18d0207..dcd20696b 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -1,6 +1,7 @@
 
 using ReTest
 using Distributions: _logpdf
+using Optimisers 
 
 @testset "distributions" begin
     @testset "$(string(covtype)) $(basedist) $(realtype)" for
@@ -55,4 +56,15 @@ using Distributions: _logpdf
             @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
         end
     end
+
+    @testset "Diagonal destructure" for
+        n_dims = 10
+        μ      = zeros(n_dims)
+        L      = ones(n_dims)
+        q      = VIMeanFieldGaussian(μ, L |> Diagonal)
+        λ, re  = Optimisers.destructure(q)
+
+        @test length(λ) == 2*n_dims
+        @test q         == re(λ)
+    end
 end

From ea2d426c2c9b96de7d640e9ab0add3b4ae853892 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 21:21:54 +0100
Subject: [PATCH 118/206] add specialization of `rand` for meanfield resulting
 in faster AD

---
 src/distributions/location_scale.jl | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 9ae749f2d..7eb1f708d 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -87,6 +87,16 @@ function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int)
     scale*rand(rng, dist, n_dims, num_samples) .+ location
 end
 
+# This specialization improves AD performance of the sampling path
+function rand(
+    rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int
+) where {L, D}
+    @unpack location, scale, dist = q
+    n_dims     = length(location)
+    scale_diag = diag(scale)
+    scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
+end
+
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real})
     @unpack location, scale, dist = q
     rand!(rng, dist, x)

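Conceptually, the specialization swaps a matrix multiply through the `Diagonal` wrapper for an elementwise broadcast over its diagonal, which forward-mode AD differentiates far more cheaply; a standalone sketch of the equivalence:

```julia
using LinearAlgebra

location = randn(4)
scale    = Diagonal(abs.(randn(4)) .+ 1)
u        = randn(4, 10)                   # base samples, one column per draw

z_generic   = scale*u .+ location         # generic location-scale path
z_meanfield = diag(scale).*u .+ location  # specialized mean-field path

z_generic ≈ z_meanfield                   # true
```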
From 3033d75938b9d37408bbe081bf73c7954aff09cf Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 21:42:16 +0100
Subject: [PATCH 119/206] add argument checks for `VIMeanFieldGaussian`,
 `VIFullRankGaussian`

---
 src/distributions/location_scale.jl | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index 7eb1f708d..a7d9fbe4c 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -6,12 +6,15 @@ The location scale variational family broadly represents various variational
 families using `location` and `scale` variational parameters.
 
 It generally represents any distribution for which the sampling path can be
-represented as the following:
+represented as follows:
 ```julia
   d = length(location)
   u = rand(dist, d)
   z = scale*u + location
 ```
+
+!!! note
+    For stable convergence, the initial scale needs to be sufficiently large.
 """
 struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
     location::L
@@ -112,21 +115,37 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
 end
 
 """
-    VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T})
+    VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true)
 
 This constructs a multivariate Gaussian distribution with a full rank covariance matrix.
 """
-function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) where {T <: Real}
+function VIFullRankGaussian(
+    μ::AbstractVector{T},
+    L::AbstractTriangular{T};
+    check_args::Bool = true
+) where {T <: Real}
+    @assert isposdef(L) "Scale must be positive definite"
+    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
+        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
+    end
     q_base = Normal{T}(zero(T), one(T))
     VILocationScale(μ, L, q_base)
 end
 
 """
-    VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T})
+    VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}; check_args = true)
 
 This constructs a multivariate Gaussian distribution with a diagonal covariance matrix.
 """
-function VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T <: Real}
+function VIMeanFieldGaussian(
+    μ::AbstractVector{T},
+    L::Diagonal{T};
+    check_args::Bool = true
+) where {T <: Real}
+    @assert isposdef(L) "Scale must be positive definite"
+    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
+        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
+    end
     q_base = Normal{T}(zero(T), one(T))
     VILocationScale(μ, L, q_base)
 end

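With these checks in place, a reasonably scaled initialization constructs silently, while a near-singular one warns (the threshold being `sqrt(eps(Float64)) ≈ 1.5e-8`); for example:

```julia
using LinearAlgebra
using AdvancedVI

d = 5
VIMeanFieldGaussian(zeros(d), Diagonal(ones(d)))         # passes quietly
VIMeanFieldGaussian(zeros(d), Diagonal(fill(1e-10, d)))  # warns: scale too small
VIMeanFieldGaussian(zeros(d), Diagonal(fill(1e-10, d));
                    check_args = false)                  # warning suppressed
```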
From 0cc36c0eb9f4fc701e73e5ee835e5e0ced0c88d1 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 21:55:02 +0100
Subject: [PATCH 120/206] update documentation

---
 docs/src/advi.md                    | 5 ++---
 docs/src/locscale.md                | 4 ++++
 src/distributions/location_scale.jl | 3 ---
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/docs/src/advi.md b/docs/src/advi.md
index afb780cb2..2cf6a7732 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -210,8 +210,8 @@ _, stats_stl, _ = AVI.optimize(
 t     = [stat.iteration  for stat ∈ stats_cfe]
 y_cfe = [stat.elbo       for stat ∈ stats_cfe]
 y_stl = [stat.elbo       for stat ∈ stats_stl]
-plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO")
-plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO")
+plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10))
+plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10))
 savefig("advi_stl_elbo.svg")
 nothing
 ```
@@ -220,7 +220,6 @@ nothing
 We can see that the noise of the STL estimator becomes smaller as VI converges.
 However, the speed of convergence may not always be significantly different.
 
-
 ## References
 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
diff --git a/docs/src/locscale.md b/docs/src/locscale.md
index 63ff5cb4e..a5966f44b 100644
--- a/docs/src/locscale.md
+++ b/docs/src/locscale.md
@@ -25,6 +25,10 @@ The derivative of the entropy with respect to ``\lambda`` is thus independent of
 
 ## Constructors
 
+!!! note
+    For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. 
+	Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities.
+
 ```@docs
 VILocationScale
 ```
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index a7d9fbe4c..ce14d7249 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -12,9 +12,6 @@ represented as follows:
   u = rand(dist, d)
   z = scale*u + location
 ```
-
-!!! note
-    For stable convergence, the initial scale needs to be sufficiently large.
 """
 struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
     location::L

From b7d3471fdd81b44a07dac068f1d84a260bb4959a Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 22:54:18 +0100
Subject: [PATCH 121/206] fix type instability, bug in argument check in
 `LocationScale`

---
 src/distributions/location_scale.jl | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index ce14d7249..ab12db842 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -57,12 +57,15 @@ end
 # end
 
 Base.length(q::VILocationScale) = length(q.location)
+
 Base.size(q::VILocationScale) = size(q.location)
 
+Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D)
+
 function StatsBase.entropy(q::VILocationScale)
     @unpack  location, scale, dist = q
     n_dims = length(location)
-    n_dims*entropy(dist) + first(logabsdet(scale))
+    n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
 end
 
 function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
@@ -121,7 +124,7 @@ function VIFullRankGaussian(
     L::AbstractTriangular{T};
     check_args::Bool = true
 ) where {T <: Real}
-    @assert isposdef(L) "Scale must be positive definite"
+    @assert eigmin(L) > eps(eltype(L)) "Scale must be positive definite"
     if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
         @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
     end
@@ -139,7 +142,7 @@ function VIMeanFieldGaussian(
     L::Diagonal{T};
     check_args::Bool = true
 ) where {T <: Real}
-    @assert isposdef(L) "Scale must be positive definite"
+    @assert eigmin(L) > eps(eltype(L)) "Scale must be a Cholesky factor"
     if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
         @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
     end

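The added `convert` matters because `entropy(dist)` for a `Float32` base distribution can still return a `Float64` (constants such as `log2π` promote), which would silently widen the whole objective; a quick check of the promotion being cut off:

```julia
using Distributions

dist = Normal{Float32}(0f0, 1f0)
typeof(entropy(dist))                    # typically Float64 due to promotion
typeof(convert(Float32, entropy(dist)))  # Float32, keeping the ELBO type-stable
```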
From df50e8346e2d3174c6e57f41812e25f5d9c9751e Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 22:57:24 +0100
Subject: [PATCH 122/206] fix missing `rand!` import

---
 src/AdvancedVI.jl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index ae0dc6844..168075420 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -7,7 +7,7 @@ using Accessors
 using Random: AbstractRNG, default_rng
 using Distributions
 import Distributions:
-    logpdf, _logpdf, rand, _rand!, _rand!,
+    logpdf, _logpdf, rand, rand!, _rand!,
     ContinuousMultivariateDistribution
 
 using Functors
@@ -26,7 +26,6 @@ using ADTypes: AbstractADType
 using ChainRules: @ignore_derivatives 
 
 using FillArrays
-using PDMats
 using Bijectors
 
 using StatsBase

From ae3e9b018518b803ed60b6eaf7c5400cdf040a10 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 22:57:43 +0100
Subject: [PATCH 123/206] refactor test, fix type bug in tests for
 `LocationScale`

---
 test/ad.jl            |  2 --
 test/advi_locscale.jl | 24 +++---------------------
 test/distributions.jl | 27 ++++++++++++---------------
 test/models/utils.jl  |  8 --------
 test/optimize.jl      | 18 ------------------
 test/runtests.jl      | 23 +++++++++++++++++++++++
 6 files changed, 38 insertions(+), 64 deletions(-)
 delete mode 100644 test/models/utils.jl

diff --git a/test/ad.jl b/test/ad.jl
index 2c4f802a1..f575b485b 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -1,7 +1,5 @@
 
 using ReTest
-using ForwardDiff, ReverseDiff, Enzyme, Zygote
-using ADTypes
 
 @testset "ad" begin
     @testset "$(adname)" for (adname, adsymbol) ∈ Dict(
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 71cf22d51..a7dcc98bc 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -2,25 +2,6 @@
 const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
 
 using ReTest
-using Bijectors
-using LogDensityProblems
-using Optimisers
-using Distributions
-using PDMats
-using LinearAlgebra
-using SimpleUnPack: @unpack
-
-struct TestModel{M,L,S}
-    model::M
-    μ_true::L
-    L_true::S
-    n_dims::Int
-    is_meanfield::Bool
-end
-
-include("models/normallognormal.jl")
-include("models/normal.jl")
-include("models/utils.jl")
 
 @testset "advi" begin
     @testset "locscale" begin
@@ -55,10 +36,11 @@ include("models/utils.jl")
 
             μ₀ = zeros(realtype, n_dims)
             L₀ = if is_meanfield
-                ones(realtype, n_dims) |> Diagonal
+                FillArrays.Eye(n_dims) |> Diagonal
             else
-                diagm(ones(realtype, n_dims)) |> LowerTriangular
+                FillArrays.Eye(n_dims) |> LowerTriangular
             end
+
             q₀ = if is_meanfield
                 VIMeanFieldGaussian(μ₀, L₀)
             else
diff --git a/test/distributions.jl b/test/distributions.jl
index dcd20696b..563de12dd 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -1,7 +1,6 @@
 
 using ReTest
 using Distributions: _logpdf
-using Optimisers 
 
 @testset "distributions" begin
     @testset "$(string(covtype)) $(basedist) $(realtype)" for
@@ -11,35 +10,33 @@ using Optimisers
 
         seed         = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
         rng          = Philox4x(UInt64, seed, 8)
-        realtype     = Float64
-        ϵ            = 1f-2
         n_dims       = 10
         n_montecarlo = 1000_000
 
-        μ  = randn(rng, realtype, n_dims)
-        L₀ = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular
-        Σ  = if covtype == :fullrank
-            Σ = (L₀*L₀' + ϵ*I) |> Hermitian
+        μ = randn(rng, realtype, n_dims)
+        L = if covtype == :fullrank
+            sample_cholesky(rng, realtype, n_dims)
         else
             Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1))
         end
+        Σ = L*L'
 
-        L = cholesky(Σ).L
         q = if covtype == :fullrank  && basedist == :gaussian
-            VIFullRankGaussian(μ, L |> LowerTriangular)
+            VIFullRankGaussian(μ, L)
         elseif covtype == :meanfield && basedist == :gaussian
-            VIMeanFieldGaussian(μ, L |> Diagonal)
+            VIMeanFieldGaussian(μ, L)
         end
         q_true = if basedist == :gaussian
             MvNormal(μ, Σ)
         end
 
         @testset "logpdf" begin
-            z = randn(rng, realtype, n_dims)
-            @test logpdf(q, z)  ≈ logpdf(q_true, z)
-            @test _logpdf(q, z) ≈ _logpdf(q_true, z)
-            @test eltype(logpdf(q, z))  == realtype
-            @test eltype(_logpdf(q, z)) == realtype
+            z = rand(rng, q)
+            @test eltype(z)             == realtype
+            @test logpdf(q, z)          ≈  logpdf(q_true, z)  rtol=realtype(1e-2)
+            @test _logpdf(q, z)         ≈  _logpdf(q_true, z) rtol=realtype(1e-2)
+            @test eltype(logpdf(q, z))  == realtype 
+            @test eltype(_logpdf(q, z)) == realtype 
         end
 
         @testset "entropy" begin
diff --git a/test/models/utils.jl b/test/models/utils.jl
deleted file mode 100644
index 3d483c46d..000000000
--- a/test/models/utils.jl
+++ /dev/null
@@ -1,8 +0,0 @@
-
-function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int)
-    A   = randn(rng, type, n_dims, n_dims) 
-    L   = tril(A)
-    idx = diagind(L)
-    @. L[idx] = log(exp(L[idx]) + 1)
-    L |> LowerTriangular
-end
diff --git a/test/optimize.jl b/test/optimize.jl
index 920a30709..5686b7246 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -1,23 +1,5 @@
 
 using ReTest
-using Bijectors
-using LogDensityProblems
-using Optimisers
-using Distributions
-using PDMats
-using LinearAlgebra
-using SimpleUnPack: @unpack
-
-struct TestModel{M,L,S}
-    model::M
-    μ_true::L
-    L_true::S
-    n_dims::Int
-    is_meanfield::Bool
-end
-
-include("models/normallognormal.jl")
-include("models/utils.jl")
 
 @testset "optimize" begin
     seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
diff --git a/test/runtests.jl b/test/runtests.jl
index 6bd3bc491..803c11c73 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,9 +8,32 @@ using Random123
 using Statistics
 using Distributions
 using LinearAlgebra
+using SimpleUnPack: @unpack
+using PDMats
+
+using Bijectors
+using LogDensityProblems
+using Optimisers
+using ADTypes
+using ForwardDiff, ReverseDiff, Zygote
 
 using AdvancedVI
 
+# Utilities
+include("utils.jl")
+
+struct TestModel{M,L,S}
+    model::M
+    μ_true::L
+    L_true::S
+    n_dims::Int
+    is_meanfield::Bool
+end
+
+include("models/normal.jl")
+include("models/normallognormal.jl")
+
+# Tests
 include("ad.jl")
 include("distributions.jl")
 include("advi_locscale.jl")

From e4002cfeb0f8edd7dd8cf02e6ee68f1eb2bf959a Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 22:58:08 +0100
Subject: [PATCH 124/206] add missing compat entries

---
 Project.toml | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index e099308a7..87aa4aac2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -15,7 +15,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -37,17 +36,21 @@ AdvancedVIZygoteExt = "Zygote"
 
 [compat]
 ADTypes = "0.1"
+Accessors = "0.1.32"
 Bijectors = "0.11, 0.12, 0.13"
 ChainRules = "1.53.0"
 DiffResults = "1"
 Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
 DocStringExtensions = "0.8, 0.9"
+FillArrays = "1.6.0"
 ForwardDiff = "0.10.25"
+Functors = "0.4.5"
 LogDensityProblems = "2.1.1"
 Optimisers = "0.2.16"
 ProgressMeter = "1.0.0"
 Requires = "0.5, 1.0"
 ReverseDiff = "1.14"
+SimpleUnPack = "1.1.0"
 StatsBase = "0.32, 0.33, 0.34"
 StatsFuns = "0.8, 0.9, 1"
 julia = "1.6"

From 8c82569208199480676de7b583cd54ff079ba8c5 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 23:19:26 +0100
Subject: [PATCH 125/206] fix missing package import in test

---
 test/Project.toml | 1 +
 test/runtests.jl  | 1 +
 2 files changed, 2 insertions(+)

diff --git a/test/Project.toml b/test/Project.toml
index 2f38c88fa..663d671dc 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
diff --git a/test/runtests.jl b/test/runtests.jl
index 803c11c73..8a6e486ef 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -10,6 +10,7 @@ using Distributions
 using LinearAlgebra
 using SimpleUnPack: @unpack
 using PDMats
+using FillArrays
 
 using Bijectors
 using LogDensityProblems

From c2e751723a63cd00b5f223a390dc34513b94b946 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 23:19:34 +0100
Subject: [PATCH 126/206] add additional tests for sampling `LocationScale`

---
 test/distributions.jl | 41 +++++++++++++++++++++++++++++++++++------
 1 file changed, 35 insertions(+), 6 deletions(-)

diff --git a/test/distributions.jl b/test/distributions.jl
index 563de12dd..c603421ee 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -31,6 +31,9 @@ using Distributions: _logpdf
         end
 
         @testset "logpdf" begin
+            seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+            rng  = Philox4x(UInt64, seed, 8)
+
             z = rand(rng, q)
             @test eltype(z)             == realtype
             @test logpdf(q, z)          ≈  logpdf(q_true, z)  rtol=realtype(1e-2)
@@ -45,12 +48,38 @@ using Distributions: _logpdf
         end
 
         @testset "sampling" begin
-            z_samples  = rand(rng, q, n_montecarlo)
-            threesigma = L
-            @test eltype(z_samples) == realtype
-            @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
-            @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
-            @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+            @testset "rand" begin
+                seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+                rng  = Philox4x(UInt64, seed, 8)
+
+                z_samples  = mapreduce(x -> rand(rng, q), hcat, 1:n_montecarlo)
+                @test eltype(z_samples) == realtype
+                @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
+                @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
+                @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+            end
+
+            @testset "rand batch" begin
+                seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+                rng  = Philox4x(UInt64, seed, 8)
+
+                z_samples  = rand(rng, q, n_montecarlo)
+                @test eltype(z_samples) == realtype
+                @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
+                @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
+                @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+            end
+
+            @testset "rand!" begin
+                seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
+                rng  = Philox4x(UInt64, seed, 8)
+
+                z_samples = Array{realtype}(undef, n_dims, n_montecarlo)
+                rand!(rng, q, z_samples)
+                @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
+                @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
+                @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
+            end
         end
     end
 

From 3a6f8bf689af5657d817674d84a886d3496864d6 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 23:19:50 +0100
Subject: [PATCH 127/206] fix bug in batch in-place `rand!` for `LocationScale`

---
 src/distributions/location_scale.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index ab12db842..ecb0b672f 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -110,8 +110,8 @@ end
 function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
     @unpack location, scale, dist = q
     rand!(rng, dist, x)
-    x *= scale
-    return x += location
+    x[:] = scale*x
+    return x .+= location
 end
 
 """

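The underlying pitfall: `x *= scale` and `x += location` are rebinding assignments, so the local `x` ended up pointing at freshly allocated arrays and the caller's preallocated buffer was never filled, whereas `x[:] = ...` and `.+=` write through to the original storage. A standalone sketch of the difference:

```julia
function rebind!(x)
    x *= 2.0    # x = x*2.0 allocates a new array; the caller's array is untouched
    return nothing
end

function mutate!(x)
    x .*= 2.0   # writes the scaled values back into the caller's array
    return nothing
end

buf = ones(3)
rebind!(buf); buf   # still [1.0, 1.0, 1.0]
mutate!(buf); buf   # now [2.0, 2.0, 2.0]
```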
From b78ef4bf3afe6649d124320595edde71d3031e02 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Tue, 22 Aug 2023 23:39:16 +0100
Subject: [PATCH 128/206] fix bug in inference test initialization

---
 test/advi_locscale.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index a7dcc98bc..76ae3724b 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -38,7 +38,7 @@ using ReTest
             L₀ = if is_meanfield
                 FillArrays.Eye(n_dims) |> Diagonal
             else
-                FillArrays.Eye(n_dims) |> LowerTriangular
+                FillArrays.Eye(n_dims) |> Matrix |> LowerTriangular
             end
 
             q₀ = if is_meanfield

From a1f7e98a612bc8e7b840c4c341ee8870aac9e29f Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Wed, 23 Aug 2023 01:29:50 +0100
Subject: [PATCH 129/206] add missing file

---
 test/utils.jl | 8 ++++++++
 1 file changed, 8 insertions(+)
 create mode 100644 test/utils.jl

diff --git a/test/utils.jl b/test/utils.jl
new file mode 100644
index 000000000..3d483c46d
--- /dev/null
+++ b/test/utils.jl
@@ -0,0 +1,8 @@
+
+function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int)
+    A   = randn(rng, type, n_dims, n_dims) 
+    L   = tril(A)
+    idx = diagind(L)
+    @. L[idx] = log(exp(L[idx]) + 1)
+    L |> LowerTriangular
+end

From 8b783eca14a21cc620f311f9f63417e9f31e5de8 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Tue, 22 Aug 2023 21:46:01 -0400
Subject: [PATCH 130/206] fix remove use of `eigmin` for 1.6

---
 src/distributions/location_scale.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
index ecb0b672f..91b6768ad 100644
--- a/src/distributions/location_scale.jl
+++ b/src/distributions/location_scale.jl
@@ -124,7 +124,7 @@ function VIFullRankGaussian(
     L::AbstractTriangular{T};
     check_args::Bool = true
 ) where {T <: Real}
-    @assert eigmin(L) > eps(eltype(L)) "Scale must be positive definite"
+    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
     if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
         @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
     end
@@ -142,7 +142,7 @@ function VIMeanFieldGaussian(
     L::Diagonal{T};
     check_args::Bool = true
 ) where {T <: Real}
-    @assert eigmin(L) > eps(eltype(L)) "Scale must be a Cholesky factor"
+    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor"
     if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
         @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
     end

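The replacement is exact rather than approximate: the eigenvalues of a triangular (or diagonal) matrix are precisely its diagonal entries, so `minimum(diag(L))` computes the same quantity as `eigmin(L)` without any factorization, sidestepping the Julia 1.6 incompatibility the commit title refers to:

```julia
using LinearAlgebra

L = LowerTriangular([2.0 0.0; -1.0 3.0])

eigvals(Matrix(L))   # [2.0, 3.0], i.e. the diagonal entries
minimum(diag(L))     # 2.0, the same minimum without an eigendecomposition
```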
From 12cd9f22611f3bf1a95ea878ade7c3f151957cd9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Wed, 23 Aug 2023 21:00:51 +0100
Subject: [PATCH 131/206] refactor adjust inference test hyperparameters to be
 more robust

---
 test/advi_locscale.jl | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 76ae3724b..524dc5e2e 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -27,10 +27,11 @@ using ReTest
             seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
             rng  = Philox4x(UInt64, seed, 8)
 
-            T = 10000
             modelstats = modelconstr(realtype; rng)
             @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
+            T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
             b    = Bijectors.bijector(model)
             b⁻¹  = inverse(b)
 
@@ -53,7 +54,7 @@ using ReTest
                 Δλ₀         = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
                 q, stats, _ = optimize(
                     obj, q₀, T;
-                    optimizer     = Optimisers.Adam(1e-2),
+                    optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
                     rng           = rng,
                     adbackend     = adbackend,
@@ -72,7 +73,7 @@ using ReTest
                 rng         = Philox4x(UInt64, seed, 8)
                 q, stats, _ = optimize(
                     obj, q₀, T;
-                    optimizer     = Optimisers.Adam(realtype(1e-2)),
+                    optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
                     rng           = rng,
                     adbackend     = adbackend,
@@ -83,7 +84,7 @@ using ReTest
                 rng_repl    = Philox4x(UInt64, seed, 8)
                 q, stats, _ = optimize(
                     obj, q₀, T;
-                    optimizer     = Optimisers.Adam(realtype(1e-2)),
+                    optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
                     rng           = rng_repl,
                     adbackend     = adbackend,

From 837c7296467ae20c66f7c061a6142295ebe50b22 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Thu, 24 Aug 2023 02:43:39 +0100
Subject: [PATCH 132/206] refactor `optimize` to return `obj_state`, add warm
 start kwargs

---
 docs/src/advi.md            |  4 +--
 docs/src/started.md         |  4 +--
 src/AdvancedVI.jl           |  6 -----
 src/objectives/elbo/advi.jl | 49 +++++++++++++++++--------------------
 src/optimize.jl             | 27 +++++++++++++-------
 test/advi_locscale.jl       | 12 ++++-----
 test/optimize.jl            | 14 +++++------
 7 files changed, 57 insertions(+), 59 deletions(-)

diff --git a/docs/src/advi.md b/docs/src/advi.md
index 2cf6a7732..3ac904365 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -189,7 +189,7 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i
 ```@setup stl
 n_max_iter = 10^4
 
-_, stats_cfe, _ = AVI.optimize(
+_, stats_cfe, _, _ = AVI.optimize(
     cfe,
     q0,
     n_max_iter;
@@ -198,7 +198,7 @@ _, stats_cfe, _ = AVI.optimize(
     optimizer     = Optimisers.Adam(1e-3)
 ); 
 
-_, stats_stl, _ = AVI.optimize(
+_, stats_stl, _, _ = AVI.optimize(
     stl,
     q0,
     n_max_iter;
diff --git a/docs/src/started.md b/docs/src/started.md
index 4e2b43801..f3ae54b1c 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -103,8 +103,8 @@ q = AVI.VIMeanFieldGaussian(μ, L)
 ```
 Passing `objective` and the initial variational approximation `q` to `optimize` performs inference.
 ```@example advi
-n_max_iter  = 10^4
-q, stats, _ = AVI.optimize(
+n_max_iter = 10^4
+q, stats, _, _ = AVI.optimize(
     objective,
     q,
     n_max_iter;
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 168075420..9bc3d3166 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -53,14 +53,8 @@ abstract type AbstractVariationalObjective end
 function init              end
 function estimate_gradient end
 
-init(::Nothing) = nothing
-
 # ADVI-specific interfaces
 abstract type AbstractEntropyEstimator end
-abstract type AbstractControlVariate end
-
-function update end
-update(::Nothing, ::Nothing) = (nothing, nothing)
 
 # entropy.jl must preceed advi.jl
 include("objectives/elbo/entropy.jl")
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index d8719fa7a..f9a61d81b 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -19,18 +19,15 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017)
 
 Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 """
-struct ADVI{Tlogπ, B,
-            EntropyEst <: AbstractEntropyEstimator,
-            ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective
-    ℓπ::Tlogπ
-    invbij::B
-    entropy::EntropyEst
-    cv::ControlVar
+struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
+    prob     ::P
+    invbij   ::B
+    entropy  ::EntropyEst
     n_samples::Int
 
-    function ADVI(prob, n_samples::Int;
-                  entropy::AbstractEntropyEstimator = ClosedFormEntropy(),
-                  cv::Union{<:AbstractControlVariate, Nothing} = nothing,
+    function ADVI(prob,
+                  n_samples::Int;
+                  entropy  ::AbstractEntropyEstimator = ClosedFormEntropy(),
                   invbij = Bijectors.identity)
         cap = LogDensityProblems.capabilities(prob)
         if cap === nothing
@@ -40,15 +37,16 @@ struct ADVI{Tlogπ, B,
                 ),
             )
         end
-        ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob)
-        new{typeof(ℓπ), typeof(invbij), typeof(entropy), typeof(cv)}(ℓπ, invbij, entropy, cv, n_samples)
+        new{typeof(prob), typeof(invbij), typeof(entropy)}(
+            prob, invbij, entropy, n_samples
+        )
     end
 end
 
 Base.show(io::IO, advi::ADVI) =
-    print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))")
+    print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 
-init(advi::ADVI) = init(advi.cv)
+init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing
 
 function (advi::ADVI)(
     rng::AbstractRNG,
@@ -57,7 +55,7 @@ function (advi::ADVI)(
 )
     𝔼ℓ = mean(eachcol(ηs)) do ηᵢ
         zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ)
-        advi.ℓπ(zᵢ) + logdetjacᵢ
+        LogDensityProblems.logdensity(advi.prob, zᵢ) + logdetjacᵢ
     end
     ℍ  = advi.entropy(q_η, ηs)
     𝔼ℓ + ℍ
@@ -78,22 +76,22 @@ Evaluate the ELBO using the ADVI formulation.
 
 """
 function (advi::ADVI)(
-    q_η::ContinuousMultivariateDistribution;
-    rng::AbstractRNG = default_rng(),
-    n_samples::Int = advi.n_samples
+    q_η      ::ContinuousMultivariateDistribution;
+    rng      ::AbstractRNG = default_rng(),
+    n_samples::Int         = advi.n_samples
 )
     ηs = rand(rng, q_η, n_samples)
     advi(rng, q_η, ηs)
 end
 
 function estimate_gradient(
-    rng::AbstractRNG,
-    adbackend::AbstractADType,
-    advi::ADVI,
+    rng          ::AbstractRNG,
+    adbackend    ::AbstractADType,
+    advi         ::ADVI,
     est_state,
-    λ::Vector{<:Real},
+    λ            ::Vector{<:Real},
     restructure,
-    out::DiffResults.MutableDiffResult
+    out          ::DiffResults.MutableDiffResult
 )
     f(λ′) = begin
         q_η = restructure(λ′)
@@ -105,8 +103,5 @@ function estimate_gradient(
     nelbo = DiffResults.value(out)
     stat  = (elbo=-nelbo,)
 
-    est_state, stat′ = update(advi.cv, est_state)
-    stat = !isnothing(stat′) ? merge(stat′, stat) : stat 
-
-    out, est_state, stat
+    out, nothing, stat
 end
diff --git a/src/optimize.jl b/src/optimize.jl
index b18c8581d..54e7ace09 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -35,13 +35,18 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
 - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
 - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
-- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient.
+- `callback!`: Callback function called after every iteration. The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient.
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
 
+When resuming from the state of a previous run, use the following keyword arguments:
+- `opt_state`: Initial state of the optimizer.
+- `obj_state`: Initial state of the objective.
+
 # Returns
 - `λ`: Variational parameters optimizing the variational objective.
 - `stats`: Statistics gathered during inference.
 - `opt_state`: Final state of the optimiser.
+- `obj_state`: Final state of the objective.
 """
 function optimize(
     objective    ::AbstractVariationalObjective,
@@ -52,6 +57,8 @@ function optimize(
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
     rng          ::AbstractRNG             = default_rng(),
     show_progress::Bool                    = true,
+    opt_state                              = nothing,
+    obj_state                              = nothing,
     callback!                              = nothing,
     prog                                   = ProgressMeter.Progress(
         n_max_iter;
@@ -62,16 +69,16 @@ function optimize(
     )              
 )
     λ         = copy(λ₀)
-    opt_state = Optimisers.setup(optimizer, λ)
-    est_state = init(objective)
+    opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ)       : opt_state
+    obj_state = isnothing(obj_state) ? init(rng, objective, λ, restructure) : obj_state
     grad_buf  = DiffResults.GradientResult(λ)
     stats     = NamedTuple[]
 
     for t = 1:n_max_iter
         stat = (iteration=t,)
 
-        grad_buf, est_state, stat′ = estimate_gradient(
-            rng, adbackend, objective, est_state, λ, restructure, grad_buf)
+        grad_buf, obj_state, stat′ = estimate_gradient(
+            rng, adbackend, objective, obj_state, λ, restructure, grad_buf)
         stat = merge(stat, stat′)
 
         g            = DiffResults.gradient(grad_buf)
@@ -80,7 +87,7 @@ function optimize(
         stat = merge(stat, stat′)
 
         if !isnothing(callback!)
-            stat′ = callback!(; est_state, stat, restructure, λ, g)
+            stat′ = callback!(; obj_state, stat, restructure, λ, g)
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
@@ -89,7 +96,7 @@ function optimize(
         pm_next!(prog, stat)
         push!(stats, stat)
     end
-    λ, map(identity, stats), opt_state
+    λ, map(identity, stats), opt_state, obj_state
 end
 
 function optimize(objective ::AbstractVariationalObjective,
@@ -97,6 +104,8 @@ function optimize(objective ::AbstractVariationalObjective,
                   n_max_iter::Int;
                   kwargs...)
     λ, restructure = Optimisers.destructure(q₀)
-    λ, stats, opt_state = optimize(objective, restructure, λ, n_max_iter; kwargs...)
-    restructure(λ), stats, opt_state
+    λ, stats, opt_state, obj_state = optimize(
+        objective, restructure, λ, n_max_iter; kwargs...
+    )
+    restructure(λ), stats, opt_state, obj_state
 end
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 524dc5e2e..e780b0744 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -51,8 +51,8 @@ using ReTest
             obj = objective(model, b⁻¹, 10)
 
             @testset "convergence" begin
-                Δλ₀         = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-                q, stats, _ = optimize(
+                Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+                q, stats, _, _ = optimize(
                     obj, q₀, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
@@ -70,8 +70,8 @@ using ReTest
             end
 
             @testset "determinism" begin
-                rng         = Philox4x(UInt64, seed, 8)
-                q, stats, _ = optimize(
+                rng = Philox4x(UInt64, seed, 8)
+                q, stats, _, _ = optimize(
                     obj, q₀, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
@@ -81,8 +81,8 @@ using ReTest
                 μ  = q.location
                 L  = q.scale
 
-                rng_repl    = Philox4x(UInt64, seed, 8)
-                q, stats, _ = optimize(
+                rng_repl = Philox4x(UInt64, seed, 8)
+                q, stats, _, _ = optimize(
                     obj, q₀, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
diff --git a/test/optimize.jl b/test/optimize.jl
index 5686b7246..2369432c9 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -20,8 +20,8 @@ using ReTest
     adbackend = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)
 
-    rng                 = Philox4x(UInt64, seed, 8)
-    q_ref, stats_ref, _ = optimize(
+    rng = Philox4x(UInt64, seed, 8)
+    q_ref, stats_ref, _, _ = optimize(
         obj, q₀, T;
         optimizer,
         show_progress = false,
@@ -33,8 +33,8 @@ using ReTest
     @testset "restructure" begin
         λ₀, re  = Optimisers.destructure(q₀)
 
-        rng         = Philox4x(UInt64, seed, 8)
-        λ, stats, _ = optimize(
+        rng = Philox4x(UInt64, seed, 8)
+        λ, stats, _, _ = optimize(
             obj, re, λ₀, T;
             optimizer,
             show_progress = false,
@@ -49,12 +49,12 @@ using ReTest
         rng = Philox4x(UInt64, seed, 8)
         test_values = rand(rng, T)
 
-        callback!(; stat, est_state, restructure, λ, g) = begin
+        callback!(; stat, obj_state, restructure, λ, g) = begin
             (test_value = test_values[stat.iteration],)
         end
 
-        rng         = Philox4x(UInt64, seed, 8)
-        _, stats, _ = optimize(
+        rng = Philox4x(UInt64, seed, 8)
+        _, stats, _, _ = optimize(
             obj, q₀, T;
             show_progress = false,
             rng,

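With the extra return values, a run can now be paused and warm-started by threading the final states back in; a hypothetical two-leg run (with `objective` and `q₀` set up as elsewhere in the tests):

```julia
using ADTypes, ForwardDiff

# First leg: optimizer and objective states are initialized internally.
q, stats, opt_state, obj_state = optimize(
    objective, q₀, 1_000;
    adbackend = AutoForwardDiff(),
)

# Second leg: resume from the states of the first leg.
q, stats, opt_state, obj_state = optimize(
    objective, q, 1_000;
    adbackend = AutoForwardDiff(),
    opt_state = opt_state,
    obj_state = obj_state,
)
```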
From 95629a5471f7e3e94a19b8096cd9df73d8dad523 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Wed, 23 Aug 2023 23:19:09 -0400
Subject: [PATCH 133/206] refactor make tests more robust, reduce the number
 of tests

---
 test/advi_locscale.jl          |  2 --
 test/distributions.jl          |  2 +-
 test/models/normal.jl          | 50 ----------------------------------
 test/models/normallognormal.jl |  2 +-
 test/runtests.jl               |  5 +---
 test/utils.jl                  |  8 ------
 6 files changed, 3 insertions(+), 66 deletions(-)
 delete mode 100644 test/models/normal.jl
 delete mode 100644 test/utils.jl

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index e780b0744..d5250ce83 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -10,8 +10,6 @@ using ReTest
             (modelname, modelconstr) ∈ Dict(
                 :NormalLogNormalMeanField => normallognormal_meanfield,
                 :NormalLogNormalFullRank  => normallognormal_fullrank,
-                :NormalMeanField          => normal_meanfield,
-                :NormalFullRank           => normal_fullrank,
             ),
             (objname, objective) ∈ Dict(
                 :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹),
diff --git a/test/distributions.jl b/test/distributions.jl
index c603421ee..175cc96b8 100644
--- a/test/distributions.jl
+++ b/test/distributions.jl
@@ -15,7 +15,7 @@ using Distributions: _logpdf
 
         μ = randn(rng, realtype, n_dims)
         L = if covtype == :fullrank
-            sample_cholesky(rng, realtype, n_dims)
+	    tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular
         else
             Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1))
         end
diff --git a/test/models/normal.jl b/test/models/normal.jl
deleted file mode 100644
index f60ad5f38..000000000
--- a/test/models/normal.jl
+++ /dev/null
@@ -1,50 +0,0 @@
-
-struct TestMvNormal{M,S}
-    μ::M
-    Σ::S
-end
-
-function LogDensityProblems.logdensity(model::TestMvNormal, θ)
-    @unpack μ, Σ = model
-    logpdf(MvNormal(μ, Σ), θ)
-end
-
-function LogDensityProblems.dimension(model::TestMvNormal)
-    length(model.μ)
-end
-
-function LogDensityProblems.capabilities(::Type{<:TestMvNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-
-function Bijectors.bijector(model::TestMvNormal)
-    identity
-end
-
-function normal_fullrank(realtype; rng = default_rng())
-    n_dims = 5
-
-    μ  = randn(rng, realtype, n_dims)
-    L₀ = sample_cholesky(rng, realtype, n_dims)
-    Σ  = L₀*L₀' |> Hermitian
-
-    Σ_chol = cholesky(Σ)
-    model  = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol))
-
-    L = Σ_chol.L |> LowerTriangular
-
-    TestModel(model, μ, L, n_dims, false)
-end
-
-function normal_meanfield(realtype; rng = default_rng())
-    n_dims = 5
-
-    μ = randn(rng, realtype, n_dims)
-    σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
-
-    model = TestMvNormal(μ, PDMats.PDiagMat(σ))
-
-    L = σ |> Diagonal
-
-    TestModel(model, μ, L, n_dims, true)
-end
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index cab73ccee..f8b84a1b7 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -32,7 +32,7 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     μ_x  = randn(rng, realtype)
     σ_x  = ℯ
     μ_y  = randn(rng, realtype, n_dims)
-    L₀_y = sample_cholesky(rng, realtype, n_dims)
+    L₀_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular
     Σ_y  = L₀_y*L₀_y' |> Hermitian
 
     model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
diff --git a/test/runtests.jl b/test/runtests.jl
index 8a6e486ef..0a2c5e66e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -20,9 +20,7 @@ using ForwardDiff, ReverseDiff, Zygote
 
 using AdvancedVI
 
-# Utilities
-include("utils.jl")
-
+# Models for Inference Tests
 struct TestModel{M,L,S}
     model::M
     μ_true::L
@@ -31,7 +29,6 @@ struct TestModel{M,L,S}
     is_meanfield::Bool
 end
 
-include("models/normal.jl")
 include("models/normallognormal.jl")
 
 # Tests
diff --git a/test/utils.jl b/test/utils.jl
deleted file mode 100644
index 3d483c46d..000000000
--- a/test/utils.jl
+++ /dev/null
@@ -1,8 +0,0 @@
-
-function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int)
-    A   = randn(rng, type, n_dims, n_dims) 
-    L   = tril(A)
-    idx = diagind(L)
-    @. L[idx] = log(exp(L[idx]) + 1)
-    L |> LowerTriangular
-end

From 0b4b865ae9376b35b776afca17baf58cea27b095 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 00:31:09 -0400
Subject: [PATCH 134/206] fix remove a cholesky in test model

---
 test/models/normallognormal.jl | 14 +++++++-------
 test/runtests.jl               |  2 +-
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index f8b84a1b7..ec591f2cd 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -29,13 +29,13 @@ end
 function normallognormal_fullrank(realtype; rng = default_rng())
     n_dims = 5
 
-    μ_x  = randn(rng, realtype)
-    σ_x  = ℯ
-    μ_y  = randn(rng, realtype, n_dims)
-    L₀_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular
-    Σ_y  = L₀_y*L₀_y' |> Hermitian
+    μ_x = randn(rng, realtype)
+    σ_x = ℯ
+    μ_y = randn(rng, realtype, n_dims)
+    L_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular
+    Σ_y = L_y*L_y' |> Hermitian
 
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y))
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y)))
 
     Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
     Σ[1,1]         = σ_x^2
@@ -56,7 +56,7 @@ function normallognormal_meanfield(realtype; rng = default_rng())
     μ_y  = randn(rng, realtype, n_dims)
     σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
 
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
+    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
 
     μ = vcat(μ_x, μ_y)
     L = vcat(σ_x, σ_y) |> Diagonal
diff --git a/test/runtests.jl b/test/runtests.jl
index 0a2c5e66e..127503be2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,8 +9,8 @@ using Statistics
 using Distributions
 using LinearAlgebra
 using SimpleUnPack: @unpack
-using PDMats
 using FillArrays
+using PDMats
 
 using Bijectors
 using LogDensityProblems

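The point of the change is to wrap the Cholesky factor that is already in hand instead of letting `PDMat(Σ_y)` refactorize the covariance; roughly (patch 139 below switches to the `Cholesky(L, 'L', 0)` constructor form for Julia 1.6):

```julia
using LinearAlgebra, PDMats

L = tril(I + ones(2, 2))/2            # a known lower-triangular factor
Σ = L*L' |> Hermitian

Σ_pd = PDMat(Σ, Cholesky(L, 'L', 0))  # no cholesky(Σ) is ever recomputed
```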
From b49f4ebc163e2feecba38fba2678e650dfbd788d Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 00:31:34 -0400
Subject: [PATCH 135/206] fix compat bounds, remove unused package

---
 Project.toml      | 28 ++++++++++++++--------------
 src/AdvancedVI.jl |  2 +-
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/Project.toml b/Project.toml
index 87aa4aac2..143e20980 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,7 +6,7 @@ version = "0.3.0"
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
-ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -20,7 +20,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 
 [weakdeps]
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -36,23 +35,24 @@ AdvancedVIZygoteExt = "Zygote"
 
 [compat]
 ADTypes = "0.1"
-Accessors = "0.1.32"
-Bijectors = "0.11, 0.12, 0.13"
-ChainRules = "1.53.0"
+Accessors = "0.1"
+Bijectors = "0.12, 0.13"
+ChainRulesCore = "1.16"
 DiffResults = "1"
-Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
+Distributions = "0.25.87"
 DocStringExtensions = "0.8, 0.9"
-FillArrays = "1.6.0"
-ForwardDiff = "0.10.25"
-Functors = "0.4.5"
-LogDensityProblems = "2.1.1"
+Enzyme = "0.11.7"
+FillArrays = "1.3"
+ForwardDiff = "0.10.36"
+Functors = "0.4"
+LogDensityProblems = "2"
 Optimisers = "0.2.16"
-ProgressMeter = "1.0.0"
-Requires = "0.5, 1.0"
-ReverseDiff = "1.14"
+ProgressMeter = "1.6"
+Requires = "1.0"
+ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StatsBase = "0.32, 0.33, 0.34"
-StatsFuns = "0.8, 0.9, 1"
+Zygote = "0.6.63"
 julia = "1.6"
 
 [extras]
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 9bc3d3166..7272303a8 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -23,7 +23,7 @@ using LogDensityProblems
 
 using ADTypes, DiffResults
 using ADTypes: AbstractADType
-using ChainRules: @ignore_derivatives 
+using ChainRulesCore: @ignore_derivatives 
 
 using FillArrays
 using Bijectors

From 947a070da945505282711f6a45f6c3723b32b7fd Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 00:32:51 -0400
Subject: [PATCH 136/206] bump compat for ADTypes 0.2

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 143e20980..075ae92f3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -34,7 +34,7 @@ AdvancedVIReverseDiffExt = "ReverseDiff"
 AdvancedVIZygoteExt = "Zygote"
 
 [compat]
-ADTypes = "0.1"
+ADTypes = "0.1, 0.2"
 Accessors = "0.1"
 Bijectors = "0.12, 0.13"
 ChainRulesCore = "1.16"

From a9b3f483f4ae3bd4ac2d569d21697c8a786c448c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 00:35:32 -0400
Subject: [PATCH 137/206] fix broken LaTeX in README

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 07407fa96..86a57cb68 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
 y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right),
 \end{aligned}
 $$
+
 a `LogDensityProblem` can be implemented as 
 ```julia
 using LogDensityProblems

From 54826eb51c0a64bd7fd85b9363300c28e77381d7 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 00:52:35 -0400
Subject: [PATCH 138/206] remove redundant use of PDMats in docs

---
 README.md           | 9 ++++-----
 docs/Project.toml   | 1 -
 docs/src/advi.md    | 5 +----
 docs/src/started.md | 6 ++----
 4 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 86a57cb68..695e9ed98 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*,
 using Bijectors
 
 function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    (; μ_x, σ_x, μ_y, Σ_y) = model
     Bijectors.Stacked(
         Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
         [1:1, 2:1+length(μ_y)])
@@ -60,19 +60,18 @@ A simpler approach is to use `Turing`, where a `Turing.Model` can be automatical
 
 Let us instantiate a random normal-log-normal model.
 ```julia
-using PDMats
+using LinearAlgebra
 
 n_dims = 10
 μ_x    = randn()
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2))
+model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
 ```
 
 ADVI can be used as follows:
 ```julia
-using LinearAlgebra
 using Optimisers
 using ADTypes, ForwardDiff
 import AdvancedVI as AVI
@@ -81,7 +80,7 @@ b     = Bijectors.bijector(model)
 b⁻¹   = inverse(b)
 
 # ADVI objective 
-objective = AVI.ADVI(model, 10; b=b⁻¹)
+objective = AVI.ADVI(model, 10; invbij=b⁻¹)
 
 # Mean-field Gaussian variational family
 d = LogDensityProblems.dimension(model)
diff --git a/docs/Project.toml b/docs/Project.toml
index 182edd3e6..568be1b61 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -7,7 +7,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 
diff --git a/docs/src/advi.md b/docs/src/advi.md
index 3ac904365..2773dda7d 100644
--- a/docs/src/advi.md
+++ b/docs/src/advi.md
@@ -117,7 +117,6 @@ StickingTheLandingEntropy
 ```@setup stl
 using LogDensityProblems
 using SimpleUnPack
-using PDMats
 using Bijectors
 using LinearAlgebra
 using Plots
@@ -151,15 +150,13 @@ n_dims = 10
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
 
 d  = LogDensityProblems.dimension(model);
 μ  = randn(d);
 L  = Diagonal(ones(d));
 q0 = AVI.VIMeanFieldGaussian(μ, L)
 
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
-
 function Bijectors.bijector(model::NormalLogNormal)
     @unpack μ_x, σ_x, μ_y, Σ_y = model
     Bijectors.Stacked(
diff --git a/docs/src/started.md b/docs/src/started.md
index f3ae54b1c..e8392fd7a 100644
--- a/docs/src/started.md
+++ b/docs/src/started.md
@@ -51,14 +51,14 @@ end
 ```
 Let's now instantiate the model
 ```@example advi
-using PDMats
+using LinearAlgebra
 
 n_dims = 10
 μ_x    = randn()
 σ_x    = exp.(randn())
 μ_y    = randn(n_dims)
 σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2));
+model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
 ```
 
 Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
 objective    = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
 ```
 For the variational family, we will use the classic mean-field Gaussian family.
 ```@example advi
-using LinearAlgebra
-
 d = LogDensityProblems.dimension(model);
 μ = randn(d);
 L = Diagonal(ones(d));

From 1d1c8ffd320463b6bd9a552227270bb2837344b0 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 01:09:27 -0400
Subject: [PATCH 139/206] fix: use `Cholesky` signature supported in 1.6

---
 test/models/normallognormal.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index ec591f2cd..e2b9e8165 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -32,10 +32,10 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     μ_x = randn(rng, realtype)
     σ_x = ℯ
     μ_y = randn(rng, realtype, n_dims)
-    L_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular
+    L_y = tril(I + ones(realtype, n_dims, n_dims))/2
     Σ_y = L_y*L_y' |> Hermitian
 
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y)))
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))
 
     Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
     Σ[1,1]         = σ_x^2
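
For context, a small sketch of the constructor this fix switches to, assuming only Julia's `LinearAlgebra` standard library: the three-argument form `Cholesky(factors, uplo, info)` wraps an already-computed triangular factor without re-running the factorization, and it accepts a plain `Matrix`.

```julia
using LinearAlgebra

L = tril(I + ones(3, 3)) / 2   # plain Matrix holding a lower-triangular factor
chol = Cholesky(L, 'L', 0)     # wrap directly; 'L' marks it as the lower factor
Matrix(chol) ≈ L * L'          # true: the wrapped factorization represents L*L'
```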

From 7bac95b1dea4b15df7844602966ebf539ee43fe9 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 01:34:21 -0400
Subject: [PATCH 140/206] revert custom variational families and docs

---
 docs/Project.toml                   |  17 -
 docs/make.jl                        |  22 -
 docs/src/advi.md                    | 227 --------
 docs/src/index.md                   |  14 -
 docs/src/locscale.md                |  85 ---
 docs/src/started.md                 | 132 -----
 src/AdvancedVI.jl                   |   9 -
 src/distributions/location_scale.jl | 151 -----
 test/Manifest.toml                  | 866 ++++++++++++++++++++++++++++
 test/Project.toml                   |   2 +
 test/advi_locscale.jl               |  30 +-
 test/distributions.jl               |  96 ---
 test/optimize.jl                    |   4 +-
 test/runtests.jl                    |   5 +-
 14 files changed, 883 insertions(+), 777 deletions(-)
 delete mode 100644 docs/Project.toml
 delete mode 100644 docs/make.jl
 delete mode 100644 docs/src/advi.md
 delete mode 100644 docs/src/index.md
 delete mode 100644 docs/src/locscale.md
 delete mode 100644 docs/src/started.md
 delete mode 100644 src/distributions/location_scale.jl
 create mode 100644 test/Manifest.toml
 delete mode 100644 test/distributions.jl

diff --git a/docs/Project.toml b/docs/Project.toml
deleted file mode 100644
index 568be1b61..000000000
--- a/docs/Project.toml
+++ /dev/null
@@ -1,17 +0,0 @@
-[deps]
-ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
-AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
-Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
-
-[compat]
-ADTypes = "0.1.6"
-Bijectors = "0.13.6"
-Documenter = "0.26, 0.27"
-LogDensityProblems = "2.1.1"
diff --git a/docs/make.jl b/docs/make.jl
deleted file mode 100644
index 5d3716089..000000000
--- a/docs/make.jl
+++ /dev/null
@@ -1,22 +0,0 @@
-
-using AdvancedVI
-using Documenter
-
-DocMeta.setdocmeta!(
-    AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true
-)
-
-makedocs(;
-    modules  = [AdvancedVI],
-    sitename = "AdvancedVI.jl",
-    repo     = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}",
-    format   = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"),
-    pages    = ["AdvancedVI"        => "index.md",
-                "Getting Started"   => "started.md",
-                "ELBO Maximization" => [
-                    "Automatic Differentiation VI" => "advi.md",   
-                    "Location Scale Family"        => "locscale.md",
-                ]],
-)
-
-deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true)
diff --git a/docs/src/advi.md b/docs/src/advi.md
deleted file mode 100644
index 2773dda7d..000000000
--- a/docs/src/advi.md
+++ /dev/null
@@ -1,227 +0,0 @@
-
-# [Automatic Differentiation Variational Inference](@id advi)
-
-## Introduction
-
-The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound (ELBO) of a variational approximation ``q_{\phi,\lambda}`` against a target posterior distribution ``\pi``.
-Maximizing the ADVI objective is equivalent to solving the problem
-
-```math
-  \mathrm{minimize}_{\lambda \in \Lambda}\quad \mathrm{KL}\left(q_{\phi,\lambda}, \pi\right).
-```
-
-The key aspects of the ADVI objective are the following:
-1. It uses the reparameterization gradient estimator.
-2. It automatically matches the support of the target posterior through "bijectors."
-
-Thanks to Item 2, the user is free to choose any unconstrained variational family, for which
-bijectors will automatically match the potentially constrained support of the target.
-
-In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}``
-from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that
-```math
-z \sim  q_{\phi,\lambda} \qquad\Leftrightarrow\qquad
-z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} 
-```
-ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``.
-
-That is,
-
-```math
-\begin{aligned}
-\mathrm{ADVI}\left(\lambda\right)
-&\triangleq
-\mathbb{E}_{\eta \sim q_{\lambda}}\left[
-  \log \pi\left( \phi^{-1}\left( \eta \right) \right)
-  + \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert
-\right]
-+ \mathbb{H}\left(q_{\lambda}\right) \\
-&=
-\mathbb{E}_{\eta \sim q_{\lambda}}\left[
-  \log \pi\left( \phi^{-1}\left( \eta \right) \right)
-\right]
-+
-\mathbb{E}_{\eta \sim q_{\lambda}}\left[
-  - \log q_{\lambda}\left( \eta \right) + \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert
-\right] \\
-&=
-\mathbb{E}_{z \sim q_{\phi,\lambda}}\left[ \log \pi\left(z\right) \right]
-+
-\mathbb{H}\left(q_{\phi,\lambda}\right)
-\end{aligned}
-```
-
-The idea of using the reparameterization gradient estimator for variational inference was first
-proposed by Titsias and Lázaro-Gredilla (2014).
-Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by
-Fjelde *et al.* (2020).
-
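-To make the estimator concrete, below is a minimal single-sample sketch (a hypothetical helper, not the `AdvancedVI` implementation), assuming a `LogDensityProblems` target `prob`, a reparameterizable `q` with a closed-form entropy, and an inverse bijector `b⁻¹`:
-```julia
-using Distributions, LogDensityProblems
-using Bijectors: with_logabsdet_jacobian
-
-# One-sample Monte Carlo estimate of
-# E[log π(ϕ⁻¹(η)) + log|J_{ϕ⁻¹}(η)|] + H(q_λ)
-function advi_estimate(prob, q, b⁻¹)
-    η       = rand(q)                           # η ∼ q_λ
-    z, ℓjac = with_logabsdet_jacobian(b⁻¹, η)   # z = ϕ⁻¹(η) and log|J_{ϕ⁻¹}(η)|
-    LogDensityProblems.logdensity(prob, z) + ℓjac + entropy(q)
-end
-```
-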
-## The `ADVI` Objective
-
-```@docs
-ADVI
-```
-
-## The `StickingTheLanding` Control Variate
-
-The STL control variate was proposed by Roeder *et al.* (2017).
-By slightly modifying the differentiation path, it implicitly forms a control variate of the form of
-```math
-\begin{aligned}
-  \mathrm{CV}_{\mathrm{STL}}\left(z\right) 
-  &\triangleq 
-  \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) \\
-  &=
-  -\nabla_{\lambda} \mathbb{E}_{z \sim q_{\nu}} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right)
-\end{aligned}
-```
-where ``\nu = \lambda`` is set to avoid differentiating through the density of ``q_{\lambda}``.
-We can see that this vector-valued function has a mean of zero and is therefore a valid control variate.
- 
-Adding this to the closed-form entropy ELBO estimator yields the STL estimator:
-```math
-\begin{aligned}
-  \widehat{\nabla \mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right)
-    &\triangleq \mathbb{E}_{u \sim \varphi}\left[ 
-	  \nabla_{\lambda} \log \pi \left(z_{\lambda}\left(u\right)\right) 
-	  - 
-	  \nabla_{\lambda} \log q_{\nu} \left(z_{\lambda}\left(u\right)\right)
-	\right] 
-	\\
-    &= 
-	\mathbb{E}\left[ \nabla_{\lambda} \log \pi\left(z_{\lambda}\left(u\right)\right) \right] 
-    + 
-	\nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) 
-	- 
-	\mathrm{CV}_{\mathrm{STL}}\left(z\right)
-	\\
-    &= 
-	\widehat{\nabla \mathrm{ELBO}}\left(\lambda\right)
-    - 
-	\mathrm{CV}_{\mathrm{STL}}\left(z\right),
-\end{aligned}
-```
-which has the same expectation as the original ADVI estimator, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``.
-The conditions under which the STL estimator results in lower variance are still an active subject of research.
-
-The main downside of the STL estimator is that it needs to evaluate and differentiate the log density of ``q_{\lambda}`` in every iteration.
-Depending on the variational family, this might be computationally inefficient or even numerically unstable.
-For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability.
-
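-As a rough sketch of the trick (hypothetical code, not the `AdvancedVI` implementation), assume a mean-field Gaussian parameterized by ``\lambda = (m, \log\sigma)`` and `ForwardDiff`: the score term is evaluated at a *detached* copy ``\nu`` of the parameters, so the gradient flows only through the sampling path.
-```julia
-using Distributions, ForwardDiff
-
-# Single-sample STL objective: log q is evaluated at detached parameters ν,
-# so ∇λ sees only the reparameterized path z = m + exp(logσ) .* u.
-function stl_objective(logπ, λ, u)
-    d        = length(u)
-    m, logσ  = λ[1:d], λ[(d+1):end]
-    z        = m .+ exp.(logσ) .* u
-    ν_m, ν_σ = ForwardDiff.value.(m), exp.(ForwardDiff.value.(logσ))
-    logπ(z) - sum(logpdf.(Normal.(ν_m, ν_σ), z))
-end
-
-λ  = randn(4); u = randn(2)
-∇λ = ForwardDiff.gradient(λ -> -stl_objective(z -> -sum(abs2, z)/2, λ, u), λ)
-```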
-
-The STL control variate can be used by changing the entropy estimator using the following object:
-```@docs
-StickingTheLandingEntropy
-```
-
-```@setup stl
-using LogDensityProblems
-using SimpleUnPack
-using Bijectors
-using LinearAlgebra
-using Plots
-
-using Optimisers
-using ADTypes, ForwardDiff
-import AdvancedVI as AVI
-
-struct NormalLogNormal{MX,SX,MY,SY}
-    μ_x::MX
-    σ_x::SX
-    μ_y::MY
-    Σ_y::SY
-end
-
-function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
-end
-
-function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
-end
-
-function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-
-n_dims = 10
-μ_x    = randn()
-σ_x    = exp.(randn())
-μ_y    = randn(n_dims)
-σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
-
-d  = LogDensityProblems.dimension(model);
-μ  = randn(d);
-L  = Diagonal(ones(d));
-q0 = AVI.VIMeanFieldGaussian(μ, L)
-
-function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
-end
-```
-
-Let us come back to the example in [Getting Started](@ref getting_started), where a `LogDensityProblem` is given as `model`.
-In this example, the true posterior is contained within the variational family.
-This setting is known as "perfect variational family specification."
-In this case, the STL estimator is able to converge exponentially fast to the true solution.
-
-Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows:
-```@example stl
-n_montecarlo = 1;
-b            = Bijectors.bijector(model);
-b⁻¹          = inverse(b)
-
-cfe = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
-```
-The STL estimator can instead be created as follows:
-```@example stl
-stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), invbij = b⁻¹);
-```
-
-```@setup stl
-n_max_iter = 10^4
-
-_, stats_cfe, _, _ = AVI.optimize(
-    cfe,
-    q0,
-    n_max_iter;
-	show_progress = false,
-    adbackend     = AutoForwardDiff(),
-    optimizer     = Optimisers.Adam(1e-3)
-); 
-
-_, stats_stl, _, _ = AVI.optimize(
-    stl,
-    q0,
-    n_max_iter;
-	show_progress = false,
-    adbackend     = AutoForwardDiff(),
-    optimizer     = Optimisers.Adam(1e-3)
-); 
-
-t     = [stat.iteration  for stat ∈ stats_cfe]
-y_cfe = [stat.elbo       for stat ∈ stats_cfe]
-y_stl = [stat.elbo       for stat ∈ stats_stl]
-plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10))
-plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10))
-savefig("advi_stl_elbo.svg")
-nothing
-```
-![](advi_stl_elbo.svg)
-
-We can see that the noise of the STL estimator becomes smaller as VI converges.
-However, the speed of convergence may not always be significantly different.
-
-## References
-1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
-2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
-3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604.
-4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR.
-5. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
-
-
diff --git a/docs/src/index.md b/docs/src/index.md
deleted file mode 100644
index dea6d405d..000000000
--- a/docs/src/index.md
+++ /dev/null
@@ -1,14 +0,0 @@
-```@meta
-CurrentModule = AdvancedVI
-```
-
-# AdvancedVI
-
-## Introduction
-[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
-VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
-`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
-
-## Provided Algorithms
-`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization:
-- [Automatic Differentiation Variational Inference](@ref advi)
diff --git a/docs/src/locscale.md b/docs/src/locscale.md
deleted file mode 100644
index a5966f44b..000000000
--- a/docs/src/locscale.md
+++ /dev/null
@@ -1,85 +0,0 @@
-
-# [Location-Scale Variational Family](@id locscale)
-
-## Introduction
-The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as
-```math
-z \sim  q_{\lambda} \qquad\Leftrightarrow\qquad
-z \stackrel{d}{=} C u + m;\quad u \sim \varphi
-```
-where ``C`` is the *scale*, ``m`` is the *location*, and ``\varphi`` is the *base distribution*.
-``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``.
-The location-scale family encompasses many practical variational families, which can be instantiated by choosing the base distribution of ``u`` and the structure of ``C``.
-
-The probability density is given by
-```math
-  q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m))
-```
-and the entropy is given as
-```math
-  \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|,
-```
-where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution.
-Notice that ``\mathcal{H}(\varphi)`` does not depend on ``\lambda``.
-The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.
-
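-As a quick numerical check of this identity (a sketch, assuming `Distributions` and `LinearAlgebra`), compare the closed-form Gaussian entropy against ``\mathcal{H}(\varphi) + \log |C|`` for a standard normal base:
-```julia
-using Distributions, LinearAlgebra
-
-m = zeros(2)
-C = LowerTriangular([1.0 0.0; 0.5 2.0])
-
-# d⋅H(φ) + log|C| for a standard normal base φ
-H_locscale = 2 * entropy(Normal()) + first(logabsdet(C))
-
-# matches the closed-form entropy of the resulting Gaussian N(m, CC')
-H_locscale ≈ entropy(MvNormal(m, Matrix(C * C')))  # true
-```
-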
-## Constructors
-
-!!! note
-    For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. 
-	Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities.
-
-```@docs
-VILocationScale
-```
-
-```@docs
-VIFullRankGaussian
-VIMeanFieldGaussian
-```
-
-## Gaussian Variational Families
-
-Gaussian variational family:
-```julia
-using AdvancedVI, LinearAlgebra, Distributions;
-μ = zeros(2);
-
-L = diagm(ones(2)) |> LowerTriangular;
-q = VIFullRankGaussian(μ, L)
-
-L = ones(2) |> Diagonal;
-q = VIMeanFieldGaussian(μ, L)
-```
-
-## Non-Gaussian Variational Families
-Student-t variational family:
-
-```julia
-using AdvancedVI, LinearAlgebra, Distributions;
-μ = zeros(2);
-ν = 3;
-
-# Full-Rank 
-L = diagm(ones(2)) |> LowerTriangular;
-q = VILocationScale(μ, L, TDist(ν))
-
-# Mean-Field
-L = ones(2) |> Diagonal;
-q = VILocationScale(μ, L, TDist(ν))
-```
-
-Multivariate Laplace family:
-```julia
-using AdvancedVI, LinearAlgebra, Distributions;
-μ = zeros(2);
-
-# Full-Rank 
-L = diagm(ones(2)) |> LowerTriangular;
-q = VILocationScale(μ, L, Laplace())
-
-# Mean-Field
-L = ones(2) |> Diagonal;
-q = VILocationScale(μ, L, Laplace())
-```
-
diff --git a/docs/src/started.md b/docs/src/started.md
deleted file mode 100644
index e8392fd7a..000000000
--- a/docs/src/started.md
+++ /dev/null
@@ -1,132 +0,0 @@
-
-# [Getting Started with `AdvancedVI`](@id getting_started)
-
-## General Usage
-Each VI algorithm provides the following:
-1. A set of supported variational families.
-2. A variational objective corresponding to the algorithm.
-Note that each variational family is subject to its own constraints.
-Thus, please refer to the documentation of the variational inference algorithm of interest. 
-
-To use `AdvancedVI`, a user needs to select a variational family and a variational objective, and feed them into `optimize`.
-
-```@docs
-optimize
-```
-
-## `ADVI` Example 
-In this tutorial, we will work with a `normal-log-normal` model.
-```math
-\begin{aligned}
-x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
-y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right)
-\end{aligned}
-```
-ADVI with `Bijectors.Exp` bijectors is able to recover the target of this model exactly.
-
-Using the `LogDensityProblems` interface, the model can be defined as follows:
-```@example advi
-using LogDensityProblems
-using SimpleUnPack
-
-struct NormalLogNormal{MX,SX,MY,SY}
-    μ_x::MX
-    σ_x::SX
-    μ_y::MY
-    Σ_y::SY
-end
-
-function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
-end
-
-function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
-end
-
-function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-```
-Let's now instantiate the model
-```@example advi
-using LinearAlgebra
-
-n_dims = 10
-μ_x    = randn()
-σ_x    = exp.(randn())
-μ_y    = randn(n_dims)
-σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2));
-```
-
-Since `x` follows a log-normal prior, its support is restricted to the positive half-line ``\mathbb{R}_+``.
-Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
-```@example advi
-using Bijectors
-
-function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
-end
-
-b   = Bijectors.bijector(model);
-b⁻¹ = inverse(b)
-```
-
-Let's now load `AdvancedVI`.
-Since ADVI relies on automatic differentiation (AD), which is the "AD" in "ADVI", we need to load an AD library *before* loading `AdvancedVI`.
-Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
-Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.
-```@example advi
-using Optimisers
-using ADTypes, ForwardDiff
-import AdvancedVI as AVI
-```
-We now need to select 1. a variational objective, and 2. a variational family.
-Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
-```@example advi
-n_montecarlo = 10;
-objective    = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹)
-```
-For the variational family, we will use the classic mean-field Gaussian family.
-```@example advi
-d = LogDensityProblems.dimension(model);
-μ = randn(d);
-L = Diagonal(ones(d));
-q = AVI.VIMeanFieldGaussian(μ, L)
-```
-Passing `objective` and the initial variational approximation `q` to `optimize` performs inference.
-```@example advi
-n_max_iter = 10^4
-q, stats, _, _ = AVI.optimize(
-    objective,
-    q,
-    n_max_iter;
-    adbackend = AutoForwardDiff(),
-    optimizer = Optimisers.Adam(1e-3)
-); 
-```
-
-The selected inference procedure stores per-iteration statistics into `stats`.
-For instance, the ELBO can be plotted as follows:
-```@example advi
-using Plots
-
-t = [stat.iteration for stat ∈ stats]
-y = [stat.elbo for stat ∈ stats]
-plot(t, y, label="ADVI", xlabel="Iteration", ylabel="ELBO")
-savefig("advi_example_elbo.svg")
-nothing
-```
-![](advi_example_elbo.svg)
-
-Further information can be gathered by defining your own `callback!`.
-
-The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows:
-```@example advi
-ELBO = objective(q; n_samples=10^4)
-```
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 7272303a8..da8b05bb3 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -67,15 +67,6 @@ export
     StickingTheLandingEntropy,
     MonteCarloEntropy
 
-# Variational Families
-
-include("distributions/location_scale.jl")
-
-export
-    VILocationScale,
-    VIFullRankGaussian,
-    VIMeanFieldGaussian
-
 # Optimization Routine
 
 function optimize end
diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl
deleted file mode 100644
index 91b6768ad..000000000
--- a/src/distributions/location_scale.jl
+++ /dev/null
@@ -1,151 +0,0 @@
-
-"""
-    VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution
-
-The location scale variational family broadly represents various variational
-families using `location` and `scale` variational parameters.
-
-It generally represents any distribution for which the sampling path can be
-represented as follows:
-```julia
-  d = length(location)
-  u = rand(dist, d)
-  z = scale*u + location
-```
-"""
-struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
-    location::L
-    scale   ::S
-    dist    ::D
-
-    function VILocationScale(location::AbstractVector{<:Real},
-                             scale   ::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}},
-                             dist    ::ContinuousUnivariateDistribution)
-        # Restricting all the arguments to have the same types creates problems 
-        # with dual-variable-based AD frameworks.
-        @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2))
-        new{typeof(location), typeof(scale), typeof(dist)}(location, scale, dist)
-    end
-end
-
-Functors.@functor VILocationScale (location, scale)
-
-# Specialization of `Optimisers.destructure` for mean-field location-scale families.
-# These are necessary because we only want to extract the diagonal elements of 
-# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
-# is very inefficient.
-# begin
-struct RestructureMeanField{L, S<:Diagonal, D}
-    q::VILocationScale{L, S, D}
-end
-
-function (re::RestructureMeanField)(flat::AbstractVector)
-    n_dims   = div(length(flat), 2)
-    location = first(flat, n_dims)
-    scale    = Diagonal(last(flat, n_dims))
-    VILocationScale(location, scale, re.q.dist)
-end
-
-function Optimisers.destructure(
-    q::VILocationScale{L, <:Diagonal, D}
-) where {L, D}
-    @unpack location, scale, dist = q
-    flat   = vcat(location, diag(scale))
-    n_dims = length(location)
-    flat, RestructureMeanField(q)
-end
-# end
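-
-# A usage sketch of the specialization above: `destructure` flattens only
-# the location and the diagonal of the scale, and the returned closure
-# rebuilds the family from a (possibly perturbed) flat vector.
-#
-#   q        = VILocationScale(zeros(2), Diagonal(ones(2)), Normal())
-#   flat, re = Optimisers.destructure(q)   # flat == [0.0, 0.0, 1.0, 1.0]
-#   re(flat .+ 0.1)                        # VILocationScale([0.1, 0.1], Diagonal([1.1, 1.1]), Normal())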
-
-Base.length(q::VILocationScale) = length(q.location)
-
-Base.size(q::VILocationScale) = size(q.location)
-
-Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D)
-
-function StatsBase.entropy(q::VILocationScale)
-    @unpack  location, scale, dist = q
-    n_dims = length(location)
-    n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
-end
-
-function logpdf(q::VILocationScale, z::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
-    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
-end
-
-function _logpdf(q::VILocationScale, z::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
-    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
-end
-
-function rand(q::VILocationScale)
-    @unpack location, scale, dist = q
-    n_dims = length(location)
-    scale*rand(dist, n_dims) + location
-end
-
-function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) 
-    @unpack location, scale, dist = q
-    n_dims = length(location)
-    scale*rand(rng, dist, n_dims, num_samples) .+ location
-end
-
-# This specialization improves AD performance of the sampling path
-function rand(
-    rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int
-) where {L, D}
-    @unpack location, scale, dist = q
-    n_dims     = length(location)
-    scale_diag = diag(scale)
-    scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
-end
-
-function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real})
-    @unpack location, scale, dist = q
-    rand!(rng, dist, x)
-    x .= scale*x
-    return x .+= location
-end
-
-function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real})
-    @unpack location, scale, dist = q
-    rand!(rng, dist, x)
-    x[:] = scale*x
-    return x .+= location
-end
-
-"""
-    VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true)
-
-This constructs a multivariate Gaussian distribution with a full rank covariance matrix.
-"""
-function VIFullRankGaussian(
-    μ::AbstractVector{T},
-    L::AbstractTriangular{T};
-    check_args::Bool = true
-) where {T <: Real}
-    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
-    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
-        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
-    end
-    q_base = Normal{T}(zero(T), one(T))
-    VILocationScale(μ, L, q_base)
-end
-
-"""
-    VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}; check_args = true)
-
-This constructs a multivariate Gaussian distribution with a diagonal covariance matrix.
-"""
-function VIMeanFieldGaussian(
-    μ::AbstractVector{T},
-    L::Diagonal{T};
-    check_args::Bool = true
-) where {T <: Real}
-    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
-    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
-        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
-    end
-    q_base = Normal{T}(zero(T), one(T))
-    VILocationScale(μ, L, q_base)
-end
diff --git a/test/Manifest.toml b/test/Manifest.toml
new file mode 100644
index 000000000..220b42bb6
--- /dev/null
+++ b/test/Manifest.toml
@@ -0,0 +1,866 @@
+# This file is machine-generated - editing it directly is not advised
+
+julia_version = "1.9.2"
+manifest_format = "2.0"
+project_hash = "a6495d9f0ea044fd0a55c1c989f1adca1ad5c855"
+
+[[deps.ADTypes]]
+git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
+uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+version = "0.2.1"
+
+[[deps.AbstractFFTs]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef"
+uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
+version = "1.5.0"
+weakdeps = ["ChainRulesCore", "Test"]
+
+    [deps.AbstractFFTs.extensions]
+    AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
+    AbstractFFTsTestExt = "Test"
+
+[[deps.Adapt]]
+deps = ["LinearAlgebra", "Requires"]
+git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24"
+uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+version = "3.6.2"
+weakdeps = ["StaticArrays"]
+
+    [deps.Adapt.extensions]
+    AdaptStaticArraysExt = "StaticArrays"
+
+[[deps.ArgCheck]]
+git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
+uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
+version = "2.3.0"
+
+[[deps.ArgTools]]
+uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
+version = "1.1.1"
+
+[[deps.Artifacts]]
+uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+
+[[deps.Atomix]]
+deps = ["UnsafeAtomics"]
+git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
+uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
+version = "0.1.0"
+
+[[deps.Base64]]
+uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+
+[[deps.Bijectors]]
+deps = ["ArgCheck", "ChainRules", "ChainRulesCore", "ChangesOfVariables", "Compat", "Distributions", "Functors", "InverseFunctions", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "MappedArrays", "Random", "Reexport", "Requires", "Roots", "SparseArrays", "Statistics"]
+git-tree-sha1 = "af192c7c235264bdc6f67321fd1c57be0dd7ffb5"
+uuid = "76274a88-744f-5084-9051-94815aaf08c4"
+version = "0.13.6"
+
+    [deps.Bijectors.extensions]
+    BijectorsDistributionsADExt = "DistributionsAD"
+    BijectorsForwardDiffExt = "ForwardDiff"
+    BijectorsLazyArraysExt = "LazyArrays"
+    BijectorsReverseDiffExt = "ReverseDiff"
+    BijectorsTrackerExt = "Tracker"
+    BijectorsZygoteExt = "Zygote"
+
+    [deps.Bijectors.weakdeps]
+    DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
+    ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+    LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
+    ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+    Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[[deps.CEnum]]
+git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90"
+uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
+version = "0.4.2"
+
+[[deps.Calculus]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
+uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
+version = "0.5.1"
+
+[[deps.ChainRules]]
+deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
+git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e"
+uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+version = "1.53.0"
+
+[[deps.ChainRulesCore]]
+deps = ["Compat", "LinearAlgebra", "SparseArrays"]
+git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644"
+uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+version = "1.16.0"
+
+[[deps.ChangesOfVariables]]
+deps = ["LinearAlgebra", "Test"]
+git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f"
+uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
+version = "0.1.8"
+weakdeps = ["InverseFunctions"]
+
+    [deps.ChangesOfVariables.extensions]
+    ChangesOfVariablesInverseFunctionsExt = "InverseFunctions"
+
+[[deps.CommonSolve]]
+git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c"
+uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
+version = "0.2.4"
+
+[[deps.CommonSubexpressions]]
+deps = ["MacroTools", "Test"]
+git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
+uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
+version = "0.3.0"
+
+[[deps.Comonicon]]
+deps = ["Configurations", "ExproniconLite", "Libdl", "Logging", "Markdown", "OrderedCollections", "PackageCompiler", "Pkg", "Scratch", "TOML", "UUIDs"]
+git-tree-sha1 = "9c360961f23e2fae4c6549bbba58a6f39c9e145c"
+uuid = "863f3e99-da2a-4334-8734-de3dacbe5542"
+version = "1.0.5"
+
+[[deps.Compat]]
+deps = ["UUIDs"]
+git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7"
+uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
+version = "4.9.0"
+weakdeps = ["Dates", "LinearAlgebra"]
+
+    [deps.Compat.extensions]
+    CompatLinearAlgebraExt = "LinearAlgebra"
+
+[[deps.CompilerSupportLibraries_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
+version = "1.0.5+0"
+
+[[deps.Configurations]]
+deps = ["ExproniconLite", "OrderedCollections", "TOML"]
+git-tree-sha1 = "434f446dbf89d08350e83bf57c0fc86f5d3ffd4e"
+uuid = "5218b696-f38b-4ac9-8b61-a12ec717816d"
+version = "0.17.5"
+
+[[deps.ConstructionBase]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816"
+uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+version = "1.5.3"
+
+    [deps.ConstructionBase.extensions]
+    ConstructionBaseIntervalSetsExt = "IntervalSets"
+    ConstructionBaseStaticArraysExt = "StaticArrays"
+
+    [deps.ConstructionBase.weakdeps]
+    IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
+    StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+
+[[deps.DataAPI]]
+git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
+uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+version = "1.15.0"
+
+[[deps.DataStructures]]
+deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
+git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
+uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+version = "0.18.15"
+
+[[deps.DataValueInterfaces]]
+git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
+uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
+version = "1.0.0"
+
+[[deps.Dates]]
+deps = ["Printf"]
+uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+
+[[deps.DiffResults]]
+deps = ["StaticArraysCore"]
+git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621"
+uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
+version = "1.1.0"
+
+[[deps.DiffRules]]
+deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
+git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272"
+uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
+version = "1.15.1"
+
+[[deps.Distributed]]
+deps = ["Random", "Serialization", "Sockets"]
+uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+
+[[deps.Distributions]]
+deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"]
+git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd"
+uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
+version = "0.25.100"
+
+    [deps.Distributions.extensions]
+    DistributionsChainRulesCoreExt = "ChainRulesCore"
+    DistributionsDensityInterfaceExt = "DensityInterface"
+
+    [deps.Distributions.weakdeps]
+    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+    DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
+
+[[deps.DistributionsAD]]
+deps = ["Adapt", "ChainRules", "ChainRulesCore", "Compat", "Distributions", "FillArrays", "LinearAlgebra", "PDMats", "Random", "Requires", "SpecialFunctions", "StaticArrays", "StatsFuns", "ZygoteRules"]
+git-tree-sha1 = "975de103eb2175cf54bf14b15ded2c68625eabdf"
+uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
+version = "0.6.52"
+
+    [deps.DistributionsAD.extensions]
+    DistributionsADForwardDiffExt = "ForwardDiff"
+    DistributionsADLazyArraysExt = "LazyArrays"
+    DistributionsADReverseDiffExt = "ReverseDiff"
+    DistributionsADTrackerExt = "Tracker"
+
+    [deps.DistributionsAD.weakdeps]
+    ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+    LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
+    ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+
+[[deps.DocStringExtensions]]
+deps = ["LibGit2"]
+git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
+uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+version = "0.9.3"
+
+[[deps.Downloads]]
+deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
+uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+version = "1.6.0"
+
+[[deps.DualNumbers]]
+deps = ["Calculus", "NaNMath", "SpecialFunctions"]
+git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
+uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
+version = "0.6.8"
+
+[[deps.Enzyme]]
+deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random"]
+git-tree-sha1 = "1f85bc8a9da6118abb95d134efc68cf4a6957341"
+uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+version = "0.11.7"
+
+[[deps.EnzymeCore]]
+deps = ["Adapt"]
+git-tree-sha1 = "643995502bdfff08bf080212c92430510be01ad5"
+uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
+version = "0.5.2"
+
+[[deps.Enzyme_jll]]
+deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
+git-tree-sha1 = "ffa4926cc857bcc5c256825bd7273a6ac989eb34"
+uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef"
+version = "0.0.80+0"
+
+[[deps.ExprTools]]
+git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec"
+uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
+version = "0.1.10"
+
+[[deps.ExproniconLite]]
+deps = ["Pkg", "TOML"]
+git-tree-sha1 = "d80b5d5990071086edf5de9018c6c69c83937004"
+uuid = "55351af7-c7e9-48d6-89ff-24e801d99491"
+version = "0.10.3"
+
+[[deps.FileWatching]]
+uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
+
+[[deps.FillArrays]]
+deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
+git-tree-sha1 = "048dd3d82558759476cff9cff999219216932a08"
+uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
+version = "1.6.0"
+
+[[deps.ForwardDiff]]
+deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
+git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
+uuid = "f6369f11-7733-5829-9624-2563aa707210"
+version = "0.10.36"
+weakdeps = ["StaticArrays"]
+
+    [deps.ForwardDiff.extensions]
+    ForwardDiffStaticArraysExt = "StaticArrays"
+
+[[deps.FunctionWrappers]]
+git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e"
+uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
+version = "1.1.3"
+
+[[deps.Functors]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545"
+uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+version = "0.4.5"
+
+[[deps.Future]]
+deps = ["Random"]
+uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
+
+[[deps.GPUArrays]]
+deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
+git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1"
+uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+version = "8.8.1"
+
+[[deps.GPUArraysCore]]
+deps = ["Adapt"]
+git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0"
+uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
+version = "0.1.5"
+
+[[deps.GPUCompiler]]
+deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
+git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3"
+uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
+version = "0.21.4"
+
+[[deps.HypergeometricFunctions]]
+deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
+git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
+uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
+version = "0.3.23"
+
+[[deps.IRTools]]
+deps = ["InteractiveUtils", "MacroTools", "Test"]
+git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5"
+uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
+version = "0.4.10"
+
+[[deps.InlineTest]]
+deps = ["Test"]
+git-tree-sha1 = "daf0743879904f0ad645ca6594e1479685f158a2"
+uuid = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
+version = "0.2.0"
+
+[[deps.InteractiveUtils]]
+deps = ["Markdown"]
+uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+
+[[deps.InverseFunctions]]
+deps = ["Test"]
+git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46"
+uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
+version = "0.1.12"
+
+[[deps.IrrationalConstants]]
+git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
+uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
+version = "0.2.2"
+
+[[deps.IteratorInterfaceExtensions]]
+git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
+uuid = "82899510-4779-5014-852e-03e436cf321d"
+version = "1.0.0"
+
+[[deps.JLLWrappers]]
+deps = ["Artifacts", "Preferences"]
+git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
+uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
+version = "1.5.0"
+
+[[deps.KernelAbstractions]]
+deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
+git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118"
+uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+version = "0.9.8"
+weakdeps = ["EnzymeCore"]
+
+    [deps.KernelAbstractions.extensions]
+    EnzymeExt = "EnzymeCore"
+
+[[deps.LLVM]]
+deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
+git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729"
+uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
+version = "6.1.0"
+
+[[deps.LLVMExtra_jll]]
+deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
+git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c"
+uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
+version = "0.0.23+0"
+
+[[deps.LazyArtifacts]]
+deps = ["Artifacts", "Pkg"]
+uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
+
+[[deps.LibCURL]]
+deps = ["LibCURL_jll", "MozillaCACerts_jll"]
+uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
+version = "0.6.3"
+
+[[deps.LibCURL_jll]]
+deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
+uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
+version = "7.84.0+0"
+
+[[deps.LibGit2]]
+deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
+uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+
+[[deps.LibSSH2_jll]]
+deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
+uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
+version = "1.10.2+0"
+
+[[deps.Libdl]]
+uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+
+[[deps.LinearAlgebra]]
+deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
+uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
+[[deps.LogDensityProblems]]
+deps = ["ArgCheck", "DocStringExtensions", "Random"]
+git-tree-sha1 = "f9a11237204bc137617194d79d813069838fcf61"
+uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+version = "2.1.1"
+
+[[deps.LogExpFunctions]]
+deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
+git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
+uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+version = "0.3.26"
+weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"]
+
+    [deps.LogExpFunctions.extensions]
+    LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
+    LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
+    LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
+
+[[deps.Logging]]
+uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+
+[[deps.MacroTools]]
+deps = ["Markdown", "Random"]
+git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
+uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+version = "0.5.11"
+
+[[deps.MappedArrays]]
+git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e"
+uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
+version = "0.4.2"
+
+[[deps.Markdown]]
+deps = ["Base64"]
+uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+
+[[deps.MbedTLS_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
+version = "2.28.2+0"
+
+[[deps.Missings]]
+deps = ["DataAPI"]
+git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272"
+uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
+version = "1.1.0"
+
+[[deps.MozillaCACerts_jll]]
+uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
+version = "2022.10.11"
+
+[[deps.NNlib]]
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
+git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22"
+uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+version = "0.9.4"
+
+    [deps.NNlib.extensions]
+    NNlibAMDGPUExt = "AMDGPU"
+    NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
+    NNlibCUDAExt = "CUDA"
+
+    [deps.NNlib.weakdeps]
+    AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+    CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+    cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+
+[[deps.NaNMath]]
+deps = ["OpenLibm_jll"]
+git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
+uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+version = "1.0.2"
+
+[[deps.NetworkOptions]]
+uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
+version = "1.2.0"
+
+[[deps.ObjectFile]]
+deps = ["Reexport", "StructIO"]
+git-tree-sha1 = "69607899b46e1f8ead70396bc51a4c361478d8f6"
+uuid = "d8793406-e978-5875-9003-1fc021f44a92"
+version = "0.4.0"
+
+[[deps.OpenBLAS_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
+uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
+version = "0.3.21+4"
+
+[[deps.OpenLibm_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
+version = "0.8.1+0"
+
+[[deps.OpenSpecFun_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
+uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
+version = "0.5.5+0"
+
+[[deps.Optimisers]]
+deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b"
+uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
+version = "0.2.20"
+
+[[deps.OrderedCollections]]
+git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
+uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+version = "1.6.2"
+
+[[deps.PDMats]]
+deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
+git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1"
+uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
+version = "0.11.17"
+
+[[deps.PackageCompiler]]
+deps = ["Artifacts", "LazyArtifacts", "Libdl", "Pkg", "Printf", "RelocatableFolders", "TOML", "UUIDs"]
+git-tree-sha1 = "1a6a868eb755e8ea9ecd000aa6ad175def0cc85b"
+uuid = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d"
+version = "2.1.7"
+
+[[deps.Pkg]]
+deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
+uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+version = "1.9.2"
+
+[[deps.PrecompileTools]]
+deps = ["Preferences"]
+git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
+uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
+version = "1.2.0"
+
+[[deps.Preferences]]
+deps = ["TOML"]
+git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1"
+uuid = "21216c6a-2e73-6563-6e65-726566657250"
+version = "1.4.0"
+
+[[deps.Printf]]
+deps = ["Unicode"]
+uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+
+[[deps.QuadGK]]
+deps = ["DataStructures", "LinearAlgebra"]
+git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee"
+uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
+version = "2.8.2"
+
+[[deps.REPL]]
+deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
+uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+
+[[deps.Random]]
+deps = ["SHA", "Serialization"]
+uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+
+[[deps.Random123]]
+deps = ["Random", "RandomNumbers"]
+git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3"
+uuid = "74087812-796a-5b5d-8853-05524746bad3"
+version = "1.6.1"
+
+[[deps.RandomNumbers]]
+deps = ["Random", "Requires"]
+git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
+uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
+version = "1.5.3"
+
+[[deps.ReTest]]
+deps = ["Distributed", "InlineTest", "Printf", "Random", "Sockets", "Test"]
+git-tree-sha1 = "dd8f6587c0abac44bcec2e42f0aeddb73550c0ec"
+uuid = "e0db7c4e-2690-44b9-bad6-7687da720f89"
+version = "0.3.2"
+
+[[deps.RealDot]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
+uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
+version = "0.1.0"
+
+[[deps.Reexport]]
+git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
+uuid = "189a3867-3050-52da-a836-e630ba90ab69"
+version = "1.2.2"
+
+[[deps.RelocatableFolders]]
+deps = ["SHA", "Scratch"]
+git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691"
+uuid = "05181044-ff0b-4ac5-8273-598c1e38db00"
+version = "1.0.0"
+
+[[deps.Requires]]
+deps = ["UUIDs"]
+git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
+uuid = "ae029012-a4dd-5104-9daa-d747884805df"
+version = "1.3.0"
+
+[[deps.ReverseDiff]]
+deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"]
+git-tree-sha1 = "d1235bdd57a93bd7504225b792b867e9a7df38d5"
+uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+version = "1.15.1"
+
+[[deps.Rmath]]
+deps = ["Random", "Rmath_jll"]
+git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
+uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
+version = "0.7.1"
+
+[[deps.Rmath_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da"
+uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
+version = "0.4.0+0"
+
+[[deps.Roots]]
+deps = ["ChainRulesCore", "CommonSolve", "Printf", "Setfield"]
+git-tree-sha1 = "ff42754a57bb0d6dcfe302fd0d4272853190421f"
+uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
+version = "2.0.19"
+
+    [deps.Roots.extensions]
+    RootsForwardDiffExt = "ForwardDiff"
+    RootsIntervalRootFindingExt = "IntervalRootFinding"
+    RootsSymPyExt = "SymPy"
+
+    [deps.Roots.weakdeps]
+    ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+    IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
+    SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
+
+[[deps.SHA]]
+uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+version = "0.7.0"
+
+[[deps.Scratch]]
+deps = ["Dates"]
+git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a"
+uuid = "6c6a2e73-6563-6170-7368-637461726353"
+version = "1.2.0"
+
+[[deps.Serialization]]
+uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+
+[[deps.Setfield]]
+deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"]
+git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac"
+uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46"
+version = "1.1.1"
+
+[[deps.SimpleUnPack]]
+git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437"
+uuid = "ce78b400-467f-4804-87d8-8f486da07d0a"
+version = "1.1.0"
+
+[[deps.Sockets]]
+uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+
+[[deps.SortingAlgorithms]]
+deps = ["DataStructures"]
+git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
+uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
+version = "1.1.1"
+
+[[deps.SparseArrays]]
+deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
+uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+
+[[deps.SpecialFunctions]]
+deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
+git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d"
+uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
+version = "2.3.1"
+weakdeps = ["ChainRulesCore"]
+
+    [deps.SpecialFunctions.extensions]
+    SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
+
+[[deps.StaticArrays]]
+deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
+git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881"
+uuid = "90137ffa-7385-5640-81b9-e52037218182"
+version = "1.6.2"
+weakdeps = ["Statistics"]
+
+    [deps.StaticArrays.extensions]
+    StaticArraysStatisticsExt = "Statistics"
+
+[[deps.StaticArraysCore]]
+git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d"
+uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
+version = "1.4.2"
+
+[[deps.Statistics]]
+deps = ["LinearAlgebra", "SparseArrays"]
+uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+version = "1.9.0"
+
+[[deps.StatsAPI]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
+uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
+version = "1.6.0"
+
+[[deps.StatsBase]]
+deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
+git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4"
+uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+version = "0.34.0"
+
+[[deps.StatsFuns]]
+deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
+git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a"
+uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+version = "1.3.0"
+weakdeps = ["ChainRulesCore", "InverseFunctions"]
+
+    [deps.StatsFuns.extensions]
+    StatsFunsChainRulesCoreExt = "ChainRulesCore"
+    StatsFunsInverseFunctionsExt = "InverseFunctions"
+
+[[deps.StructArrays]]
+deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"]
+git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389"
+uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
+version = "0.6.15"
+
+[[deps.StructIO]]
+deps = ["Test"]
+git-tree-sha1 = "010dc73c7146869c042b49adcdb6bf528c12e859"
+uuid = "53d494c1-5632-5724-8f4c-31dff12d585f"
+version = "0.3.0"
+
+[[deps.SuiteSparse]]
+deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
+uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
+
+[[deps.SuiteSparse_jll]]
+deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
+uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
+version = "5.10.1+6"
+
+[[deps.TOML]]
+deps = ["Dates"]
+uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
+version = "1.0.3"
+
+[[deps.TableTraits]]
+deps = ["IteratorInterfaceExtensions"]
+git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
+uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
+version = "1.0.1"
+
+[[deps.Tables]]
+deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"]
+git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec"
+uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+version = "1.10.1"
+
+[[deps.Tar]]
+deps = ["ArgTools", "SHA"]
+uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
+version = "1.10.0"
+
+[[deps.Test]]
+deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
+uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[[deps.TimerOutputs]]
+deps = ["ExprTools", "Printf"]
+git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7"
+uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
+version = "0.5.23"
+
+[[deps.Tracker]]
+deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"]
+git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c"
+uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+version = "0.2.26"
+weakdeps = ["PDMats"]
+
+    [deps.Tracker.extensions]
+    TrackerPDMatsExt = "PDMats"
+
+[[deps.UUIDs]]
+deps = ["Random", "SHA"]
+uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+
+[[deps.Unicode]]
+uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+
+[[deps.UnsafeAtomics]]
+git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
+uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
+version = "0.2.1"
+
+[[deps.UnsafeAtomicsLLVM]]
+deps = ["LLVM", "UnsafeAtomics"]
+git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e"
+uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
+version = "0.1.3"
+
+[[deps.Zlib_jll]]
+deps = ["Libdl"]
+uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
+version = "1.2.13+0"
+
+[[deps.Zygote]]
+deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
+git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b"
+uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
+version = "0.6.63"
+
+    [deps.Zygote.extensions]
+    ZygoteColorsExt = "Colors"
+    ZygoteDistancesExt = "Distances"
+    ZygoteTrackerExt = "Tracker"
+
+    [deps.Zygote.weakdeps]
+    Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
+    Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
+    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+
+[[deps.ZygoteRules]]
+deps = ["ChainRulesCore", "MacroTools"]
+git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d"
+uuid = "700de1a5-db45-46bc-99cf-38207098b444"
+version = "0.2.3"
+
+[[deps.libblastrampoline_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
+version = "5.8.0+0"
+
+[[deps.nghttp2_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
+version = "1.48.0+0"
+
+[[deps.p7zip_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
+version = "17.4.0+0"
diff --git a/test/Project.toml b/test/Project.toml
index 663d671dc..5ce8fcd88 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -3,9 +3,11 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index d5250ce83..f2ce94a5d 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -9,7 +9,6 @@ using ReTest
             realtype ∈ [Float64], # Currently only tested against Float64
             (modelname, modelconstr) ∈ Dict(
                 :NormalLogNormalMeanField => normallognormal_meanfield,
-                :NormalLogNormalFullRank  => normallognormal_fullrank,
             ),
             (objname, objective) ∈ Dict(
                 :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹),
@@ -17,7 +16,7 @@ using ReTest
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwardDiff => AutoForwardDiff(),
-                :ReverseDiff => AutoReverseDiff(),
+                # :ReverseDiff => AutoReverseDiff(),
                 # :Zygote      => AutoZygote(), 
                 # :Enzyme      => AutoEnzyme(),
             )
@@ -32,19 +31,10 @@ using ReTest
 
             b    = Bijectors.bijector(model)
             b⁻¹  = inverse(b)
+            μ₀   = zeros(realtype, n_dims)
+	    L₀   = Diagonal(ones(realtype, n_dims))
 
-            μ₀ = zeros(realtype, n_dims)
-            L₀ = if is_meanfield
-                FillArrays.Eye(n_dims) |> Diagonal
-            else
-                FillArrays.Eye(n_dims) |> Matrix |> LowerTriangular
-            end
-
-            q₀ = if is_meanfield
-                VIMeanFieldGaussian(μ₀, L₀)
-            else
-                VIFullRankGaussian(μ₀, L₀)
-            end
+	    q₀ = TuringDiagMvNormal(μ₀, diag(L₀))
 
             obj = objective(model, b⁻¹, 10)
 
@@ -58,8 +48,8 @@ using ReTest
                     adbackend     = adbackend,
                 )
 
-                μ  = q.location
-                L  = q.scale
+		μ  = mean(q)
+		L  = sqrt(cov(q))
                 Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
                 @test Δλ ≤ Δλ₀/T^(1/4)
@@ -76,8 +66,8 @@ using ReTest
                     rng           = rng,
                     adbackend     = adbackend,
                 )
-                μ  = q.location
-                L  = q.scale
+		μ  = mean(q)
+		L  = sqrt(cov(q))
 
                 rng_repl = Philox4x(UInt64, seed, 8)
                 q, stats, _, _ = optimize(
@@ -87,8 +77,8 @@ using ReTest
                     rng           = rng_repl,
                     adbackend     = adbackend,
                 )
-                μ_repl = q.location
-                L_repl = q.scale
+		μ_repl = mean(q)
+		L_repl = sqrt(cov(q))
                 @test μ == μ_repl
                 @test L == L_repl
             end
diff --git a/test/distributions.jl b/test/distributions.jl
deleted file mode 100644
index 175cc96b8..000000000
--- a/test/distributions.jl
+++ /dev/null
@@ -1,96 +0,0 @@
-
-using ReTest
-using Distributions: _logpdf
-
-@testset "distributions" begin
-    @testset "$(string(covtype)) $(basedist) $(realtype)" for
-        basedist = [:gaussian],
-        covtype  = [:meanfield, :fullrank],
-        realtype = [Float32,     Float64]
-
-        seed         = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-        rng          = Philox4x(UInt64, seed, 8)
-        n_dims       = 10
-        n_montecarlo = 1000_000
-
-        μ = randn(rng, realtype, n_dims)
-        L = if covtype == :fullrank
-	    tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular
-        else
-            Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1))
-        end
-        Σ = L*L'
-
-        q = if covtype == :fullrank  && basedist == :gaussian
-            VIFullRankGaussian(μ, L)
-        elseif covtype == :meanfield && basedist == :gaussian
-            VIMeanFieldGaussian(μ, L)
-        end
-        q_true = if basedist == :gaussian
-            MvNormal(μ, Σ)
-        end
-
-        @testset "logpdf" begin
-            seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-            rng  = Philox4x(UInt64, seed, 8)
-
-            z = rand(rng, q)
-            @test eltype(z)             == realtype
-            @test logpdf(q, z)          ≈  logpdf(q_true, z)  rtol=realtype(1e-2)
-            @test _logpdf(q, z)         ≈  _logpdf(q_true, z) rtol=realtype(1e-2)
-            @test eltype(logpdf(q, z))  == realtype 
-            @test eltype(_logpdf(q, z)) == realtype 
-        end
-
-        @testset "entropy" begin
-            @test eltype(entropy(q)) == realtype
-            @test entropy(q)         ≈ entropy(q_true)
-        end
-
-        @testset "sampling" begin
-            @testset "rand" begin
-                seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-                rng  = Philox4x(UInt64, seed, 8)
-
-                z_samples  = mapreduce(x -> rand(rng, q), hcat, 1:n_montecarlo)
-                @test eltype(z_samples) == realtype
-                @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
-                @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
-                @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
-            end
-
-            @testset "rand batch" begin
-                seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-                rng  = Philox4x(UInt64, seed, 8)
-
-                z_samples  = rand(rng, q, n_montecarlo)
-                @test eltype(z_samples) == realtype
-                @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
-                @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
-                @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
-            end
-
-            @testset "rand!" begin
-                seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-                rng  = Philox4x(UInt64, seed, 8)
-
-                z_samples = Array{realtype}(undef, n_dims, n_montecarlo)
-                rand!(rng, q, z_samples)
-                @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ       rtol=realtype(1e-2)
-                @test dropdims(var(z_samples, dims=2),  dims=2) ≈ diag(Σ) rtol=realtype(1e-2)
-                @test cov(z_samples, dims=2)                    ≈ Σ       rtol=realtype(1e-2)
-            end
-        end
-    end
-
-    @testset "Diagonal destructure" for
-        n_dims = 10
-        μ      = zeros(n_dims)
-        L      = ones(n_dims)
-        q      = VIMeanFieldGaussian(μ, L |> Diagonal)
-        λ, re  = Optimisers.destructure(q)
-
-        @test length(λ) == 2*n_dims
-        @test q         == re(λ)
-    end
-end
diff --git a/test/optimize.jl b/test/optimize.jl
index 2369432c9..56ca63c0f 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -12,9 +12,7 @@ using ReTest
 
     # Global Test Configurations
     b⁻¹ = Bijectors.bijector(model) |> inverse
-    μ₀  = zeros(Float64, n_dims)
-    L₀  = ones(Float64, n_dims) |> Diagonal
-    q₀  = VIMeanFieldGaussian(μ₀, L₀)
+    q₀  = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
     obj = ADVI(model, 10; invbij=b⁻¹)
 
     adbackend = AutoForwardDiff()
diff --git a/test/runtests.jl b/test/runtests.jl
index 127503be2..fd68ed794 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -12,6 +12,10 @@ using SimpleUnPack: @unpack
 using FillArrays
 using PDMats
 
+using Functors
+using DistributionsAD
+@functor TuringDiagMvNormal
+
 using Bijectors
 using LogDensityProblems
 using Optimisers
@@ -33,7 +37,6 @@ include("models/normallognormal.jl")
 
 # Tests
 include("ad.jl")
-include("distributions.jl")
 include("advi_locscale.jl")
 include("optimize.jl")
 

From d2ae29fffcbfacad59268b1c6835b43858e138db Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 02:21:56 -0400
Subject: [PATCH 141/206] remove doc action for now

---
 .github/workflows/CI.yml | 27 ---------------------------
 1 file changed, 27 deletions(-)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 26f6876f5..7ba573a15 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -61,30 +61,3 @@ jobs:
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           path-to-lcov: lcov.info
-  docs:
-    name: Documentation
-    runs-on: ubuntu-latest
-    permissions:
-      contents: write
-      statuses: write
-    steps:
-      - uses: actions/checkout@v3
-      - uses: julia-actions/setup-julia@v1
-        with:
-          version: '1'
-      - name: Configure doc environment
-        run: |
-          julia --project=docs/ -e '
-            using Pkg
-            Pkg.develop(PackageSpec(path=pwd()))
-            Pkg.instantiate()'
-      - uses: julia-actions/julia-buildpkg@v1
-      - uses: julia-actions/julia-docdeploy@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-      - run: |
-          julia --project=docs -e '
-            using Documenter: DocMeta, doctest
-            using AdvancedVI
-            DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true)
-            doctest(AdvancedVI)'

From fb84e3d3aa0e383c94fe88e0a8b33c845f916cd7 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 24 Aug 2023 02:27:37 -0400
Subject: [PATCH 142/206] revert README for now

---
 README.md | 301 ++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 222 insertions(+), 79 deletions(-)

diff --git a/README.md b/README.md
index 695e9ed98..f0bf6cc10 100644
--- a/README.md
+++ b/README.md
@@ -1,108 +1,251 @@
-
 # AdvancedVI.jl
-[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms.
-VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness.
-`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem.
-The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. 
-For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) by simply converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`.
+A library for variational Bayesian inference in Julia.
+
+At the time of writing (05/02/2020), the variational inference (VI) interface and some algorithms are implemented in [Turing.jl](https://github.com/TuringLang/Turing.jl). The idea is to soon move the VI functionality out of Turing.jl and into this package.
+
+The purpose of this package will then be to provide a common interface, together with implementations of standard algorithms and utilities, that is easy to use and that other packages, e.g. Turing.jl, can integrate with through a light wrapper around AdvancedVI.jl.
+
+As an example, in Turing.jl we support automatic differentiation variational inference (ADVI), but really the only piece of code tied to Turing.jl is the conversion of a `Turing.Model` to a `logjoint(z)` function which computes `z ↦ log p(x, z)`, with `x` denoting the observations embedded in the `Turing.Model`. As long as this `logjoint(z)` method is compatible with some AD framework, e.g. `ForwardDiff.jl` or `Zygote.jl`, this is all we need from Turing.jl to be able to perform ADVI!
+
+## [WIP] Interface
+- `vi`: the main interface to the functionality in this package
+  - `vi(model, alg)`: only used when `alg` has a default variational posterior which it will provide.
+  - `vi(model, alg, q::VariationalPosterior, θ)`: `q` represents the family of variational distributions and `θ` contains the initial parameters "indexing" the starting distribution. This assumes that there exists an implementation `Variational.update(q, θ)` which returns the variational posterior corresponding to parameters `θ`.
+  - `vi(model, alg, getq::Function, θ)`: here `getq(θ)` is a function returning a `VariationalPosterior` corresponding to `θ`.
+- `optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())`
+- `grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)`
+  - Different combinations of variational objectives (`vo`), VI methods (`alg`), and variational posteriors (`q`) might use different gradient estimators. `grad!` allows us to specify these different behaviors. (See the sketch below.)
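+
+For instance, a rough sketch of these call forms (`model`, `alg`, `q0`, and `getq` are hypothetical placeholders; the interface is still WIP):
+```julia
+q = vi(model, alg)            # `alg` provides a default variational posterior
+q = vi(model, alg, q0, θ)     # family given via `Variational.update(q0, θ)`
+q = vi(model, alg, getq, θ)   # family given via `getq(θ)`
+```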
 
 ## Examples
+### Variational Inference
+A very simple generative model is the following
 
-`AdvancedVI` expects a `LogDensityProblem`.
-For example, for the normal-log-normal model:
+    μ ~ 𝒩(0, 1)
+    xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n
 
-$$
-\begin{aligned}
-x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\
-y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right),
-\end{aligned}
-$$
+where μ and xᵢ are some ℝᵈ vectors and 𝒩 denotes a d-dimensional multivariate Normal distribution.
 
-a `LogDensityProblem` can be implemented as 
+Given a set of `n` observations `[x₁, …, xₙ]` we're interested in finding the distribution `p(μ∣x₁, …, xₙ)` over the mean `μ`. We can obtain (an approximation to) this distribution using AdvancedVI.jl!
+
+First we generate some observations and set up the problem:
 ```julia
-using LogDensityProblems
+julia> using Distributions
 
-struct NormalLogNormal{MX,SX,MY,SY}
-    μ_x::MX
-    σ_x::SX
-    μ_y::MY
-    Σ_y::SY
-end
+julia> d = 2; n = 100;
 
-function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
-end
+julia> observations = randn((d, n)); # 100 observations from 2D 𝒩(0, 1)
 
-function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
-end
+julia> # Define generative model
+       #    μ ~ 𝒩(0, 1)
+       #    xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n
+       prior(μ) = logpdf(MvNormal(ones(d)), μ)
+prior (generic function with 1 method)
 
-function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-```
+julia> likelihood(x, μ) = sum(logpdf(MvNormal(μ, ones(d)), x))
+likelihood (generic function with 1 method)
+
+julia> logπ(μ) = likelihood(observations, μ) + prior(μ)
+logπ (generic function with 1 method)
 
-Since the support of `x` is constrained to be $$\mathbb{R}_+$$, and inference is best done in the unconstrained space $$\mathbb{R}_+$$, we need to use a *bijector* to match support.
-This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, 2015).
+julia> logπ(randn(2))  # <= just checking that it works
+-311.74132761437653
+```
+Now there are two main ways of specifying the approximate posterior (and its family). The first is by providing a mapping from distribution parameters to the distribution `θ ↦ q(⋅∣θ)`:
 ```julia
-using Bijectors
+julia> using DistributionsAD, AdvancedVI
 
-function Bijectors.bijector(model::NormalLogNormal)
-    (; μ_x, σ_x, μ_y, Σ_y) = model
-    Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
-end
+julia> # Using a function z ↦ q(⋅∣z)
+       getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[d + 1:2d]))
+getq (generic function with 1 method)
 ```
+Then we make the choice of algorithm, a subtype of `VariationalInference`, 
+```julia
+julia> # Perform VI
+       advi = ADVI(10, 10_000)
+ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 10000)
+```
+And finally we can perform VI! The usual interface is to call `vi`, which behind the scenes takes care of the optimization and returns the resulting variational posterior:
+```julia
+julia> q = vi(logπ, advi, getq, randn(4))
+[ADVI] Optimizing...100% Time: 0:00:01
+TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[0.16282745378074515, 0.15789310089462574], σ=[0.09519377533754399, 0.09273176907111745])
+```
+Let's have a look at the resulting ELBO:
+```julia
+julia> AdvancedVI.elbo(advi, q, logπ, 1000)
+-287.7866366886285
+```
+Unfortunately, the *final* value of the ELBO is not always a very good diagnostic, though the ELBO is an important metric to keep an eye on during training since an *increase* in the ELBO means we're going in the right direction. Luckily, this is such a simple problem that we can indeed obtain a closed form solution! Because we're lazy (at least I am), we'll let [ConjugatePriors.jl](https://github.com/JuliaStats/ConjugatePriors.jl) do this for us:
+```julia
+julia> # True posterior
+       using ConjugatePriors
+
+julia> pri = MvNormal(zeros(2), ones(2));
 
-A simpler approach is to use `Turing`, where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated.
+julia> true_posterior = posterior((pri, pri.Σ), MvNormal, observations)
+DiagNormal(
+dim: 2
+μ: [0.1746546592601148, 0.16457110079543008]
+Σ: [0.009900990099009901 0.0; 0.0 0.009900990099009901]
+)
+```
+Compared with our variational approximation, this looks pretty good! It's worth noting that in this particular case the variational posterior seems to slightly underestimate the variance.
 
-Let us instantiate a random normal-log-normal model.
+To conclude, let's make a somewhat pretty picture:
 ```julia
-using LinearAlgebra
-
-n_dims = 10
-μ_x    = randn()
-σ_x    = exp.(randn())
-μ_y    = randn(n_dims)
-σ_y    = exp.(randn(n_dims))
-model  = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
+julia> using Plots
+
+julia> p_samples = rand(true_posterior, 10_000); q_samples = rand(q, 10_000);
+
+julia> p1 = histogram(p_samples[1, :], label="p"); histogram!(q_samples[1, :], alpha=0.7, label="q")
+
+julia> title!(raw"$\mu_1$")
+
+julia> p2 = histogram(p_samples[2, :], label="p"); histogram!(q_samples[2, :], alpha=0.7, label="q")
+
+julia> title!(raw"$\mu_2$")
+
+julia> plot(p1, p2)
 ```
+![Histogram](hist.png?raw=true)
+
+### Simple example: using AdvancedVI.jl to directly minimize the KL-divergence between two distributions `p(z)` and `q(z)`
+In VI we aim to approximate the true posterior `p(z ∣ x)` with a variational posterior `q(z)` by maximizing the ELBO:
+
+    ELBO(q) = 𝔼_q[log p(x, z) - log q(z)]
+
+Observe that we can express the ELBO as the negative KL-divergence between `q(⋅)` and `p(x, ⋅)`:
+
+    ELBO(q) = - 𝔼_q[log (q(z) / p(x, z))]
+            = - KL(q(⋅) || p(x, ⋅))
+
+So if we apply VI to something that isn't an actual posterior, i.e. there's no data involved and we write `p(z ∣ x) = p(z)`, we're really just minimizing the KL-divergence between the distributions.
+
+Therefore, we can try out `AdvancedVI.jl` real quick by using the interface to minimize the KL-divergence between two distributions:
 
-ADVI can be used as follows:
 ```julia
-using Optimisers
-using ADTypes, ForwardDiff
-import AdvancedVI as AVI
-
-b     = Bijectors.bijector(model)
-b⁻¹   = inverse(b)
-
-# ADVI objective 
-objective = AVI.ADVI(model, 10; invbij=b⁻¹)
-
-# Mean-field Gaussian variational family
-d = LogDensityProblems.dimension(model)
-μ = randn(d)
-L = Diagonal(ones(d))
-q = AVI.VIMeanFieldGaussian(μ, L)
-
-# Run inference
-n_max_iter = 10^4
-q, stats, _ = AVI.optimize(
-    objective,
-    q,
-    n_max_iter;
-    adbackend = ADTypes.AutoForwardDiff(),
-    optimizer = Optimisers.Adam(1e-3)
+julia> using Distributions, DistributionsAD, AdvancedVI
+
+julia> # Target distribution
+       p = MvNormal(ones(2))
+ZeroMeanDiagNormal(
+dim: 2
+μ: [0.0, 0.0]
+Σ: [1.0 0.0; 0.0 1.0]
 )
 
-# Evaluate final ELBO with 10^3 Monte Carlo samples
-objective(q; n_samples=10^3)
+julia> logπ(z) = logpdf(p, z)
+logπ (generic function with 1 method)
+
+julia> # Make a choice of VI algorithm
+       advi = ADVI(10, 1000)
+ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 1000)
+```
+Now there are two different ways of specifying the approximate posterior (and its family); the first is by providing a mapping from parameters to distribution `θ ↦ q(⋅∣θ)`:
+```julia
+julia> # Using a function z ↦ q(⋅∣z)
+       getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4]))
+getq (generic function with 1 method)
+
+julia> # Perform VI
+       q = vi(logπ, advi, getq, randn(4))
+┌ Info: [ADVI] Should only be seen once: optimizer created for θ
+└   objectid(θ) = 0x5ddb564423896704
+[ADVI] Optimizing...100% Time: 0:00:01
+TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[-0.012691337868985757, -0.0004442434543332919], σ=[1.0334797673569802, 0.9957355128767893])
+```
+Or we can check the ELBO (which in this case, since no data is involved, is the negative KL-divergence, as mentioned):
+```julia
+julia> AdvancedVI.elbo(advi, q, logπ, 1000)  # empirical estimate
+0.08031049170093245
+```
+It's worth noting that the actual value of the ELBO doesn't really tell us too much about the quality of fit. In this particular case, because we're *directly* minimizing the KL-divergence, we can only say something useful if we reach 0, in which case we have obtained the true distribution.
+
+Let's just quickly check the mean-squared error between `log p(z)` and `log q(z)` for a random set of samples from the target `p`:
+```julia
+julia> zs = rand(p, 100);
+
+julia> mean(abs2, logpdf(q, zs) - logpdf(p, zs))
+0.0014889109427524852
+```
+That doesn't look too bad!
+
+## Implementing your own training loop
+Sometimes it might be convenient to roll your own training loop rather than using `vi(...)`. Here's some pseudo-code for how one would do that together with Turing.jl:
+
+```julia
+using Turing, AdvancedVI, DiffResults
+using Turing: Variational
+
+using ProgressMeter
+
+# Assuming you have an instance of a Turing model (`model`)
+
+# 1. Create log-joint needed for ELBO evaluation
+logπ = Variational.make_logjoint(model)
+
+# 2. Define objective
+variational_objective = Variational.ELBO()
+
+# 3. Optimizer
+optimizer = Variational.DecayedADAGrad()
+
+# 4. VI-algorithm
+alg = ADVI(10, 1000)
+
+# 5. Variational distribution
+function getq(θ)
+    # ...
+end
+
+# 6. [OPTIONAL] Implement convergence criterion
+function hasconverged(args...)
+    # ...
+end
+
+# 7. [OPTIONAL] Implement a callback for tracking stats
+function callback(args...)
+    # ...
+end
+
+# 8. Train
+converged = false
+step = 1
+
+prog = ProgressMeter.Progress(num_steps, 1) # `num_steps`: total number of iterations, assumed given
+
+θ = copy(θ_init) # `θ_init`: initial variational parameters, assumed given
+diff_results = DiffResults.GradientResult(θ)
+
+while (step ≤ num_steps) && !converged
+    # 1. Compute gradient and objective value; results are stored in `diff_results`
+    AdvancedVI.grad!(variational_objective, alg, getq, model, θ, diff_results)
+
+    # 2. Extract gradient from `diff_result`
+    ∇ = DiffResults.gradient(diff_results)
+
+    # 3. Apply optimizer, e.g. multiplying by step-size
+    Δ = apply!(optimizer, θ, ∇)
+
+    # 4. Update parameters
+    @. θ = θ - Δ
+
+    # 5. Do whatever analysis you want
+    callback(args...)
+
+    # 6. Update
+    converged = hasconverged(args...) # or some other user-defined check
+    step += 1
+
+    ProgressMeter.next!(prog)
+end
 ```
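+
+For instance, a minimal (hypothetical) convergence criterion could declare convergence once the parameter update becomes negligible:
+```julia
+using LinearAlgebra: norm
+
+hasconverged(θ_prev, θ) = norm(θ - θ_prev) < 1e-6
+```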
 
 
 ## References
 
+- Jordan, Michael I., Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. "An introduction to variational methods for graphical models." Machine learning 37, no. 2 (1999): 183-233.
+- Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. "Variational inference: A review for statisticians." Journal of the American statistical Association 112, no. 518 (2017): 859-877.
 - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015.
+- Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882.
+- Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003.
+

From 0575b23ee90f9677a2d4db9482d9fcb4feeea846 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 25 Aug 2023 19:29:56 +0100
Subject: [PATCH 143/206] refactor remove redundant `rng` argument to `ADVI`,
 improve docs

---
 src/objectives/elbo/advi.jl    | 25 +++++++++++++++++++++----
 src/objectives/elbo/entropy.jl |  3 +++
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index f9a61d81b..ef0ac50d1 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -3,6 +3,20 @@
     ADVI(prob, n_samples; kwargs...)
 
 Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
+This computes the evidence lower-bound (ELBO) through the ADVI formulation:
+```math
+\\begin{aligned}
+\\mathrm{ADVI}\\left(\\lambda\\right)
+&\\triangleq
+\\mathbb{E}_{\\eta \\sim q_{\\lambda}}\\left[
+  \\log \\pi\\left( \\phi^{-1}\\left( \\eta \\right) \\right)
+  +
+  \\log \\lvert J_{\\phi^{-1}}\\left(\\eta\\right) \\rvert
+\\right]
++ \\mathbb{H}\\left(q_{\\lambda}\\right),
+\\end{aligned}
+```
+where ``\\phi^{-1}`` is an "inverse bijector."
 
 # Arguments
 - `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface.
@@ -11,13 +25,17 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017)
 # Keyword Arguments
 - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy())
 - `cv`: A control variate.
-- `invbij`: A bijective mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.)
+- `invbij`: An inverse bijective mapping that matches the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.)
 
 # Requirements
 - ``q_{\\lambda}`` implements `rand`.
 - `logdensity(prob)` must be differentiable by the selected AD backend.
 
 Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
+
+# References
+* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
+* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
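+
+# Examples
+A rough usage sketch (assuming a `LogDensityProblems`-compatible `model` and an inverse bijector `b⁻¹`):
+```julia
+advi = ADVI(model, 10; invbij = b⁻¹)
+advi(q) # estimate the ELBO of `q` using `advi.n_samples` Monte Carlo samples
+```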
 """
 struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
     prob     ::P
@@ -49,7 +67,6 @@ Base.show(io::IO, advi::ADVI) =
 init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing
 
 function (advi::ADVI)(
-    rng::AbstractRNG,
     q_η::ContinuousMultivariateDistribution,
     ηs ::AbstractMatrix
 )
@@ -81,7 +98,7 @@ function (advi::ADVI)(
     n_samples::Int         = advi.n_samples
 )
     ηs = rand(rng, q_η, n_samples)
-    advi(rng, q_η, ηs)
+    advi(q_η, ηs)
 end
 
 function estimate_gradient(
@@ -96,7 +113,7 @@ function estimate_gradient(
     f(λ′) = begin
         q_η = restructure(λ′)
         ηs  = rand(rng, q_η, advi.n_samples)
-        -advi(rng, q_η, ηs)
+        -advi(q_η, ηs)
     end
     value_and_gradient!(adbackend, f, λ, out)
 
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 97ccda299..e6212c463 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -21,6 +21,9 @@ The "sticking the landing" entropy estimator.
 # Requirements
 - `q` implements `logpdf`.
 - `logpdf(q, η)` must be differentiable by the selected AD framework.
+
+# References
+* Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
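+
+# Examples
+A sketch of plugging the estimator into the ADVI objective (assuming a target `model` and an inverse bijector `b⁻¹`):
+```julia
+ADVI(model, 10; invbij = b⁻¹, entropy = StickingTheLandingEntropy())
+```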
 """
 struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 

From ecc52428b8ed79f18525fbc14cbf7d6632f9cac9 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 25 Aug 2023 19:30:23 +0100
Subject: [PATCH 144/206] fix wrong whitespace in tests

---
 test/advi_locscale.jl | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index f2ce94a5d..93ece4124 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -32,9 +32,9 @@ using ReTest
             b    = Bijectors.bijector(model)
             b⁻¹  = inverse(b)
             μ₀   = zeros(realtype, n_dims)
-	    L₀   = Diagonal(ones(realtype, n_dims))
+            L₀   = Diagonal(ones(realtype, n_dims))
 
-	    q₀ = TuringDiagMvNormal(μ₀, diag(L₀))
+            q₀ = TuringDiagMvNormal(μ₀, diag(L₀))
 
             obj = objective(model, b⁻¹, 10)
 
@@ -48,8 +48,8 @@ using ReTest
                     adbackend     = adbackend,
                 )
 
-		μ  = mean(q)
-		L  = sqrt(cov(q))
+                μ  = mean(q)
+                L  = sqrt(cov(q))
                 Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
                 @test Δλ ≤ Δλ₀/T^(1/4)
@@ -66,8 +66,8 @@ using ReTest
                     rng           = rng,
                     adbackend     = adbackend,
                 )
-		μ  = mean(q)
-		L  = sqrt(cov(q))
+                μ  = mean(q)
+                L  = sqrt(cov(q))
 
                 rng_repl = Philox4x(UInt64, seed, 8)
                 q, stats, _, _ = optimize(
@@ -77,8 +77,8 @@ using ReTest
                     rng           = rng_repl,
                     adbackend     = adbackend,
                 )
-		μ_repl = mean(q)
-		L_repl = sqrt(cov(q))
+                μ_repl = mean(q)
+                L_repl = sqrt(cov(q))
                 @test μ == μ_repl
                 @test L == L_repl
             end

From 1cff3df3af793b684934107521031a55df222419 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 25 Aug 2023 19:56:56 +0100
Subject: [PATCH 145/206] refactor `estimate_gradient` to `estimate_gradient!`,
 add docs

---
 src/AdvancedVI.jl           | 55 +++++++++++++++++++++++++++++++++----
 src/objectives/elbo/advi.jl | 10 ++-----
 src/optimize.jl             |  2 +-
 3 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index da8b05bb3..609266b47 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -40,18 +40,63 @@ using StatsBase: entropy
         out::DiffResults.MutableDiffResult
     )
 
-Compute the value and gradient of a function `f` at `θ` using the automatic
-differentiation backend `ad`.  The result is stored in `out`. 
-The function `f` must return a scalar value. The gradient is stored in `out` as a
-vector of the same length as `θ`.
+Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`.
+The result is stored in `out`. 
+The function `f` must return a scalar value. The gradient is stored in `out` as a vector of the same length as `θ`.
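+
+For example (a sketch, assuming a ForwardDiff-backed implementation of `value_and_gradient!` is loaded):
+```julia
+θ   = randn(2)
+out = DiffResults.GradientResult(θ)
+value_and_gradient!(ADTypes.AutoForwardDiff(), x -> sum(abs2, x), θ, out)
+DiffResults.value(out), DiffResults.gradient(out) # (‖θ‖², 2θ)
+```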
 """
 function value_and_gradient! end
 
 # estimators
+"""
+    abstract type AbstractVariationalObjective end
+
+A VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`.
+Furthermore, it should implement the functions `init` and `estimate_gradient!`.
+"""
 abstract type AbstractVariationalObjective end
 
+"""
+    init(
+        rng::AbstractRNG,
+        obj::AbstractVariationalObjective,
+        λ::AbstractVector,
+        restructure
+    )
+
+Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
+This is relevant only if `obj` is stateful.
+
+!!! warning
+    This is an internal function. Thus, the signature is subject to change without
+    notice.
+"""
 function init              end
-function estimate_gradient end
+
+"""
+    estimate_gradient!(
+        rng         ::AbstractRNG,
+        adbackend   ::AbstractADType,
+        obj         ::AbstractVariationalObjective,
+        obj_state,
+        λ           ::AbstractVector,
+        restructure,
+        out         ::DiffResults.MutableDiffResult
+    )
+
+Estimate (possibly stochastic) gradients of the objective `obj` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.
+The estimated objective value and gradient are then stored in `out`.
+If the objective is stateful, `obj_state` is its previous state.
+
+# Returns
+- `out`: The `MutableDiffResult` containing the objective value and gradient estimates.
+- `obj_state`: The updated state of the objective estimator.
+- `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`)
+
+!!! warning
+    This is an internal function. Thus, the signature is subject to change without
+    notice.
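+
+A sketch of a typical implementation for a custom objective (`MyObjective` and `estimate_objective` are hypothetical):
+```julia
+function estimate_gradient!(rng, adbackend, obj::MyObjective, obj_state, λ, restructure, out)
+    f(λ′) = -estimate_objective(rng, obj, restructure(λ′)) # negate: we minimize
+    value_and_gradient!(adbackend, f, λ, out)
+    out, obj_state, (objective = -DiffResults.value(out),)
+end
+```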
+"""
+function estimate_gradient! end
 
 # ADVI-specific interfaces
 abstract type AbstractEntropyEstimator end
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index ef0ac50d1..0e373f9a3 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -85,12 +85,8 @@ end
         n_samples::Int = advi.n_samples
     )
 
-Evaluate the ELBO using the ADVI formulation.
-
-# Arguments
-- `q_η`: Variational approximation before applying a bijector (unconstrained support).
-- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO.
-
+Estimate the ELBO of the variational approximation `q_η` using the ADVI
+formulation with `n_samples` Monte Carlo samples.
 """
 function (advi::ADVI)(
     q_η      ::ContinuousMultivariateDistribution;
@@ -101,7 +97,7 @@ function (advi::ADVI)(
     advi(q_η, ηs)
 end
 
-function estimate_gradient(
+function estimate_gradient!(
     rng          ::AbstractRNG,
     adbackend    ::AbstractADType,
     advi         ::ADVI,
diff --git a/src/optimize.jl b/src/optimize.jl
index 54e7ace09..f21e757a4 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -77,7 +77,7 @@ function optimize(
     for t = 1:n_max_iter
         stat = (iteration=t,)
 
-        grad_buf, obj_state, stat′ = estimate_gradient(
+        grad_buf, obj_state, stat′ = estimate_gradient!(
             rng, adbackend, objective, obj_state, λ, restructure, grad_buf)
         stat = merge(stat, stat′)
 

From 54acd8af483af503d17997d91cb093ed420c0140 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Fri, 25 Aug 2023 20:05:07 +0100
Subject: [PATCH 146/206] refactor add default `init` impl, update docs

---
 src/AdvancedVI.jl           | 17 ++++++++++++-----
 src/objectives/elbo/advi.jl |  2 --
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 609266b47..db433a678 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -42,7 +42,8 @@ using StatsBase: entropy
 
 Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`.
 The result is stored in `out`. 
-The function `f` must return a scalar value. The gradient is stored in `out` as a vector of the same length as `θ`.
+The function `f` must return a scalar value.
+The gradient is stored in `out` as a vector of the same length as `θ`.
 """
 function value_and_gradient! end
 
@@ -51,7 +52,8 @@ function value_and_gradient! end
     abstract type AbstractVariationalObjective end
 
 A VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`.
-Furthermore, it should implement the functions `init` and `estimate_gradient!`.
+Furthermore, it should implement the function `estimate_gradient!`.
+If the estimator is stateful, it can implement `init` to initialize the state.
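+
+A sketch of a stateful estimator opting into `init` (`MyObjective` is hypothetical):
+```julia
+init(rng::AbstractRNG, obj::MyObjective, λ::AbstractVector, restructure) = (iteration = 0,)
+```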
 """
 abstract type AbstractVariationalObjective end
 
@@ -64,13 +66,18 @@ abstract type AbstractVariationalObjective end
     )
 
 Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
-This is relevant only if `obj` is stateful.
+This function needs to be implemented only if `obj` is stateful.
 
 !!! warning
     This is an internal function. Thus, the signature is subject to change without
     notice.
 """
-function init              end
+init(
+    rng::AbstractRNG,
+    obj::AbstractVariationalObjective,
+    λ::AbstractVector,
+    restructure
+) = nothing
 
 """
     estimate_gradient!(
@@ -85,7 +92,7 @@ function init              end
 
 Estimate (possibly stochastic) gradients of the objective `obj` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.
 The estimated objective value and gradient are then stored in `out`.
-If the objective is stateful, `obj_state` is its previous state.
+If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`.
 
 # Returns
 - `out`: The `MutableDiffResult` containing the objective value and gradient estimates.
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 0e373f9a3..5a3ce96ea 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -64,8 +64,6 @@ end
 Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 
-init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing
-
 function (advi::ADVI)(
     q_η::ContinuousMultivariateDistribution,
     ηs ::AbstractMatrix

From 61a2272cfb01d3052595f23fabb4cf85ba81b320 Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 26 Aug 2023 21:24:55 +0100
Subject: [PATCH 147/206] manually merge commit
 ff32ac642d6aa3a08d371ed895aa6b4026b06b92

---
 src/optimize.jl  | 64 +++++++++++++++++++++++++-----------------------
 test/optimize.jl | 31 ++++++++++++++++++++---
 2 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index f21e757a4..ea2fd5a12 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -8,7 +8,8 @@ end
         objective    ::AbstractVariationalObjective,
         restructure,
         λ₀           ::AbstractVector{<:Real},
-        n_max_iter   ::Int;
+        n_max_iter   ::Int,
+        objargs...;
         kwargs...
     )              
 
@@ -17,7 +18,8 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
     optimize(
         objective ::AbstractVariationalObjective,
         q,
-        n_max_iter::Int;
+        n_max_iter::Int,
+        objargs...;
         kwargs...
     )              
 
@@ -29,36 +31,34 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 - `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
 - `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`.
 - `n_max_iter`: Maximum number of iterations.
+- `objargs...`: Arguments to be passed to `objective`.
+- `kwargs...`: Additional keyword arguments. (See below.)
 
 # Keyword Arguments
 - `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.)
 - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
 - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
 - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
-- `callback!`: Callback function called after every iteration. The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient.
+- `callback!`: Callback function called after every iteration. The signature is `cb(; stat, restructure, λ, g)`, and it returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, and `g` is the stochastic estimate of the gradient. (Default: `nothing`.)
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
-
-When resuming from the state of a previous run, use the following keyword arguments:
-- `opt_state`: Initial state of the optimizer.
-- `obj_state`: Initial state of the objective.
+- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.)
 
 # Returns
 - `λ`: Variational parameters optimizing the variational objective.
-- `stats`: Statistics gathered during inference.
-- `opt_state`: Final state of the optimiser.
-- `obj_state`: Final state of the objective.
+- `logstats`: Statistics and logs gathered during optimization.
+- `states`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
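+
+For example, warm-starting from a previous run is (roughly) done as:
+```julia
+λ, logstats, state = optimize(objective, restructure, λ₀, n_max_iter; adbackend, optimizer)
+λ, logstats, state = optimize(objective, restructure, λ, n_max_iter; adbackend, optimizer, state)
+```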
 """
 function optimize(
     objective    ::AbstractVariationalObjective,
     restructure,
     λ₀           ::AbstractVector{<:Real},
-    n_max_iter   ::Int;
-    adbackend::AbstractADType, 
+    n_max_iter   ::Int,
+    objargs...;
+    adbackend    ::AbstractADType, 
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
     rng          ::AbstractRNG             = default_rng(),
     show_progress::Bool                    = true,
-    opt_state                              = nothing,
-    obj_state                              = nothing,
+    state        ::NamedTuple              = NamedTuple(),
     callback!                              = nothing,
     prog                                   = ProgressMeter.Progress(
         n_max_iter;
@@ -66,37 +66,39 @@ function optimize(
         barlen    = 31,
         showspeed = true,
         enabled   = show_progress
-    )              
+    )
 )
-    λ         = copy(λ₀)
-    opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ)       : opt_state
-    obj_state = isnothing(obj_state) ? init(rng, objective, λ, restructure) : obj_state
-    grad_buf  = DiffResults.GradientResult(λ)
-    stats     = NamedTuple[]
+    λ        = copy(λ₀)
+    opt_st   = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ)
+    obj_st   = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure)
+    grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ))
+    logstats = NamedTuple[]
 
     for t = 1:n_max_iter
         stat = (iteration=t,)
 
-        grad_buf, obj_state, stat′ = estimate_gradient!(
-            rng, adbackend, objective, obj_state, λ, restructure, grad_buf)
+        grad_buf, obj_st, stat′ = estimate_gradient(
+            rng, adbackend, objective, obj_st,
+            λ, restructure, grad_buf; objargs...
+        )
         stat = merge(stat, stat′)
 
-        g            = DiffResults.gradient(grad_buf)
-        opt_state, λ = Optimisers.update!(opt_state, λ, g)
-        stat′ = (iteration = t,)
-        stat = merge(stat, stat′)
+        g         = DiffResults.gradient(grad_buf)
+        opt_st, λ = Optimisers.update!(opt_st, λ, g)
 
         if !isnothing(callback!)
-            stat′ = callback!(; obj_state, stat, restructure, λ, g)
+            stat′ = callback!(; stat, restructure, λ, g)
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
         @debug "Iteration $t" stat...
 
         pm_next!(prog, stat)
-        push!(stats, stat)
+        push!(logstats, stat)
     end
-    λ, map(identity, stats), opt_state, obj_state
+    state    = (opt=opt_st, obj=obj_st)
+    logstats = map(identity, logstats)
+    λ, logstats, state
 end
 
 function optimize(objective ::AbstractVariationalObjective,
@@ -104,8 +106,8 @@ function optimize(objective ::AbstractVariationalObjective,
                   n_max_iter::Int;
                   kwargs...)
     λ, restructure = Optimisers.destructure(q₀)
-    λ, stats, opt_state, obj_state = optimize(
+    λ, logstats, state = optimize(
         objective, restructure, λ, n_max_iter; kwargs...
     )
-    restructure(λ), stats, opt_state, obj_state
+    restructure(λ), logstats, state
 end
diff --git a/test/optimize.jl b/test/optimize.jl
index 56ca63c0f..78d07d001 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -19,7 +19,7 @@ using ReTest
     optimizer = Optimisers.Adam(1e-2)
 
     rng = Philox4x(UInt64, seed, 8)
-    q_ref, stats_ref, _, _ = optimize(
+    q_ref, stats_ref, _ = optimize(
         obj, q₀, T;
         optimizer,
         show_progress = false,
@@ -32,7 +32,7 @@ using ReTest
         λ₀, re  = Optimisers.destructure(q₀)
 
         rng = Philox4x(UInt64, seed, 8)
-        λ, stats, _, _ = optimize(
+        λ, stats, _ = optimize(
             obj, re, λ₀, T;
             optimizer,
             show_progress = false,
@@ -52,7 +52,7 @@ using ReTest
         end
 
         rng = Philox4x(UInt64, seed, 8)
-        _, stats, _, _ = optimize(
+        _, stats, _ = optimize(
             obj, q₀, T;
             show_progress = false,
             rng,
@@ -61,4 +61,29 @@ using ReTest
         )
         @test [stat.test_value for stat ∈ stats] == test_values
     end
+
+    @testset "warm start" begin
+        rng = Philox4x(UInt64, seed, 8)
+
+        T_first = div(T,2)
+        T_last  = T - T_first
+
+        q_first, _, state = optimize(
+            obj, q₀, T_first;
+            optimizer,
+            show_progress = false,
+            rng,
+            adbackend
+        )
+
+        q, stats, _ = optimize(
+            obj, q_first, T_last;
+            optimizer,
+            show_progress = false,
+            state,
+            rng,
+            adbackend
+        )
+        @test q == q_ref
+    end
 end

From c56d29ef1c2954673b7941fce6c3d8d664fe020c Mon Sep 17 00:00:00 2001
From: Ray Kim <msca8h@naver.com>
Date: Sat, 26 Aug 2023 22:03:36 +0100
Subject: [PATCH 148/206] fix test for new interface, change interface for
 `optimize`, `advi`

---
 src/AdvancedVI.jl           |  3 +-
 src/objectives/elbo/advi.jl | 81 ++++++++++++++++++++-----------------
 src/optimize.jl             | 21 ++++++----
 test/advi_locscale.jl       | 33 ++++++++-------
 test/optimize.jl            | 21 +++++-----
 5 files changed, 87 insertions(+), 72 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index db433a678..91f714e48 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -82,6 +82,7 @@ init(
 """
     estimate_gradient!(
         rng         ::AbstractRNG,
+        prob,
         adbackend   ::AbstractADType,
         obj         ::AbstractVariationalObjective,
         obj_state,
@@ -90,7 +91,7 @@ init(
         out         ::DiffResults.MutableDiffResult
     )
 
-Estimate (possibly stochastic) gradients of the objective `obj` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.
+Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.
 The estimated objective value and gradient are then stored in `out`.
 If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`.
 
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 5a3ce96ea..1ce573717 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -1,6 +1,6 @@
 
 """
-    ADVI(prob, n_samples; kwargs...)
+    ADVI(n_samples; kwargs...)
 
 Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
 This computes the evidence lower-bound (ELBO) through the ADVI formulation:
@@ -19,17 +19,14 @@ This computes the evidence lower-bound (ELBO) through the ADVI formulation:
 where ``\\phi^{-1}`` is an "inverse bijector."
 
 # Arguments
-- `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface.
 - `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.)
 
 # Keyword Arguments
 - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy())
-- `cv`: A control variate.
-- `invbij`: An inverse bijective mapping that matches the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.)
 
 # Requirements
 - ``q_{\\lambda}`` implements `rand`.
-- `logdensity(prob)` must be differentiable by the selected AD backend.
+- The target `logdensity(prob)` must be differentiable by the selected AD backend.
 
 Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 
@@ -37,27 +34,12 @@ Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 * Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
 * Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
 """
-struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
-    prob     ::P
-    invbij   ::B
+struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
     entropy  ::EntropyEst
     n_samples::Int
 
-    function ADVI(prob,
-                  n_samples::Int;
-                  entropy  ::AbstractEntropyEstimator = ClosedFormEntropy(),
-                  invbij = Bijectors.identity)
-        cap = LogDensityProblems.capabilities(prob)
-        if cap === nothing
-            throw(
-                ArgumentError(
-                    "The log density function does not support the LogDensityProblems.jl interface",
-                ),
-            )
-        end
-        new{typeof(prob), typeof(invbij), typeof(entropy)}(
-            prob, invbij, entropy, n_samples
-        )
+    function ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy())
+        new{typeof(entropy)}(entropy, n_samples)
     end
 end
 
@@ -65,38 +47,64 @@ Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 
 function (advi::ADVI)(
-    q_η::ContinuousMultivariateDistribution,
+    prob,
+    q ::ContinuousMultivariateDistribution,
+    zs::AbstractMatrix
+)
+    𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs))
+    ℍ  = advi.entropy(q, zs)
+    𝔼ℓ + ℍ
+end
+
+function (advi::ADVI)(
+    prob,
+    q_trans::Bijectors.TransformedDistribution,
     ηs ::AbstractMatrix
 )
+    @unpack dist, transform = q_trans
+    q   = dist
+    b⁻¹ = transform
     𝔼ℓ = mean(eachcol(ηs)) do ηᵢ
-        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ)
-        LogDensityProblems.logdensity(advi.prob, zᵢ) + logdetjacᵢ
+        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, ηᵢ)
+        LogDensityProblems.logdensity(prob, zᵢ) + logdetjacᵢ
     end
-    ℍ  = advi.entropy(q_η, ηs)
+    ℍ  = advi.entropy(q, ηs)
     𝔼ℓ + ℍ
 end
 
 """
     (advi::ADVI)(
-        q_η::ContinuousMultivariateDistribution;
+        prob, q;
         rng::AbstractRNG = Random.default_rng(),
         n_samples::Int = advi.n_samples
     )
 
-Estimate the ELBO of the variational approximation `q_η` using the ADVI
-formulation with `n_samples` Monte Carlo samples.
+Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation with `n_samples` Monte Carlo samples.
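+
+For example (a sketch):
+```julia
+advi = ADVI(10)
+advi(prob, q; n_samples = 10^3) # ELBO of `q` targeting `prob` with 10³ samples
+```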
 """
 function (advi::ADVI)(
-    q_η      ::ContinuousMultivariateDistribution;
+    prob,
+    q        ::ContinuousMultivariateDistribution;
+    rng      ::AbstractRNG = default_rng(),
+    n_samples::Int         = advi.n_samples
+)
+    zs = rand(rng, q, n_samples)
+    advi(prob, q, zs)
+end
+
+function (advi::ADVI)(
+    prob,
+    q_trans  ::Bijectors.TransformedDistribution;
     rng      ::AbstractRNG = default_rng(),
     n_samples::Int         = advi.n_samples
 )
-    ηs = rand(rng, q_η, n_samples)
-    advi(q_η, ηs)
+    q  = q_trans.dist
+    ηs = rand(rng, q, n_samples)
+    advi(prob, q_trans, ηs)
 end
 
 function estimate_gradient!(
     rng          ::AbstractRNG,
+    prob,
     adbackend    ::AbstractADType,
     advi         ::ADVI,
     est_state,
@@ -105,9 +113,10 @@ function estimate_gradient!(
     out          ::DiffResults.MutableDiffResult
 )
     f(λ′) = begin
-        q_η = restructure(λ′)
-        ηs  = rand(rng, q_η, advi.n_samples)
-        -advi(q_η, ηs)
+        q_trans = restructure(λ′)
+        q       = q_trans.dist
+        ηs      = rand(rng, q, advi.n_samples)
+        -advi(prob, q_trans, ηs)
     end
     value_and_gradient!(adbackend, f, λ, out)
 
diff --git a/src/optimize.jl b/src/optimize.jl
index ea2fd5a12..5425d938b 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -5,6 +5,7 @@ end
 
 """
     optimize(
+        prob,
         objective    ::AbstractVariationalObjective,
         restructure,
         λ₀           ::AbstractVector{<:Real},
@@ -13,9 +14,10 @@ end
         kwargs...
     )              
 
-Optimize the variational objective `objective` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`.
+Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`.
 
     optimize(
+        prob,
         objective ::AbstractVariationalObjective,
         q,
         n_max_iter::Int,
@@ -23,7 +25,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
         kwargs...
     )              
 
-Optimize the variational objective `objective` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface.
+Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface.
 
 # Arguments
 - `objective`: Variational Objective.
@@ -49,6 +51,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
 - `states`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
 """
 function optimize(
+    prob,
     objective    ::AbstractVariationalObjective,
     restructure,
     λ₀           ::AbstractVector{<:Real},
@@ -77,9 +80,9 @@ function optimize(
     for t = 1:n_max_iter
         stat = (iteration=t,)
 
-        grad_buf, obj_st, stat′ = estimate_gradient(
-            rng, adbackend, objective, obj_st,
-            λ, restructure, grad_buf; objargs...
+        grad_buf, obj_st, stat′ = estimate_gradient!(
+            rng, prob, adbackend, objective, obj_st,
+            λ, restructure, grad_buf, objargs...
         )
         stat = merge(stat, stat′)
 
@@ -101,13 +104,15 @@ function optimize(
     λ, logstats, state
 end
 
-function optimize(objective ::AbstractVariationalObjective,
+function optimize(prob,
+                  objective ::AbstractVariationalObjective,
                   q₀,
-                  n_max_iter::Int;
+                  n_max_iter::Int,
+                  objargs...;
                   kwargs...)
     λ, restructure = Optimisers.destructure(q₀)
     λ, logstats, state = optimize(
-        objective, restructure, λ, n_max_iter; kwargs...
+        prob, objective, restructure, λ, n_max_iter, objargs...; kwargs...
     )
     restructure(λ), logstats, state
 end
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 93ece4124..85cfea71e 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -11,8 +11,8 @@ using ReTest
                 :NormalLogNormalMeanField => normallognormal_meanfield,
             ),
             (objname, objective) ∈ Dict(
-                :ADVIClosedFormEntropy  => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹),
-                :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹, entropy = StickingTheLandingEntropy()),
+                :ADVIClosedFormEntropy  => ADVI(10),
+                :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
             ),
             (adbackname, adbackend) ∈ Dict(
                :ForwardDiff => AutoForwardDiff(),
@@ -34,22 +34,21 @@ using ReTest
             μ₀   = zeros(realtype, n_dims)
             L₀   = Diagonal(ones(realtype, n_dims))
 
-            q₀ = TuringDiagMvNormal(μ₀, diag(L₀))
-
-            obj = objective(model, b⁻¹, 10)
+            q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
+            q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
 
             @testset "convergence" begin
                 Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-                q, stats, _, _ = optimize(
-                    obj, q₀, T;
+                q, stats, _ = optimize(
+                    model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
                     rng           = rng,
                     adbackend     = adbackend,
                 )
 
-                μ  = mean(q)
-                L  = sqrt(cov(q))
+                μ  = mean(q.dist)
+                L  = sqrt(cov(q.dist))
                 Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
                 @test Δλ ≤ Δλ₀/T^(1/4)
@@ -59,26 +58,26 @@ using ReTest
 
             @testset "determinism" begin
                 rng = Philox4x(UInt64, seed, 8)
-                q, stats, _, _ = optimize(
-                    obj, q₀, T;
+                q, stats, _ = optimize(
+                    model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
                     rng           = rng,
                     adbackend     = adbackend,
                 )
-                μ  = mean(q)
-                L  = sqrt(cov(q))
+                μ  = mean(q.dist)
+                L  = sqrt(cov(q.dist))
 
                 rng_repl = Philox4x(UInt64, seed, 8)
-                q, stats, _, _ = optimize(
-                    obj, q₀, T;
+                q, stats, _ = optimize(
+                    model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
                     rng           = rng_repl,
                     adbackend     = adbackend,
                 )
-                μ_repl = mean(q)
-                L_repl = sqrt(cov(q))
+                μ_repl = mean(q.dist)
+                L_repl = sqrt(cov(q.dist))
                 @test μ == μ_repl
                 @test L == L_repl
             end
diff --git a/test/optimize.jl b/test/optimize.jl
index 78d07d001..2af56c1fa 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -11,16 +11,17 @@ using ReTest
     @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
     # Global Test Configurations
-    b⁻¹ = Bijectors.bijector(model) |> inverse
-    q₀  = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-    obj = ADVI(model, 10; invbij=b⁻¹)
+    b⁻¹  = Bijectors.bijector(model) |> inverse
+    q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+    q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
+    obj  = ADVI(10)
 
     adbackend = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)
 
     rng = Philox4x(UInt64, seed, 8)
     q_ref, stats_ref, _ = optimize(
-        obj, q₀, T;
+        model, obj, q₀_z, T;
         optimizer,
         show_progress = false,
         rng,
@@ -29,11 +30,11 @@ using ReTest
     λ_ref, _ = Optimisers.destructure(q_ref)
 
     @testset "restructure" begin
-        λ₀, re  = Optimisers.destructure(q₀)
+        λ₀, re  = Optimisers.destructure(q₀_z)
 
         rng = Philox4x(UInt64, seed, 8)
         λ, stats, _ = optimize(
-            obj, re, λ₀, T;
+            model, obj, re, λ₀, T;
             optimizer,
             show_progress = false,
             rng,
@@ -47,13 +48,13 @@ using ReTest
         rng = Philox4x(UInt64, seed, 8)
         test_values = rand(rng, T)
 
-        callback!(; stat, obj_state, restructure, λ, g) = begin
+        callback!(; stat, restructure, λ, g) = begin
             (test_value = test_values[stat.iteration],)
         end
 
         rng = Philox4x(UInt64, seed, 8)
         _, stats, _ = optimize(
-            obj, q₀, T;
+            model, obj, q₀_z, T;
             show_progress = false,
             rng,
             adbackend,
@@ -69,7 +70,7 @@ using ReTest
         T_last  = T - T_first
 
         q_first, _, state = optimize(
-            obj, q₀, T_first;
+            model, obj, q₀_z, T_first;
             optimizer,
             show_progress = false,
             rng,
@@ -77,7 +78,7 @@ using ReTest
         )
 
         q, stats, _ = optimize(
-            obj, q_first, T_last;
+            model, obj, q_first, T_last;
             optimizer,
             show_progress = false,
             state,

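A minimal sketch of the calling convention after this refactor, assuming `model` implements the `LogDensityProblems` interface and `n_dims` is its dimensionality (the helper names mirror the test suite and are illustrative):

    using AdvancedVI, ADTypes, Bijectors, DistributionsAD, Optimisers

    b⁻¹  = Bijectors.bijector(model) |> inverse
    q₀_η = TuringDiagMvNormal(zeros(n_dims), ones(n_dims))
    q₀_z = Bijectors.transformed(q₀_η, b⁻¹)  # variational family on the target's support

    # The target is now passed explicitly as the first argument:
    q, logstats, state = optimize(
        model, ADVI(10), q₀_z, 10_000;
        optimizer = Optimisers.Adam(1e-2),
        adbackend = AutoForwardDiff(),
    )

    # The ELBO can also be estimated directly from the objective:
    elbo = ADVI(10)(model, q₀_z; n_samples = 100)
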
From 913b46953fbaaf81150bb308daee6c06d7bfa47d Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Fri, 1 Sep 2023 00:27:14 -0400
Subject: [PATCH 149/206] fix integer subtype error in documentation of advi

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/objectives/elbo/advi.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 1ce573717..77e1c750f 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -19,7 +19,7 @@ This computes the evidence lower-bound (ELBO) through the ADVI formulation:
 where ``\\phi^{-1}`` is an "inverse bijector."
 
 # Arguments
-- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.)
+- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
 
 # Keyword Arguments
 - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`.)

From 385a653c2d1c37e6fe088c6ecc08b86647f16159 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 1 Sep 2023 00:28:31 -0400
Subject: [PATCH 150/206] fix remove redundant argument for `advi`

---
 src/objectives/elbo/advi.jl | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index f9a61d81b..b70c32993 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -49,7 +49,6 @@ Base.show(io::IO, advi::ADVI) =
 init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing
 
 function (advi::ADVI)(
-    rng::AbstractRNG,
     q_η::ContinuousMultivariateDistribution,
     ηs ::AbstractMatrix
 )
@@ -81,7 +80,7 @@ function (advi::ADVI)(
     n_samples::Int         = advi.n_samples
 )
     ηs = rand(rng, q_η, n_samples)
-    advi(rng, q_η, ηs)
+    advi(q_η, ηs)
 end
 
 function estimate_gradient(
@@ -96,7 +95,7 @@ function estimate_gradient(
     f(λ′) = begin
         q_η = restructure(λ′)
         ηs  = rand(rng, q_η, advi.n_samples)
-        -advi(rng, q_η, ηs)
+        -advi(q_η, ηs)
     end
     value_and_gradient!(adbackend, f, λ, out)
 

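The removed `rng` argument was redundant because all randomness enters through the samples `ηs`; once they are drawn, the estimate is a deterministic function of `q_η`. A sketch, assuming an already-constructed `ADVI` objective `advi`:

    using Distributions, LinearAlgebra, Random

    q_η = MvNormal(zeros(2), Diagonal(ones(2)))
    ηs  = rand(Random.default_rng(), q_η, 10)  # randomness enters here, ...
    advi(q_η, ηs)                              # ... so this call needs no `rng`
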
From c9df90e72a842b15c2aa6c41d32ecb14331a7c6b Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 1 Sep 2023 00:31:48 -0400
Subject: [PATCH 151/206] remove manifest

---
 README.md          |   1 -
 test/Manifest.toml | 866 ---------------------------------------------
 2 files changed, 867 deletions(-)
 delete mode 100644 test/Manifest.toml

diff --git a/README.md b/README.md
index f0bf6cc10..18ba63e50 100644
--- a/README.md
+++ b/README.md
@@ -248,4 +248,3 @@ end
 - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015.
 - Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882.
 - Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003.
-
diff --git a/test/Manifest.toml b/test/Manifest.toml
deleted file mode 100644
index 220b42bb6..000000000
--- a/test/Manifest.toml
+++ /dev/null
@@ -1,866 +0,0 @@
-# This file is machine-generated - editing it directly is not advised
-
-julia_version = "1.9.2"
-manifest_format = "2.0"
-project_hash = "a6495d9f0ea044fd0a55c1c989f1adca1ad5c855"
-
-[[deps.ADTypes]]
-git-tree-sha1 = "a4c8e0f8c09d4aa708289c1a5fc23e2d1970017a"
-uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
-version = "0.2.1"
-
-[[deps.AbstractFFTs]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef"
-uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
-version = "1.5.0"
-weakdeps = ["ChainRulesCore", "Test"]
-
-    [deps.AbstractFFTs.extensions]
-    AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
-    AbstractFFTsTestExt = "Test"
-
-[[deps.Adapt]]
-deps = ["LinearAlgebra", "Requires"]
-git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24"
-uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "3.6.2"
-weakdeps = ["StaticArrays"]
-
-    [deps.Adapt.extensions]
-    AdaptStaticArraysExt = "StaticArrays"
-
-[[deps.ArgCheck]]
-git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
-uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
-version = "2.3.0"
-
-[[deps.ArgTools]]
-uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
-version = "1.1.1"
-
-[[deps.Artifacts]]
-uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
-
-[[deps.Atomix]]
-deps = ["UnsafeAtomics"]
-git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
-uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
-version = "0.1.0"
-
-[[deps.Base64]]
-uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
-
-[[deps.Bijectors]]
-deps = ["ArgCheck", "ChainRules", "ChainRulesCore", "ChangesOfVariables", "Compat", "Distributions", "Functors", "InverseFunctions", "IrrationalConstants", "LinearAlgebra", "LogExpFunctions", "MappedArrays", "Random", "Reexport", "Requires", "Roots", "SparseArrays", "Statistics"]
-git-tree-sha1 = "af192c7c235264bdc6f67321fd1c57be0dd7ffb5"
-uuid = "76274a88-744f-5084-9051-94815aaf08c4"
-version = "0.13.6"
-
-    [deps.Bijectors.extensions]
-    BijectorsDistributionsADExt = "DistributionsAD"
-    BijectorsForwardDiffExt = "ForwardDiff"
-    BijectorsLazyArraysExt = "LazyArrays"
-    BijectorsReverseDiffExt = "ReverseDiff"
-    BijectorsTrackerExt = "Tracker"
-    BijectorsZygoteExt = "Zygote"
-
-    [deps.Bijectors.weakdeps]
-    DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
-    ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-    LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
-    ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
-    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-    Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-
-[[deps.CEnum]]
-git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90"
-uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
-version = "0.4.2"
-
-[[deps.Calculus]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad"
-uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
-version = "0.5.1"
-
-[[deps.ChainRules]]
-deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
-git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e"
-uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.53.0"
-
-[[deps.ChainRulesCore]]
-deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644"
-uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "1.16.0"
-
-[[deps.ChangesOfVariables]]
-deps = ["LinearAlgebra", "Test"]
-git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f"
-uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
-version = "0.1.8"
-weakdeps = ["InverseFunctions"]
-
-    [deps.ChangesOfVariables.extensions]
-    ChangesOfVariablesInverseFunctionsExt = "InverseFunctions"
-
-[[deps.CommonSolve]]
-git-tree-sha1 = "0eee5eb66b1cf62cd6ad1b460238e60e4b09400c"
-uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
-version = "0.2.4"
-
-[[deps.CommonSubexpressions]]
-deps = ["MacroTools", "Test"]
-git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
-uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
-version = "0.3.0"
-
-[[deps.Comonicon]]
-deps = ["Configurations", "ExproniconLite", "Libdl", "Logging", "Markdown", "OrderedCollections", "PackageCompiler", "Pkg", "Scratch", "TOML", "UUIDs"]
-git-tree-sha1 = "9c360961f23e2fae4c6549bbba58a6f39c9e145c"
-uuid = "863f3e99-da2a-4334-8734-de3dacbe5542"
-version = "1.0.5"
-
-[[deps.Compat]]
-deps = ["UUIDs"]
-git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7"
-uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "4.9.0"
-weakdeps = ["Dates", "LinearAlgebra"]
-
-    [deps.Compat.extensions]
-    CompatLinearAlgebraExt = "LinearAlgebra"
-
-[[deps.CompilerSupportLibraries_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
-version = "1.0.5+0"
-
-[[deps.Configurations]]
-deps = ["ExproniconLite", "OrderedCollections", "TOML"]
-git-tree-sha1 = "434f446dbf89d08350e83bf57c0fc86f5d3ffd4e"
-uuid = "5218b696-f38b-4ac9-8b61-a12ec717816d"
-version = "0.17.5"
-
-[[deps.ConstructionBase]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816"
-uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
-version = "1.5.3"
-
-    [deps.ConstructionBase.extensions]
-    ConstructionBaseIntervalSetsExt = "IntervalSets"
-    ConstructionBaseStaticArraysExt = "StaticArrays"
-
-    [deps.ConstructionBase.weakdeps]
-    IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
-    StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
-
-[[deps.DataAPI]]
-git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c"
-uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.15.0"
-
-[[deps.DataStructures]]
-deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
-git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d"
-uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
-version = "0.18.15"
-
-[[deps.DataValueInterfaces]]
-git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
-uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
-version = "1.0.0"
-
-[[deps.Dates]]
-deps = ["Printf"]
-uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
-
-[[deps.DiffResults]]
-deps = ["StaticArraysCore"]
-git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621"
-uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
-version = "1.1.0"
-
-[[deps.DiffRules]]
-deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
-git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272"
-uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "1.15.1"
-
-[[deps.Distributed]]
-deps = ["Random", "Serialization", "Sockets"]
-uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
-
-[[deps.Distributions]]
-deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"]
-git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd"
-uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.100"
-
-    [deps.Distributions.extensions]
-    DistributionsChainRulesCoreExt = "ChainRulesCore"
-    DistributionsDensityInterfaceExt = "DensityInterface"
-
-    [deps.Distributions.weakdeps]
-    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-    DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
-
-[[deps.DistributionsAD]]
-deps = ["Adapt", "ChainRules", "ChainRulesCore", "Compat", "Distributions", "FillArrays", "LinearAlgebra", "PDMats", "Random", "Requires", "SpecialFunctions", "StaticArrays", "StatsFuns", "ZygoteRules"]
-git-tree-sha1 = "975de103eb2175cf54bf14b15ded2c68625eabdf"
-uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
-version = "0.6.52"
-
-    [deps.DistributionsAD.extensions]
-    DistributionsADForwardDiffExt = "ForwardDiff"
-    DistributionsADLazyArraysExt = "LazyArrays"
-    DistributionsADReverseDiffExt = "ReverseDiff"
-    DistributionsADTrackerExt = "Tracker"
-
-    [deps.DistributionsAD.weakdeps]
-    ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-    LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
-    ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
-    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-
-[[deps.DocStringExtensions]]
-deps = ["LibGit2"]
-git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
-uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.9.3"
-
-[[deps.Downloads]]
-deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
-uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
-version = "1.6.0"
-
-[[deps.DualNumbers]]
-deps = ["Calculus", "NaNMath", "SpecialFunctions"]
-git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566"
-uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
-version = "0.6.8"
-
-[[deps.Enzyme]]
-deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random"]
-git-tree-sha1 = "1f85bc8a9da6118abb95d134efc68cf4a6957341"
-uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
-version = "0.11.7"
-
-[[deps.EnzymeCore]]
-deps = ["Adapt"]
-git-tree-sha1 = "643995502bdfff08bf080212c92430510be01ad5"
-uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
-version = "0.5.2"
-
-[[deps.Enzyme_jll]]
-deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "ffa4926cc857bcc5c256825bd7273a6ac989eb34"
-uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef"
-version = "0.0.80+0"
-
-[[deps.ExprTools]]
-git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec"
-uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
-version = "0.1.10"
-
-[[deps.ExproniconLite]]
-deps = ["Pkg", "TOML"]
-git-tree-sha1 = "d80b5d5990071086edf5de9018c6c69c83937004"
-uuid = "55351af7-c7e9-48d6-89ff-24e801d99491"
-version = "0.10.3"
-
-[[deps.FileWatching]]
-uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
-
-[[deps.FillArrays]]
-deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
-git-tree-sha1 = "048dd3d82558759476cff9cff999219216932a08"
-uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "1.6.0"
-
-[[deps.ForwardDiff]]
-deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
-git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
-uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.36"
-weakdeps = ["StaticArrays"]
-
-    [deps.ForwardDiff.extensions]
-    ForwardDiffStaticArraysExt = "StaticArrays"
-
-[[deps.FunctionWrappers]]
-git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e"
-uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
-version = "1.1.3"
-
-[[deps.Functors]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545"
-uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-version = "0.4.5"
-
-[[deps.Future]]
-deps = ["Random"]
-uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
-
-[[deps.GPUArrays]]
-deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
-git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1"
-uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "8.8.1"
-
-[[deps.GPUArraysCore]]
-deps = ["Adapt"]
-git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0"
-uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
-version = "0.1.5"
-
-[[deps.GPUCompiler]]
-deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3"
-uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.21.4"
-
-[[deps.HypergeometricFunctions]]
-deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
-git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
-uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
-version = "0.3.23"
-
-[[deps.IRTools]]
-deps = ["InteractiveUtils", "MacroTools", "Test"]
-git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5"
-uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
-version = "0.4.10"
-
-[[deps.InlineTest]]
-deps = ["Test"]
-git-tree-sha1 = "daf0743879904f0ad645ca6594e1479685f158a2"
-uuid = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
-version = "0.2.0"
-
-[[deps.InteractiveUtils]]
-deps = ["Markdown"]
-uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-
-[[deps.InverseFunctions]]
-deps = ["Test"]
-git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46"
-uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
-version = "0.1.12"
-
-[[deps.IrrationalConstants]]
-git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
-uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
-version = "0.2.2"
-
-[[deps.IteratorInterfaceExtensions]]
-git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
-uuid = "82899510-4779-5014-852e-03e436cf321d"
-version = "1.0.0"
-
-[[deps.JLLWrappers]]
-deps = ["Artifacts", "Preferences"]
-git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
-uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
-version = "1.5.0"
-
-[[deps.KernelAbstractions]]
-deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118"
-uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-version = "0.9.8"
-weakdeps = ["EnzymeCore"]
-
-    [deps.KernelAbstractions.extensions]
-    EnzymeExt = "EnzymeCore"
-
-[[deps.LLVM]]
-deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729"
-uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "6.1.0"
-
-[[deps.LLVMExtra_jll]]
-deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c"
-uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
-version = "0.0.23+0"
-
-[[deps.LazyArtifacts]]
-deps = ["Artifacts", "Pkg"]
-uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
-
-[[deps.LibCURL]]
-deps = ["LibCURL_jll", "MozillaCACerts_jll"]
-uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
-version = "0.6.3"
-
-[[deps.LibCURL_jll]]
-deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
-uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
-version = "7.84.0+0"
-
-[[deps.LibGit2]]
-deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
-uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
-
-[[deps.LibSSH2_jll]]
-deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
-uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
-version = "1.10.2+0"
-
-[[deps.Libdl]]
-uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
-
-[[deps.LinearAlgebra]]
-deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
-uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-
-[[deps.LogDensityProblems]]
-deps = ["ArgCheck", "DocStringExtensions", "Random"]
-git-tree-sha1 = "f9a11237204bc137617194d79d813069838fcf61"
-uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-version = "2.1.1"
-
-[[deps.LogExpFunctions]]
-deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
-git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa"
-uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-version = "0.3.26"
-weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"]
-
-    [deps.LogExpFunctions.extensions]
-    LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
-    LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
-    LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
-
-[[deps.Logging]]
-uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
-
-[[deps.MacroTools]]
-deps = ["Markdown", "Random"]
-git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
-uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.11"
-
-[[deps.MappedArrays]]
-git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e"
-uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
-version = "0.4.2"
-
-[[deps.Markdown]]
-deps = ["Base64"]
-uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
-
-[[deps.MbedTLS_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
-version = "2.28.2+0"
-
-[[deps.Missings]]
-deps = ["DataAPI"]
-git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272"
-uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "1.1.0"
-
-[[deps.MozillaCACerts_jll]]
-uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
-version = "2022.10.11"
-
-[[deps.NNlib]]
-deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "3d42748c725c3f088bcda47fa2aca89e74d59d22"
-uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.9.4"
-
-    [deps.NNlib.extensions]
-    NNlibAMDGPUExt = "AMDGPU"
-    NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
-    NNlibCUDAExt = "CUDA"
-
-    [deps.NNlib.weakdeps]
-    AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
-    CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-    cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
-
-[[deps.NaNMath]]
-deps = ["OpenLibm_jll"]
-git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
-uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
-version = "1.0.2"
-
-[[deps.NetworkOptions]]
-uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
-version = "1.2.0"
-
-[[deps.ObjectFile]]
-deps = ["Reexport", "StructIO"]
-git-tree-sha1 = "69607899b46e1f8ead70396bc51a4c361478d8f6"
-uuid = "d8793406-e978-5875-9003-1fc021f44a92"
-version = "0.4.0"
-
-[[deps.OpenBLAS_jll]]
-deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
-uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
-version = "0.3.21+4"
-
-[[deps.OpenLibm_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
-version = "0.8.1+0"
-
-[[deps.OpenSpecFun_jll]]
-deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
-uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
-version = "0.5.5+0"
-
-[[deps.Optimisers]]
-deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b"
-uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
-version = "0.2.20"
-
-[[deps.OrderedCollections]]
-git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
-uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.6.2"
-
-[[deps.PDMats]]
-deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
-git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1"
-uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
-version = "0.11.17"
-
-[[deps.PackageCompiler]]
-deps = ["Artifacts", "LazyArtifacts", "Libdl", "Pkg", "Printf", "RelocatableFolders", "TOML", "UUIDs"]
-git-tree-sha1 = "1a6a868eb755e8ea9ecd000aa6ad175def0cc85b"
-uuid = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d"
-version = "2.1.7"
-
-[[deps.Pkg]]
-deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
-uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-version = "1.9.2"
-
-[[deps.PrecompileTools]]
-deps = ["Preferences"]
-git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
-uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
-version = "1.2.0"
-
-[[deps.Preferences]]
-deps = ["TOML"]
-git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1"
-uuid = "21216c6a-2e73-6563-6e65-726566657250"
-version = "1.4.0"
-
-[[deps.Printf]]
-deps = ["Unicode"]
-uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
-
-[[deps.QuadGK]]
-deps = ["DataStructures", "LinearAlgebra"]
-git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee"
-uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
-version = "2.8.2"
-
-[[deps.REPL]]
-deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
-uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
-
-[[deps.Random]]
-deps = ["SHA", "Serialization"]
-uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-
-[[deps.Random123]]
-deps = ["Random", "RandomNumbers"]
-git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3"
-uuid = "74087812-796a-5b5d-8853-05524746bad3"
-version = "1.6.1"
-
-[[deps.RandomNumbers]]
-deps = ["Random", "Requires"]
-git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
-uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
-version = "1.5.3"
-
-[[deps.ReTest]]
-deps = ["Distributed", "InlineTest", "Printf", "Random", "Sockets", "Test"]
-git-tree-sha1 = "dd8f6587c0abac44bcec2e42f0aeddb73550c0ec"
-uuid = "e0db7c4e-2690-44b9-bad6-7687da720f89"
-version = "0.3.2"
-
-[[deps.RealDot]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
-uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
-version = "0.1.0"
-
-[[deps.Reexport]]
-git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
-uuid = "189a3867-3050-52da-a836-e630ba90ab69"
-version = "1.2.2"
-
-[[deps.RelocatableFolders]]
-deps = ["SHA", "Scratch"]
-git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691"
-uuid = "05181044-ff0b-4ac5-8273-598c1e38db00"
-version = "1.0.0"
-
-[[deps.Requires]]
-deps = ["UUIDs"]
-git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
-uuid = "ae029012-a4dd-5104-9daa-d747884805df"
-version = "1.3.0"
-
-[[deps.ReverseDiff]]
-deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"]
-git-tree-sha1 = "d1235bdd57a93bd7504225b792b867e9a7df38d5"
-uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
-version = "1.15.1"
-
-[[deps.Rmath]]
-deps = ["Random", "Rmath_jll"]
-git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b"
-uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
-version = "0.7.1"
-
-[[deps.Rmath_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da"
-uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f"
-version = "0.4.0+0"
-
-[[deps.Roots]]
-deps = ["ChainRulesCore", "CommonSolve", "Printf", "Setfield"]
-git-tree-sha1 = "ff42754a57bb0d6dcfe302fd0d4272853190421f"
-uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
-version = "2.0.19"
-
-    [deps.Roots.extensions]
-    RootsForwardDiffExt = "ForwardDiff"
-    RootsIntervalRootFindingExt = "IntervalRootFinding"
-    RootsSymPyExt = "SymPy"
-
-    [deps.Roots.weakdeps]
-    ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
-    IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
-    SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
-
-[[deps.SHA]]
-uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
-version = "0.7.0"
-
-[[deps.Scratch]]
-deps = ["Dates"]
-git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a"
-uuid = "6c6a2e73-6563-6170-7368-637461726353"
-version = "1.2.0"
-
-[[deps.Serialization]]
-uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
-
-[[deps.Setfield]]
-deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"]
-git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac"
-uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46"
-version = "1.1.1"
-
-[[deps.SimpleUnPack]]
-git-tree-sha1 = "58e6353e72cde29b90a69527e56df1b5c3d8c437"
-uuid = "ce78b400-467f-4804-87d8-8f486da07d0a"
-version = "1.1.0"
-
-[[deps.Sockets]]
-uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
-
-[[deps.SortingAlgorithms]]
-deps = ["DataStructures"]
-git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
-uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "1.1.1"
-
-[[deps.SparseArrays]]
-deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
-uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-
-[[deps.SpecialFunctions]]
-deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
-git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d"
-uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
-version = "2.3.1"
-weakdeps = ["ChainRulesCore"]
-
-    [deps.SpecialFunctions.extensions]
-    SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
-
-[[deps.StaticArrays]]
-deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
-git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881"
-uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.6.2"
-weakdeps = ["Statistics"]
-
-    [deps.StaticArrays.extensions]
-    StaticArraysStatisticsExt = "Statistics"
-
-[[deps.StaticArraysCore]]
-git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d"
-uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
-version = "1.4.2"
-
-[[deps.Statistics]]
-deps = ["LinearAlgebra", "SparseArrays"]
-uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-version = "1.9.0"
-
-[[deps.StatsAPI]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
-uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
-version = "1.6.0"
-
-[[deps.StatsBase]]
-deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
-git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4"
-uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.34.0"
-
-[[deps.StatsFuns]]
-deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
-git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a"
-uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-version = "1.3.0"
-weakdeps = ["ChainRulesCore", "InverseFunctions"]
-
-    [deps.StatsFuns.extensions]
-    StatsFunsChainRulesCoreExt = "ChainRulesCore"
-    StatsFunsInverseFunctionsExt = "InverseFunctions"
-
-[[deps.StructArrays]]
-deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"]
-git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389"
-uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
-version = "0.6.15"
-
-[[deps.StructIO]]
-deps = ["Test"]
-git-tree-sha1 = "010dc73c7146869c042b49adcdb6bf528c12e859"
-uuid = "53d494c1-5632-5724-8f4c-31dff12d585f"
-version = "0.3.0"
-
-[[deps.SuiteSparse]]
-deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
-uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
-
-[[deps.SuiteSparse_jll]]
-deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
-uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
-version = "5.10.1+6"
-
-[[deps.TOML]]
-deps = ["Dates"]
-uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
-version = "1.0.3"
-
-[[deps.TableTraits]]
-deps = ["IteratorInterfaceExtensions"]
-git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
-uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
-version = "1.0.1"
-
-[[deps.Tables]]
-deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"]
-git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec"
-uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
-version = "1.10.1"
-
-[[deps.Tar]]
-deps = ["ArgTools", "SHA"]
-uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
-version = "1.10.0"
-
-[[deps.Test]]
-deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
-uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[[deps.TimerOutputs]]
-deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7"
-uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.23"
-
-[[deps.Tracker]]
-deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"]
-git-tree-sha1 = "92364c27aa35c0ee36e6e010b704adaade6c409c"
-uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.26"
-weakdeps = ["PDMats"]
-
-    [deps.Tracker.extensions]
-    TrackerPDMatsExt = "PDMats"
-
-[[deps.UUIDs]]
-deps = ["Random", "SHA"]
-uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
-
-[[deps.Unicode]]
-uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
-
-[[deps.UnsafeAtomics]]
-git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
-uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
-version = "0.2.1"
-
-[[deps.UnsafeAtomicsLLVM]]
-deps = ["LLVM", "UnsafeAtomics"]
-git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e"
-uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
-version = "0.1.3"
-
-[[deps.Zlib_jll]]
-deps = ["Libdl"]
-uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
-version = "1.2.13+0"
-
-[[deps.Zygote]]
-deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b"
-uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.63"
-
-    [deps.Zygote.extensions]
-    ZygoteColorsExt = "Colors"
-    ZygoteDistancesExt = "Distances"
-    ZygoteTrackerExt = "Tracker"
-
-    [deps.Zygote.weakdeps]
-    Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
-    Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
-    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-
-[[deps.ZygoteRules]]
-deps = ["ChainRulesCore", "MacroTools"]
-git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d"
-uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.2.3"
-
-[[deps.libblastrampoline_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
-version = "5.8.0+0"
-
-[[deps.nghttp2_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
-version = "1.48.0+0"
-
-[[deps.p7zip_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
-version = "17.4.0+0"

From 19d11d141a788ce6f476172f02c004854a0d892d Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 1 Sep 2023 00:44:21 -0400
Subject: [PATCH 152/206] refactor remove imports and use fully qualified names

---
 src/AdvancedVI.jl              | 22 ++++++----------------
 src/objectives/elbo/advi.jl    | 16 ++++++++--------
 src/objectives/elbo/entropy.jl |  2 +-
 src/optimize.jl                |  4 ++--
 src/utils.jl                   |  0
 5 files changed, 17 insertions(+), 27 deletions(-)
 delete mode 100644 src/utils.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 91f714e48..1c662cfbd 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -4,11 +4,8 @@ module AdvancedVI
 using SimpleUnPack: @unpack, @pack!
 using Accessors
 
-using Random: AbstractRNG, default_rng
+using Random
 using Distributions
-import Distributions:
-    logpdf, _logpdf, rand, rand!, _rand!,
-    ContinuousMultivariateDistribution
 
 using Functors
 using Optimisers
@@ -17,19 +14,16 @@ using DocStringExtensions
 
 using ProgressMeter
 using LinearAlgebra
-using LinearAlgebra: AbstractTriangular
 
 using LogDensityProblems
 
 using ADTypes, DiffResults
-using ADTypes: AbstractADType
-using ChainRulesCore: @ignore_derivatives 
+using ChainRulesCore
 
 using FillArrays
 using Bijectors
 
 using StatsBase
-using StatsBase: entropy
 
 # derivatives
 """
@@ -59,7 +53,7 @@ abstract type AbstractVariationalObjective end
 
 """
     init(
-        rng::AbstractRNG,
+        rng::Random.AbstractRNG,
         obj::AbstractVariationalObjective,
         λ::AbstractVector,
         restructure
@@ -73,7 +67,7 @@ This function needs to be implemented only if `obj` is stateful.
     notice.
 """
 init(
-    rng::AbstractRNG,
+    rng::Random.AbstractRNG,
     obj::AbstractVariationalObjective,
     λ::AbstractVector,
     restructure
@@ -81,9 +75,9 @@ init(
 
 """
     estimate_gradient!(
-        rng         ::AbstractRNG,
+        rng         ::Random.AbstractRNG,
         prob,
-        adbackend   ::AbstractADType,
+        adbackend   ::ADTypes.AbstractADType,
         obj         ::AbstractVariationalObjective,
         obj_state,
         λ           ::AbstractVector,
@@ -114,7 +108,6 @@ include("objectives/elbo/entropy.jl")
 include("objectives/elbo/advi.jl")
 
 export
-    ELBO,
     ADVI,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
@@ -128,9 +121,6 @@ include("optimize.jl")
 
 export optimize
 
-include("utils.jl")
-
-
 # optional dependencies 
 if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
     using Requires
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 77e1c750f..97a08b95d 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -48,7 +48,7 @@ Base.show(io::IO, advi::ADVI) =
 
 function (advi::ADVI)(
     prob,
-    q ::ContinuousMultivariateDistribution,
+    q ::Distributions.ContinuousMultivariateDistribution,
     zs::AbstractMatrix
 )
     𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs))
@@ -59,7 +59,7 @@ end
 function (advi::ADVI)(
     prob,
     q_trans::Bijectors.TransformedDistribution,
-    ηs ::AbstractMatrix
+    ηs     ::AbstractMatrix
 )
     @unpack dist, transform = q_trans
     q   = dist
@@ -84,8 +84,8 @@ Estimate the ELBO of the variational approximation `q` of the target `prob` usin
 function (advi::ADVI)(
     prob,
     q        ::ContinuousMultivariateDistribution;
-    rng      ::AbstractRNG = default_rng(),
-    n_samples::Int         = advi.n_samples
+    rng      ::Random.AbstractRNG = Random.default_rng(),
+    n_samples::Int                = advi.n_samples
 )
     zs = rand(rng, q, n_samples)
     advi(prob, q, zs)
@@ -94,8 +94,8 @@ end
 function (advi::ADVI)(
     prob,
     q_trans  ::Bijectors.TransformedDistribution;
-    rng      ::AbstractRNG = default_rng(),
-    n_samples::Int         = advi.n_samples
+    rng      ::Random.AbstractRNG = Random.default_rng(),
+    n_samples::Int                = advi.n_samples
 )
     q  = q_trans.dist
     ηs = rand(rng, q, n_samples)
@@ -103,9 +103,9 @@ function (advi::ADVI)(
 end
 
 function estimate_gradient!(
-    rng          ::AbstractRNG,
+    rng          ::Random.AbstractRNG,
     prob,
-    adbackend    ::AbstractADType,
+    adbackend    ::ADTypes.AbstractADType,
     advi         ::ADVI,
     est_state,
     λ            ::Vector{<:Real},
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index e6212c463..63854ec0f 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -28,7 +28,7 @@ The "sticking the landing" entropy estimator.
 struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 
 function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix)
-    @ignore_derivatives mean(eachcol(ηs)) do ηᵢ
+    ChainRulesCore.@ignore_derivatives mean(eachcol(ηs)) do ηᵢ
         -logpdf(q, ηᵢ)
     end
 end
diff --git a/src/optimize.jl b/src/optimize.jl
index 5425d938b..85cac75e9 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -57,9 +57,9 @@ function optimize(
     λ₀           ::AbstractVector{<:Real},
     n_max_iter   ::Int,
     objargs...;
-    adbackend    ::AbstractADType, 
+    adbackend    ::ADTypes.AbstractADType, 
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
-    rng          ::AbstractRNG             = default_rng(),
+    rng          ::Random.AbstractRNG      = Random.default_rng(),
     show_progress::Bool                    = true,
     state        ::NamedTuple              = NamedTuple(),
     callback!                              = nothing,
diff --git a/src/utils.jl b/src/utils.jl
deleted file mode 100644
index e69de29bb..000000000

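For a self-contained look at the `StickingTheLandingEntropy` hunk above: the estimator averages `-logpdf(q, ηᵢ)` over the sample columns, and `ChainRulesCore.@ignore_derivatives` marks the term as a constant for the AD backend. A sketch restating the estimator, with illustrative demo values:

    using Distributions, LinearAlgebra, Statistics
    using ChainRulesCore

    struct StickingTheLandingEntropy end

    function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix)
        # Average negative log-density over the samples; no derivatives
        # propagate through this expression.
        ChainRulesCore.@ignore_derivatives mean(eachcol(ηs)) do ηᵢ
            -logpdf(q, ηᵢ)
        end
    end

    q  = MvNormal(zeros(2), Diagonal(ones(2)))
    ηs = rand(q, 2^14)
    StickingTheLandingEntropy()(q, ηs)  # ≈ entropy(q) for large sample sizes
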
From 59bd4f848bc034b8428408381e951fead53d6f74 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Fri, 1 Sep 2023 00:46:24 -0400
Subject: [PATCH 153/206] update documentation for
 `AbstractVariationalObjective`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/AdvancedVI.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 1c662cfbd..8abf3da9f 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -43,7 +43,7 @@ function value_and_gradient! end
 
 # estimators
 """
-    abstract type AbstractVariationalObjective end
+    AbstractVariationalObjective
 
 A VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`.
 Furthermore, it should implement the function `estimate_gradient!`.

From dedc5cf1a99bf7771d68380126e480129ed050af Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 1 Sep 2023 01:01:07 -0400
Subject: [PATCH 154/206] refactor use StableRNG instead of Random123

---
 test/Project.toml     |  2 +-
 test/advi_locscale.jl |  8 ++++----
 test/optimize.jl      | 14 +++++++-------
 test/runtests.jl      |  3 +--
 4 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/test/Project.toml b/test/Project.toml
index 5ce8fcd88..2c06aa538 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,10 +14,10 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Random123 = "74087812-796a-5b5d-8853-05524746bad3"
 ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 85cfea71e..7c60188ce 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -21,8 +21,8 @@ using ReTest
                 # :Enzyme      => AutoEnzyme(),
             )
 
-            seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-            rng  = Philox4x(UInt64, seed, 8)
+            seed = 0x38bef07cf9cc549d
+            rng  = StableRNG(seed)
 
             modelstats = modelconstr(realtype; rng)
             @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
@@ -57,7 +57,7 @@ using ReTest
             end
 
             @testset "determinism" begin
-                rng = Philox4x(UInt64, seed, 8)
+                rng = StableRNG(seed)
                 q, stats, _ = optimize(
                     model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
@@ -68,7 +68,7 @@ using ReTest
                 μ  = mean(q.dist)
                 L  = sqrt(cov(q.dist))
 
-                rng_repl = Philox4x(UInt64, seed, 8)
+                rng_repl = StableRNG(seed)
                 q, stats, _ = optimize(
                     model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
diff --git a/test/optimize.jl b/test/optimize.jl
index 2af56c1fa..c7173f512 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -2,8 +2,8 @@
 using ReTest
 
 @testset "optimize" begin
-    seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797)
-    rng  = Philox4x(UInt64, seed, 8)
+    seed = 0x38bef07cf9cc549d
+    rng  = StableRNG(seed)
 
     T = 1000
     modelstats = normallognormal_meanfield(Float64; rng)
@@ -19,7 +19,7 @@ using ReTest
     adbackend = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)
 
-    rng = Philox4x(UInt64, seed, 8)
+    rng  = StableRNG(seed)
     q_ref, stats_ref, _ = optimize(
         model, obj, q₀_z, T;
         optimizer,
@@ -32,7 +32,7 @@ using ReTest
     @testset "restructure" begin
         λ₀, re  = Optimisers.destructure(q₀_z)
 
-        rng = Philox4x(UInt64, seed, 8)
+        rng  = StableRNG(seed)
         λ, stats, _ = optimize(
             model, obj, re, λ₀, T;
             optimizer,
@@ -45,14 +45,14 @@ using ReTest
     end
 
     @testset "callback" begin
-        rng = Philox4x(UInt64, seed, 8)
+        rng  = StableRNG(seed)
         test_values = rand(rng, T)
 
         callback!(; stat, restructure, λ, g) = begin
             (test_value = test_values[stat.iteration],)
         end
 
-        rng = Philox4x(UInt64, seed, 8)
+        rng  = StableRNG(seed)
         _, stats, _ = optimize(
             model, obj, q₀_z, T;
             show_progress = false,
@@ -64,7 +64,7 @@ using ReTest
     end
 
     @testset "warm start" begin
-        rng = Philox4x(UInt64, seed, 8)
+        rng  = StableRNG(seed)
 
         T_first = div(T,2)
         T_last  = T - T_first
diff --git a/test/runtests.jl b/test/runtests.jl
index fd68ed794..ef85f16b8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,8 +3,7 @@ using ReTest
 using ReTest: @testset, @test
 
 using Comonicon
-using Random
-using Random123
+using Random, StableRNGs
 using Statistics
 using Distributions
 using LinearAlgebra

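`StableRNGs` replaces the counter-based `Philox4x` with a generator whose streams are documented to stay identical across Julia releases, which is the property the determinism tests actually rely on. A quick sketch of the seeding behavior:

    using StableRNGs

    rng₁ = StableRNG(0x38bef07cf9cc549d)
    rng₂ = StableRNG(0x38bef07cf9cc549d)
    rand(rng₁, 3) == rand(rng₂, 3)  # true; equal seeds yield identical streams
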
From e35dc67f24f1d5c60e4a5c5959b0e085f19506a9 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 1 Sep 2023 01:08:18 -0400
Subject: [PATCH 155/206] refactor migrate to Test, re-enable x86 tests

---
 .github/workflows/CI.yml | 2 +-
 test/Project.toml        | 2 +-
 test/ad.jl               | 2 +-
 test/advi_locscale.jl    | 2 +-
 test/optimize.jl         | 2 +-
 test/runtests.jl         | 9 ++-------
 6 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 7ba573a15..9731f20c2 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -20,7 +20,7 @@ jobs:
           - windows-latest
         arch:
           - x64
-          # - x86 # Uncomment after https://github.com/JuliaTesting/ReTest.jl/pull/52 is merged
+          - x86
         exclude:
           - os: macOS-latest
             arch: x86
diff --git a/test/Project.toml b/test/Project.toml
index 2c06aa538..0e81ec08b 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -14,10 +14,10 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/test/ad.jl b/test/ad.jl
index f575b485b..b716ca2f2 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -1,5 +1,5 @@
 
-using ReTest
+using Test
 
 @testset "ad" begin
     @testset "$(adname)" for (adname, adsymbol) ∈ Dict(
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 7c60188ce..db2338a3b 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -1,7 +1,7 @@
 
 const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress"
 
-using ReTest
+using Test
 
 @testset "advi" begin
     @testset "locscale" begin
diff --git a/test/optimize.jl b/test/optimize.jl
index c7173f512..6f7986d00 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -1,5 +1,5 @@
 
-using ReTest
+using Test
 
 @testset "optimize" begin
     seed = 0x38bef07cf9cc549d
diff --git a/test/runtests.jl b/test/runtests.jl
index ef85f16b8..a4220f986 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,6 @@
 
-using ReTest
-using ReTest: @testset, @test
+using Test
+using Test: @testset, @test
 
 using Comonicon
 using Random, StableRNGs
@@ -38,8 +38,3 @@ include("models/normallognormal.jl")
 include("ad.jl")
 include("advi_locscale.jl")
 include("optimize.jl")
-
-@main function runtests(patterns...; dry::Bool = false)
-    retest(patterns...; dry = dry, verbose = Inf)
-end
-

From 641318331387e3de818cf2c159ae7bc41e313abe Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Tue, 5 Sep 2023 18:30:50 +0100
Subject: [PATCH 156/206] refactor remove inner constructor for `ADVI`

---
 src/objectives/elbo/advi.jl | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 97a08b95d..a7d655d36 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -37,12 +37,10 @@ Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
     entropy  ::EntropyEst
     n_samples::Int
-
-    function ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy())
-        new{typeof(entropy)}(entropy, n_samples)
-    end
 end
 
+ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) = ADVI(entropy, n_samples)
+
 Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 

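With the inner constructor removed, the default struct constructor becomes available alongside the keyword convenience method. The first two calls below are equivalent; the keyword form also accepts other entropy estimators:

    using AdvancedVI

    ADVI(10)                                         # ClosedFormEntropy() by default
    ADVI(ClosedFormEntropy(), 10)                    # plain struct constructor
    ADVI(10; entropy = StickingTheLandingEntropy())  # alternative estimator
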
From 1668bae6ee3532fcd751037022f1d6da4dc7257c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Wed, 6 Sep 2023 21:56:38 -0400
Subject: [PATCH 157/206] fix swap `export`s and `include`s

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/AdvancedVI.jl | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 8abf3da9f..d4d776f22 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -103,15 +103,14 @@ function estimate_gradient! end
 # ADVI-specific interfaces
 abstract type AbstractEntropyEstimator end
 
-# entropy.jl must preceed advi.jl
-include("objectives/elbo/entropy.jl")
-include("objectives/elbo/advi.jl")
-
 export
     ADVI,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
     MonteCarloEntropy
+# entropy.jl must precede advi.jl
+include("objectives/elbo/entropy.jl")
+include("objectives/elbo/advi.jl")
 
 # Optimization Routine
 

From a8f12541c32cc1c74408ffce84257ce0a4eab526 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Wed, 6 Sep 2023 21:57:24 -0400
Subject: [PATCH 158/206] fix docs for `ADVI`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/objectives/elbo/advi.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index a7d655d36..5ba0ef340 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -26,7 +26,7 @@ where ``\\phi^{-1}`` is an "inverse bijector."
 
 # Requirements
 - ``q_{\\lambda}`` implements `rand`.
-- The target `logdensity(prob)` must be differentiable by the selected AD backend.
+- The target `logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
 
 Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 

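The corrected requirement refers to the LogDensityProblems.jl convention: the target exposes `logdensity(prob, x)`, and only the dependence on `x` needs to be differentiable by the AD backend. A minimal sketch of a conforming target, mirroring the Gaussian test model added later in this series (`GaussianTarget` is a hypothetical name used only for illustration):

    using LogDensityProblems, Distributions

    struct GaussianTarget{M,S}
        μ::M
        Σ::S
    end

    # Only differentiability with respect to `x` is required here.
    LogDensityProblems.logdensity(prob::GaussianTarget, x) =
        logpdf(MvNormal(prob.μ, prob.Σ), x)
    LogDensityProblems.dimension(prob::GaussianTarget) = length(prob.μ)
    LogDensityProblems.capabilities(::Type{<:GaussianTarget}) =
        LogDensityProblems.LogDensityOrder{0}()
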
From 7b368c12b1dd198ba73607dc565465f1f16c5890 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Wed, 6 Sep 2023 21:59:13 -0400
Subject: [PATCH 159/206] fix use `FillArrays` in the test problems

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 test/advi_locscale.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index db2338a3b..033736dfd 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -31,8 +31,8 @@ using Test
 
             b    = Bijectors.bijector(model)
             b⁻¹  = inverse(b)
-            μ₀   = zeros(realtype, n_dims)
-            L₀   = Diagonal(ones(realtype, n_dims))
+            μ₀   = Zeros(realtype, n_dims)
+            L₀   = Diagonal(Ones(realtype, n_dims))
 
             q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
             q₀_z = Bijectors.transformed(q₀_η, b⁻¹)

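`Zeros` and `Ones` are the lazy constant arrays from FillArrays.jl; they compare equal to their dense counterparts while allocating no per-element storage, which keeps these test initializations cheap. A quick illustration:

    using FillArrays, LinearAlgebra

    μ₀ = Zeros(Float64, 5)            # lazy zero vector, O(1) storage
    L₀ = Diagonal(Ones(Float64, 5))   # lazy unit diagonal
    μ₀ == zeros(Float64, 5)           # true; behaves like the dense version
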
From f216b376f50fdc22e639406eee654d4a0b1922d6 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Wed, 6 Sep 2023 22:19:39 -0400
Subject: [PATCH 160/206] fix `optimize` docs

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/optimize.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 85cac75e9..2d8f57f6d 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -37,13 +37,13 @@ Optimize the variational objective `objective` targeting `prob` by estimating (s
 - `kwargs...`: Additional keyword arguments. (See below.)
 
 # Keyword Arguments
-- `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.)
-- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
-- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
-- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
+- `adbackend::ADtypes.AbstractADType`: Automatic differentiation backend. 
+- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
+- `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
+- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
 - `callback!`: Callback function called after every iteration. The signature is `cb(; stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, and `g` is the stochastic estimate of the gradient. (Default: `nothing`.)
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
-- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.)
+- `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)
 
 # Returns
 - `λ`: Variational parameters optimizing the variational objective.

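With the keyword types folded into the argument names, a call at this stage of the interface might look like the sketch below. Here `model` and `q₀` stand in for a user-defined target and initial variational approximation, and `AutoForwardDiff` is one of the ADTypes.jl backend selectors:

    using AdvancedVI, ADTypes, Optimisers

    q, logstats, state = optimize(
        model, ADVI(10), q₀, 10_000;
        adbackend     = AutoForwardDiff(),
        optimizer     = Optimisers.Adam(1e-3),
        show_progress = false,
    )
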
From 9e0338db1601ec02a63ea04470cfbdfaab430d6c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 00:06:14 -0400
Subject: [PATCH 161/206] fix improve argument names and docs for `optimize`

---
 src/optimize.jl | 95 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 58 insertions(+), 37 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 2d8f57f6d..44617f859 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -5,34 +5,34 @@ end
 
 """
     optimize(
-        prob,
-        objective    ::AbstractVariationalObjective,
+        problem,
+        objective   ::AbstractVariationalObjective,
         restructure,
-        λ₀           ::AbstractVector{<:Real},
-        n_max_iter   ::Int,
+        param_init  ::AbstractVector{<:Real},
+        max_iter    ::Int,
         objargs...;
         kwargs...
     )              
 
-Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`.
+Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` to the function `restructure`.
 
     optimize(
-        prob,
-        objective ::AbstractVariationalObjective,
-        q,
-        n_max_iter::Int,
+        problem,
+        objective             ::AbstractVariationalObjective,
+        variational_dist_init,
+        max_iter              ::Int,
         objargs...;
         kwargs...
     )              
 
-Optimize the variational objective `objective` targeting `prob` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface.
+Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the initial variational approximation `variational_dist_init` supports the `Optimisers.destructure` interface.
 
 # Arguments
 - `objective`: Variational Objective.
-- `λ₀`: Initial value of the variational parameters.
+- `param_init`: Initial value of the variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
-- `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`.
-- `n_max_iter`: Maximum number of iterations.
+- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
+- `max_iter`: Maximum number of iterations.
 - `objargs...`: Arguments to be passed to `objective`.
 - `kwargs...`: Additional keyword arguments. (See below.)
 
@@ -41,47 +41,64 @@ Optimize the variational objective `objective` targeting `prob` by estimating (s
 - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
 - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
 - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
-- `callback!`: Callback function called after every iteration. The signature is `cb(; stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, and `g` is the stochastic estimate of the gradient. (Default: `nothing`.)
+- `callback!`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
 - `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)
 
 # Returns
-- `λ`: Variational parameters optimizing the variational objective.
-- `logstats`: Statistics and logs gathered during optimization.
-- `states`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run.
+- `params`: Variational parameters optimizing the variational objective.
+- `stats`: Statistics gathered during optimization.
+- `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
+
+# Callback
+The callback function `callback!` has a signature of
+
+    cb(; stat, state, param, restructure, gradient)
+
+The arguments are as follows:
+- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`.
+- `state`: Collection of the internal states used for optimization.
+- `param`: Variational parameters.
+- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. 
+- `gradient`: The estimated (possibly stochastic) gradient.
+
+`cb` can return a `NamedTuple` containing some additional information computed within `cb`.
+This will be appended to the statistics of the corresponding iteration.
+Otherwise, just return `nothing`.
+
 """
 function optimize(
-    prob,
+    problem,
     objective    ::AbstractVariationalObjective,
     restructure,
-    λ₀           ::AbstractVector{<:Real},
-    n_max_iter   ::Int,
+    params_init  ::AbstractVector{<:Real},
+    max_iter     ::Int,
     objargs...;
     adbackend    ::ADTypes.AbstractADType, 
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
     rng          ::Random.AbstractRNG      = Random.default_rng(),
     show_progress::Bool                    = true,
-    state        ::NamedTuple              = NamedTuple(),
+    state_init   ::NamedTuple              = NamedTuple(),
     callback!                              = nothing,
     prog                                   = ProgressMeter.Progress(
-        n_max_iter;
+        max_iter;
         desc      = "Optimizing",
         barlen    = 31,
         showspeed = true,
         enabled   = show_progress
     )
 )
-    λ        = copy(λ₀)
-    opt_st   = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ)
-    obj_st   = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure)
+    λ        = copy(params_init)
+    opt_st   = haskey(state_init, :opt) ? state_init.opt : Optimisers.setup(optimizer, λ)
+    obj_st   = haskey(state_init, :obj) ? state_init.obj : init(rng, objective, λ, restructure)
     grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ))
-    logstats = NamedTuple[]
+    stats    = NamedTuple[]
 
-    for t = 1:n_max_iter
+    for t = 1:max_iter
         stat = (iteration=t,)
 
         grad_buf, obj_st, stat′ = estimate_gradient!(
-            rng, prob, adbackend, objective, obj_st,
+            rng, problem, adbackend, objective, obj_st,
             λ, restructure, grad_buf, objargs...
         )
         stat = merge(stat, stat′)
@@ -90,29 +107,33 @@ function optimize(
         opt_st, λ = Optimisers.update!(opt_st, λ, g)
 
         if !isnothing(callback!)
-            stat′ = callback!(; stat, restructure, λ, g)
+            stat′ = callback!(
+                ; stat, restructure, params=λ, gradient=g,
+                state=(optimizer=opt_st, objective=obj_st)
+            )
             stat = !isnothing(stat′) ? merge(stat′, stat) : stat
         end
         
         @debug "Iteration $t" stat...
 
         pm_next!(prog, stat)
-        push!(logstats, stat)
+        push!(stats, stat)
     end
-    state    = (opt=opt_st, obj=obj_st)
-    logstats = map(identity, logstats)
-    λ, logstats, state
+    state  = (optimizer=opt_st, objective=obj_st)
+    stats  = map(identity, stats)
+    params = λ
+    params, stats, state
 end
 
-function optimize(prob,
+function optimize(problem,
                   objective ::AbstractVariationalObjective,
-                  q₀,
+                  variational_dist_init,
                   n_max_iter::Int,
                   objargs...;
                   kwargs...)
-    λ, restructure = Optimisers.destructure(q₀)
+    λ, restructure = Optimisers.destructure(variational_dist_init)
     λ, logstats, state = optimize(
-        prob, objective, restructure, λ, n_max_iter, objargs...; kwargs...
+        problem, objective, restructure, λ, n_max_iter, objargs...; kwargs...
     )
     restructure(λ), logstats, state
 end

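Because the callback receives its arguments as keywords, a user callback only needs to name what it uses and can splat the rest. A sketch of a callback matching the `callback!(; stat, restructure, params, gradient, state)` call introduced in this commit (`my_callback` is a hypothetical user function):

    # Records the current parameter norm; the returned NamedTuple is merged
    # into the statistics of the corresponding iteration.
    my_callback(; params, args...) = (param_norm = sqrt(sum(abs2, params)),)

    # passed to `optimize` as `callback! = my_callback` at this point in the series
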
From d6fcaf6ec86b9d2d06e9e26506d97cbf316661e2 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 00:19:11 -0400
Subject: [PATCH 162/206] fix tests to match new interface of `optimize`

---
 test/optimize.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/optimize.jl b/test/optimize.jl
index 6f7986d00..21718f521 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -48,7 +48,7 @@ using Test
         rng  = StableRNG(seed)
         test_values = rand(rng, T)
 
-        callback!(; stat, restructure, λ, g) = begin
+        callback!(; stat, args...) = begin
             (test_value = test_values[stat.iteration],)
         end
 
@@ -81,7 +81,7 @@ using Test
             model, obj, q_first, T_last;
             optimizer,
             show_progress = false,
-            state,
+            state_init    = state,
             rng,
             adbackend
         )

From 5799f1e246a9c1edd86a307d0644ba09369e9ae3 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 00:19:33 -0400
Subject: [PATCH 163/206] refactor move utility functions to new file

---
 src/AdvancedVI.jl | 13 ++++++++-----
 src/optimize.jl   |  8 ++------
 src/utils.jl      | 23 +++++++++++++++++++++++
 3 files changed, 33 insertions(+), 11 deletions(-)
 create mode 100644 src/utils.jl

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index d4d776f22..35b493c8d 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -67,10 +67,10 @@ This function needs to be implemented only if `obj` is stateful.
     notice.
 """
 init(
-    rng::Random.AbstractRNG,
-    obj::AbstractVariationalObjective,
-    λ::AbstractVector,
-    restructure
+    ::Random.AbstractRNG,
+    ::AbstractVariationalObjective,
+    ::AbstractVector,
+    ::Any
 ) = nothing
 
 """
@@ -108,6 +108,7 @@ export
     ClosedFormEntropy,
     StickingTheLandingEntropy,
     MonteCarloEntropy
+
 # entropy.jl must precede advi.jl
 include("objectives/elbo/entropy.jl")
 include("objectives/elbo/advi.jl")
@@ -116,9 +117,11 @@ include("objectives/elbo/advi.jl")
 
 function optimize end
 
+export optimize
+
+include("utils.jl")
 include("optimize.jl")
 
-export optimize
 
 # optional dependencies 
 if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
diff --git a/src/optimize.jl b/src/optimize.jl
index 44617f859..5f257c428 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -1,8 +1,4 @@
 
-function pm_next!(pm, stats::NamedTuple)
-    ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
-end
-
 """
     optimize(
         problem,
@@ -89,8 +85,8 @@ function optimize(
     )
 )
     λ        = copy(params_init)
-    opt_st   = haskey(state_init, :opt) ? state_init.opt : Optimisers.setup(optimizer, λ)
-    obj_st   = haskey(state_init, :obj) ? state_init.obj : init(rng, objective, λ, restructure)
+    opt_st   = maybe_init_optimizer(state_init, optimizer, λ)
+    obj_st   = maybe_init_objective(state_init, rng, objective, λ, restructure)
     grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ))
     stats    = NamedTuple[]
 
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 000000000..ce11d0be7
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,23 @@
+
+function pm_next!(pm, stats::NamedTuple)
+    ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
+end
+
+function maybe_init_optimizer(
+    state_init::Union{Nothing, NamedTuple},
+    optimizer ::Optimisers.AbstractRule,
+    λ         ::AbstractVector
+)
+    haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, λ)
+end
+
+function maybe_init_objective(
+    state_init::Union{Nothing, NamedTuple},
+    rng       ::Random.AbstractRNG,
+    objective ::AbstractVariationalObjective,
+    λ         ::AbstractVector,
+    restructure
+)
+    haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure)
+end
+

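The two `maybe_init_*` helpers are what make warm-starting work: when `state_init` carries an `:optimizer` or `:objective` entry it is reused, otherwise a fresh state is set up. A sketch of resuming a run from a previous state, assuming `model`, `obj`, `q₀`, and `adbackend` are defined as in the tests:

    q_mid, _, state = optimize(
        model, obj, q₀, 500;
        adbackend, show_progress = false,
    )

    # Resume the optimizer and objective states from the first run.
    q_end, _, _ = optimize(
        model, obj, q_mid, 500;
        adbackend, show_progress = false, state_init = state,
    )
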
From 2229d61d229e130c557f2797ed8ac7affffec193 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Thu, 7 Sep 2023 00:21:21 -0400
Subject: [PATCH 164/206] fix docs for `optimize`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/optimize.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 5f257c428..97146319d 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -30,7 +30,6 @@ Optimize the variational objective `objective` targeting the problem `problem` b
 - `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
 - `max_iter`: Maximum number of iterations.
 - `objargs...`: Arguments to be passed to `objective`.
-- `kwargs...`: Additional keyword arguments. (See below.)
 
 # Keyword Arguments
 - `adbackend::ADtypes.AbstractADType`: Automatic differentiation backend. 

From bc48e14b31eb8c35387f5dbadf21e4d519689c4c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Thu, 7 Sep 2023 00:36:09 -0400
Subject: [PATCH 165/206] refactor advi internal objective

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/objectives/elbo/advi.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 5ba0ef340..37b53a47c 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -110,7 +110,7 @@ function estimate_gradient!(
     restructure,
     out          ::DiffResults.MutableDiffResult
 )
-    f(λ′) = begin
+    function f(λ′)
         q_trans = restructure(λ′)
         q       = q_trans.dist
         ηs      = rand(rng, q, advi.n_samples)

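The inner objective closes over `restructure`, which comes from the `Optimisers.destructure` round trip: flatten a structure into a parameter vector plus a rebuild function, then reconstruct it at any nearby point. A self-contained illustration on a plain NamedTuple:

    using Optimisers

    λ, restructure = Optimisers.destructure((μ = zeros(2), σ = ones(2)))
    λ                            # [0.0, 0.0, 1.0, 1.0]
    q′ = restructure(λ .+ 0.1)   # rebuilt with perturbed parameters
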
From 9949a04bff27a1ba69523a8e8eb96bcbeadea09c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 00:38:25 -0400
Subject: [PATCH 166/206] refactor move `rng` to be an optional first argument

---
 src/optimize.jl       | 74 ++++++++++++++++++++++++++++++++++++++++---
 test/advi_locscale.jl |  9 ++----
 test/optimize.jl      | 32 +++++++++++++------
 test/runtests.jl      |  2 +-
 4 files changed, 95 insertions(+), 22 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 97146319d..47fd102f9 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -62,7 +62,30 @@ This will be appended to the statistic of the current corresponding iteration.
 Otherwise, just return `nothing`.
 
 """
+
 function optimize(
+    problem,
+    objective    ::AbstractVariationalObjective,
+    restructure,
+    params_init  ::AbstractVector{<:Real},
+    max_iter     ::Int,
+    objargs...;
+    kwargs...
+)
+    optimize(
+        Random.default_rng(),
+        problem,
+        objective,
+        restructure,
+        params_init,
+        max_iter,
+        objargs...;
+        kwargs...
+    )
+end
+
+function optimize(
+    rng          ::Random.AbstractRNG,
     problem,
     objective    ::AbstractVariationalObjective,
     restructure,
@@ -71,7 +94,6 @@ function optimize(
     objargs...;
     adbackend    ::ADTypes.AbstractADType, 
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
-    rng          ::Random.AbstractRNG      = Random.default_rng(),
     show_progress::Bool                    = true,
     state_init   ::NamedTuple              = NamedTuple(),
     callback!                              = nothing,
@@ -120,15 +142,57 @@ function optimize(
     params, stats, state
 end
 
-function optimize(problem,
-                  objective ::AbstractVariationalObjective,
+function optimize(
+    problem,
+    objective    ::AbstractVariationalObjective,
+    restructure,
+    params_init  ::AbstractVector{<:Real},
+    max_iter     ::Int,
+    objargs...;
+    kwargs...
+)
+    optimize(
+        Random.default_rng(),
+        problem,
+        objective,
+        restructure,
+        params_init,
+        max_iter,
+        objargs...;
+        kwargs...
+    )
+end
+
+function optimize(rng                   ::Random.AbstractRNG,
+                  problem,
+                  objective             ::AbstractVariationalObjective,
                   variational_dist_init,
-                  n_max_iter::Int,
+                  n_max_iter            ::Int,
                   objargs...;
                   kwargs...)
     λ, restructure = Optimisers.destructure(variational_dist_init)
     λ, logstats, state = optimize(
-        problem, objective, restructure, λ, n_max_iter, objargs...; kwargs...
+        rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs...
     )
     restructure(λ), logstats, state
 end
+
+
+function optimize(
+    problem,
+    objective              ::AbstractVariationalObjective,
+    variational_dist_init,
+    max_iter               ::Int,
+    objargs...;
+    kwargs...
+)
+    optimize(
+        Random.default_rng(),
+        problem,
+        objective,
+        variational_dist_init,
+        max_iter,
+        objargs...;
+        kwargs...
+    )
+end
diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl
index 033736dfd..dab8f5601 100644
--- a/test/advi_locscale.jl
+++ b/test/advi_locscale.jl
@@ -40,10 +40,9 @@ using Test
             @testset "convergence" begin
                 Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
                 q, stats, _ = optimize(
-                    model, objective, q₀_z, T;
+                    rng, model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
-                    rng           = rng,
                     adbackend     = adbackend,
                 )
 
@@ -59,10 +58,9 @@ using Test
             @testset "determinism" begin
                 rng = StableRNG(seed)
                 q, stats, _ = optimize(
-                    model, objective, q₀_z, T;
+                    rng, model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
-                    rng           = rng,
                     adbackend     = adbackend,
                 )
                 μ  = mean(q.dist)
@@ -70,10 +68,9 @@ using Test
 
                 rng_repl = StableRNG(seed)
                 q, stats, _ = optimize(
-                    model, objective, q₀_z, T;
+                    rng_repl, model, objective, q₀_z, T;
                     optimizer     = Optimisers.Adam(realtype(η)),
                     show_progress = PROGRESS,
-                    rng           = rng_repl,
                     adbackend     = adbackend,
                 )
                 μ_repl = mean(q.dist)
diff --git a/test/optimize.jl b/test/optimize.jl
index 21718f521..3de3cdc36 100644
--- a/test/optimize.jl
+++ b/test/optimize.jl
@@ -21,23 +21,38 @@ using Test
 
     rng  = StableRNG(seed)
     q_ref, stats_ref, _ = optimize(
-        model, obj, q₀_z, T;
+        rng, model, obj, q₀_z, T;
         optimizer,
         show_progress = false,
-        rng,
         adbackend,
     )
     λ_ref, _ = Optimisers.destructure(q_ref)
 
+    @testset "default_rng" begin
+        optimize(
+            model, obj, q₀_z, T;
+            optimizer,
+            show_progress = false,
+            adbackend,
+        )
+
+        λ₀, re  = Optimisers.destructure(q₀_z)
+        optimize(
+            model, obj, re, λ₀, T;
+            optimizer,
+            show_progress = false,
+            adbackend,
+        )
+    end
+
     @testset "restructure" begin
         λ₀, re  = Optimisers.destructure(q₀_z)
 
         rng  = StableRNG(seed)
         λ, stats, _ = optimize(
-            model, obj, re, λ₀, T;
+            rng, model, obj, re, λ₀, T;
             optimizer,
             show_progress = false,
-            rng,
             adbackend,
         )
         @test λ     == λ_ref
@@ -54,9 +69,8 @@ using Test
 
         rng  = StableRNG(seed)
         _, stats, _ = optimize(
-            model, obj, q₀_z, T;
+            rng, model, obj, q₀_z, T;
             show_progress = false,
-            rng,
             adbackend,
             callback!
         )
@@ -70,19 +84,17 @@ using Test
         T_last  = T - T_first
 
         q_first, _, state = optimize(
-            model, obj, q₀_z, T_first;
+            rng, model, obj, q₀_z, T_first;
             optimizer,
             show_progress = false,
-            rng,
             adbackend
         )
 
         q, stats, _ = optimize(
-            model, obj, q_first, T_last;
+            rng, model, obj, q_first, T_last;
             optimizer,
             show_progress = false,
             state_init    = state,
-            rng,
             adbackend
         )
         @test q == q_ref
diff --git a/test/runtests.jl b/test/runtests.jl
index a4220f986..5f8fab41f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -36,5 +36,5 @@ include("models/normallognormal.jl")
 
 # Tests
 include("ad.jl")
-include("advi_locscale.jl")
 include("optimize.jl")
+include("advi_locscale.jl")

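With the RNG promoted to an optional first positional argument, determinism becomes a matter of passing identically seeded generators, which is exactly what the updated tests exercise. A sketch, again assuming `model`, `obj`, `q₀`, and `adbackend` are in scope:

    using StableRNGs

    seed = 0x38bef07cf9cc549d
    q₁, _, _ = optimize(StableRNG(seed), model, obj, q₀, 1_000;
                        adbackend, show_progress = false)
    q₂, _, _ = optimize(StableRNG(seed), model, obj, q₀, 1_000;
                        adbackend, show_progress = false)
    # Identically seeded runs should yield identical variational parameters.
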
From 92cf3547da9c58a6da3d97065dd6993669845f61 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 00:49:01 -0400
Subject: [PATCH 167/206] fix docs for optimize

---
 src/optimize.jl | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 47fd102f9..5beef7126 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -9,9 +9,6 @@
         objargs...;
         kwargs...
     )              
-
-Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` to the function `restructure`.
-
     optimize(
         problem,
         objective             ::AbstractVariationalObjective,
@@ -21,7 +18,7 @@ Optimize the variational objective `objective` targeting the problem `problem` b
         kwargs...
     )              
 
-Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the initial variational approximation `variational_dist_init` supports the `Optimisers.destructure` interface.
+Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`.
 
 # Arguments
 - `objective`: Variational Objective.

From d75fd3cf96c3e708757708c94674654da77a81d1 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 22:23:06 -0400
Subject: [PATCH 168/206] add compat bounds to test dependencies

---
 test/Project.toml | 21 ++++++++++++++++++++-
 test/runtests.jl  |  1 -
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/test/Project.toml b/test/Project.toml
index 0e81ec08b..56aa3dfff 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,7 +1,6 @@
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
-Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -21,3 +20,23 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[compat]
+ADTypes = "0.2.1"
+Bijectors = "0.13.6"
+Distributions = "0.25.100"
+DistributionsAD = "0.6.45"
+Enzyme = "0.11.7"
+FillArrays = "1.6.1"
+ForwardDiff = "0.10.36"
+Functors = "0.4.5"
+LogDensityProblems = "2.1.1"
+Optimisers = "0.3.0"
+PDMats = "0.11.7"
+Pkg = "1.9.2"
+ReverseDiff = "1.15.1"
+SimpleUnPack = "1.1.0"
+StableRNGs = "1.0.0"
+Tracker = "0.2.20"
+Zygote = "0.6.63"
+julia = "1.6"
diff --git a/test/runtests.jl b/test/runtests.jl
index 5f8fab41f..9b48bc371 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,6 @@
 using Test
 using Test: @testset, @test
 
-using Comonicon
 using Random, StableRNGs
 using Statistics
 using Distributions

From faa91ce33dbb48124dafb76bd46253fc72fe00b2 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 22:25:37 -0400
Subject: [PATCH 169/206] update compat bound for `Optimisers`

---
 Project.toml      | 2 +-
 test/Project.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Project.toml b/Project.toml
index 075ae92f3..658c168c4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,7 +46,7 @@ FillArrays = "1.3"
 ForwardDiff = "0.10.36"
 Functors = "0.4"
 LogDensityProblems = "2"
-Optimisers = "0.2.16"
+Optimisers = "0.2.16, 0.3"
 ProgressMeter = "1.6"
 Requires = "1.0"
 ReverseDiff = "1.15.1"
diff --git a/test/Project.toml b/test/Project.toml
index 56aa3dfff..e61061db8 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -31,7 +31,7 @@ FillArrays = "1.6.1"
 ForwardDiff = "0.10.36"
 Functors = "0.4.5"
 LogDensityProblems = "2.1.1"
-Optimisers = "0.3.0"
+Optimisers = "0.2.16, 0.3"
 PDMats = "0.11.7"
 Pkg = "1.9.2"
 ReverseDiff = "1.15.1"

From 6dc0bb745dea89cc40228e6a7d13a315dbbf2e33 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 7 Sep 2023 22:32:27 -0400
Subject: [PATCH 170/206] fix test compat

---
 test/Project.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/Project.toml b/test/Project.toml
index e61061db8..89c4f77b2 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -33,7 +33,6 @@ Functors = "0.4.5"
 LogDensityProblems = "2.1.1"
 Optimisers = "0.2.16, 0.3"
 PDMats = "0.11.7"
-Pkg = "1.9.2"
 ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StableRNGs = "1.0.0"

From e941ad4b922e40c170fd6d6ed25c026cfc093cc7 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 23 Oct 2023 00:23:08 -0400
Subject: [PATCH 171/206] fix remove `!` in callback

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/optimize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 5beef7126..72db40e7b 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -45,7 +45,7 @@ Optimize the variational objective `objective` targeting the problem `problem` b
 # Callback
 The callback function `callback!` has a signature of
 
-    cb(; stat, state, param, restructure, gradient)
+    callback!(; stat, state, param, restructure, gradient)
 
 The arguments are as follows:
 - `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`.

From 15e05534f102430e09b781e7fad916bd84254bd8 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:38:44 -0400
Subject: [PATCH 172/206] fix rng argument position in `advi`

---
 src/objectives/elbo/advi.jl | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 37b53a47c..a14fced51 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -72,34 +72,35 @@ end
 
 """
     (advi::ADVI)(
-        prob, q;
-        rng::AbstractRNG = Random.default_rng(),
-        n_samples::Int = advi.n_samples
+        [rng], prob, q; n_samples::Int = advi.n_samples
     )
 
 Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation with `n_samples` Monte Carlo samples.
 """
 function (advi::ADVI)(
+    rng      ::Random.AbstractRNG,
     prob,
-    q        ::ContinuousMultivariateDistribution;
-    rng      ::Random.AbstractRNG = Random.default_rng(),
-    n_samples::Int                = advi.n_samples
+    q        ::ContinuousDistribution;
+    n_samples::Int = advi.n_samples
 )
     zs = rand(rng, q, n_samples)
-    advi(q, zs)
+    advi(prob, q, zs)
 end
 
 function (advi::ADVI)(
+    rng      ::Random.AbstractRNG,
     prob,
     q_trans  ::Bijectors.TransformedDistribution;
-    rng      ::Random.AbstractRNG = Random.default_rng(),
-    n_samples::Int                = advi.n_samples
+    n_samples::Int  = advi.n_samples
 )
     q  = q_trans.dist
     ηs = rand(rng, q, n_samples)
-    advi(q_trans, ηs)
+    advi(prob, q_trans, ηs)
 end
 
+(advi::ADVI)(prob, q::Distribution; n_samples::Int = advi.n_samples) =
+    advi(Random.default_rng(), prob, q; n_samples)
+
 function estimate_gradient!(
     rng          ::Random.AbstractRNG,
     prob,

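After this change the ELBO estimator takes the RNG first, with a convenience method falling back to `Random.default_rng()`. A usage sketch, with `model` and `q` assumed to be defined as in the tests:

    using Random

    advi = ADVI(10)
    elbo = advi(MersenneTwister(1), model, q; n_samples = 1_024)  # explicit RNG
    elbo_default = advi(model, q; n_samples = 1_024)              # default RNG
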
From a643cf28cedbb719c5f0276beecf9676723b45a5 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:39:30 -0400
Subject: [PATCH 173/206] fix callback signature in `optimize`

---
 src/optimize.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 72db40e7b..17f31689f 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -43,9 +43,9 @@ Optimize the variational objective `objective` targeting the problem `problem` b
 - `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
 
 # Callback
-The callback function `callback!` has a signature of
+The callback function `callback` has a signature of
 
-    callback!(; stat, state, param, restructure, gradient)
+    callback(; stat, state, param, restructure, gradient)
 
 The arguments are as follows:
 - `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`.
@@ -93,7 +93,7 @@ function optimize(
     optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
     show_progress::Bool                    = true,
     state_init   ::NamedTuple              = NamedTuple(),
-    callback!                              = nothing,
+    callback                               = nothing,
     prog                                   = ProgressMeter.Progress(
         max_iter;
         desc      = "Optimizing",
@@ -120,8 +120,8 @@ function optimize(
         g         = DiffResults.gradient(grad_buf)
         opt_st, λ = Optimisers.update!(opt_st, λ, g)
 
-        if !isnothing(callback!)
-            stat′ = callback!(
+        if !isnothing(callback)
+            stat′ = callback(
                 ; stat, restructure, params=λ, gradient=g,
                 state=(optimizer=opt_st, objective=obj_st)
             )

From ffa69a33adb08678c7618d8e25c0402d6b10a649 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:39:55 -0400
Subject: [PATCH 174/206] refactor reorganize test files and naming

---
 .../advi_distributionsad.jl}                  |  0
 test/{ => interface}/ad.jl                    |  0
 test/interface/advi.jl                        | 55 +++++++++++++++++++
 test/{ => interface}/optimize.jl              |  4 +-
 test/models/normal.jl                         | 43 +++++++++++++++
 test/runtests.jl                              |  9 ++-
 6 files changed, 106 insertions(+), 5 deletions(-)
 rename test/{advi_locscale.jl => inference/advi_distributionsad.jl} (100%)
 rename test/{ => interface}/ad.jl (100%)
 create mode 100644 test/interface/advi.jl
 rename test/{ => interface}/optimize.jl (97%)
 create mode 100644 test/models/normal.jl

diff --git a/test/advi_locscale.jl b/test/inference/advi_distributionsad.jl
similarity index 100%
rename from test/advi_locscale.jl
rename to test/inference/advi_distributionsad.jl
diff --git a/test/ad.jl b/test/interface/ad.jl
similarity index 100%
rename from test/ad.jl
rename to test/interface/ad.jl
diff --git a/test/interface/advi.jl b/test/interface/advi.jl
new file mode 100644
index 000000000..904305ace
--- /dev/null
+++ b/test/interface/advi.jl
@@ -0,0 +1,55 @@
+
+using Test
+
+@testset "advi" begin
+    seed = (0x38bef07cf9cc549d)
+    rng  = StableRNG(seed)
+
+    @testset "with bijector"  begin
+        modelstats = normallognormal_meanfield(Float64; rng)
+
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        b⁻¹  = Bijectors.bijector(model) |> inverse
+        q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
+        obj  = ADVI(10)
+
+        rng      = StableRNG(seed)
+        elbo_ref = obj(rng, model, q₀_z; n_samples=1024)
+
+        @testset "determinism" begin
+            rng  = StableRNG(seed)
+            elbo = obj(rng, model, q₀_z; n_samples=1024)
+            @test elbo == elbo_ref
+        end
+
+        @testset "default_rng" begin
+            elbo = obj(model, q₀_z; n_samples=1024)
+            @test elbo ≈ elbo_ref rtol=0.1
+        end
+    end
+
+    @testset "without bijector"  begin
+        modelstats = normal_meanfield(Float64; rng)
+
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        q₀_z = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+
+        obj      = ADVI(10)
+        rng      = StableRNG(seed)
+        elbo_ref = obj(rng, model, q₀_z; n_samples=1024)
+
+        @testset "determinism" begin
+            rng  = StableRNG(seed)
+            elbo = obj(rng, model, q₀_z; n_samples=1024)
+            @test elbo == elbo_ref
+        end
+
+        @testset "default_rng" begin
+            elbo = obj(model, q₀_z; n_samples=1024)
+            @test elbo ≈ elbo_ref rtol=0.1
+        end
+    end
+end
diff --git a/test/optimize.jl b/test/interface/optimize.jl
similarity index 97%
rename from test/optimize.jl
rename to test/interface/optimize.jl
index 3de3cdc36..bbbb49987 100644
--- a/test/optimize.jl
+++ b/test/interface/optimize.jl
@@ -63,7 +63,7 @@ using Test
         rng  = StableRNG(seed)
         test_values = rand(rng, T)
 
-        callback!(; stat, args...) = begin
+        callback(; stat, args...) = begin
             (test_value = test_values[stat.iteration],)
         end
 
@@ -72,7 +72,7 @@ using Test
             rng, model, obj, q₀_z, T;
             show_progress = false,
             adbackend,
-            callback!
+            callback
         )
         @test [stat.test_value for stat ∈ stats] == test_values
     end
diff --git a/test/models/normal.jl b/test/models/normal.jl
new file mode 100644
index 000000000..3efa75243
--- /dev/null
+++ b/test/models/normal.jl
@@ -0,0 +1,43 @@
+
+struct TestNormal{M,S}
+    μ::M
+    Σ::S
+end
+
+function LogDensityProblems.logdensity(model::TestNormal, θ)
+    @unpack μ, Σ = model
+    logpdf(MvNormal(μ, Σ), θ)
+end
+
+function LogDensityProblems.dimension(model::TestNormal)
+    length(model.μ)
+end
+
+function LogDensityProblems.capabilities(::Type{<:TestNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+function normal_fullrank(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ = randn(rng, realtype, n_dims)
+    L = tril(I + ones(realtype, n_dims, n_dims))/2
+    Σ = L*L' |> Hermitian
+
+    model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))
+
+    TestModel(model, μ, L, n_dims, false)
+end
+
+function normal_meanfield(realtype; rng = default_rng())
+    n_dims = 5
+
+    μ = randn(rng, realtype, n_dims)
+    σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
+
+    model = TestNormal(μ, Diagonal(σ.^2))
+
+    L = σ |> Diagonal
+
+    TestModel(model, μ, L, n_dims, true)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 9b48bc371..6fda0be81 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -32,8 +32,11 @@ struct TestModel{M,L,S}
 end
 
 include("models/normallognormal.jl")
+include("models/normal.jl")
 
 # Tests
-include("ad.jl")
-include("optimize.jl")
-include("advi_locscale.jl")
+include("interface/ad.jl")
+include("interface/optimize.jl")
+include("interface/advi.jl")
+
+include("inference/advi_distributionsad.jl")

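The relocated model files all return a `TestModel` bundle, so an individual test set can be reproduced interactively. A sketch using the keyword-`rng` signature as it stands at this commit (it becomes a positional argument a few commits later):

    using StableRNGs, SimpleUnPack

    modelstats = normal_meanfield(Float64; rng = StableRNG(1))
    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
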
From d5026e14aa441c32b12fa7050c583ec0011d0063 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Mon, 23 Oct 2023 00:41:19 -0400
Subject: [PATCH 175/206] fix simplify description for `optimize`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/optimize.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 17f31689f..7e62e2558 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -18,7 +18,9 @@
         kwargs...
     )              
 
-Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`.
+Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients.
+
+The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`.
 
 # Arguments
 - `objective`: Variational Objective.

From 764406b2a33687c83efb6969f9aeb16be735db81 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:51:02 -0400
Subject: [PATCH 176/206] fix remove redundant `Nothing` type signature for
 `maybe_init`

---
 src/utils.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/utils.jl b/src/utils.jl
index ce11d0be7..8dd7c37bf 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -4,7 +4,7 @@ function pm_next!(pm, stats::NamedTuple)
 end
 
 function maybe_init_optimizer(
-    state_init::Union{Nothing, NamedTuple},
+    state_init::NamedTuple,
     optimizer ::Optimisers.AbstractRule,
     λ         ::AbstractVector
 )
@@ -12,7 +12,7 @@ function maybe_init_optimizer(
 end
 
 function maybe_init_objective(
-    state_init::Union{Nothing, NamedTuple},
+    state_init::NamedTuple,
     rng       ::Random.AbstractRNG,
     objective ::AbstractVariationalObjective,
     λ         ::AbstractVector,

From 65006cb23cf258d0dd7acfa41d025a6c9aead4f4 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:51:41 -0400
Subject: [PATCH 177/206] fix remove "internal use" warning in documentation

---
 src/AdvancedVI.jl | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 35b493c8d..203f5ae1b 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -61,10 +61,6 @@ abstract type AbstractVariationalObjective end
 
 Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
 This function needs to be implemented only if `obj` is stateful.
-
-!!! warning
-    This is an internal function. Thus, the signature is subject to change without
-    notice.
 """
 init(
     ::Random.AbstractRNG,
@@ -93,10 +89,6 @@ If the objective is stateful, `obj_state` is its previous state, otherwise, it i
 - `out`: The `MutableDiffResult` containing the objective value and gradient estimates.
 - `obj_state`: The updated state of the objective estimator.
 - `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`)
-
-!!! warning
-    This is an internal function. Thus, the signature is subject to change without
-    notice.
 """
 function estimate_gradient! end
 

From b23a610f1fdf3865827522d7c7b44a2badace2cd Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:51:59 -0400
Subject: [PATCH 178/206] refactor change `estimate_gradient!` signature to be
 type stable

---
 src/objectives/elbo/advi.jl | 12 ++++++------
 src/optimize.jl             |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index a14fced51..7640f4867 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -102,14 +102,14 @@ end
     advi(Random.default_rng(), prob, q; n_samples)
 
 function estimate_gradient!(
-    rng          ::Random.AbstractRNG,
+    rng       ::Random.AbstractRNG,
+    advi      ::ADVI,
+    adbackend ::ADTypes.AbstractADType,
+    out       ::DiffResults.MutableDiffResult,
     prob,
-    adbackend    ::ADTypes.AbstractADType,
-    advi         ::ADVI,
-    est_state,
-    λ            ::Vector{<:Real},
+    λ,
     restructure,
-    out          ::DiffResults.MutableDiffResult
+    est_state,
 )
     function f(λ′)
         q_trans = restructure(λ′)
diff --git a/src/optimize.jl b/src/optimize.jl
index 17f31689f..cfe6179e6 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -112,8 +112,8 @@ function optimize(
         stat = (iteration=t,)
 
         grad_buf, obj_st, stat′ = estimate_gradient!(
-            rng, problem, adbackend, objective, obj_st,
-            λ, restructure, grad_buf, objargs...
+            rng, objective, adbackend, grad_buf, problem,
+            λ, restructure,  obj_st, objargs...
         )
         stat = merge(stat, stat′)
 

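The reordered signature puts the typed arguments first, so dispatch does not depend on the untyped tail. A third-party objective would implement the extension point with the same shape; below is a rough sketch with the actual gradient computation stubbed out (`MyObjective` is hypothetical and not part of the package):

    using AdvancedVI, ADTypes, DiffResults, Random

    struct MyObjective <: AdvancedVI.AbstractVariationalObjective end

    function AdvancedVI.estimate_gradient!(
        rng        ::Random.AbstractRNG,
        obj        ::MyObjective,
        adbackend  ::ADTypes.AbstractADType,
        out        ::DiffResults.MutableDiffResult,
        prob, λ, restructure, obj_state,
    )
        # A real implementation would write the estimate into `out` here;
        # this stub reports a zero value and zero gradient.
        DiffResults.value!(out, zero(eltype(λ)))
        fill!(DiffResults.gradient(out), 0)
        stat = (note = "stub gradient",)
        out, obj_state, stat
    end
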
From 9c242a53dbd1f0b6f32be200695f94e6dd8b4201 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 00:57:10 -0400
Subject: [PATCH 179/206] add signature for computing `advi` over a fixed set
 of samples

---
 src/objectives/elbo/advi.jl | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 7640f4867..90e84b564 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -44,6 +44,13 @@ ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) =
 Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 
+"""
+    (advi::ADVI)(
+        [rng], prob, q, zs::AbstractMatrix
+    )
+
+Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample).
+"""
 function (advi::ADVI)(
     prob,
     q ::Distributions.ContinuousMultivariateDistribution,

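The fixed-sample method makes it possible to evaluate the ELBO on common random numbers, for example to compare two approximations on exactly the same draws. A sketch, assuming `rng`, `model`, and a multivariate `q` as in the tests:

    advi = ADVI(10)
    zs   = rand(rng, q, 256)    # 256 draws from `q`, one per column
    elbo = advi(model, q, zs)   # ELBO estimated on exactly these samples
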
From e0148637882b370717e128ce419508b1c1d88dd3 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 01:10:44 -0400
Subject: [PATCH 180/206] fix change test tolerance

---
 test/interface/advi.jl | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/interface/advi.jl b/test/interface/advi.jl
index 904305ace..409a0ca64 100644
--- a/test/interface/advi.jl
+++ b/test/interface/advi.jl
@@ -16,16 +16,16 @@ using Test
         obj  = ADVI(10)
 
         rng      = StableRNG(seed)
-        elbo_ref = obj(rng, model, q₀_z; n_samples=1024)
+        elbo_ref = obj(rng, model, q₀_z; n_samples=10^4)
 
         @testset "determinism" begin
             rng  = StableRNG(seed)
-            elbo = obj(rng, model, q₀_z; n_samples=1024)
+            elbo = obj(rng, model, q₀_z; n_samples=10^4)
             @test elbo == elbo_ref
         end
 
         @testset "default_rng" begin
-            elbo = obj(model, q₀_z; n_samples=1024)
+            elbo = obj(model, q₀_z; n_samples=10^4)
             @test elbo ≈ elbo_ref rtol=0.1
         end
     end
@@ -39,16 +39,16 @@ using Test
 
         obj      = ADVI(10)
         rng      = StableRNG(seed)
-        elbo_ref = obj(rng, model, q₀_z; n_samples=1024)
+        elbo_ref = obj(rng, model, q₀_z; n_samples=10^4)
 
         @testset "determinism" begin
             rng  = StableRNG(seed)
-            elbo = obj(rng, model, q₀_z; n_samples=1024)
+            elbo = obj(rng, model, q₀_z; n_samples=10^4)
             @test elbo == elbo_ref
         end
 
         @testset "default_rng" begin
-            elbo = obj(model, q₀_z; n_samples=1024)
+            elbo = obj(model, q₀_z; n_samples=10^4)
             @test elbo ≈ elbo_ref rtol=0.1
         end
     end

From 71184fa4af2b46e02649fd939052fa2701c7c862 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 01:20:43 -0400
Subject: [PATCH 181/206] fix update documentation for `estimate_gradient!`

---
 src/AdvancedVI.jl | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 203f5ae1b..dd7f10ae4 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -72,13 +72,13 @@ init(
 """
     estimate_gradient!(
         rng         ::Random.AbstractRNG,
-        prob,
-        adbackend   ::ADTypes.AbstractADType,
         obj         ::AbstractVariationalObjective,
-        obj_state,
-        λ           ::AbstractVector,
-        restructure,
+        adbackend   ::ADTypes.AbstractADType,
         out         ::DiffResults.MutableDiffResult
+        prob,
+        λ,
+        restructure,
+        obj_state,
     )
 
 Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.

From 9f6d6634e9a91c0129eb8642af27cd0c544b74c9 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 01:20:55 -0400
Subject: [PATCH 182/206] refactor remove type constraint for variational
 parameters

---
 src/optimize.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index df09355d0..0868044f6 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -4,7 +4,7 @@
         problem,
         objective   ::AbstractVariationalObjective,
         restructure,
-        param_init  ::AbstractVector{<:Real},
+        param_init,
         max_iter    ::Int,
         objargs...;
         kwargs...
@@ -66,7 +66,7 @@ function optimize(
     problem,
     objective    ::AbstractVariationalObjective,
     restructure,
-    params_init  ::AbstractVector{<:Real},
+    params_init,
     max_iter     ::Int,
     objargs...;
     kwargs...
@@ -88,7 +88,7 @@ function optimize(
     problem,
     objective    ::AbstractVariationalObjective,
     restructure,
-    params_init  ::AbstractVector{<:Real},
+    params_init,
     max_iter     ::Int,
     objargs...;
     adbackend    ::ADTypes.AbstractADType, 
@@ -145,7 +145,7 @@ function optimize(
     problem,
     objective    ::AbstractVariationalObjective,
     restructure,
-    params_init  ::AbstractVector{<:Real},
+    params_init,
     max_iter     ::Int,
     objargs...;
     kwargs...

From a673520f7510660e632ad95ffbe1a6f574c10bef Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 01:22:45 -0400
Subject: [PATCH 183/206] fix remove dead code

---
 src/optimize.jl | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 0868044f6..208ffabed 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -62,27 +62,6 @@ Otherwise, just return `nothing`.
 
 """
 
-function optimize(
-    problem,
-    objective    ::AbstractVariationalObjective,
-    restructure,
-    params_init,
-    max_iter     ::Int,
-    objargs...;
-    kwargs...
-)
-    optimize(
-        Random.default_rng(),
-        problem,
-        objective,
-        restructure,
-        params_init,
-        max_iter,
-        objargs...;
-        kwargs...
-    )
-end
-
 function optimize(
     rng          ::Random.AbstractRNG,
     problem,

From a3f98867545c374c51b2d2956ded59384e05e57a Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 01:27:48 -0400
Subject: [PATCH 184/206] add compat entry for stdlib

---
 Project.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/Project.toml b/Project.toml
index 658c168c4..700415614 100644
--- a/Project.toml
+++ b/Project.toml
@@ -45,9 +45,11 @@ Enzyme = "0.11.7"
 FillArrays = "1.3"
 ForwardDiff = "0.10.36"
 Functors = "0.4"
+LinearAlgebra = "1"
 LogDensityProblems = "2"
 Optimisers = "0.2.16, 0.3"
 ProgressMeter = "1.6"
+Random = "1"
 Requires = "1.0"
 ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"

From 7a92708950cf5a69957863b7925dbd3a49b294b3 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 01:31:31 -0400
Subject: [PATCH 185/206] add compat entry for stdlib in `test/`

---
 test/Project.toml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/test/Project.toml b/test/Project.toml
index 89c4f77b2..eb4d9d594 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -30,12 +30,16 @@ Enzyme = "0.11.7"
 FillArrays = "1.6.1"
 ForwardDiff = "0.10.36"
 Functors = "0.4.5"
+LinearAlgebra = "1"
 LogDensityProblems = "2.1.1"
 Optimisers = "0.2.16, 0.3"
 PDMats = "0.11.7"
+Random = "1"
 ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StableRNGs = "1.0.0"
+Statistics = "1.9"
+Test = "1"
 Tracker = "0.2.20"
 Zygote = "0.6.63"
 julia = "1.6"

From 5dd434d9dcffb9c7b96b2d5bfa8f564c42ba4e42 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 23:26:59 -0400
Subject: [PATCH 186/206] fix rng argument position in tests

---
 test/inference/advi_distributionsad.jl | 2 +-
 test/interface/advi.jl                 | 4 ++--
 test/interface/optimize.jl             | 2 +-
 test/models/normal.jl                  | 4 ++--
 test/models/normallognormal.jl         | 6 +++---
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl
index dab8f5601..2d60a16a4 100644
--- a/test/inference/advi_distributionsad.jl
+++ b/test/inference/advi_distributionsad.jl
@@ -24,7 +24,7 @@ using Test
             seed = (0x38bef07cf9cc549d)
             rng  = StableRNG(seed)
 
-            modelstats = modelconstr(realtype; rng)
+            modelstats = modelconstr(rng, realtype)
             @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
             T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
diff --git a/test/interface/advi.jl b/test/interface/advi.jl
index 409a0ca64..16db09ca5 100644
--- a/test/interface/advi.jl
+++ b/test/interface/advi.jl
@@ -6,7 +6,7 @@ using Test
     rng  = StableRNG(seed)
 
     @testset "with bijector"  begin
-        modelstats = normallognormal_meanfield(Float64; rng)
+        modelstats = normallognormal_meanfield(rng, Float64)
 
         @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
@@ -31,7 +31,7 @@ using Test
     end
 
     @testset "without bijector"  begin
-        modelstats = normal_meanfield(Float64; rng)
+        modelstats = normal_meanfield(rng, Float64)
 
         @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
index bbbb49987..1384a4b4e 100644
--- a/test/interface/optimize.jl
+++ b/test/interface/optimize.jl
@@ -6,7 +6,7 @@ using Test
     rng  = StableRNG(seed)
 
     T = 1000
-    modelstats = normallognormal_meanfield(Float64; rng)
+    modelstats = normallognormal_meanfield(rng, Float64)
 
     @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
diff --git a/test/models/normal.jl b/test/models/normal.jl
index 3efa75243..3f305e1a0 100644
--- a/test/models/normal.jl
+++ b/test/models/normal.jl
@@ -17,7 +17,7 @@ function LogDensityProblems.capabilities(::Type{<:TestNormal})
     LogDensityProblems.LogDensityOrder{0}()
 end
 
-function normal_fullrank(realtype; rng = default_rng())
+function normal_fullrank(rng::Random.AbstractRNG, realtype::Type)
     n_dims = 5
 
     μ = randn(rng, realtype, n_dims)
@@ -29,7 +29,7 @@ function normal_fullrank(realtype; rng = default_rng())
     TestModel(model, μ, L, n_dims, false)
 end
 
-function normal_meanfield(realtype; rng = default_rng())
+function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)
     n_dims = 5
 
     μ = randn(rng, realtype, n_dims)
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
index e2b9e8165..c2cb2b0e9 100644
--- a/test/models/normallognormal.jl
+++ b/test/models/normallognormal.jl
@@ -26,7 +26,7 @@ function Bijectors.bijector(model::NormalLogNormal)
         [1:1, 2:1+length(μ_y)])
 end
 
-function normallognormal_fullrank(realtype; rng = default_rng())
+function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
     n_dims = 5
 
     μ_x = randn(rng, realtype)
@@ -43,12 +43,12 @@ function normallognormal_fullrank(realtype; rng = default_rng())
     Σ = Σ |> Hermitian
 
     μ = vcat(μ_x, μ_y)
-    L = cholesky(Σ).L |> LowerTriangular
+    L = cholesky(Σ).L
 
     TestModel(model, μ, L, n_dims+1, false)
 end
 
-function normallognormal_meanfield(realtype; rng = default_rng())
+function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
     n_dims = 5
 
     μ_x  = randn(rng, realtype)

From a764d9baf6826e519faeb5762f2b7195437d16ec Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 23:30:59 -0400
Subject: [PATCH 187/206] refactor change name of inference test

---
 test/inference/advi_distributionsad.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl
index 2d60a16a4..01c7a96e6 100644
--- a/test/inference/advi_distributionsad.jl
+++ b/test/inference/advi_distributionsad.jl
@@ -3,8 +3,8 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
 
 using Test
 
-@testset "advi" begin
-    @testset "locscale" begin
+@testset "inference_advi" begin
+    @testset "distributionsad" begin
         @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
             realtype ∈ [Float64], # Currently only tested against Float64
             (modelname, modelconstr) ∈ Dict(

From 8af8a5f6ede44600156398546f58097b10b8bdd5 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Mon, 23 Oct 2023 23:48:20 -0400
Subject: [PATCH 188/206] fix documentation for `optimize`

---
 src/optimize.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 208ffabed..1a39cb419 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -35,7 +35,7 @@ The variational approximation can be constructed by passing the variational para
 - `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
 - `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
 - `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
-- `callback!`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
+- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
 - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
 - `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)
 

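A hedged sketch of how the renamed `callback` keyword is exercised; the interface tests later in this series suggest the callback receives per-iteration statistics via the `stat` keyword and may return a `NamedTuple` that is merged into the logged statistics. `model`, `objective`, and `q₀` are placeholders here:

```julia
# Record a derived quantity each iteration (placeholder problem/objective).
callback(; stat, args...) = (iter_sq = stat.iteration^2,)

q, stats, _ = optimize(
    model, objective, q₀, 1_000;
    callback = callback,
)
```
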
From 5f1fb52b5be0c46ea295087f9f3644396f239c66 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Tue, 24 Oct 2023 00:06:50 -0400
Subject: [PATCH 189/206] refactor rewrite the documentation for the global
 interfaces

---
 src/AdvancedVI.jl | 73 ++++++++++++++++++++++++-----------------------
 1 file changed, 37 insertions(+), 36 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index dd7f10ae4..54c2b1eb0 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -27,17 +27,15 @@ using StatsBase
 
 # derivatives
 """
-    value_and_gradient!(
-        ad::ADTypes.AbstractADType,
-        f,
-        θ::AbstractVector{<:Real},
-        out::DiffResults.MutableDiffResult
-    )
-
-Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`.
-The result is stored in `out`. 
-The function `f` must return a scalar value.
-The gradient is stored in `out` as a vector of the same length as `θ`.
+    value_and_gradient!(ad, f, θ, out)
+
+Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`.
+
+# Arguments
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. 
+- `f`: Function subject to differentiation.
+- `θ`: The point to evaluate the gradient.
+- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
 """
 function value_and_gradient! end
 
@@ -45,22 +43,26 @@ function value_and_gradient! end
 """
     AbstractVariationalObjective
 
-An VI algorithm supported by `AdvancedVI` should implement a subtype of  `AbstractVariationalObjective`.
-Furthermore, it should implement the functions `estimate_gradient`.
+Abstract type for the VI algorithms supported by `AdvancedVI`.
+
+# Implementations
+To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`.
+Also, it should provide gradients by implementing the function `estimate_gradient!`.
 If the estimator is stateful, it can implement `init` to initialize the state.
 """
 abstract type AbstractVariationalObjective end
 
 """
-    init(
-        rng::Random.AbstractRNG,
-        obj::AbstractVariationalObjective,
-        λ::AbstractVector,
-        restructure
-    )
+    init(rng, obj, λ, restructure)
 
 Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
 This function needs to be implemented only if `obj` is stateful.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `obj::AbstractVariationalObjective`: Variational objective.
+- `λ`: Initial variational parameters.
+- `restructure`: Function that reconstructs the variational approximation from `λ`.
 """
 init(
     ::Random.AbstractRNG,
@@ -70,25 +72,24 @@ init(
 ) = nothing
 
 """
-    estimate_gradient!(
-        rng         ::Random.AbstractRNG,
-        obj         ::AbstractVariationalObjective,
-        adbackend   ::ADTypes.AbstractADType,
-        out         ::DiffResults.MutableDiffResult
-        prob,
-        λ,
-        restructure,
-        obj_state,
-    )
-
-Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.
-The estimated objective value and gradient are then stored in `out`.
-If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`.
+    estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state)
+
+Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `obj::AbstractVariationalObjective`: Variational objective.
+- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend. 
+- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. 
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
+- `λ`: Variational parameters to evaluate the gradient on.
+- `restructure`: Function that reconstructs the variational approximation from `λ`.
+- `obj_state`: Previous state of the objective.
 
 # Returns
-- `out`: The `MutableDiffResult` containing the objective value and gradient estimates.
-- `obj_state`: The updated state of the objective estimator.
-- `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`)
+- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates.
+- `obj_state`: The updated state of the objective.
+- `stat::NamedTuple`: Statistics and logs generated during estimation.
 """
 function estimate_gradient! end
 

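Read together, these docstrings spell out the contract for adding a new objective. A minimal skeleton under those assumptions; `MyObjective` and `my_estimator` are hypothetical, and only the documented interface functions are real:

```julia
using Random, DiffResults
using AdvancedVI

struct MyObjective <: AdvancedVI.AbstractVariationalObjective end

# Only needed if the objective carries state between iterations.
AdvancedVI.init(::Random.AbstractRNG, ::MyObjective, λ, restructure) = nothing

function AdvancedVI.estimate_gradient!(
    rng, obj::MyObjective, adbackend, out, prob, λ, restructure, obj_state
)
    # Differentiate the negative objective estimate with respect to λ.
    f(λ′) = -my_estimator(rng, obj, restructure(λ′), prob)  # hypothetical estimator
    AdvancedVI.value_and_gradient!(adbackend, f, λ, out)
    # Return the buffer, the (unchanged) state, and per-iteration statistics.
    out, obj_state, (loss = DiffResults.value(out),)
end
```
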
From 2491c64ac825e5b3b59309b0eacb63363dab00ce Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Tue, 24 Oct 2023 00:11:54 -0400
Subject: [PATCH 190/206] fix compat error

---
 test/Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/Project.toml b/test/Project.toml
index eb4d9d594..490782cb5 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -38,7 +38,7 @@ Random = "1"
 ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StableRNGs = "1.0.0"
-Statistics = "1.9"
+Statistics = "1"
 Test = "1"
 Tracker = "0.2.20"
 Zygote = "0.6.63"

From 92d148988e1fcd9ca9e67ff084fa11d5c5960e69 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Tue, 24 Oct 2023 00:12:12 -0400
Subject: [PATCH 191/206] fix documentation for `optimize` to be single line

---
 src/optimize.jl | 23 ++++-------------------
 1 file changed, 4 insertions(+), 19 deletions(-)

diff --git a/src/optimize.jl b/src/optimize.jl
index 1a39cb419..7e0032dce 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -1,33 +1,18 @@
 
 """
-    optimize(
-        problem,
-        objective   ::AbstractVariationalObjective,
-        restructure,
-        param_init,
-        max_iter    ::Int,
-        objargs...;
-        kwargs...
-    )              
-    optimize(
-        problem,
-        objective             ::AbstractVariationalObjective,
-        variational_dist_init,
-        max_iter              ::Int,
-        objargs...;
-        kwargs...
-    )              
+    optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...)              
+    optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...)              
 
 Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients.
 
 The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`.
 
 # Arguments
-- `objective`: Variational Objective.
+- `objective::AbstractVariationalObjective`: Variational objective.
 - `param_init`: Initial value of the variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
 - `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
-- `max_iter`: Maximum number of iterations.
+- `max_iter::Int`: Maximum number of iterations.
 - `objargs...`: Arguments to be passed to `objective`.
 
 # Keyword Arguments

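A sketch of the second documented method, assuming `model` implements the `LogDensityProblem` interface and the parameters of `q₀` are extractable through `Optimisers.destructure` (this mirrors the usage in the test suite):

```julia
using AdvancedVI, Optimisers, ADTypes

q, stats, state = optimize(
    model, ADVI(10), q₀, 10_000;
    optimizer = Optimisers.Adam(1e-2),
    adbackend = AutoForwardDiff(),
)
```
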
From a03e955245c20800580e541d0402a70d4588235a Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Tue, 24 Oct 2023 00:15:41 -0400
Subject: [PATCH 192/206] refactor remove begin end for one-liner

---
 test/interface/optimize.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
index 1384a4b4e..3459b4c35 100644
--- a/test/interface/optimize.jl
+++ b/test/interface/optimize.jl
@@ -63,9 +63,7 @@ using Test
         rng  = StableRNG(seed)
         test_values = rand(rng, T)
 
-        callback(; stat, args...) = begin
-            (test_value = test_values[stat.iteration],)
-        end
+        callback(; stat, args...) = (test_value = test_values[stat.iteration],)
 
         rng  = StableRNG(seed)
         _, stats, _ = optimize(

From ff83c036a3c2ee5f8d1d33f62d1894592a00497b Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 10 Nov 2023 02:38:04 -0500
Subject: [PATCH 193/206] refactor create unified interface for estimating
 objectives

---
 src/AdvancedVI.jl           | 24 +++++++++++++++++++++-
 src/objectives/elbo/advi.jl | 40 ++++++++++++++++++++-----------------
 test/interface/advi.jl      | 12 +++++------
 3 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 54c2b1eb0..b1decc4a7 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -46,7 +46,7 @@ function value_and_gradient! end
 Abstract type for the VI algorithms supported by `AdvancedVI`.
 
 # Implementations
-To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`.
+To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`.
 Also, it should provide gradients by implementing the function `estimate_gradient!`.
 If the estimator is stateful, it can implement `init` to initialize the state.
 """
@@ -71,6 +71,28 @@ init(
     ::Any
 ) = nothing
 
+"""
+    estimate_objective([rng,] obj, q, prob; kwargs...)
+
+Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `obj::AbstractVariationalObjective`: Variational objective.
+- `q`: Variational approximation.
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
+
+# Keyword Arguments
+For the keyword arguments, refer to the documentation of the respective variational objective.
+
+# Returns
+- `obj_est`: Estimate of the objective value.
+"""
+function estimate_objective end
+
+export estimate_objective
+
+
 """
     estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state)
 
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 90e84b564..ef339cfd8 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -51,19 +51,21 @@ Base.show(io::IO, advi::ADVI) =
 
 Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample).
 """
-function (advi::ADVI)(
+function estimate_objective_with_samples(
+    advi::ADVI,
+    q   ::Distributions.ContinuousMultivariateDistribution,
     prob,
-    q ::Distributions.ContinuousMultivariateDistribution,
-    zs::AbstractMatrix
+    zs  ::AbstractMatrix
 )
     𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs))
     ℍ  = advi.entropy(q, zs)
     𝔼ℓ + ℍ
 end
 
-function (advi::ADVI)(
-    prob,
+function estimate_objective_with_samples(
+    advi   ::ADVI,
     q_trans::Bijectors.TransformedDistribution,
+    prob,
     ηs     ::AbstractMatrix
 )
     @unpack dist, transform = q_trans
@@ -78,35 +80,37 @@ function (advi::ADVI)(
 end
 
 """
-    (advi::ADVI)(
-        [rng], prob, q; n_samples::Int = advi.n_samples
+    estimate_objective(
+        advi::ADVI, [rng], prob, q; n_samples::Int = advi.n_samples
     )
 
 Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples.
 """
-function (advi::ADVI)(
+function estimate_objective(
     rng      ::Random.AbstractRNG,
-    prob,
-    q        ::ContinuousDistribution;
+    advi     ::ADVI,
+    q        ::ContinuousDistribution,
+    prob;
     n_samples::Int = advi.n_samples
 )
     zs = rand(rng, q, n_samples)
-    advi(prob, q, zs)
+    estimate_objective_with_samples(advi, q, prob, zs)
 end
 
-function (advi::ADVI)(
+function estimate_objective(
     rng      ::Random.AbstractRNG,
-    prob,
-    q_trans  ::Bijectors.TransformedDistribution;
+    advi     ::ADVI,
+    q_trans  ::Bijectors.TransformedDistribution,
+    prob;
     n_samples::Int  = advi.n_samples
 )
     q  = q_trans.dist
     ηs = rand(rng, q, n_samples)
-    advi(prob, q_trans, ηs)
+    estimate_objective_with_samples(advi, q_trans, prob, ηs)
 end
 
-(advi::ADVI)(prob, q::Distribution; n_samples::Int = advi.n_samples) =
-    advi(Random.default_rng(), prob, q; n_samples)
+estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
+    estimate_objective(Random.default_rng(), advi, q, prob; n_samples)
 
 function estimate_gradient!(
     rng       ::Random.AbstractRNG,
@@ -122,7 +126,7 @@ function estimate_gradient!(
         q_trans = restructure(λ′)
         q       = q_trans.dist
         ηs      = rand(rng, q, advi.n_samples)
-        -advi(prob, q_trans, ηs)
+        -estimate_objective_with_samples(advi, q_trans, prob, ηs)
     end
     value_and_gradient!(adbackend, f, λ, out)
 
diff --git a/test/interface/advi.jl b/test/interface/advi.jl
index 16db09ca5..1df396e4d 100644
--- a/test/interface/advi.jl
+++ b/test/interface/advi.jl
@@ -16,16 +16,16 @@ using Test
         obj  = ADVI(10)
 
         rng      = StableRNG(seed)
-        elbo_ref = obj(rng, model, q₀_z; n_samples=10^4)
+        elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
 
         @testset "determinism" begin
             rng  = StableRNG(seed)
-            elbo = obj(rng, model, q₀_z; n_samples=10^4)
+            elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
             @test elbo == elbo_ref
         end
 
         @testset "default_rng" begin
-            elbo = obj(model, q₀_z; n_samples=10^4)
+            elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
             @test elbo ≈ elbo_ref rtol=0.1
         end
     end
@@ -39,16 +39,16 @@ using Test
 
         obj      = ADVI(10)
         rng      = StableRNG(seed)
-        elbo_ref = obj(rng, model, q₀_z; n_samples=10^4)
+        elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
 
         @testset "determinism" begin
             rng  = StableRNG(seed)
-            elbo = obj(rng, model, q₀_z; n_samples=10^4)
+            elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
             @test elbo == elbo_ref
         end
 
         @testset "default_rng" begin
-            elbo = obj(model, q₀_z; n_samples=10^4)
+            elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
             @test elbo ≈ elbo_ref rtol=0.1
         end
     end

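The updated interface tests above already show the new call shape; for reference, a condensed sketch (`model` and `q₀_z` are the test fixtures):

```julia
obj  = ADVI(10)
elbo = estimate_objective(rng, obj, q₀_z, model; n_samples = 10^4)

# Without an explicit RNG, Random.default_rng() is used.
elbo_default = estimate_objective(obj, q₀_z, model; n_samples = 10^4)
```
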
From aecc655dcfb87b6f09d0278e405935afe162db60 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 10 Nov 2023 03:01:45 -0500
Subject: [PATCH 194/206] refactor unify interface for entropy estimator, fix
 advi docs

---
 src/objectives/elbo/advi.jl    | 62 ++++++++++++++++++++++------------
 src/objectives/elbo/entropy.jl | 43 +++++++++++++++++++----
 2 files changed, 76 insertions(+), 29 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index ef339cfd8..54024db89 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -45,20 +45,28 @@ Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 
 """
-    (advi::ADVI)(
-        [rng], prob, q, zs::AbstractMatrix
-    )
+    estimate_objective_with_samples(advi, q, prob, mc_samples)
+
+Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo samples.
+
+# Arguments
+- `advi::ADVI`: ADVI objective.
+- `q`: Variational approximation
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
+- `mc_samples::AbstractMatrix`: Samples to be used to estimate the energy. (Each column is a single sample.)
+
+# Returns
+- `obj_est`: Estimate of the objective value.
 
-Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation over the Monte Carlo samples `zs` (each column is a sample).
 """
 function estimate_objective_with_samples(
-    advi::ADVI,
-    q   ::Distributions.ContinuousMultivariateDistribution,
+    advi      ::ADVI,
+    q         ::Distributions.ContinuousMultivariateDistribution,
     prob,
-    zs  ::AbstractMatrix
+    mc_samples::AbstractMatrix
 )
-    𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(zs))
-    ℍ  = advi.entropy(q, zs)
+    𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples))
+    ℍ  = estimate_entropy(advi.entropy, mc_samples, q)
     𝔼ℓ + ℍ
 end
 
@@ -66,25 +74,34 @@ function estimate_objective_with_samples(
     advi   ::ADVI,
     q_trans::Bijectors.TransformedDistribution,
     prob,
-    ηs     ::AbstractMatrix
+    mc_samples_unconstr::AbstractMatrix
 )
     @unpack dist, transform = q_trans
     q   = dist
     b⁻¹ = transform
-    𝔼ℓ = mean(eachcol(ηs)) do ηᵢ
-        zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, ηᵢ)
-        LogDensityProblems.logdensity(prob, zᵢ) + logdetjacᵢ
+    𝔼ℓ = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr
+        mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, mc_sample_unconstr)
+        LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ
     end
-    ℍ  = advi.entropy(q, ηs)
+    ℍ  = estimate_entropy(advi.entropy, mc_samples_unconstr, q)
     𝔼ℓ + ℍ
 end
 
 """
-    estimate_objective(
-        advi::ADVI, [rng], prob, q; n_samples::Int = advi.n_samples
-    )
+    estimate_objective([rng,] advi, q, prob; n_samples)
+
+Estimate the ELBO using the ADVI formulation.
+
+# Arguments
+- `advi::ADVI`: ADVI objective.
+- `q`: Variational approximation
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
+
+# Keyword Arguments
+- `n_samples::Int = advi.n_samples`: Number of samples to be used to estimate the objective.
 
-Estimate the ELBO of the variational approximation `q` of the target `prob` using the ADVI formulation using `n_samples` number of Monte Carlo samples.
+# Returns
+- `obj_est`: Estimate of the objective value.
 """
 function estimate_objective(
     rng      ::Random.AbstractRNG,
@@ -93,8 +110,8 @@ function estimate_objective(
     prob;
     n_samples::Int = advi.n_samples
 )
-    zs = rand(rng, q, n_samples)
-    estimate_objective_with_samples(advi, q, prob, zs)
+    mc_samples = rand(rng, q, n_samples)
+    estimate_objective_with_samples(advi, q, prob, mc_samples)
 end
 
 function estimate_objective(
@@ -105,8 +122,8 @@ function estimate_objective(
     n_samples::Int  = advi.n_samples
 )
     q  = q_trans.dist
-    ηs = rand(rng, q, n_samples)
-    estimate_objective_with_samples(advi, q_trans, prob, ηs)
+    mc_unconstr_samples = rand(rng, q, n_samples)
+    estimate_objective_with_samples(advi, q_trans, prob, mc_unconstr_samples)
 end
 
 estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
@@ -122,6 +139,7 @@ function estimate_gradient!(
     restructure,
     est_state,
 )
+    q_trans_stop = restructure(λ)
     function f(λ′)
         q_trans = restructure(λ′)
         q       = q_trans.dist
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 63854ec0f..48dad2756 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -1,15 +1,44 @@
 
+"""
+    estimate_entropy(entropy_estimator, mc_samples, q)
+
+Estimate the entropy of `q`.
+
+# Arguments
+- `entropy_estimator`: Entropy estimation strategy.
+- `q`: Variational approximation.
+- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
+
+# Returns
+- `entropy_est`: Estimate of the entropy of `q`.
+"""
+
+function estimate_entropy end
+
+
+"""
+    ClosedFormEntropy()
+
+Use closed-form expression of entropy.
+
+# Requirements
+- `q` implements `entropy`.
+
+# References
+* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
+* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
+"""
 struct ClosedFormEntropy <: AbstractEntropyEstimator end
 
-function (::ClosedFormEntropy)(q, ::AbstractMatrix)
+function estimate_entropy(::ClosedFormEntropy, ::Any, q)
     entropy(q)
 end
 
 struct MonteCarloEntropy <: AbstractEntropyEstimator end
 
-function (::MonteCarloEntropy)(q, ηs::AbstractMatrix)
-    mean(eachcol(ηs)) do ηᵢ
-        -logpdf(q, ηᵢ)
+function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q)
+    mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q, mc_sample)
     end
 end
 
@@ -27,8 +56,8 @@ The "sticking the landing" entropy estimator.
 """
 struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 
-function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix)
-    ChainRulesCore.@ignore_derivatives mean(eachcol(ηs)) do ηᵢ
-        -logpdf(q, ηᵢ)
+function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q)
+    ChainRulesCore.@ignore_derivatives mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q, mc_sample)
     end
 end

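The entropy strategy is now chosen once, at objective construction time; a sketch using the constructors introduced so far:

```julia
obj_cf  = ADVI(10)                                         # closed-form entropy (default)
obj_mc  = ADVI(10, entropy = MonteCarloEntropy())          # Monte Carlo estimate
obj_stl = ADVI(10, entropy = StickingTheLandingEntropy())  # "sticking the landing"
```
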
From a8d532ae33de54c295f611ae4c00bf53ab764e07 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 10 Nov 2023 03:25:00 -0500
Subject: [PATCH 195/206] fix STL estimator to use manually stopped gradients
 instead

---
 src/objectives/elbo/advi.jl    | 73 +++++++++++++++++++++-------------
 src/objectives/elbo/entropy.jl | 13 +++---
 2 files changed, 52 insertions(+), 34 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 54024db89..8b7a2771f 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -51,7 +51,7 @@ Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo sam
 
 # Arguments
 - `advi::ADVI`: ADVI objective.
-- `q`: Variational approximation
+- `q`: Variational approximation.
 - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
 - `mc_samples::AbstractMatrix`: Samples to be used to estimate the energy. (Each column is a single sample.)
 
@@ -59,34 +59,64 @@ Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo sam
 - `obj_est`: Estimate of the objective value.
 
 """
+function estimate_objective_with_samples(
+    advi      ::ADVI,
+    q         ::Union{Distributions.ContinuousMultivariateDistribution,
+                      Bijectors.TransformedDistribution},
+    prob,
+    mc_samples::AbstractMatrix
+)
+    estimate_objective_with_samples(advi, q, q, prob, mc_samples)
+end
+
+
 function estimate_objective_with_samples(
     advi      ::ADVI,
     q         ::Distributions.ContinuousMultivariateDistribution,
+    q_stop    ::Distributions.ContinuousMultivariateDistribution,
     prob,
     mc_samples::AbstractMatrix
 )
     𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples))
-    ℍ  = estimate_entropy(advi.entropy, mc_samples, q)
+    ℍ  = estimate_entropy(advi.entropy, mc_samples, q, q_stop)
     𝔼ℓ + ℍ
 end
 
 function estimate_objective_with_samples(
-    advi   ::ADVI,
-    q_trans::Bijectors.TransformedDistribution,
+    advi        ::ADVI,
+    q_trans     ::Bijectors.TransformedDistribution,
+    q_trans_stop::Bijectors.TransformedDistribution,
     prob,
     mc_samples_unconstr::AbstractMatrix
 )
     @unpack dist, transform = q_trans
-    q   = dist
-    b⁻¹ = transform
-    𝔼ℓ = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr
+    q      = dist
+    q_stop = q_trans_stop.dist
+    b⁻¹    = transform
+    𝔼ℓ     = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr
         mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, mc_sample_unconstr)
         LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ
     end
-    ℍ  = estimate_entropy(advi.entropy, mc_samples_unconstr, q)
+    ℍ  = estimate_entropy(advi.entropy, mc_samples_unconstr, q, q_stop)
     𝔼ℓ + ℍ
 end
 
+function rand_uncontrained_samples(
+    rng      ::Random.AbstractRNG,
+    q        ::ContinuousDistribution,
+    n_samples::Int,
+)
+    rand(rng, q, n_samples)
+end
+
+function rand_uncontrained_samples(
+    rng      ::Random.AbstractRNG,
+    q_trans  ::Bijectors.TransformedDistribution,
+    n_samples::Int,
+)
+    rand(rng, q_trans.dist, n_samples)
+end
+
 """
     estimate_objective([rng,] advi, q, prob; n_samples)
 
@@ -106,24 +136,12 @@ Estimate the ELBO using the ADVI formulation.
 function estimate_objective(
     rng      ::Random.AbstractRNG,
     advi     ::ADVI,
-    q        ::ContinuousDistribution,
+    q,
     prob;
     n_samples::Int = advi.n_samples
 )
-    mc_samples = rand(rng, q, n_samples)
-    estimate_objective_with_samples(advi, q, prob, mc_samples)
-end
-
-function estimate_objective(
-    rng      ::Random.AbstractRNG,
-    advi     ::ADVI,
-    q_trans  ::Bijectors.TransformedDistribution,
-    prob;
-    n_samples::Int  = advi.n_samples
-)
-    q  = q_trans.dist
-    mc_unconstr_samples = rand(rng, q, n_samples)
-    estimate_objective_with_samples(advi, q_trans, prob, mc_unconstr_samples)
+    mc_samples_unconstr = rand_uncontrained_samples(rng, q, n_samples)
+    estimate_objective_with_samples(advi, q, prob, mc_samples_unconstr)
 end
 
 estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
@@ -139,12 +157,11 @@ function estimate_gradient!(
     restructure,
     est_state,
 )
-    q_trans_stop = restructure(λ)
+    q_stop = restructure(λ)
     function f(λ′)
-        q_trans = restructure(λ′)
-        q       = q_trans.dist
-        ηs      = rand(rng, q, advi.n_samples)
-        -estimate_objective_with_samples(advi, q_trans, prob, ηs)
+        q = restructure(λ′)
+        mc_samples_unconstr = rand_uncontrained_samples(rng, q, advi.n_samples)
+        -estimate_objective_with_samples(advi, q, q_stop, prob, mc_samples_unconstr)
     end
     value_and_gradient!(adbackend, f, λ, out)
 
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 48dad2756..461aa030e 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -1,12 +1,13 @@
 
 """
-    estimate_entropy(entropy_estimator, mc_samples, q)
+    estimate_entropy(entropy_estimator, mc_samples, q, q_stop)
 
 Estimate the entropy of `q`.
 
 # Arguments
 - `entropy_estimator`: Entropy estimation strategy.
 - `q`: Variational approximation.
+- `q_stop`: Variational approximation with "stopped gradients".
 - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
 
 # Returns
@@ -30,13 +31,13 @@ Use closed-form expression of entropy.
 """
 struct ClosedFormEntropy <: AbstractEntropyEstimator end
 
-function estimate_entropy(::ClosedFormEntropy, ::Any, q)
+function estimate_entropy(::ClosedFormEntropy, ::Any, q, ::Any)
     entropy(q)
 end
 
 struct MonteCarloEntropy <: AbstractEntropyEstimator end
 
-function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q)
+function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, ::Any)
     mean(eachcol(mc_samples)) do mc_sample
         -logpdf(q, mc_sample)
     end
@@ -56,8 +57,8 @@ The "sticking the landing" entropy estimator.
 """
 struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 
-function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q)
-    ChainRulesCore.@ignore_derivatives mean(eachcol(mc_samples)) do mc_sample
-        -logpdf(q, mc_sample)
+function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, ::Any, q_stop)
+    mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q_stop, mc_sample)
     end
 end

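The key idea above: instead of `ChainRulesCore.@ignore_derivatives`, the entropy term is evaluated under `q_stop = restructure(λ)`, which is built *outside* the differentiated closure. The score term `∇_λ log q_λ(z)` then vanishes while the reparameterization (path) gradient through the samples survives. A minimal sketch of that closure, assuming a reparameterized `q` and the bindings (`rng`, `λ`, `prob`, `restructure`, `n_samples`) from `estimate_gradient!`:

```julia
using Statistics, Distributions, LogDensityProblems

q_stop = restructure(λ)  # AD never differentiates through this copy
function f(λ′)
    q  = restructure(λ′)
    zs = rand(rng, q, n_samples)  # gradients flow through the samples (path term)
    𝔼ℓ = mean(z -> LogDensityProblems.logdensity(prob, z), eachcol(zs))
    ℍ  = mean(z -> -logpdf(q_stop, z), eachcol(zs))  # q_stop is constant in λ′: no score term
    -(𝔼ℓ + ℍ)
end
```
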
From 65e9b126a6c0038bac9f6904e8d851b211b528b3 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 10 Nov 2023 03:25:15 -0500
Subject: [PATCH 196/206] add inference test for a non-bijector model

---
 test/inference/advi_distributionsad.jl | 81 +++++++++++++++++++++++++-
 1 file changed, 78 insertions(+), 3 deletions(-)

diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl
index 01c7a96e6..9919ce2b3 100644
--- a/test/inference/advi_distributionsad.jl
+++ b/test/inference/advi_distributionsad.jl
@@ -4,6 +4,81 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
 using Test
 
 @testset "inference_advi" begin
+    @testset "distributionsad" begin
+        @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+            realtype ∈ [Float64], # Currently only tested against Float64
+            (modelname, modelconstr) ∈ Dict(
+                :Normal=> normal_meanfield,
+            ),
+            (objname, objective) ∈ Dict(
+                :ADVIClosedFormEntropy  => ADVI(10),
+                :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
+            ),
+            (adbackname, adbackend) ∈ Dict(
+                :ForwarDiff  => AutoForwardDiff(),
+                #:ReverseDiff => AutoReverseDiff(),
+                #:Zygote      => AutoZygote(), 
+                #:Enzyme      => AutoEnzyme(),
+            )
+
+            seed = (0x38bef07cf9cc549d)
+            rng  = StableRNG(seed)
+
+            modelstats = modelconstr(rng, realtype)
+            @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+            T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
+            μ₀   = Zeros(realtype, n_dims)
+            L₀   = Diagonal(Ones(realtype, n_dims))
+            q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))
+
+            @testset "convergence" begin
+                Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+                q, stats, _ = optimize(
+                    rng, model, objective, q₀_z, T;
+                    optimizer     = Optimisers.Adam(realtype(η)),
+                    show_progress = PROGRESS,
+                    adbackend     = adbackend,
+                )
+
+                μ  = mean(q)
+                L  = sqrt(cov(q))
+                Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+                @test Δλ ≤ Δλ₀/T^(1/4)
+                @test eltype(μ) == eltype(μ_true)
+                @test eltype(L) == eltype(L_true)
+            end
+
+            @testset "determinism" begin
+                rng = StableRNG(seed)
+                q, stats, _ = optimize(
+                    rng, model, objective, q₀_z, T;
+                    optimizer     = Optimisers.Adam(realtype(η)),
+                    show_progress = PROGRESS,
+                    adbackend     = adbackend,
+                )
+                μ  = mean(q)
+                L  = sqrt(cov(q))
+
+                rng_repl = StableRNG(seed)
+                q, stats, _ = optimize(
+                    rng_repl, model, objective, q₀_z, T;
+                    optimizer     = Optimisers.Adam(realtype(η)),
+                    show_progress = PROGRESS,
+                    adbackend     = adbackend,
+                )
+                μ_repl = mean(q)
+                L_repl = sqrt(cov(q))
+                @test μ == μ_repl
+                @test L == L_repl
+            end
+        end
+    end
+end
+
+@testset "inference_bijectors_advi" begin
     @testset "distributionsad" begin
         @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
             realtype ∈ [Float64], # Currently only tested against Float64
@@ -16,9 +91,9 @@ using Test
             ),
             (adbackname, adbackend) ∈ Dict(
                 :ForwarDiff  => AutoForwardDiff(),
-                # :ReverseDiff => AutoReverseDiff(),
-                # :Zygote      => AutoZygote(), 
-                # :Enzyme      => AutoEnzyme(),
+                #:ReverseDiff => AutoReverseDiff(),
+                #:Zygote      => AutoZygote(), 
+                #:Enzyme      => AutoEnzyme(),
             )
 
             seed = (0x38bef07cf9cc549d)

From 3691f160a7923d7a569cb93eb51719e8f4065beb Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Sat, 11 Nov 2023 00:24:08 -0500
Subject: [PATCH 197/206] refactor add indirections to handle STL and bijectors
 in ADVI

---
 src/objectives/elbo/advi.jl    | 106 ++++++++++++++++-----------------
 src/objectives/elbo/entropy.jl |  24 ++++----
 2 files changed, 60 insertions(+), 70 deletions(-)

diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
index 8b7a2771f..98f8ae993 100644
--- a/src/objectives/elbo/advi.jl
+++ b/src/objectives/elbo/advi.jl
@@ -39,82 +39,76 @@ struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObject
     n_samples::Int
 end
 
-ADVI(n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy()) = ADVI(entropy, n_samples)
+ADVI(
+    n_samples::Int;
+    entropy  ::AbstractEntropyEstimator = ClosedFormEntropy()
+) = ADVI(entropy, n_samples)
 
 Base.show(io::IO, advi::ADVI) =
     print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
 
-"""
-    estimate_objective_with_samples(advi, q, prob, mc_samples)
+maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
 
-Estimate the ELBO using the ADVI formulation over a set of given Monte Carlo samples.
+maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
 
-# Arguments
-- `advi::ADVI`: ADVI objective.
-- `q`: Variational approximation.
-- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
-- `mc_samples::AbstractMatrix`: Samples to be used to estimate the energy. (Each column is a single sample.)
-
-# Returns
-- `obj_est`: Estimate of the objective value.
+function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, mc_samples, q, q_stop)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    estimate_entropy(entropy_estimator, mc_samples, q_maybe_stop)
+end
 
-"""
-function estimate_objective_with_samples(
-    advi      ::ADVI,
-    q         ::Union{Distributions.ContinuousMultivariateDistribution,
-                      Bijectors.TransformedDistribution},
-    prob,
-    mc_samples::AbstractMatrix
-)
-    estimate_objective_with_samples(advi, q, q, prob, mc_samples)
+function estimate_energy_with_samples(::ADVI, mc_samples::AbstractMatrix, prob)
+    mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples))
 end
 
+function estimate_energy_with_samples_bijector(::ADVI, mc_samples::AbstractMatrix, invbij, prob)
+    mean(eachcol(mc_samples)) do mc_sample
+        mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(invbij, mc_sample)
+        LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ
+    end
+end
 
-function estimate_objective_with_samples(
+function estimate_advi_maybe_stl_with_samples(
     advi      ::ADVI,
-    q         ::Distributions.ContinuousMultivariateDistribution,
-    q_stop    ::Distributions.ContinuousMultivariateDistribution,
-    prob,
-    mc_samples::AbstractMatrix
+    q         ::ContinuousDistribution,
+    q_stop    ::ContinuousDistribution,
+    mc_samples::AbstractMatrix,
+    prob
 )
-    𝔼ℓ = mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples))
-    ℍ  = estimate_entropy(advi.entropy, mc_samples, q, q_stop)
-    𝔼ℓ + ℍ
+    energy  = estimate_energy_with_samples(advi, mc_samples, prob)
+    entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop)
+    energy + entropy
 end
 
-function estimate_objective_with_samples(
+function estimate_advi_maybe_stl_with_samples(
     advi        ::ADVI,
     q_trans     ::Bijectors.TransformedDistribution,
     q_trans_stop::Bijectors.TransformedDistribution,
-    prob,
-    mc_samples_unconstr::AbstractMatrix
+    mc_samples  ::AbstractMatrix,
+    prob
 )
-    @unpack dist, transform = q_trans
-    q      = dist
-    q_stop = q_trans_stop.dist
-    b⁻¹    = transform
-    𝔼ℓ     = mean(eachcol(mc_samples_unconstr)) do mc_sample_unconstr
-        mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(b⁻¹, mc_sample_unconstr)
-        LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ
-    end
-    ℍ  = estimate_entropy(advi.entropy, mc_samples_unconstr, q, q_stop)
-    𝔼ℓ + ℍ
+    q       = q_trans.dist
+    invbij  = q_trans.transform
+    q_stop  = q_trans_stop.dist
+    energy  = estimate_energy_with_samples_bijector(advi, mc_samples, invbij, prob)
+    entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop)
+    energy + entropy
 end
 
-function rand_uncontrained_samples(
+rand_unconstrained(
     rng      ::Random.AbstractRNG,
     q        ::ContinuousDistribution,
-    n_samples::Int,
-)
-    rand(rng, q, n_samples)
-end
+    n_samples::Int
+) = rand(rng, q, n_samples)
 
-function rand_uncontrained_samples(
+rand_unconstrained(
     rng      ::Random.AbstractRNG,
-    q_trans  ::Bijectors.TransformedDistribution,
-    n_samples::Int,
-)
-    rand(rng, q_trans.dist, n_samples)
+    q        ::Bijectors.TransformedDistribution,
+    n_samples::Int
+) = rand(rng, q.dist, n_samples)
+
+function estimate_advi_maybe_stl(rng::Random.AbstractRNG, advi::ADVI, q, q_stop, prob)
+    mc_samples = rand_unconstrained(rng, q, advi.n_samples)
+    estimate_advi_maybe_stl_with_samples(advi, q, q_stop, mc_samples, prob)
 end
 
 """
@@ -140,8 +134,8 @@ function estimate_objective(
     prob;
     n_samples::Int = advi.n_samples
 )
-    mc_samples_unconstr = rand_uncontrained_samples(rng, q, n_samples)
-    estimate_objective_with_samples(advi, q, prob, mc_samples_unconstr)
+    mc_samples = rand_unconstrained(rng, q, n_samples)
+    estimate_advi_maybe_stl_with_samples(advi, q, q, mc_samples, prob)
 end
 
 estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
@@ -160,8 +154,8 @@ function estimate_gradient!(
     q_stop = restructure(λ)
     function f(λ′)
         q = restructure(λ′)
-        mc_samples_unconstr = rand_uncontrained_samples(rng, q, advi.n_samples)
-        -estimate_objective_with_samples(advi, q, q_stop, prob, mc_samples_unconstr)
+        elbo = estimate_advi_maybe_stl(rng, advi, q, q_stop, prob)
+        -elbo
     end
     value_and_gradient!(adbackend, f, λ, out)
 
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 461aa030e..6fa3095e8 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -1,13 +1,12 @@
 
 """
-    estimate_entropy(entropy_estimator, mc_samples, q, q_stop)
+    estimate_entropy(entropy_estimator, mc_samples, q)
 
 Estimate the entropy of `q`.
 
 # Arguments
 - `entropy_estimator`: Entropy estimation strategy.
 - `q`: Variational approximation.
-- `q_stop`: Variational approximation with "stopped gradients".
 - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
 
 # Returns
@@ -16,7 +15,6 @@ Estimate the entropy of `q`.
 
 function estimate_entropy end
 
-
 """
     ClosedFormEntropy()
 
@@ -31,18 +29,10 @@ Use closed-form expression of entropy.
 """
 struct ClosedFormEntropy <: AbstractEntropyEstimator end
 
-function estimate_entropy(::ClosedFormEntropy, ::Any, q, ::Any)
+function estimate_entropy(::ClosedFormEntropy, ::Any, q)
     entropy(q)
 end
 
-struct MonteCarloEntropy <: AbstractEntropyEstimator end
-
-function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, ::Any)
-    mean(eachcol(mc_samples)) do mc_sample
-        -logpdf(q, mc_sample)
-    end
-end
-
 """
     StickingTheLandingEntropy()
 
@@ -57,8 +47,14 @@ The "sticking the landing" entropy estimator.
 """
 struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 
-function estimate_entropy(::StickingTheLandingEntropy, mc_samples::AbstractMatrix, ::Any, q_stop)
+struct MonteCarloEntropy <: AbstractEntropyEstimator end
+
+function estimate_entropy(
+    ::Union{MonteCarloEntropy, StickingTheLandingEntropy},
+    mc_samples::AbstractMatrix,
+    q
+)
     mean(eachcol(mc_samples)) do mc_sample
-        -logpdf(q_stop, mc_sample)
+        -logpdf(q, mc_sample)
     end
 end

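With this indirection the two Monte Carlo estimators share a single `estimate_entropy` method; the STL behaviour comes entirely from which distribution `maybe_stop_entropy_score` hands it. Schematically:

```julia
# Inside estimate_entropy_maybe_stl: pick q or the stopped copy by dispatch.
q_used = maybe_stop_entropy_score(advi.entropy, q, q_stop)  # q_stop only for STL
ℍ      = estimate_entropy(advi.entropy, mc_samples, q_used)
```
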
From a063583c8e5b9cb83efd006110f662591e68c0ba Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Sat, 11 Nov 2023 00:44:07 -0500
Subject: [PATCH 198/206] refactor split inference tests for
 advi+distributionsad

---
 test/inference/advi_distributionsad.jl        | 208 ++++++------------
 .../advi_distributionsad_bijectors.jl         |  81 +++++++
 test/runtests.jl                              |   1 +
 3 files changed, 146 insertions(+), 144 deletions(-)
 create mode 100644 test/inference/advi_distributionsad_bijectors.jl

diff --git a/test/inference/advi_distributionsad.jl b/test/inference/advi_distributionsad.jl
index 9919ce2b3..e82a9ec0e 100644
--- a/test/inference/advi_distributionsad.jl
+++ b/test/inference/advi_distributionsad.jl
@@ -3,156 +3,76 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
 
 using Test
 
-@testset "inference_advi" begin
-    @testset "distributionsad" begin
-        @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
-            realtype ∈ [Float64], # Currently only tested against Float64
-            (modelname, modelconstr) ∈ Dict(
-                :Normal=> normal_meanfield,
-            ),
-            (objname, objective) ∈ Dict(
-                :ADVIClosedFormEntropy  => ADVI(10),
-                :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
-            ),
-            (adbackname, adbackend) ∈ Dict(
-                :ForwarDiff  => AutoForwardDiff(),
-                #:ReverseDiff => AutoReverseDiff(),
-                #:Zygote      => AutoZygote(), 
-                #:Enzyme      => AutoEnzyme(),
+@testset "inference_advi_distributionsad" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+        realtype ∈ [Float64, Float32],
+        (modelname, modelconstr) ∈ Dict(
+            :Normal=> normal_meanfield,
+        ),
+        (objname, objective) ∈ Dict(
+            :ADVIClosedFormEntropy  => ADVI(10),
+            :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
+        ),
+        (adbackname, adbackend) ∈ Dict(
+            :ForwarDiff  => AutoForwardDiff(),
+            #:ReverseDiff => AutoReverseDiff(),
+            #:Zygote      => AutoZygote(), 
+            #:Enzyme      => AutoEnzyme(),
+        )
+
+        seed = (0x38bef07cf9cc549d)
+        rng  = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
+        μ₀   = Zeros(realtype, n_dims)
+        L₀   = Diagonal(Ones(realtype, n_dims))
+        q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))
+
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
             )
 
-            seed = (0x38bef07cf9cc549d)
-            rng  = StableRNG(seed)
+            μ  = mean(q)
+            L  = sqrt(cov(q))
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-            modelstats = modelconstr(rng, realtype)
-            @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-            T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
-
-            μ₀   = Zeros(realtype, n_dims)
-            L₀   = Diagonal(Ones(realtype, n_dims))
-            q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))
-
-            @testset "convergence" begin
-                Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-                q, stats, _ = optimize(
-                    rng, model, objective, q₀_z, T;
-                    optimizer     = Optimisers.Adam(realtype(η)),
-                    show_progress = PROGRESS,
-                    adbackend     = adbackend,
-                )
-
-                μ  = mean(q)
-                L  = sqrt(cov(q))
-                Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
-
-                @test Δλ ≤ Δλ₀/T^(1/4)
-                @test eltype(μ) == eltype(μ_true)
-                @test eltype(L) == eltype(L_true)
-            end
-
-            @testset "determinism" begin
-                rng = StableRNG(seed)
-                q, stats, _ = optimize(
-                    rng, model, objective, q₀_z, T;
-                    optimizer     = Optimisers.Adam(realtype(η)),
-                    show_progress = PROGRESS,
-                    adbackend     = adbackend,
-                )
-                μ  = mean(q)
-                L  = sqrt(cov(q))
-
-                rng_repl = StableRNG(seed)
-                q, stats, _ = optimize(
-                    rng_repl, model, objective, q₀_z, T;
-                    optimizer     = Optimisers.Adam(realtype(η)),
-                    show_progress = PROGRESS,
-                    adbackend     = adbackend,
-                )
-                μ_repl = mean(q)
-                L_repl = sqrt(cov(q))
-                @test μ == μ_repl
-                @test L == L_repl
-            end
+            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
         end
-    end
-end
 
-@testset "inference_bijectors_advi" begin
-    @testset "distributionsad" begin
-        @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
-            realtype ∈ [Float64], # Currently only tested against Float64
-            (modelname, modelconstr) ∈ Dict(
-                :NormalLogNormalMeanField => normallognormal_meanfield,
-            ),
-            (objname, objective) ∈ Dict(
-                :ADVIClosedFormEntropy  => ADVI(10),
-                :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
-            ),
-            (adbackname, adbackend) ∈ Dict(
-                :ForwarDiff  => AutoForwardDiff(),
-                #:ReverseDiff => AutoReverseDiff(),
-                #:Zygote      => AutoZygote(), 
-                #:Enzyme      => AutoEnzyme(),
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
             )
-
-            seed = (0x38bef07cf9cc549d)
-            rng  = StableRNG(seed)
-
-            modelstats = modelconstr(rng, realtype)
-            @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-            T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
-
-            b    = Bijectors.bijector(model)
-            b⁻¹  = inverse(b)
-            μ₀   = Zeros(realtype, n_dims)
-            L₀   = Diagonal(Ones(realtype, n_dims))
-
-            q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
-            q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-
-            @testset "convergence" begin
-                Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-                q, stats, _ = optimize(
-                    rng, model, objective, q₀_z, T;
-                    optimizer     = Optimisers.Adam(realtype(η)),
-                    show_progress = PROGRESS,
-                    adbackend     = adbackend,
-                )
-
-                μ  = mean(q.dist)
-                L  = sqrt(cov(q.dist))
-                Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
-
-                @test Δλ ≤ Δλ₀/T^(1/4)
-                @test eltype(μ) == eltype(μ_true)
-                @test eltype(L) == eltype(L_true)
-            end
-
-            @testset "determinism" begin
-                rng = StableRNG(seed)
-                q, stats, _ = optimize(
-                    rng, model, objective, q₀_z, T;
-                    optimizer     = Optimisers.Adam(realtype(η)),
-                    show_progress = PROGRESS,
-                    adbackend     = adbackend,
-                )
-                μ  = mean(q.dist)
-                L  = sqrt(cov(q.dist))
-
-                rng_repl = StableRNG(seed)
-                q, stats, _ = optimize(
-                    rng_repl, model, objective, q₀_z, T;
-                    optimizer     = Optimisers.Adam(realtype(η)),
-                    show_progress = PROGRESS,
-                    adbackend     = adbackend,
-                )
-                μ_repl = mean(q.dist)
-                L_repl = sqrt(cov(q.dist))
-                @test μ == μ_repl
-                @test L == L_repl
-            end
+            μ  = mean(q)
+            L  = sqrt(cov(q))
+
+            rng_repl = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng_repl, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ_repl = mean(q)
+            L_repl = sqrt(cov(q))
+            @test μ == μ_repl
+            @test L == L_repl
         end
     end
 end
+
diff --git a/test/inference/advi_distributionsad_bijectors.jl b/test/inference/advi_distributionsad_bijectors.jl
new file mode 100644
index 000000000..29602fe76
--- /dev/null
+++ b/test/inference/advi_distributionsad_bijectors.jl
@@ -0,0 +1,81 @@
+
+const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
+
+using Test
+
+@testset "inference_advi_distributionsad_bijectors" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+        realtype ∈ [Float64, Float32],
+        (modelname, modelconstr) ∈ Dict(
+            :NormalLogNormalMeanField => normallognormal_meanfield,
+        ),
+        (objname, objective) ∈ Dict(
+            :ADVIClosedFormEntropy  => ADVI(10),
+            :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
+        ),
+        (adbackname, adbackend) ∈ Dict(
+            :ForwarDiff  => AutoForwardDiff(),
+            #:ReverseDiff => AutoReverseDiff(),
+            #:Zygote      => AutoZygote(), 
+            #:Enzyme      => AutoEnzyme(),
+        )
+
+        seed = (0x38bef07cf9cc549d)
+        rng  = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
+        b    = Bijectors.bijector(model)
+        b⁻¹  = inverse(b)
+        μ₀   = Zeros(realtype, n_dims)
+        L₀   = Diagonal(Ones(realtype, n_dims))
+
+        q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
+        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
+
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+
+            μ  = mean(q.dist)
+            L  = sqrt(cov(q.dist))
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ  = mean(q.dist)
+            L  = sqrt(cov(q.dist))
+            
+            rng_repl = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng_repl, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ_repl = mean(q.dist)
+            L_repl = sqrt(cov(q.dist))
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6fda0be81..757a931d8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -40,3 +40,4 @@ include("interface/optimize.jl")
 include("interface/advi.jl")
 
 include("inference/advi_distributionsad.jl")
+include("inference/advi_distributionsad_bijectors.jl")

From 316b629eb965a591019b7149bbcf7fc72e613b9b Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Tue, 21 Nov 2023 01:50:41 -0500
Subject: [PATCH 199/206] refactor rename advi to repgradelbo and not use
 bijectors directly

---
 Project.toml                                  |   2 -
 src/AdvancedVI.jl                             |   9 +-
 src/objectives/elbo/advi.jl                   | 166 ------------------
 src/objectives/elbo/repgradelbo.jl            | 126 +++++++++++++
 src/utils.jl                                  |   3 +
 test/Project.toml                             |   2 -
 .../advi_distributionsad_bijectors.jl         |  81 ---------
 ...nsad.jl => repgradelbo_distributionsad.jl} |  20 +--
 test/interface/advi.jl                        |  55 ------
 test/interface/optimize.jl                    |  22 ++-
 test/interface/repgradelbo.jl                 |  28 +++
 test/models/normallognormal.jl                |  65 -------
 test/runtests.jl                              |   8 +-
 13 files changed, 182 insertions(+), 405 deletions(-)
 delete mode 100644 src/objectives/elbo/advi.jl
 create mode 100644 src/objectives/elbo/repgradelbo.jl
 delete mode 100644 test/inference/advi_distributionsad_bijectors.jl
 rename test/inference/{advi_distributionsad.jl => repgradelbo_distributionsad.jl} (78%)
 delete mode 100644 test/interface/advi.jl
 create mode 100644 test/interface/repgradelbo.jl
 delete mode 100644 test/models/normallognormal.jl

diff --git a/Project.toml b/Project.toml
index 700415614..7799d5057 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,7 +5,6 @@ version = "0.3.0"
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -36,7 +35,6 @@ AdvancedVIZygoteExt = "Zygote"
 [compat]
 ADTypes = "0.1, 0.2"
 Accessors = "0.1"
-Bijectors = "0.12, 0.13"
 ChainRulesCore = "1.16"
 DiffResults = "1"
 Distributions = "0.25.87"
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index b1decc4a7..bb5b6e856 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -11,7 +11,6 @@ using Functors
 using Optimisers
 
 using DocStringExtensions
-
 using ProgressMeter
 using LinearAlgebra
 
@@ -21,7 +20,6 @@ using ADTypes, DiffResults
 using ChainRulesCore
 
 using FillArrays
-using Bijectors
 
 using StatsBase
 
@@ -115,18 +113,17 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
 """
 function estimate_gradient! end
 
-# ADVI-specific interfaces
+# ELBO-specific interfaces
 abstract type AbstractEntropyEstimator end
 
 export
-    ADVI,
+    RepGradELBO,
     ClosedFormEntropy,
     StickingTheLandingEntropy,
     MonteCarloEntropy
 
-# entropy.jl must preceed advi.jl
 include("objectives/elbo/entropy.jl")
-include("objectives/elbo/advi.jl")
+include("objectives/elbo/repgradelbo.jl")
 
 # Optimization Routine
 
diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl
deleted file mode 100644
index 98f8ae993..000000000
--- a/src/objectives/elbo/advi.jl
+++ /dev/null
@@ -1,166 +0,0 @@
-
-"""
-    ADVI(n_samples; kwargs...)
-
-Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective.
-This computes the evidence lower-bound (ELBO) through the ADVI formulation:
-```math
-\\begin{aligned}
-\\mathrm{ADVI}\\left(\\lambda\\right)
-&\\triangleq
-\\mathbb{E}_{\\eta \\sim q_{\\lambda}}\\left[
-  \\log \\pi\\left( \\phi^{-1}\\left( \\eta \\right) \\right)
-  +
-  \\log \\lvert J_{\\phi^{-1}}\\left(\\eta\\right) \\rvert
-\\right]
-+ \\mathbb{H}\\left(q_{\\lambda}\\right),
-\\end{aligned}
-```
-where ``\\phi^{-1}`` is an "inverse bijector."
-
-# Arguments
-- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
-
-# Keyword Arguments
-- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy())
-
-# Requirements
-- ``q_{\\lambda}`` implements `rand`.
-- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend.
-
-Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
-
-# References
-* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
-* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
-"""
-struct ADVI{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
-    entropy  ::EntropyEst
-    n_samples::Int
-end
-
-ADVI(
-    n_samples::Int;
-    entropy  ::AbstractEntropyEstimator = ClosedFormEntropy()
-) = ADVI(entropy, n_samples)
-
-Base.show(io::IO, advi::ADVI) =
-    print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))")
-
-maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
-
-maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
-
-function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, mc_samples, q, q_stop)
-    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
-    estimate_entropy(entropy_estimator, mc_samples, q_maybe_stop)
-end
-
-function estimate_energy_with_samples(::ADVI, mc_samples::AbstractMatrix, prob)
-    mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(mc_samples))
-end
-
-function estimate_energy_with_samples_bijector(::ADVI, mc_samples::AbstractMatrix, invbij, prob)
-    mean(eachcol(mc_samples)) do mc_sample
-        mc_sample, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(invbij, mc_sample)
-        LogDensityProblems.logdensity(prob, mc_sample) + logdetjacᵢ
-    end
-end
-
-function estimate_advi_maybe_stl_with_samples(
-    advi      ::ADVI,
-    q         ::ContinuousDistribution,
-    q_stop    ::ContinuousDistribution,
-    mc_samples::AbstractMatrix,
-    prob
-)
-    energy  = estimate_energy_with_samples(advi, mc_samples, prob)
-    entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop)
-    energy + entropy
-end
-
-function estimate_advi_maybe_stl_with_samples(
-    advi        ::ADVI,
-    q_trans     ::Bijectors.TransformedDistribution,
-    q_trans_stop::Bijectors.TransformedDistribution,
-    mc_samples  ::AbstractMatrix,
-    prob
-)
-    q       = q_trans.dist
-    invbij  = q_trans.transform
-    q_stop  = q_trans_stop.dist
-    energy  = estimate_energy_with_samples_bijector(advi, mc_samples, invbij, prob)
-    entropy = estimate_entropy_maybe_stl(advi.entropy, mc_samples, q, q_stop)
-    energy + entropy
-end
-
-rand_unconstrained(
-    rng      ::Random.AbstractRNG,
-    q        ::ContinuousDistribution,
-    n_samples::Int
-) = rand(rng, q, n_samples)
-
-rand_unconstrained(
-    rng      ::Random.AbstractRNG,
-    q        ::Bijectors.TransformedDistribution,
-    n_samples::Int
-) = rand(rng, q.dist, n_samples)
-
-function estimate_advi_maybe_stl(rng::Random.AbstractRNG, advi::ADVI, q, q_stop, prob)
-    mc_samples = rand_unconstrained(rng, q, advi.n_samples)
-    estimate_advi_maybe_stl_with_samples(advi, q, q_stop, mc_samples, prob)
-end
-
-"""
-    estimate_objective([rng,] advi, q, prob; n_samples)
-
-Estimate the ELBO using the ADVI formulation.
-
-# Arguments
-- `advi::ADVI`: ADVI objective.
-- `q`: Variational approximation
-- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
-
-# Keyword Arguments
-- `n_samples::Int = advi.n_samples`: Number of samples to be used to estimate the objective.
-
-# Returns
-- `obj_est`: Estimate of the objective value.
-"""
-function estimate_objective(
-    rng      ::Random.AbstractRNG,
-    advi     ::ADVI,
-    q,
-    prob;
-    n_samples::Int = advi.n_samples
-)
-    mc_samples = rand_unconstrained(rng, q, n_samples)
-    estimate_advi_maybe_stl_with_samples(advi, q, q, mc_samples, prob)
-end
-
-estimate_objective(advi::ADVI, q::Distribution, prob; n_samples::Int = advi.n_samples) =
-    estimate_objective(Random.default_rng(), advi, q, prob; n_samples)
-
-function estimate_gradient!(
-    rng       ::Random.AbstractRNG,
-    advi      ::ADVI,
-    adbackend ::ADTypes.AbstractADType,
-    out       ::DiffResults.MutableDiffResult,
-    prob,
-    λ,
-    restructure,
-    est_state,
-)
-    q_stop = restructure(λ)
-    function f(λ′)
-        q = restructure(λ′)
-        elbo = estimate_advi_maybe_stl(rng, advi, q, q_stop, prob)
-        -elbo
-    end
-    value_and_gradient!(adbackend, f, λ, out)
-
-    nelbo = DiffResults.value(out)
-    stat  = (elbo=-nelbo,)
-
-    out, nothing, stat
-end
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
new file mode 100644
index 000000000..09ba1a793
--- /dev/null
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -0,0 +1,126 @@
+
+"""
+    RepGradELBO(n_samples; kwargs...)
+
+Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014].
+This computes the evidence lower-bound (ELBO) through the formulation:
+```math
+\\begin{aligned}
+\\mathrm{ELBO}\\left(\\lambda\\right)
+&\\triangleq
+\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
+  \\log \\pi\\left(z\\right)
+\\right]
++ \\mathbb{H}\\left(q_{\\lambda}\\right),
+\\end{aligned}
+```
+
+# Arguments
+- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
+
+# Keyword Arguments
+- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`)
+
+# Requirements
+- ``q_{\\lambda}`` implements `rand`.
+- The target `logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
+
+Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
+
+# References
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML.
+[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML.
+[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In ICLR.
+"""
+struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
+    entropy  ::EntropyEst
+    n_samples::Int
+end
+
+RepGradELBO(
+    n_samples::Int;
+    entropy  ::AbstractEntropyEstimator = ClosedFormEntropy()
+) = RepGradELBO(entropy, n_samples)
+
+Base.show(io::IO, obj::RepGradELBO) =
+    print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))")
+
+maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
+
+maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
+
+function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    estimate_entropy(entropy_estimator, samples, q_maybe_stop)
+end
+
+function estimate_energy_with_samples(::RepGradELBO, samples, prob)
+    mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+end
+
+function estimate_repgradelbo_maybe_stl_with_samples(
+    obj::RepGradELBO, q, q_stop, samples::AbstractMatrix, prob
+)
+    energy  = estimate_energy_with_samples(obj, samples, prob)
+    entropy = estimate_entropy_maybe_stl(obj.entropy, samples, q, q_stop)
+    energy + entropy
+end
+
+function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELBO, q, q_stop, prob)
+    samples = rand(rng, q, obj.n_samples)
+    estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob)
+end
+
+"""
+    estimate_objective([rng,] obj, q, prob; n_samples)
+
+Estimate the ELBO using the reparameterization gradient formulation.
+
+# Arguments
+- `obj::RepGradELBO`: The ELBO objective.
+- `q`: Variational approximation
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
+
+# Keyword Arguments
+- `n_samples::Int = obj.n_samples`: Number of samples to be used to estimate the objective.
+
+# Returns
+- `obj_est`: Estimate of the objective value.
+"""
+function estimate_objective(
+    rng::Random.AbstractRNG,
+    obj::RepGradELBO,
+    q,
+    prob;
+    n_samples::Int = obj.n_samples
+)
+    samples = rand(rng, q, n_samples)
+    estimate_repgradelbo_maybe_stl_with_samples(obj, q, q, samples, prob)
+end
+
+estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
+    estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
+
+function estimate_gradient!(
+    rng      ::Random.AbstractRNG,
+    obj      ::RepGradELBO,
+    adbackend::ADTypes.AbstractADType,
+    out      ::DiffResults.MutableDiffResult,
+    prob,
+    λ,
+    restructure,
+    est_state,
+)
+    q_stop = restructure(λ)
+    function f(λ′)
+        q = restructure(λ′)
+        elbo = estimate_repgradelbo_maybe_stl(rng, obj, q, q_stop, prob)
+        -elbo
+    end
+    value_and_gradient!(adbackend, f, λ, out)
+
+    nelbo = DiffResults.value(out)
+    stat  = (elbo=-nelbo,)
+
+    out, nothing, stat
+end
diff --git a/src/utils.jl b/src/utils.jl
index 8dd7c37bf..76637fa3c 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -21,3 +21,6 @@ function maybe_init_objective(
     haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure)
 end
 
+eachsample(samples::AbstractMatrix) = eachcol(samples)
+
+eachsample(samples::AbstractVector) = samples
diff --git a/test/Project.toml b/test/Project.toml
index 490782cb5..7d0bf2d2f 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,5 @@
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
-Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -23,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ADTypes = "0.2.1"
-Bijectors = "0.13.6"
 Distributions = "0.25.100"
 DistributionsAD = "0.6.45"
 Enzyme = "0.11.7"
diff --git a/test/inference/advi_distributionsad_bijectors.jl b/test/inference/advi_distributionsad_bijectors.jl
deleted file mode 100644
index 29602fe76..000000000
--- a/test/inference/advi_distributionsad_bijectors.jl
+++ /dev/null
@@ -1,81 +0,0 @@
-
-const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
-
-using Test
-
-@testset "inference_advi_distributionsad_bijectors" begin
-    @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
-            :NormalLogNormalMeanField => normallognormal_meanfield,
-        ),
-        (objname, objective) ∈ Dict(
-            :ADVIClosedFormEntropy  => ADVI(10),
-            :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
-        ),
-        (adbackname, adbackend) ∈ Dict(
-            :ForwarDiff  => AutoForwardDiff(),
-            #:ReverseDiff => AutoReverseDiff(),
-            #:Zygote      => AutoZygote(), 
-            #:Enzyme      => AutoEnzyme(),
-        )
-
-        seed = (0x38bef07cf9cc549d)
-        rng  = StableRNG(seed)
-
-        modelstats = modelconstr(rng, realtype)
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
-
-        b    = Bijectors.bijector(model)
-        b⁻¹  = inverse(b)
-        μ₀   = Zeros(realtype, n_dims)
-        L₀   = Diagonal(Ones(realtype, n_dims))
-
-        q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
-        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-
-        @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
-            q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
-                optimizer     = Optimisers.Adam(realtype(η)),
-                show_progress = PROGRESS,
-                adbackend     = adbackend,
-            )
-
-            μ  = mean(q.dist)
-            L  = sqrt(cov(q.dist))
-            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
-
-            @test Δλ ≤ Δλ₀/T^(1/4)
-            @test eltype(μ) == eltype(μ_true)
-            @test eltype(L) == eltype(L_true)
-        end
-
-        @testset "determinism" begin
-            rng = StableRNG(seed)
-            q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
-                optimizer     = Optimisers.Adam(realtype(η)),
-                show_progress = PROGRESS,
-                adbackend     = adbackend,
-            )
-            μ  = mean(q.dist)
-            L  = sqrt(cov(q.dist))
-            
-            rng_repl = StableRNG(seed)
-            q, stats, _ = optimize(
-                rng_repl, model, objective, q₀_z, T;
-                optimizer     = Optimisers.Adam(realtype(η)),
-                show_progress = PROGRESS,
-                adbackend     = adbackend,
-            )
-            μ_repl = mean(q.dist)
-            L_repl = sqrt(cov(q.dist))
-            @test μ == μ_repl
-            @test L == L_repl
-        end
-    end
-end
diff --git a/test/inference/advi_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
similarity index 78%
rename from test/inference/advi_distributionsad.jl
rename to test/inference/repgradelbo_distributionsad.jl
index e82a9ec0e..29cb2d834 100644
--- a/test/inference/advi_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -3,15 +3,15 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
 
 using Test
 
-@testset "inference_advi_distributionsad" begin
+@testset "inference RepGradELBO DistributionsAD" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
         realtype ∈ [Float64, Float32],
         (modelname, modelconstr) ∈ Dict(
             :Normal=> normal_meanfield,
         ),
         (objname, objective) ∈ Dict(
-            :ADVIClosedFormEntropy  => ADVI(10),
-            :ADVIStickingTheLanding => ADVI(10, entropy = StickingTheLandingEntropy()),
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
+            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
         ),
         (adbackname, adbackend) ∈ Dict(
             :ForwarDiff  => AutoForwardDiff(),
@@ -28,14 +28,14 @@ using Test
 
         T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
 
-        μ₀   = Zeros(realtype, n_dims)
-        L₀   = Diagonal(Ones(realtype, n_dims))
-        q₀_z = TuringDiagMvNormal(μ₀, diag(L₀))
+        μ0 = Zeros(realtype, n_dims)
+        L0 = Diagonal(Ones(realtype, n_dims))
+        q0 = TuringDiagMvNormal(μ0, diag(L0))
 
         @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
             q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
+                rng, model, objective, q0, T;
                 optimizer     = Optimisers.Adam(realtype(η)),
                 show_progress = PROGRESS,
                 adbackend     = adbackend,
@@ -53,7 +53,7 @@ using Test
         @testset "determinism" begin
             rng = StableRNG(seed)
             q, stats, _ = optimize(
-                rng, model, objective, q₀_z, T;
+                rng, model, objective, q0, T;
                 optimizer     = Optimisers.Adam(realtype(η)),
                 show_progress = PROGRESS,
                 adbackend     = adbackend,
@@ -63,7 +63,7 @@ using Test
 
             rng_repl = StableRNG(seed)
             q, stats, _ = optimize(
-                rng_repl, model, objective, q₀_z, T;
+                rng_repl, model, objective, q0, T;
                 optimizer     = Optimisers.Adam(realtype(η)),
                 show_progress = PROGRESS,
                 adbackend     = adbackend,
diff --git a/test/interface/advi.jl b/test/interface/advi.jl
deleted file mode 100644
index 1df396e4d..000000000
--- a/test/interface/advi.jl
+++ /dev/null
@@ -1,55 +0,0 @@
-
-using Test
-
-@testset "advi" begin
-    seed = (0x38bef07cf9cc549d)
-    rng  = StableRNG(seed)
-
-    @testset "with bijector"  begin
-        modelstats = normallognormal_meanfield(rng, Float64)
-
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        b⁻¹  = Bijectors.bijector(model) |> inverse
-        q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-        obj  = ADVI(10)
-
-        rng      = StableRNG(seed)
-        elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-
-        @testset "determinism" begin
-            rng  = StableRNG(seed)
-            elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-            @test elbo == elbo_ref
-        end
-
-        @testset "default_rng" begin
-            elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
-            @test elbo ≈ elbo_ref rtol=0.1
-        end
-    end
-
-    @testset "without bijector"  begin
-        modelstats = normal_meanfield(rng, Float64)
-
-        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
-
-        q₀_z = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-
-        obj      = ADVI(10)
-        rng      = StableRNG(seed)
-        elbo_ref = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-
-        @testset "determinism" begin
-            rng  = StableRNG(seed)
-            elbo = estimate_objective(rng, obj, q₀_z, model; n_samples=10^4)
-            @test elbo == elbo_ref
-        end
-
-        @testset "default_rng" begin
-            elbo = estimate_objective(obj, q₀_z, model; n_samples=10^4)
-            @test elbo ≈ elbo_ref rtol=0.1
-        end
-    end
-end
diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
index 3459b4c35..6e69616bd 100644
--- a/test/interface/optimize.jl
+++ b/test/interface/optimize.jl
@@ -1,27 +1,25 @@
 
 using Test
 
-@testset "optimize" begin
+@testset "interface optimize" begin
     seed = (0x38bef07cf9cc549d)
     rng  = StableRNG(seed)
 
     T = 1000
-    modelstats = normallognormal_meanfield(rng, Float64)
+    modelstats = normal_meanfield(rng, Float64)
 
     @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
 
     # Global Test Configurations
-    b⁻¹  = Bijectors.bijector(model) |> inverse
-    q₀_η = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
-    q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
-    obj  = ADVI(10)
+    q0  = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+    obj = RepGradELBO(10)
 
     adbackend = AutoForwardDiff()
     optimizer = Optimisers.Adam(1e-2)
 
     rng  = StableRNG(seed)
     q_ref, stats_ref, _ = optimize(
-        rng, model, obj, q₀_z, T;
+        rng, model, obj, q0, T;
         optimizer,
         show_progress = false,
         adbackend,
@@ -30,13 +28,13 @@ using Test
 
     @testset "default_rng" begin
         optimize(
-            model, obj, q₀_z, T;
+            model, obj, q0, T;
             optimizer,
             show_progress = false,
             adbackend,
         )
 
-        λ₀, re  = Optimisers.destructure(q₀_z)
+        λ₀, re  = Optimisers.destructure(q0)
         optimize(
             model, obj, re, λ₀, T;
             optimizer,
@@ -46,7 +44,7 @@ using Test
     end
 
     @testset "restructure" begin
-        λ₀, re  = Optimisers.destructure(q₀_z)
+        λ₀, re  = Optimisers.destructure(q0)
 
         rng  = StableRNG(seed)
         λ, stats, _ = optimize(
@@ -67,7 +65,7 @@ using Test
 
         rng  = StableRNG(seed)
         _, stats, _ = optimize(
-            rng, model, obj, q₀_z, T;
+            rng, model, obj, q0, T;
             show_progress = false,
             adbackend,
             callback
@@ -82,7 +80,7 @@ using Test
         T_last  = T - T_first
 
         q_first, _, state = optimize(
-            rng, model, obj, q₀_z, T_first;
+            rng, model, obj, q0, T_first;
             optimizer,
             show_progress = false,
             adbackend
diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl
new file mode 100644
index 000000000..61ff0111c
--- /dev/null
+++ b/test/interface/repgradelbo.jl
@@ -0,0 +1,28 @@
+
+using Test
+
+@testset "interface RepGradELBO" begin
+    seed = (0x38bef07cf9cc549d)
+    rng  = StableRNG(seed)
+
+    modelstats = normal_meanfield(rng, Float64)
+
+    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+    q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+
+    obj      = RepGradELBO(10)
+    rng      = StableRNG(seed)
+    elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4)
+
+    @testset "determinism" begin
+        rng  = StableRNG(seed)
+        elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4)
+        @test elbo == elbo_ref
+    end
+
+    @testset "default_rng" begin
+        elbo = estimate_objective(obj, q0, model; n_samples=10^4)
+        @test elbo ≈ elbo_ref rtol=0.1
+    end
+end
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
deleted file mode 100644
index c2cb2b0e9..000000000
--- a/test/models/normallognormal.jl
+++ /dev/null
@@ -1,65 +0,0 @@
-
-struct NormalLogNormal{MX,SX,MY,SY}
-    μ_x::MX
-    σ_x::SX
-    μ_y::MY
-    Σ_y::SY
-end
-
-function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
-end
-
-function LogDensityProblems.dimension(model::NormalLogNormal)
-    length(model.μ_y) + 1
-end
-
-function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
-end
-
-function Bijectors.bijector(model::NormalLogNormal)
-    @unpack μ_x, σ_x, μ_y, Σ_y = model
-    Bijectors.Stacked(
-        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
-        [1:1, 2:1+length(μ_y)])
-end
-
-function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
-    n_dims = 5
-
-    μ_x = randn(rng, realtype)
-    σ_x = ℯ
-    μ_y = randn(rng, realtype, n_dims)
-    L_y = tril(I + ones(realtype, n_dims, n_dims))/2
-    Σ_y = L_y*L_y' |> Hermitian
-
-    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))
-
-    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
-    Σ[1,1]         = σ_x^2
-    Σ[2:end,2:end] = Σ_y
-    Σ = Σ |> Hermitian
-
-    μ = vcat(μ_x, μ_y)
-    L = cholesky(Σ).L
-
-    TestModel(model, μ, L, n_dims+1, false)
-end
-
-function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
-    n_dims = 5
-
-    μ_x  = randn(rng, realtype)
-    σ_x  = ℯ
-    μ_y  = randn(rng, realtype, n_dims)
-    σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
-
-    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
-
-    μ = vcat(μ_x, μ_y)
-    L = vcat(σ_x, σ_y) |> Diagonal
-
-    TestModel(model, μ, L, n_dims+1, true)
-end
diff --git a/test/runtests.jl b/test/runtests.jl
index 757a931d8..a855541cf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,7 +14,6 @@ using Functors
 using DistributionsAD
 @functor TuringDiagMvNormal
 
-using Bijectors
 using LogDensityProblems
 using Optimisers
 using ADTypes
@@ -30,14 +29,11 @@ struct TestModel{M,L,S}
     n_dims::Int
     is_meanfield::Bool
 end
-
-include("models/normallognormal.jl")
 include("models/normal.jl")
 
 # Tests
 include("interface/ad.jl")
 include("interface/optimize.jl")
-include("interface/advi.jl")
+include("interface/repgradelbo.jl")
 
-include("inference/advi_distributionsad.jl")
-include("inference/advi_distributionsad_bijectors.jl")
+include("inference/repgradelbo_distributionsad.jl")

From 13b208868dc2b3d3d6d5fe9b38228fa2879684cd Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 23 Nov 2023 00:44:57 -0500
Subject: [PATCH 200/206] fix documentation for estimate_objective

---
 src/AdvancedVI.jl                  | 20 ++++++++++++++++++--
 src/objectives/elbo/repgradelbo.jl | 20 --------------------
 2 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index bb5b6e856..d17a088ca 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -70,7 +70,7 @@ init(
 ) = nothing
 
 """
-    estimate_objective([rng,] obj, q, prob, kwargs...)
+    estimate_objective([rng,] obj, q, prob; kwargs...)
 
 Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`.
 
@@ -81,7 +81,8 @@ Estimate the variational objective `obj` targeting `prob` with respect to the va
 - `q`: Variational approximation.
 
 # Keyword Arguments
-For the keywword arguments, refer to the respective documentation for each variational objective.
+Depending on the objective, additional keyword arguments may apply.
+Please refer to the respective documentation of each variational objective for more info.
 
 # Returns
 - `obj_est`: Estimate of the objective value.
@@ -116,6 +117,21 @@ function estimate_gradient! end
 # ELBO-specific interfaces
 abstract type AbstractEntropyEstimator end
 
+"""
+    estimate_entropy(entropy_estimator, mc_samples, q)
+
+Estimate the entropy of `q`.
+
+# Arguments
+- `entropy_estimator`: Entropy estimation strategy.
+- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
+- `q`: Variational approximation.
+
+# Returns
+- `entropy_est`: Estimate of the entropy of `q`.
+"""
+function estimate_entropy end
+
 export
     RepGradELBO,
     ClosedFormEntropy,
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 09ba1a793..48a5461f8 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -45,10 +45,6 @@ RepGradELBO(
 Base.show(io::IO, obj::RepGradELBO) =
     print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))")
 
-maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
-
-maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
-
 function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
     q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
     estimate_entropy(entropy_estimator, samples, q_maybe_stop)
@@ -71,22 +67,6 @@ function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELB
     estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob)
 end
 
-"""
-    estimate_objective([rng,] obj, q, prob; n_samples)
-
-Estimate the ELBO using the reparameterization gradient formulation.
-
-# Arguments
-- `obj::RepGradELBO`: The ELBO objective.
-- `q`: Variational approximation
-- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
-
-# Keyword Arguments
-- `n_samples::Int = obj.n_samples`: Number of samples to be used to estimate the objective.
-
-# Returns
-- `obj_est`: Estimate of the objective value.
-"""
 function estimate_objective(
     rng::Random.AbstractRNG,
     obj::RepGradELBO,
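A short sketch of the `estimate_objective` entry point documented above, mirroring `test/interface/repgradelbo.jl`; `model` and `q0` are stand-ins for a LogDensityProblems target and a variational approximation:

    using AdvancedVI, StableRNGs

    obj = RepGradELBO(10)

    # Monte Carlo ELBO estimate for a fixed q; more samples tighten the estimate.
    elbo = estimate_objective(obj, q0, model; n_samples = 10^4)

    # Passing an explicit RNG makes the estimate reproducible.
    elbo_rng = estimate_objective(StableRNG(1), obj, q0, model; n_samples = 10^4)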

From b0e1be14bed230f307d704641bdc911728367613 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 23 Nov 2023 02:05:48 -0500
Subject: [PATCH 201/206] refactor add indirection in repgradelbo for
 interacting with `q`

---
 src/objectives/elbo/entropy.jl     | 20 +++-----------
 src/objectives/elbo/repgradelbo.jl | 44 ++++++++++++++++++++----------
 2 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 6fa3095e8..231b16523 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -1,20 +1,4 @@
 
-"""
-    estimate_entropy(entropy_estimator, mc_samples, q)
-
-Estimate the entropy of `q`.
-
-# Arguments
-- `entropy_estimator`: Entropy estimation strategy.
-- `q`: Variational approximation.
-- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
-
-# Returns
-- `obj_est`: Estimate of the objective value.
-"""
-
-function estimate_entropy end
-
 """
     ClosedFormEntropy()
 
@@ -29,6 +13,8 @@ Use closed-form expression of entropy.
 """
 struct ClosedFormEntropy <: AbstractEntropyEstimator end
 
+maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
+
 function estimate_entropy(::ClosedFormEntropy, ::Any, q)
     entropy(q)
 end
@@ -49,6 +35,8 @@ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
 
 struct MonteCarloEntropy <: AbstractEntropyEstimator end
 
+maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
+
 function estimate_entropy(
     ::Union{MonteCarloEntropy, StickingTheLandingEntropy},
     mc_samples::AbstractMatrix,
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 48a5461f8..28bd681fc 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -50,21 +50,32 @@ function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator,
     estimate_entropy(entropy_estimator, samples, q_maybe_stop)
 end
 
-function estimate_energy_with_samples(::RepGradELBO, samples, prob)
+function estimate_energy_with_samples(prob, samples)
     mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
 end
 
-function estimate_repgradelbo_maybe_stl_with_samples(
-    obj::RepGradELBO, q, q_stop, samples::AbstractMatrix, prob
-)
-    energy  = estimate_energy_with_samples(obj, samples, prob)
-    entropy = estimate_entropy_maybe_stl(obj.entropy, samples, q, q_stop)
-    energy + entropy
-end
+"""
+    reparam_with_entropy(rng, n_samples, q, q_stop, ent_est)
+
+Draw `n_samples` from `q` and compute its entropy.
 
-function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELBO, q, q_stop, prob)
-    samples = rand(rng, q, obj.n_samples)
-    estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob)
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `n_samples::Int`: Number of Monte Carlo samples.
+- `q`: Variational approximation.
+- `q_stop`: `q` but with its gradient stopped.
+- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
+
+# Returns
+- `samples`: Monte Carlo samples generated through reparameterization. Their support matches that of the target distribution.
+- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
+"""
+function reparam_with_entropy(
+    rng::Random.AbstractRNG, n_samples::Int, q, q_stop, ent_est
+)
+    samples = rand(rng, q, n_samples)
+    entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
+    samples, entropy
 end
 
 function estimate_objective(
@@ -74,8 +85,9 @@ function estimate_objective(
     prob;
     n_samples::Int = obj.n_samples
 )
-    samples = rand(rng, q, n_samples)
-    estimate_repgradelbo_maybe_stl_with_samples(obj, q, q, samples, prob)
+    samples, entropy =  reparam_with_entropy(rng, n_samples, q, q, obj.entropy)
+    energy = estimate_energy_with_samples(prob, samples)
+    energy + entropy
 end
 
 estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
@@ -89,12 +101,14 @@ function estimate_gradient!(
     prob,
     λ,
     restructure,
-    est_state,
+    state,
 )
     q_stop = restructure(λ)
     function f(λ′)
         q = restructure(λ′)
-        elbo = estimate_repgradelbo_maybe_stl(rng, obj, q, q_stop, prob)
+        samples, entropy = reparam_with_entropy(rng, obj.n_samples, q, q_stop, obj.entropy)
+        energy = estimate_energy_with_samples(prob, samples)
+        elbo = energy + entropy
         -elbo
     end
     value_and_gradient!(adbackend, f, λ, out)
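The pieces that `estimate_objective` now composes can also be called directly. A sketch with this patch's argument order (`n_samples` precedes `q`; a later patch in this series reorders it), using the unexported internals and stand-in `q` and `prob`:

    using Random

    rng = Random.default_rng()
    # Draw reparameterized samples together with an entropy estimate;
    # passing `q` twice means no gradient stopping, as in `estimate_objective`.
    samples, entropy = AdvancedVI.reparam_with_entropy(rng, 10, q, q, ClosedFormEntropy())
    energy = AdvancedVI.estimate_energy_with_samples(prob, samples)
    elbo   = energy + entropy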

From 7361ed4d5abafa46a3b16b74fec0be612d859d7e Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Thu, 23 Nov 2023 02:37:42 -0500
Subject: [PATCH 202/206] add TransformedDistribution support as extension

---
 Project.toml                                  |  4 +
 ext/AdvancedVIBijectorsExt.jl                 | 37 +++++++++
 src/AdvancedVI.jl                             |  3 +
 test/Project.toml                             |  2 +
 .../repgradelbo_distributionsad_bijectors.jl  | 81 +++++++++++++++++++
 test/models/normallognormal.jl                | 65 +++++++++++++++
 test/runtests.jl                              |  3 +
 7 files changed, 195 insertions(+)
 create mode 100644 ext/AdvancedVIBijectorsExt.jl
 create mode 100644 test/inference/repgradelbo_distributionsad_bijectors.jl
 create mode 100644 test/models/normallognormal.jl

diff --git a/Project.toml b/Project.toml
index 7799d5057..f4ea1bccb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -21,6 +21,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [weakdeps]
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -31,10 +32,12 @@ AdvancedVIEnzymeExt = "Enzyme"
 AdvancedVIForwardDiffExt = "ForwardDiff"
 AdvancedVIReverseDiffExt = "ReverseDiff"
 AdvancedVIZygoteExt = "Zygote"
+AdvancedVIBijectorsExt = "Bijectors"
 
 [compat]
 ADTypes = "0.1, 0.2"
 Accessors = "0.1"
+Bijectors = "0.13"
 ChainRulesCore = "1.16"
 DiffResults = "1"
 Distributions = "0.25.87"
@@ -56,6 +59,7 @@ Zygote = "0.6.63"
 julia = "1.6"
 
 [extras]
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl
new file mode 100644
index 000000000..5d9dc774f
--- /dev/null
+++ b/ext/AdvancedVIBijectorsExt.jl
@@ -0,0 +1,37 @@
+
+module AdvancedVIBijectorsExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using Bijectors
+    using Random
+else
+    using ..AdvancedVI
+    using ..Bijectors
+    using ..Random
+end
+
+function AdvancedVI.reparam_with_entropy(
+    rng      ::Random.AbstractRNG,
+    n_samples::Int,
+    q        ::Bijectors.TransformedDistribution,
+    q_stop   ::Bijectors.TransformedDistribution,
+    ent_est
+)
+    transform     = q.transform
+    q_base        = q.dist
+    q_base_stop   = q_stop.dist
+    ∑logabsdetjac = 0.0
+    base_samples  = rand(rng, q_base, n_samples)
+    samples       = mapreduce(hcat, eachcol(base_samples)) do base_sample
+        sample, logabsdetjac = with_logabsdet_jacobian(transform, base_sample)
+        ∑logabsdetjac       += logabsdetjac
+        sample
+    end
+    entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
+        ent_est, base_samples, q_base, q_base_stop
+    )
+    entropy      = entropy_base + ∑logabsdetjac/n_samples
+    samples, entropy
+end
+end
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index d17a088ca..89f866963 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -158,6 +158,9 @@ end
 
 @static if !isdefined(Base, :get_extension)
     function __init__()
+        @require Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" begin
+            include("../ext/AdvancedVIBijectorsExt.jl")
+        end
         @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
             include("../ext/AdvancedVIEnzymeExt.jl")
         end
diff --git a/test/Project.toml b/test/Project.toml
index 7d0bf2d2f..a751b89d9 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,5 +1,6 @@
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -22,6 +23,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ADTypes = "0.2.1"
+Bijectors = "0.13"
 Distributions = "0.25.100"
 DistributionsAD = "0.6.45"
 Enzyme = "0.11.7"
diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_distributionsad_bijectors.jl
new file mode 100644
index 000000000..9f1e3cc4a
--- /dev/null
+++ b/test/inference/repgradelbo_distributionsad_bijectors.jl
@@ -0,0 +1,81 @@
+
+const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
+
+using Test
+
+@testset "inference RepGradELBO DistributionsAD Bijectors" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+        realtype ∈ [Float64, Float32],
+        (modelname, modelconstr) ∈ Dict(
+            :NormalLogNormalMeanField => normallognormal_meanfield,
+        ),
+        (objname, objective) ∈ Dict(
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
+            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        ),
+        (adbackname, adbackend) ∈ Dict(
+            :ForwarDiff  => AutoForwardDiff(),
+            #:ReverseDiff => AutoReverseDiff(),
+            #:Zygote      => AutoZygote(), 
+            #:Enzyme      => AutoEnzyme(),
+        )
+
+        seed = (0x38bef07cf9cc549d)
+        rng  = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
+        b    = Bijectors.bijector(model)
+        b⁻¹  = inverse(b)
+        μ₀   = Zeros(realtype, n_dims)
+        L₀   = Diagonal(Ones(realtype, n_dims))
+
+        q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
+        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
+
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+
+            μ  = mean(q.dist)
+            L  = sqrt(cov(q.dist))
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ  = mean(q.dist)
+            L  = sqrt(cov(q.dist))
+
+            rng_repl = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng_repl, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ_repl = mean(q.dist)
+            L_repl = sqrt(cov(q.dist))
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
new file mode 100644
index 000000000..6615084b9
--- /dev/null
+++ b/test/models/normallognormal.jl
@@ -0,0 +1,65 @@
+
+struct NormalLogNormal{MX,SX,MY,SY}
+    μ_x::MX
+    σ_x::SX
+    μ_y::MY
+    Σ_y::SY
+end
+
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
+end
+
+function LogDensityProblems.dimension(model::NormalLogNormal)
+    length(model.μ_y) + 1
+end
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+function Bijectors.bijector(model::NormalLogNormal)
+    @unpack μ_x, σ_x, μ_y, Σ_y = model
+    Bijectors.Stacked(
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),
+        [1:1, 2:1+length(μ_y)])
+end
+
+function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
+    n_dims = 5
+
+    μ_x = randn(rng, realtype)
+    σ_x = ℯ
+    μ_y = randn(rng, realtype, n_dims)
+    L_y = tril(I + ones(realtype, n_dims, n_dims))/2
+    Σ_y = L_y*L_y' |> Hermitian
+
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))
+
+    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
+    Σ[1,1]         = σ_x^2
+    Σ[2:end,2:end] = Σ_y
+    Σ = Σ |> Hermitian
+
+    μ = vcat(μ_x, μ_y)
+    L = cholesky(Σ).L
+
+    TestModel(model, μ, L, n_dims+1, false)
+end
+
+function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
+    n_dims = 5
+
+    μ_x  = randn(rng, realtype)
+    σ_x  = ℯ
+    μ_y  = randn(rng, realtype, n_dims)
+    σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
+
+    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
+
+    μ = vcat(μ_x, μ_y)
+    L = vcat(σ_x, σ_y) |> Diagonal
+
+    TestModel(model, μ, L, n_dims+1, true)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a855541cf..b14b8b2ed 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,6 +2,7 @@
 using Test
 using Test: @testset, @test
 
+using Bijectors
 using Random, StableRNGs
 using Statistics
 using Distributions
@@ -30,6 +31,7 @@ struct TestModel{M,L,S}
     is_meanfield::Bool
 end
 include("models/normal.jl")
+include("models/normallognormal.jl")
 
 # Tests
 include("interface/ad.jl")
@@ -37,3 +39,4 @@ include("interface/optimize.jl")
 include("interface/repgradelbo.jl")
 
 include("inference/repgradelbo_distributionsad.jl")
+include("inference/repgradelbo_distributionsad_bijectors.jl")

From d2e76143f18cfbe0631816e8af03b6b531256067 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <msca8h@naver.com>
Date: Fri, 8 Dec 2023 02:24:30 -0500
Subject: [PATCH 203/206] Update src/objectives/elbo/repgradelbo.jl

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
---
 src/objectives/elbo/repgradelbo.jl | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 28bd681fc..7e0931352 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -43,7 +43,12 @@ RepGradELBO(
 ) = RepGradELBO(entropy, n_samples)
 
-Base.show(io::IO, obj::RepGradELBO) =
-    print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))")
+function Base.show(io::IO, obj::RepGradELBO)
+    print(io, "RepGradELBO(entropy=")
+    print(io, obj.entropy)
+    print(io, ", n_samples=")
+    print(io, obj.n_samples)
+    print(io, ")")
+end
 
 function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
     q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
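Why the block form above matters: Julia's short definition form `f(x) = expr` binds only the single expression after `=` as the method body, so a multi-statement `show` needs an explicit `function ... end`. A minimal illustration:

    # Short form: only the first expression after `=` is the method body.
    short_show(io::IO) =
        print(io, "a")
    # A second indented print written below it would be a separate
    # top-level statement, not part of `short_show`.

    # Block form: every statement belongs to the method.
    function block_show(io::IO)
        print(io, "a")
        print(io, "b")
    end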

From 77686b5c776de4e42637932b0310c36b6e4a8d86 Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 8 Dec 2023 02:31:21 -0500
Subject: [PATCH 204/206] fix docstring for entropy estimator

---
 src/objectives/elbo/entropy.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
index 231b16523..6c5b4739d 100644
--- a/src/objectives/elbo/entropy.jl
+++ b/src/objectives/elbo/entropy.jl
@@ -5,7 +5,7 @@
 Use closed-form expression of entropy.
 
 # Requirements
-- `q` implements `entropy`.
+- The variational approximation implements `entropy`.
 
 # References
 * Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
@@ -25,7 +25,7 @@ end
 The "sticking the landing" entropy estimator.
 
 # Requirements
-- `q` implements `logpdf`.
+- The variational approximation `q` implements `logpdf`.
 - `logpdf(q, η)` must be differentiable by the selected AD framework.
 
 # References
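The two strategies these docstrings describe are selected when constructing the objective, as in the test suites; a sketch:

    # Closed-form entropy: requires `entropy(q)` to exist.
    obj_cf  = RepGradELBO(10)

    # Sticking-the-landing: needs only a differentiable `logpdf(q, x)`,
    # and tends to give lower-variance gradients near the optimum.
    obj_stl = RepGradELBO(10, entropy = StickingTheLandingEntropy())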

From 8461b43c821980bd433631fd3ce9e369db56b01c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 8 Dec 2023 03:12:52 -0500
Subject: [PATCH 205/206] fix `reparam_with_entropy` specialization for
 bijectors

---
 ext/AdvancedVIBijectorsExt.jl      | 32 +++++++++++++++++++-----------
 src/objectives/elbo/repgradelbo.jl | 19 +++++++++++-------
 src/utils.jl                       | 10 ++++++++++
 3 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl
index 5d9dc774f..1b200ac5d 100644
--- a/ext/AdvancedVIBijectorsExt.jl
+++ b/ext/AdvancedVIBijectorsExt.jl
@@ -13,25 +13,33 @@ end
 
 function AdvancedVI.reparam_with_entropy(
     rng      ::Random.AbstractRNG,
-    n_samples::Int,
     q        ::Bijectors.TransformedDistribution,
     q_stop   ::Bijectors.TransformedDistribution,
-    ent_est
+    n_samples::Int,
+    ent_est  ::AdvancedVI.AbstractEntropyEstimator
 )
-    transform     = q.transform
-    q_base        = q.dist
-    q_base_stop   = q_stop.dist
-    ∑logabsdetjac = 0.0
-    base_samples  = rand(rng, q_base, n_samples)
-    samples       = mapreduce(hcat, eachcol(base_samples)) do base_sample
-        sample, logabsdetjac = with_logabsdet_jacobian(transform, base_sample)
-        ∑logabsdetjac       += logabsdetjac
-        sample
+    transform    = q.transform
+    q_base       = q.dist
+    q_base_stop  = q_stop.dist
+    base_samples = rand(rng, q_base, n_samples)
+    it           = AdvancedVI.eachsample(base_samples)
+    sample_init  = first(it)
+
+    samples_and_logjac = mapreduce(
+        AdvancedVI.catsamples_and_acc,
+        Iterators.drop(it, 1);
+        init=with_logabsdet_jacobian(transform, sample_init)
+    ) do sample
+        with_logabsdet_jacobian(transform, sample)
     end
+    samples = first(samples_and_logjac)
+    logjac  = last(samples_and_logjac)
+
     entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
         ent_est, base_samples, q_base, q_base_stop
     )
-    entropy      = entropy_base + ∑logabsdetjac/n_samples
+
+    entropy = entropy_base + logjac/n_samples
     samples, entropy
 end
 end
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index 28bd681fc..04f353200 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -42,8 +42,13 @@ RepGradELBO(
     entropy  ::AbstractEntropyEstimator = ClosedFormEntropy()
 ) = RepGradELBO(entropy, n_samples)
 
-Base.show(io::IO, obj::RepGradELBO) =
-    print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))")
+function Base.show(io::IO, obj::RepGradELBO)
+    print(io, "RepGradELBO(entropy=")
+    print(io, obj.entropy)
+    print(io, ", n_samples=")
+    print(io, obj.n_samples)
+    print(io, ")")
+end
 
 function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
     q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
@@ -55,15 +60,15 @@ function estimate_energy_with_samples(prob, samples)
 end
 
 """
-    reparam_with_entropy(rng, n_samples, q, q_stop, ent_est)
+    reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
 
 Draw `n_samples` from `q` and compute its entropy.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
-- `n_samples::Int`: Number of Monte Carlo samples.
 - `q`: Variational approximation.
 - `q_stop`: `q` but with its gradient stopped.
+- `n_samples::Int`: Number of Monte Carlo samples.
 - `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
 
 # Returns
@@ -71,7 +76,7 @@ Draw `n_samples` from `q` and compute its entropy.
 - `entropy`: An estimate (or exact value) of the differential entropy of `q`.
 """
 function reparam_with_entropy(
-    rng::Random.AbstractRNG, n_samples::Int, q, q_stop, ent_est
+    rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
 )
     samples = rand(rng, q, n_samples)
     entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
@@ -85,7 +90,7 @@ function estimate_objective(
     prob;
     n_samples::Int = obj.n_samples
 )
-    samples, entropy =  reparam_with_entropy(rng, n_samples, q, q, obj.entropy)
+    samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy)
     energy = estimate_energy_with_samples(prob, samples)
     energy + entropy
 end
@@ -106,7 +111,7 @@ function estimate_gradient!(
     q_stop = restructure(λ)
     function f(λ′)
         q = restructure(λ′)
-        samples, entropy = reparam_with_entropy(rng, obj.n_samples, q, q_stop, obj.entropy)
+        samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
         energy = estimate_energy_with_samples(prob, samples)
         elbo = energy + entropy
         -elbo
diff --git a/src/utils.jl b/src/utils.jl
index 76637fa3c..8e67ff1a3 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -24,3 +24,13 @@ end
 eachsample(samples::AbstractMatrix) = eachcol(samples)
 
 eachsample(samples::AbstractVector) = samples
+
+function catsamples_and_acc(
+    state_curr::Tuple{<:AbstractArray,  <:Real},
+    state_new ::Tuple{<:AbstractVector, <:Real}
+)
+    x  = hcat(first(state_curr), first(state_new))
+    ∑y = last(state_curr) + last(state_new)
+    return (x, ∑y)
+end
+
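A toy sketch of the fold this helper drives inside the Bijectors specialization of `reparam_with_entropy`:

    # Each step appends one sample column and accumulates the scalar
    # (here, a running log-Jacobian).
    s_curr = (hcat([1.0, 2.0]), 0.5)   # (2×1 matrix, accumulator so far)
    s_new  = ([3.0, 4.0], 0.25)        # (new sample, its logabsdetjac)
    x, acc = AdvancedVI.catsamples_and_acc(s_curr, s_new)
    # x == [1.0 3.0; 2.0 4.0]; acc == 0.75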

From bd925cce08473c9f5698beac05d810c146ecc56c Mon Sep 17 00:00:00 2001
From: Kyurae Kim <kyrkim@seas.upenn.edu>
Date: Fri, 8 Dec 2023 03:19:58 -0500
Subject: [PATCH 206/206] enable Zygote for non-bijector tests

---
 test/inference/repgradelbo_distributionsad.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
index 29cb2d834..b6db22a62 100644
--- a/test/inference/repgradelbo_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -16,7 +16,7 @@ using Test
         (adbackname, adbackend) ∈ Dict(
             :ForwarDiff  => AutoForwardDiff(),
             #:ReverseDiff => AutoReverseDiff(),
-            #:Zygote      => AutoZygote(), 
+            :Zygote      => AutoZygote(), 
             #:Enzyme      => AutoEnzyme(),
         )