diff --git a/Project.toml b/Project.toml
index 800fa34f3..f4ea1bccb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,37 +1,71 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.2.4"
+version = "0.3.0"
 
 [deps]
-Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+
+[weakdeps]
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[extensions]
+AdvancedVIEnzymeExt = "Enzyme"
+AdvancedVIForwardDiffExt = "ForwardDiff"
+AdvancedVIReverseDiffExt = "ReverseDiff"
+AdvancedVIZygoteExt = "Zygote"
+AdvancedVIBijectorsExt = "Bijectors"
 
 [compat]
-Bijectors = "0.11, 0.12, 0.13"
-Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
-DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
+ADTypes = "0.1, 0.2"
+Accessors = "0.1"
+Bijectors = "0.13"
+ChainRulesCore = "1.16"
+DiffResults = "1"
+Distributions = "0.25.87"
 DocStringExtensions = "0.8, 0.9"
-ForwardDiff = "0.10.3"
-ProgressMeter = "1.0.0"
-Requires = "0.5, 1.0"
+Enzyme = "0.11.7"
+FillArrays = "1.3"
+ForwardDiff = "0.10.36"
+Functors = "0.4"
+LinearAlgebra = "1"
+LogDensityProblems = "2"
+Optimisers = "0.2.16, 0.3"
+ProgressMeter = "1.6"
+Random = "1"
+Requires = "1.0"
+ReverseDiff = "1.15.1"
+SimpleUnPack = "1.1.0"
 StatsBase = "0.32, 0.33, 0.34"
-StatsFuns = "0.8, 0.9, 1"
-Tracker = "0.2.3"
+Zygote = "0.6.63"
 julia = "1.6"
 
 [extras]
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
 test = ["Pkg", "Test"]
diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl
new file mode 100644
index 000000000..1b200ac5d
--- /dev/null
+++ b/ext/AdvancedVIBijectorsExt.jl
@@ -0,0 +1,45 @@
+
+module AdvancedVIBijectorsExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using Bijectors
+    using Random
+else
+    using ..AdvancedVI
+    using ..Bijectors
+    using ..Random
+end
+
+function AdvancedVI.reparam_with_entropy(
+    rng      ::Random.AbstractRNG,
+    q        ::Bijectors.TransformedDistribution,
+    q_stop   ::Bijectors.TransformedDistribution,
+    n_samples::Int,
+    ent_est  ::AdvancedVI.AbstractEntropyEstimator
+)
+    transform    = q.transform
+    q_base       = q.dist
+    q_base_stop  = q_stop.dist
+    base_samples = rand(rng, q_base, n_samples)
+    it           = AdvancedVI.eachsample(base_samples)
+    sample_init  = first(it)
+
+    samples_and_logjac = mapreduce(
+        AdvancedVI.catsamples_and_acc,
+        Iterators.drop(it, 1);
+        init=with_logabsdet_jacobian(transform, sample_init)
+    ) do sample
+        with_logabsdet_jacobian(transform, sample)
+    end
+    samples = first(samples_and_logjac)
+    logjac  = last(samples_and_logjac)
+
+    entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
+        ent_est, base_samples, q_base, q_base_stop
+    )
+
+    entropy = entropy_base + logjac/n_samples
+    samples, entropy
+end
+end
diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl
new file mode 100644
index 000000000..8333299f0
--- /dev/null
+++ b/ext/AdvancedVIEnzymeExt.jl
@@ -0,0 +1,26 @@
+
+module AdvancedVIEnzymeExt
+
+if isdefined(Base, :get_extension)
+    using Enzyme
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+else
+    using ..Enzyme
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+end
+
+# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
+) where {T<:Real}
+    y = f(θ)
+    DiffResults.value!(out, y)
+    ∇θ = DiffResults.gradient(out)
+    fill!(∇θ, zero(T))
+    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
+    return out
+end
+
+end
diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl
new file mode 100644
index 000000000..5949bdf81
--- /dev/null
+++ b/ext/AdvancedVIForwardDiffExt.jl
@@ -0,0 +1,29 @@
+
+module AdvancedVIForwardDiffExt
+
+if isdefined(Base, :get_extension)
+    using ForwardDiff
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+else
+    using ..ForwardDiff
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+end
+
+getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
+
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
+) where {T<:Real}
+    chunk_size = getchunksize(ad)
+    config = if isnothing(chunk_size)
+        ForwardDiff.GradientConfig(f, θ)
+    else
+        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
+    end
+    ForwardDiff.gradient!(out, f, θ, config)
+    return out
+end
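+
+# Usage sketch: the chunk size used above can be fixed through the `ADTypes` constructor
+# (assuming its `chunksize` keyword), e.g.
+#
+#     AdvancedVI.value_and_gradient!(ADTypes.AutoForwardDiff(; chunksize=4), f, θ, out)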
+
+end
diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl
new file mode 100644
index 000000000..520cd9ff1
--- /dev/null
+++ b/ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,23 @@
+
+module AdvancedVIReverseDiffExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+    using ReverseDiff
+else
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+    using ..ReverseDiff
+end
+
+# ReverseDiff without compiled tape
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
+)
+    tp = ReverseDiff.GradientTape(f, θ)
+    ReverseDiff.gradient!(out, tp, θ)
+    return out
+end
+
+end
diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl
new file mode 100644
index 000000000..7b8f8817a
--- /dev/null
+++ b/ext/AdvancedVIZygoteExt.jl
@@ -0,0 +1,24 @@
+
+module AdvancedVIZygoteExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+    using Zygote
+else
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+    using ..Zygote
+end
+
+function AdvancedVI.value_and_gradient!(
+    ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
+)
+    y, back = Zygote.pullback(f, θ)
+    ∇θ = back(one(y))
+    DiffResults.value!(out, y)
+    DiffResults.gradient!(out, only(∇θ))
+    return out
+end
+
+end
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index e203a13ca..89f866963 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -1,270 +1,180 @@
+
 module AdvancedVI
 
-using Random: Random
+using SimpleUnPack: @unpack, @pack!
+using Accessors
 
-using Distributions, DistributionsAD, Bijectors
-using DocStringExtensions
+using Random
+using Distributions
 
-using ProgressMeter, LinearAlgebra
+using Functors
+using Optimisers
 
-using ForwardDiff
-using Tracker
+using DocStringExtensions
+using ProgressMeter
+using LinearAlgebra
 
-const PROGRESS = Ref(true)
-function turnprogress(switch::Bool)
-    @info("[AdvancedVI]: global PROGRESS is set as $switch")
-    PROGRESS[] = switch
-end
+using LogDensityProblems
 
-const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))
+using ADTypes, DiffResults
+using ChainRulesCore
 
-include("ad.jl")
-include("utils.jl")
+using FillArrays
 
-using Requires
-function __init__()
-    @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin
-        apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
-        Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
-        Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
-    end
-    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-        include("compat/zygote.jl")
-        export ZygoteAD
-
-        function AdvancedVI.grad!(
-            vo,
-            alg::VariationalInference{<:AdvancedVI.ZygoteAD},
-            q,
-            model,
-            θ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-            args...
-        )
-            f(θ) = if (q isa Distribution)
-                - vo(alg, update(q, θ), model, args...)
-            else
-                - vo(alg, q(θ), model, args...)
-            end
-            y, back = Zygote.pullback(f, θ)
-            dy = first(back(1.0))
-            DiffResults.value!(out, y)
-            DiffResults.gradient!(out, dy)
-            return out
-        end
-    end
-    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
-        include("compat/reversediff.jl")
-        export ReverseDiffAD
-
-        function AdvancedVI.grad!(
-            vo,
-            alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
-            q,
-            model,
-            θ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-            args...
-        )
-            f(θ) = if (q isa Distribution)
-                - vo(alg, update(q, θ), model, args...)
-            else
-                - vo(alg, q(θ), model, args...)
-            end
-            tp = AdvancedVI.tape(f, θ)
-            ReverseDiff.gradient!(out, tp, θ)
-            return out
-        end
-    end
-    @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
-        include("compat/enzyme.jl")
-        export EnzymeAD
-
-        function AdvancedVI.grad!(
-            vo,
-            alg::VariationalInference{<:AdvancedVI.EnzymeAD},
-            q,
-            model,
-            θ::AbstractVector{<:Real},
-            out::DiffResults.MutableDiffResult,
-            args...
-        )
-            f(θ) = if (q isa Distribution)
-                - vo(alg, update(q, θ), model, args...)
-            else
-                - vo(alg, q(θ), model, args...)
-            end
-            # Use `Enzyme.ReverseWithPrimal` once it is released:
-            # https://github.com/EnzymeAD/Enzyme.jl/pull/598
-            y = f(θ)
-            DiffResults.value!(out, y)
-            dy = DiffResults.gradient(out)
-            fill!(dy, 0)
-            Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
-            return out
-        end
-    end
-end
+using StatsBase
 
-export
-    vi,
-    ADVI,
-    ELBO,
-    elbo,
-    TruncatedADAGrad,
-    DecayedADAGrad,
-    VariationalInference
+# derivatives
+"""
+    value_and_gradient!(ad, f, θ, out)
 
-abstract type VariationalInference{AD} end
+Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`.
 
-getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
-getADtype(::VariationalInference{AD}) where AD = AD
+# Arguments
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `f`: Function subject to differentiation.
+- `θ`: The point at which to evaluate the value and gradient.
+- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
+"""
+function value_and_gradient! end
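+
+# A minimal usage sketch (mirroring `test/interface/ad.jl`); assumes an AD package such as
+# ForwardDiff is loaded so that the corresponding extension provides the implementation:
+#
+#     using ADTypes, DiffResults, ForwardDiff
+#     λ   = randn(5)
+#     buf = DiffResults.GradientResult(λ)
+#     AdvancedVI.value_and_gradient!(ADTypes.AutoForwardDiff(), x -> sum(abs2, x), λ, buf)
+#     DiffResults.value(buf)     # ≈ sum(abs2, λ)
+#     DiffResults.gradient(buf)  # ≈ 2λ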
 
-abstract type VariationalObjective end
+# estimators
+"""
+    AbstractVariationalObjective
 
-const VariationalPosterior = Distribution{Multivariate, Continuous}
+Abstract type for the variational objectives used by the VI algorithms supported by `AdvancedVI`.
+
+# Implementations
+To be supported by `AdvancedVI`, a VI algorithm must define a subtype of `AbstractVariationalObjective` and implement `estimate_objective`.
+It should also provide gradients by implementing the function `estimate_gradient!`.
+If the objective is stateful, it can implement `init` to initialize the state.
+"""
+abstract type AbstractVariationalObjective end
 
 """
-    grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)
+    init(rng, obj, λ, restructure)
 
-Computes the gradients used in `optimize!`. Default implementation is provided for 
-`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
-This implicitly also gives a default implementation of `optimize!`.
+Initialize the state of the variational objective `obj` given the initial variational parameters `λ`.
+This function needs to be implemented only if `obj` is stateful.
 
-Variance reduction techniques, e.g. control variates, should be implemented in this function.
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `obj::AbstractVariationalObjective`: Variational objective.
+- `λ`: Initial variational parameters.
+- `restructure`: Function that reconstructs the variational approximation from `λ`.
 """
-function grad! end
+init(
+    ::Random.AbstractRNG,
+    ::AbstractVariationalObjective,
+    ::AbstractVector,
+    ::Any
+) = nothing
 
 """
-    vi(model, alg::VariationalInference)
-    vi(model, alg::VariationalInference, q::VariationalPosterior)
-    vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray)
+    estimate_objective([rng,] obj, q, prob; kwargs...)
 
-Constructs the variational posterior from the `model` and performs the optimization
-following the configuration of the given `VariationalInference` instance.
+Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`.
 
 # Arguments
-- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations
-- `alg`: the VI algorithm used
-- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists.
-- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior`
-- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior
+- `rng::Random.AbstractRNG`: Random number generator.
+- `obj::AbstractVariationalObjective`: Variational objective.
+- `q`: Variational approximation.
+- `prob`: The target log-joint density implementing the `LogDensityProblems` interface.
+
+# Keyword Arguments
+Depending on the objective, additional keyword arguments may apply.
+Please refer to the respective documentation of each variational objective for more info.
+
+# Returns
+- `obj_est`: Estimate of the objective value.
 """
-function vi end
-
-function update end
-
-# default implementations
-function grad!(
-    vo,
-    alg::VariationalInference{<:ForwardDiffAD},
-    q,
-    model,
-    θ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    args...
-)
-    f(θ_) = if (q isa Distribution)
-        - vo(alg, update(q, θ_), model, args...)
-    else
-        - vo(alg, q(θ_), model, args...)
-    end
+function estimate_objective end
 
-    # Set chunk size and do ForwardMode.
-    chunk_size = getchunksize(typeof(alg))
-    config = if chunk_size == 0
-        ForwardDiff.GradientConfig(f, θ)
-    else
-        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
-    end
-    ForwardDiff.gradient!(out, f, θ, config)
-end
+export estimate_objective
 
-function grad!(
-    vo,
-    alg::VariationalInference{<:TrackerAD},
-    q,
-    model,
-    θ::AbstractVector{<:Real},
-    out::DiffResults.MutableDiffResult,
-    args...
-)
-    θ_tracked = Tracker.param(θ)
-    y = if (q isa Distribution)
-        - vo(alg, update(q, θ_tracked), model, args...)
-    else
-        - vo(alg, q(θ_tracked), model, args...)
-    end
-    Tracker.back!(y, 1.0)
 
-    DiffResults.value!(out, Tracker.data(y))
-    DiffResults.gradient!(out, Tracker.grad(θ_tracked))
-end
+"""
+    estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state)
 
+Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`.
 
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `obj::AbstractVariationalObjective`: Variational objective.
+- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
+- `prob`: The target log-joint density implementing the `LogDensityProblems` interface.
+- `λ`: Variational parameters at which the gradient is evaluated.
+- `restructure`: Function that reconstructs the variational approximation from `λ`.
+- `obj_state`: Previous state of the objective.
+
+# Returns
+- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
+- `obj_state`: The updated state of the objective.
+- `stat::NamedTuple`: Statistics and logs generated during estimation.
 """
-    optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())
+function estimate_gradient! end
+
+# ELBO-specific interfaces
+abstract type AbstractEntropyEstimator end
 
-Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute
-the steps.
 """
-function optimize!(
-    vo,
-    alg::VariationalInference,
-    q,
-    model,
-    θ::AbstractVector{<:Real};
-    optimizer = TruncatedADAGrad()
-)
-    # TODO: should we always assume `samples_per_step` and `max_iters` for all algos?
-    alg_name = alg_str(alg)
-    samples_per_step = alg.samples_per_step
-    max_iters = alg.max_iters
-    
-    num_params = length(θ)
-
-    # TODO: really need a better way to warn the user about potentially
-    # not using the correct accumulator
-    if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc))
-        # this message should only occurr once in the optimization process
-        @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ)
-    end
+    estimate_entropy(entropy_estimator, mc_samples, q)
 
-    diff_result = DiffResults.GradientResult(θ)
+Estimate the entropy of `q`.
 
-    i = 0
-    prog = if PROGRESS[]
-        ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0)
-    else
-        0
-    end
+# Arguments
+- `entropy_estimator`: Entropy estimation strategy.
+- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
+- `q`: Variational approximation.
 
-    # add criterion? A running mean maybe?
-    time_elapsed = @elapsed while (i < max_iters) # & converged
-        grad!(vo, alg, q, model, θ, diff_result, samples_per_step)
+# Returns
+- `entropy_est`: An estimate of the entropy of `q`.
+"""
+function estimate_entropy end
 
-        # apply update rule
-        Δ = DiffResults.gradient(diff_result)
-        Δ = apply!(optimizer, θ, Δ)
-        @. θ = θ - Δ
-        
-        AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result)
-        PROGRESS[] && (ProgressMeter.next!(prog))
+export
+    RepGradELBO,
+    ClosedFormEntropy,
+    StickingTheLandingEntropy,
+    MonteCarloEntropy
 
-        i += 1
-    end
+include("objectives/elbo/entropy.jl")
+include("objectives/elbo/repgradelbo.jl")
 
-    return θ
-end
+# Optimization Routine
+
+function optimize end
 
-# objectives
-include("objectives.jl")
+export optimize
 
-# optimisers
-include("optimisers.jl")
+include("utils.jl")
+include("optimize.jl")
 
-# VI algorithms
-include("advi.jl")
 
-end # module
+# optional dependencies 
+if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
+    using Requires
+end
+
+@static if !isdefined(Base, :get_extension)
+    function __init__()
+        @require Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" begin
+            include("../ext/AdvancedVIBijectorsExt.jl")
+        end
+        @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
+            include("../ext/AdvancedVIEnzymeExt.jl")
+        end
+        @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
+            include("../ext/AdvancedVIForwardDiffExt.jl")
+        end
+        @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
+            include("../ext/AdvancedVIReverseDiffExt.jl")
+        end
+        @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
+            include("../ext/AdvancedVIZygoteExt.jl")
+        end
+    end
+end
+
+end
+
diff --git a/src/ad.jl b/src/ad.jl
deleted file mode 100644
index 62e785e1b..000000000
--- a/src/ad.jl
+++ /dev/null
@@ -1,46 +0,0 @@
-##############################
-# Global variables/constants #
-##############################
-const ADBACKEND = Ref(:forwarddiff)
-setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
-function setadbackend(::Val{:forward_diff})
-    Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
-    setadbackend(Val(:forwarddiff))
-end
-function setadbackend(::Val{:forwarddiff})
-    ADBACKEND[] = :forwarddiff
-end
-
-function setadbackend(::Val{:reverse_diff})
-    Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.",  :setadbackend)
-    setadbackend(Val(:tracker))
-end
-function setadbackend(::Val{:tracker})
-    ADBACKEND[] = :tracker
-end
-
-const ADSAFE = Ref(false)
-function setadsafe(switch::Bool)
-    @info("[AdvancedVI]: global ADSAFE is set as $switch")
-    ADSAFE[] = switch
-end
-
-const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically
-
-function setchunksize(chunk_size::Int)
-    @info("[AdvancedVI]: AD chunk size is set as $chunk_size")
-    CHUNKSIZE[] = chunk_size
-end
-
-abstract type ADBackend end
-struct ForwardDiffAD{chunk} <: ADBackend end
-getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk
-
-struct TrackerAD <: ADBackend end
-
-ADBackend() = ADBackend(ADBACKEND[])
-ADBackend(T::Symbol) = ADBackend(Val(T))
-
-ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
-ADBackend(::Val{:tracker}) = TrackerAD
-ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
diff --git a/src/advi.jl b/src/advi.jl
deleted file mode 100644
index 7f9e73460..000000000
--- a/src/advi.jl
+++ /dev/null
@@ -1,99 +0,0 @@
-using StatsFuns
-using DistributionsAD
-using Bijectors
-using Bijectors: TransformedDistribution
-
-
-"""
-$(TYPEDEF)
-
-Automatic Differentiation Variational Inference (ADVI) with automatic differentiation
-backend `AD`.
-
-# Fields
-
-$(TYPEDFIELDS)
-"""
-struct ADVI{AD} <: VariationalInference{AD}
-    "Number of samples used to estimate the ELBO in each optimization step."
-    samples_per_step::Int
-    "Maximum number of gradient steps."
-    max_iters::Int
-end
-
-function ADVI(samples_per_step::Int=1, max_iters::Int=1000)
-    return ADVI{ADBackend()}(samples_per_step, max_iters)
-end
-
-alg_str(::ADVI) = "ADVI"
-
-function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad())
-    θ = copy(θ_init)
-    optimize!(elbo, alg, q, model, θ; optimizer = optimizer)
-
-    # If `q` is a mean-field approx we use the specialized `update` function
-    if q isa Distribution
-        return update(q, θ)
-    else
-        # Otherwise we assume it's a mapping θ → q
-        return q(θ)
-    end
-end
-
-
-function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad())
-    θ = copy(θ_init)
-
-    # `model` assumed to be callable z ↦ p(x, z)
-    optimize!(elbo, alg, q, model, θ; optimizer = optimizer)
-
-    return θ
-end
-
-# WITHOUT updating parameters inside ELBO
-function (elbo::ELBO)(
-    rng::Random.AbstractRNG,
-    alg::ADVI,
-    q::VariationalPosterior,
-    logπ::Function,
-    num_samples
-)
-    #   𝔼_q(z)[log p(xᵢ, z)]
-    # = ∫ log p(xᵢ, z) q(z) dz
-    # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ   (since change of variables)
-    # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ                   (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ))
-    # = 𝔼_q̃(ϕ)[log p(xᵢ, z)]
-
-    #   𝔼_q(z)[log q(z)]
-    # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ     (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ))
-    # = 𝔼_q̃(ϕ) [log q(f(ϕ))]
-    # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|]
-    # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|]
-    # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|]
-
-    # Finally, the ELBO is given by
-    # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)]
-    #      = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ))
-
-    # If f: supp(p(z | x)) → ℝ then
-    # ELBO = 𝔼[log p(x, z) - log q(z)]
-    #      = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃))
-    #      = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃))
-
-    # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
-    z, logjac = rand_and_logjac(rng, q)
-    res = (logπ(z) + logjac) / num_samples
-
-    if q isa TransformedDistribution
-        res += entropy(q.dist)
-    else
-        res += entropy(q)
-    end
-    
-    for i = 2:num_samples
-        z, logjac = rand_and_logjac(rng, q)
-        res += (logπ(z) + logjac) / num_samples
-    end
-
-    return res
-end
diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl
deleted file mode 100644
index c6bb9ac39..000000000
--- a/src/compat/enzyme.jl
+++ /dev/null
@@ -1,5 +0,0 @@
-struct EnzymeAD <: ADBackend end
-ADBackend(::Val{:enzyme}) = EnzymeAD
-function setadbackend(::Val{:enzyme})
-    ADBACKEND[] = :enzyme
-end
diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl
deleted file mode 100644
index 721d03618..000000000
--- a/src/compat/reversediff.jl
+++ /dev/null
@@ -1,16 +0,0 @@
-using .ReverseDiff: compile, GradientTape
-using .ReverseDiff.DiffResults: GradientResult
-
-struct ReverseDiffAD{cache} <: ADBackend end
-const RDCache = Ref(false)
-setcache(b::Bool) = RDCache[] = b
-getcache() = RDCache[]
-ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()}
-function setadbackend(::Val{:reversediff})
-    ADBACKEND[] = :reversediff
-end
-
-tape(f, x) = GradientTape(f, x)
-function taperesult(f, x)
-    return tape(f, x), GradientResult(x)
-end
diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl
deleted file mode 100644
index 40022e215..000000000
--- a/src/compat/zygote.jl
+++ /dev/null
@@ -1,5 +0,0 @@
-struct ZygoteAD <: ADBackend end
-ADBackend(::Val{:zygote}) = ZygoteAD
-function setadbackend(::Val{:zygote})
-    ADBACKEND[] = :zygote
-end
diff --git a/src/objectives.jl b/src/objectives.jl
deleted file mode 100644
index 5a6b61b0c..000000000
--- a/src/objectives.jl
+++ /dev/null
@@ -1,7 +0,0 @@
-struct ELBO <: VariationalObjective end
-
-function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...)
-    return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...)
-end
-
-const elbo = ELBO()
diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl
new file mode 100644
index 000000000..6c5b4739d
--- /dev/null
+++ b/src/objectives/elbo/entropy.jl
@@ -0,0 +1,48 @@
+
+"""
+    ClosedFormEntropy()
+
+Use the closed-form expression of the entropy of the variational approximation.
+
+# Requirements
+- The variational approximation implements `entropy`.
+
+# References
+* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR.
+* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research.
+"""
+struct ClosedFormEntropy <: AbstractEntropyEstimator end
+
+maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q
+
+function estimate_entropy(::ClosedFormEntropy, ::Any, q)
+    entropy(q)
+end
+
+"""
+    StickingTheLandingEntropy()
+
+The "sticking the landing" entropy estimator.
+
+# Requirements
+- The variational approximation `q` implements `logpdf`.
+- `logpdf(q, η)` must be differentiable by the selected AD framework.
+
+# References
+* Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30.
+"""
+struct StickingTheLandingEntropy <: AbstractEntropyEstimator end
+
+struct MonteCarloEntropy <: AbstractEntropyEstimator end
+
+maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
+
+function estimate_entropy(
+    ::Union{MonteCarloEntropy, StickingTheLandingEntropy},
+    mc_samples::AbstractMatrix,
+    q
+)
+    mean(eachcol(mc_samples)) do mc_sample
+        -logpdf(q, mc_sample)
+    end
+end
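+
+# A rough comparison of the two strategies (a sketch; assumes a multivariate normal `q`):
+#
+#     using Distributions, LinearAlgebra
+#     q       = MvNormal(zeros(2), I)
+#     samples = rand(q, 1_000)
+#     estimate_entropy(ClosedFormEntropy(), nothing, q)   # exact: entropy(q)
+#     estimate_entropy(MonteCarloEntropy(), samples, q)   # Monte Carlo estimate of entropy(q)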
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
new file mode 100644
index 000000000..04f353200
--- /dev/null
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -0,0 +1,125 @@
+
+"""
+    RepGradELBO(n_samples; kwargs...)
+
+Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014].
+This computes the evidence lower-bound (ELBO) through the formulation:
+```math
+\\begin{aligned}
+\\mathrm{ELBO}\\left(\\lambda\\right)
+&\\triangleq
+\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
+  \\log \\pi\\left(z\\right)
+\\right]
++ \\mathbb{H}\\left(q_{\\lambda}\\right),
+\\end{aligned}
+```
+
+# Arguments
+- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
+
+# Keyword Arguments
+- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: `ClosedFormEntropy()`.)
+
+# Requirements
+- ``q_{\\lambda}`` implements `rand`.
+- The target `logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
+
+Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
+
+# References
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML.
+[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML.
+[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In ICLR.
+"""
+struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
+    entropy  ::EntropyEst
+    n_samples::Int
+end
+
+RepGradELBO(
+    n_samples::Int;
+    entropy  ::AbstractEntropyEstimator = ClosedFormEntropy()
+) = RepGradELBO(entropy, n_samples)
+
+function Base.show(io::IO, obj::RepGradELBO)
+    print(io, "RepGradELBO(entropy=")
+    print(io, obj.entropy)
+    print(io, ", n_samples=")
+    print(io, obj.n_samples)
+    print(io, ")")
+end
+
+function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
+    q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
+    estimate_entropy(entropy_estimator, samples, q_maybe_stop)
+end
+
+function estimate_energy_with_samples(prob, samples)
+    mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+end
+
+"""
+    reparam_with_entropy(rng, q, q_stop, n_samples, ent_est)
+
+Draw `n_samples` from `q` and compute its entropy.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `q`: Variational approximation.
+- `q_stop`: `q` but with its gradient stopped.
+- `n_samples::Int`: Number of Monte Carlo samples.
+- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.)
+
+# Returns
+- `samples`: Monte Carlo samples generated through reparameterization. Their support matches that of the target distribution.
+- `entropy`: An estimate (or exact value) of the differential entropy of `q`.
+"""
+function reparam_with_entropy(
+    rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
+)
+    samples = rand(rng, q, n_samples)
+    entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
+    samples, entropy
+end
+
+function estimate_objective(
+    rng::Random.AbstractRNG,
+    obj::RepGradELBO,
+    q,
+    prob;
+    n_samples::Int = obj.n_samples
+)
+    samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy)
+    energy = estimate_energy_with_samples(prob, samples)
+    energy + entropy
+end
+
+estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
+    estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
+
+function estimate_gradient!(
+    rng      ::Random.AbstractRNG,
+    obj      ::RepGradELBO,
+    adbackend::ADTypes.AbstractADType,
+    out      ::DiffResults.MutableDiffResult,
+    prob,
+    λ,
+    restructure,
+    state,
+)
+    q_stop = restructure(λ)
+    function f(λ′)
+        q = restructure(λ′)
+        samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
+        energy = estimate_energy_with_samples(prob, samples)
+        elbo = energy + entropy
+        -elbo
+    end
+    value_and_gradient!(adbackend, f, λ, out)
+
+    nelbo = DiffResults.value(out)
+    stat  = (elbo=-nelbo,)
+
+    out, nothing, stat
+end
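+
+# Usage sketch (cf. `test/interface/repgradelbo.jl`): `prob` is assumed to implement the
+# `LogDensityProblems` interface and `q` to support `rand`:
+#
+#     obj_cf  = RepGradELBO(10)                                      # closed-form entropy
+#     obj_stl = RepGradELBO(10, entropy = StickingTheLandingEntropy())
+#     estimate_objective(obj_cf, q, prob; n_samples = 10^4)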
diff --git a/src/optimisers.jl b/src/optimisers.jl
deleted file mode 100644
index 8077f98cb..000000000
--- a/src/optimisers.jl
+++ /dev/null
@@ -1,94 +0,0 @@
-const ϵ = 1e-8
-
-"""
-    TruncatedADAGrad(η=0.1, τ=1.0, n=100)
-
-Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated.
-
-## Parameters
-  - η: learning rate
-  - τ: constant scale factor
-  - n: number of previous gradient norms to use in the scaling.
-```
-## References
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
-
-[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E).
-"""
-mutable struct TruncatedADAGrad
-    eta::Float64
-    tau::Float64
-    n::Int
-    
-    iters::IdDict
-    acc::IdDict
-end
-
-function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100)
-    TruncatedADAGrad(η, τ, n, IdDict(), IdDict())
-end
-
-function apply!(o::TruncatedADAGrad, x, Δ)
-    T = eltype(Tracker.data(Δ))
-    
-    η = o.eta
-    τ = o.tau
-
-    g² = get!(
-        o.acc,
-        x,
-        [zeros(T, size(x)) for j = 1:o.n]
-    )::Array{typeof(Tracker.data(Δ)), 1}
-    i = get!(o.iters, x, 1)::Int
-
-    # Example: suppose i = 12 and o.n = 10
-    idx = mod(i - 1, o.n) + 1 # => idx = 2
-
-    # set the current
-    @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ
-
-    # TODO: make more efficient and stable
-    s = sum(g²)
-    
-    # increment
-    o.iters[x] += 1
-    
-    # TODO: increment (but "truncate")
-    # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1
-
-    @. Δ *= η / (τ + sqrt(s) + ϵ)
-end
-
-"""
-    DecayedADAGrad(η=0.1, pre=1.0, post=0.9)
-
-Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated.
-
-## Parameters
-  - η: learning rate
-  - pre: weight of new gradient norm
-  - post: weight of histroy of gradient norms
-```
-## References
-[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser.
-Parameters don't need tuning.
-"""
-mutable struct DecayedADAGrad
-    eta::Float64
-    pre::Float64
-    post::Float64
-
-    acc::IdDict
-end
-
-DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict())
-
-function apply!(o::DecayedADAGrad, x, Δ)
-    T = eltype(Tracker.data(Δ))
-    
-    η = o.eta
-    acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x))
-    @. acc = o.post * acc + o.pre * Δ^2
-    @. Δ *= η / (√acc + ϵ)
-end
diff --git a/src/optimize.jl b/src/optimize.jl
new file mode 100644
index 000000000..7e0032dce
--- /dev/null
+++ b/src/optimize.jl
@@ -0,0 +1,161 @@
+
+"""
+    optimize([rng,] problem, objective, restructure, param_init, max_iter, objargs...; kwargs...)
+    optimize([rng,] problem, objective, variational_dist_init, max_iter, objargs...; kwargs...)
+
+Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients.
+
+The variational approximation can be constructed either by passing the variational parameters `param_init` together with `restructure`, which reconstructs the approximation from the flattened parameters, or by passing the initial variational approximation `variational_dist_init` directly.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.)
+- `problem`: Target problem, e.g. a log-joint density implementing the `LogDensityProblems` interface.
+- `objective::AbstractVariationalObjective`: Variational objective.
+- `param_init`: Initial value of the variational parameters.
+- `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
+- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
+- `max_iter::Int`: Maximum number of iterations.
+- `objargs...`: Additional arguments passed to `objective`.
+
+# Keyword Arguments
+- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.)
+- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.)
+- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.)
+- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress)`.)
+- `state_init::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.)
+
+# Returns
+- `params`: Variational parameters optimizing the variational objective. (If `variational_dist_init` was passed, the optimized variational approximation is returned instead.)
+- `stats`: Statistics gathered during optimization.
+- `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
+
+# Callback
+The callback function `callback` has the signature
+
+    callback(; stat, state, params, restructure, gradient)
+
+The keyword arguments are as follows:
+- `stat`: Statistics gathered during the current iteration. The content varies depending on `objective`.
+- `state`: Collection of the internal states used for optimization.
+- `params`: Current variational parameters.
+- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the variational approximation.
+- `gradient`: The estimated (possibly stochastic) gradient.
+
+`callback` may return a `NamedTuple` containing additional information computed within the callback.
+This will be merged into the statistics of the current iteration.
+Otherwise, it should return `nothing`.
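+
+For instance, a callback that merely tags each iteration with an extra value could look like (a sketch mirroring `test/interface/optimize.jl`):
+
+    callback(; stat, args...) = (iteration_squared = stat.iteration^2,)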
+
+"""
+
+function optimize(
+    rng          ::Random.AbstractRNG,
+    problem,
+    objective    ::AbstractVariationalObjective,
+    restructure,
+    params_init,
+    max_iter     ::Int,
+    objargs...;
+    adbackend    ::ADTypes.AbstractADType, 
+    optimizer    ::Optimisers.AbstractRule = Optimisers.Adam(),
+    show_progress::Bool                    = true,
+    state_init   ::NamedTuple              = NamedTuple(),
+    callback                               = nothing,
+    prog                                   = ProgressMeter.Progress(
+        max_iter;
+        desc      = "Optimizing",
+        barlen    = 31,
+        showspeed = true,
+        enabled   = show_progress
+    )
+)
+    λ        = copy(params_init)
+    opt_st   = maybe_init_optimizer(state_init, optimizer, λ)
+    obj_st   = maybe_init_objective(state_init, rng, objective, λ, restructure)
+    grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ))
+    stats    = NamedTuple[]
+
+    for t = 1:max_iter
+        stat = (iteration=t,)
+
+        grad_buf, obj_st, stat′ = estimate_gradient!(
+            rng, objective, adbackend, grad_buf, problem,
+            λ, restructure,  obj_st, objargs...
+        )
+        stat = merge(stat, stat′)
+
+        g         = DiffResults.gradient(grad_buf)
+        opt_st, λ = Optimisers.update!(opt_st, λ, g)
+
+        if !isnothing(callback)
+            stat′ = callback(
+                ; stat, restructure, params=λ, gradient=g,
+                state=(optimizer=opt_st, objective=obj_st)
+            )
+            stat = !isnothing(stat′) ? merge(stat′, stat) : stat
+        end
+        
+        @debug "Iteration $t" stat...
+
+        pm_next!(prog, stat)
+        push!(stats, stat)
+    end
+    state  = (optimizer=opt_st, objective=obj_st)
+    stats  = map(identity, stats)
+    params = λ
+    params, stats, state
+end
+
+function optimize(
+    problem,
+    objective    ::AbstractVariationalObjective,
+    restructure,
+    params_init,
+    max_iter     ::Int,
+    objargs...;
+    kwargs...
+)
+    optimize(
+        Random.default_rng(),
+        problem,
+        objective,
+        restructure,
+        params_init,
+        max_iter,
+        objargs...;
+        kwargs...
+    )
+end
+
+function optimize(rng                   ::Random.AbstractRNG,
+                  problem,
+                  objective             ::AbstractVariationalObjective,
+                  variational_dist_init,
+                  n_max_iter            ::Int,
+                  objargs...;
+                  kwargs...)
+    λ, restructure = Optimisers.destructure(variational_dist_init)
+    λ, logstats, state = optimize(
+        rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs...
+    )
+    restructure(λ), logstats, state
+end
+
+
+function optimize(
+    problem,
+    objective              ::AbstractVariationalObjective,
+    variational_dist_init,
+    max_iter               ::Int,
+    objargs...;
+    kwargs...
+)
+    optimize(
+        Random.default_rng(),
+        problem,
+        objective,
+        variational_dist_init,
+        max_iter,
+        objargs...;
+        kwargs...
+    )
+end
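+
+# End-to-end sketch (cf. `test/interface/optimize.jl`); `model` implements the
+# `LogDensityProblems` interface and `q0` supports `Optimisers.destructure`:
+#
+#     q, stats, state = optimize(
+#         model, RepGradELBO(10), q0, 1_000;
+#         adbackend     = AutoForwardDiff(),
+#         optimizer     = Optimisers.Adam(1e-2),
+#         show_progress = false,
+#     )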
diff --git a/src/utils.jl b/src/utils.jl
index bb4c1f18f..8e67ff1a3 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,15 +1,36 @@
-using Distributions
 
-using Bijectors: Bijectors
+function pm_next!(pm, stats::NamedTuple)
+    ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
+end
 
+function maybe_init_optimizer(
+    state_init::NamedTuple,
+    optimizer ::Optimisers.AbstractRule,
+    λ         ::AbstractVector
+)
+    haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, λ)
+end
 
-function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution)
-    x = rand(rng, dist)
-    return x, zero(eltype(x))
+function maybe_init_objective(
+    state_init::NamedTuple,
+    rng       ::Random.AbstractRNG,
+    objective ::AbstractVariationalObjective,
+    λ         ::AbstractVector,
+    restructure
+)
+    haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure)
 end
 
-function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution)
-    x = rand(rng, dist.dist)
-    y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
-    return y, logjac
+eachsample(samples::AbstractMatrix) = eachcol(samples)
+
+eachsample(samples::AbstractVector) = samples
+
+function catsamples_and_acc(
+    state_curr::Tuple{<:AbstractArray,  <:Real},
+    state_new ::Tuple{<:AbstractVector, <:Real}
+)
+    x  = hcat(first(state_curr), first(state_new))
+    ∑y = last(state_curr) + last(state_new)
+    return (x, ∑y)
 end
+
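+# Sketch of the intended use (cf. `ext/AdvancedVIBijectorsExt.jl`): fold over an iterator of
+# (sample, accumulator) pairs, stacking the samples column-wise and summing the accumulator.
+#
+#     pairs = ((randn(3), 0.1), (randn(3), 0.2), (randn(3), 0.3))
+#     stacked, total = reduce(catsamples_and_acc, pairs)   # 3×3 matrix, total ≈ 0.6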
diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 000000000..a751b89d9
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,45 @@
+[deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[compat]
+ADTypes = "0.2.1"
+Bijectors = "0.13"
+Distributions = "0.25.100"
+DistributionsAD = "0.6.45"
+Enzyme = "0.11.7"
+FillArrays = "1.6.1"
+ForwardDiff = "0.10.36"
+Functors = "0.4.5"
+LinearAlgebra = "1"
+LogDensityProblems = "2.1.1"
+Optimisers = "0.2.16, 0.3"
+PDMats = "0.11.7"
+Random = "1"
+ReverseDiff = "1.15.1"
+SimpleUnPack = "1.1.0"
+StableRNGs = "1.0.0"
+Statistics = "1"
+Test = "1"
+Tracker = "0.2.20"
+Zygote = "0.6.63"
+julia = "1.6"
diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
new file mode 100644
index 000000000..b6db22a62
--- /dev/null
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -0,0 +1,78 @@
+
+const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
+
+using Test
+
+@testset "inference RepGradELBO DistributionsAD" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+        realtype ∈ [Float64, Float32],
+        (modelname, modelconstr) ∈ Dict(
+            :Normal=> normal_meanfield,
+        ),
+        (objname, objective) ∈ Dict(
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
+            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        ),
+        (adbackname, adbackend) ∈ Dict(
+            :ForwardDiff => AutoForwardDiff(),
+            #:ReverseDiff => AutoReverseDiff(),
+            :Zygote      => AutoZygote(), 
+            #:Enzyme      => AutoEnzyme(),
+        )
+
+        seed = (0x38bef07cf9cc549d)
+        rng  = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
+        μ0 = Zeros(realtype, n_dims)
+        L0 = Diagonal(Ones(realtype, n_dims))
+        q0 = TuringDiagMvNormal(μ0, diag(L0))
+
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+            q, stats, _ = optimize(
+                rng, model, objective, q0, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+
+            μ  = mean(q)
+            L  = sqrt(cov(q))
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng, model, objective, q0, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ  = mean(q)
+            L  = sqrt(cov(q))
+
+            rng_repl = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng_repl, model, objective, q0, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ_repl = mean(q)
+            L_repl = sqrt(cov(q))
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
+
diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_distributionsad_bijectors.jl
new file mode 100644
index 000000000..9f1e3cc4a
--- /dev/null
+++ b/test/inference/repgradelbo_distributionsad_bijectors.jl
@@ -0,0 +1,81 @@
+
+const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false
+
+using Test
+
+@testset "inference RepGradELBO DistributionsAD Bijectors" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)"  for
+        realtype ∈ [Float64, Float32],
+        (modelname, modelconstr) ∈ Dict(
+            :NormalLogNormalMeanField => normallognormal_meanfield,
+        ),
+        (objname, objective) ∈ Dict(
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
+            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        ),
+        (adbackname, adbackend) ∈ Dict(
+            :ForwardDiff => AutoForwardDiff(),
+            #:ReverseDiff => AutoReverseDiff(),
+            #:Zygote      => AutoZygote(), 
+            #:Enzyme      => AutoEnzyme(),
+        )
+
+        seed = (0x38bef07cf9cc549d)
+        rng  = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+        T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
+
+        b    = Bijectors.bijector(model)
+        b⁻¹  = inverse(b)
+        μ₀   = Zeros(realtype, n_dims)
+        L₀   = Diagonal(Ones(realtype, n_dims))
+
+        q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
+        q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
+
+        @testset "convergence" begin
+            Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+
+            μ  = mean(q.dist)
+            L  = sqrt(cov(q.dist))
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ  = mean(q.dist)
+            L  = sqrt(cov(q.dist))
+
+            rng_repl = StableRNG(seed)
+            q, stats, _ = optimize(
+                rng_repl, model, objective, q₀_z, T;
+                optimizer     = Optimisers.Adam(realtype(η)),
+                show_progress = PROGRESS,
+                adbackend     = adbackend,
+            )
+            μ_repl = mean(q.dist)
+            L_repl = sqrt(cov(q.dist))
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
diff --git a/test/interface/ad.jl b/test/interface/ad.jl
new file mode 100644
index 000000000..b716ca2f2
--- /dev/null
+++ b/test/interface/ad.jl
@@ -0,0 +1,22 @@
+
+using Test
+
+@testset "ad" begin
+    @testset "$(adname)" for (adname, adsymbol) ∈ Dict(
+          :ForwardDiff => AutoForwardDiff(),
+          :ReverseDiff => AutoReverseDiff(),
+          :Zygote      => AutoZygote(),
+          # :Enzyme      => AutoEnzyme(), # Currently not tested against.
+        )
+        D = 10
+        A = randn(D, D)
+        λ = randn(D)
+        grad_buf = DiffResults.GradientResult(λ)
+        f(λ′) = λ′'*A*λ′ / 2
+        AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf)
+        ∇ = DiffResults.gradient(grad_buf)
+        fval = DiffResults.value(grad_buf)
+        @test ∇ ≈ (A + A')*λ/2
+        @test fval ≈ λ'*A*λ / 2
+    end
+end
diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
new file mode 100644
index 000000000..6e69616bd
--- /dev/null
+++ b/test/interface/optimize.jl
@@ -0,0 +1,98 @@
+
+using Test
+
+@testset "interface optimize" begin
+    seed = (0x38bef07cf9cc549d)
+    rng  = StableRNG(seed)
+
+    T = 1000
+    modelstats = normal_meanfield(rng, Float64)
+
+    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+    # Global Test Configurations
+    q0  = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+    obj = RepGradELBO(10)
+
+    adbackend = AutoForwardDiff()
+    optimizer = Optimisers.Adam(1e-2)
+
+    rng  = StableRNG(seed)
+    q_ref, stats_ref, _ = optimize(
+        rng, model, obj, q0, T;
+        optimizer,
+        show_progress = false,
+        adbackend,
+    )
+    λ_ref, _ = Optimisers.destructure(q_ref)
+
+    @testset "default_rng" begin
+        optimize(
+            model, obj, q0, T;
+            optimizer,
+            show_progress = false,
+            adbackend,
+        )
+
+        λ₀, re  = Optimisers.destructure(q0)
+        optimize(
+            model, obj, re, λ₀, T;
+            optimizer,
+            show_progress = false,
+            adbackend,
+        )
+    end
+
+    @testset "restructure" begin
+        λ₀, re  = Optimisers.destructure(q0)
+
+        rng  = StableRNG(seed)
+        λ, stats, _ = optimize(
+            rng, model, obj, re, λ₀, T;
+            optimizer,
+            show_progress = false,
+            adbackend,
+        )
+        @test λ     == λ_ref
+        @test stats == stats_ref
+    end
+
+    @testset "callback" begin
+        rng  = StableRNG(seed)
+        test_values = rand(rng, T)
+
+        callback(; stat, args...) = (test_value = test_values[stat.iteration],)
+
+        rng  = StableRNG(seed)
+        _, stats, _ = optimize(
+            rng, model, obj, q0, T;
+            show_progress = false,
+            adbackend,
+            callback
+        )
+        @test [stat.test_value for stat ∈ stats] == test_values
+    end
+
+    @testset "warm start" begin
+        rng  = StableRNG(seed)
+
+        T_first = div(T,2)
+        T_last  = T - T_first
+
+        q_first, _, state = optimize(
+            rng, model, obj, q0, T_first;
+            optimizer,
+            show_progress = false,
+            adbackend
+        )
+
+        q, stats, _ = optimize(
+            rng, model, obj, q_first, T_last;
+            optimizer,
+            show_progress = false,
+            state_init    = state,
+            adbackend
+        )
+        @test q == q_ref
+    end
+end
diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl
new file mode 100644
index 000000000..61ff0111c
--- /dev/null
+++ b/test/interface/repgradelbo.jl
@@ -0,0 +1,28 @@
+
+using Test
+
+@testset "interface RepGradELBO" begin
+    seed = (0x38bef07cf9cc549d)
+    rng  = StableRNG(seed)
+
+    modelstats = normal_meanfield(rng, Float64)
+
+    @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
+
+    q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims))
+
+    obj      = RepGradELBO(10)
+    rng      = StableRNG(seed)
+    elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4)
+
+    @testset "determinism" begin
+        rng  = StableRNG(seed)
+        elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4)
+        @test elbo == elbo_ref
+    end
+
+    @testset "default_rng" begin
+        elbo = estimate_objective(obj, q0, model; n_samples=10^4)
+        @test elbo ≈ elbo_ref rtol=0.1
+    end
+end
diff --git a/test/models/normal.jl b/test/models/normal.jl
new file mode 100644
index 000000000..3f305e1a0
--- /dev/null
+++ b/test/models/normal.jl
@@ -0,0 +1,43 @@
+
+struct TestNormal{M,S}
+    μ::M
+    Σ::S
+end
+
+function LogDensityProblems.logdensity(model::TestNormal, θ)
+    @unpack μ, Σ = model
+    logpdf(MvNormal(μ, Σ), θ)
+end
+
+function LogDensityProblems.dimension(model::TestNormal)
+    length(model.μ)
+end
+
+function LogDensityProblems.capabilities(::Type{<:TestNormal})
+    LogDensityProblems.LogDensityOrder{0}()
+end
+
+function normal_fullrank(rng::Random.AbstractRNG, realtype::Type)
+    n_dims = 5
+
+    μ = randn(rng, realtype, n_dims)
+    L = tril(I + ones(realtype, n_dims, n_dims))/2
+    Σ = L*L' |> Hermitian
+
+    model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))
+
+    TestModel(model, μ, L, n_dims, false)
+end
+
+function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)
+    n_dims = 5
+
+    μ = randn(rng, realtype, n_dims)
+    σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
+
+    model = TestNormal(μ, Diagonal(σ.^2))
+
+    L = σ |> Diagonal
+
+    TestModel(model, μ, L, n_dims, true)
+end
diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl
new file mode 100644
index 000000000..6615084b9
--- /dev/null
+++ b/test/models/normallognormal.jl
@@ -0,0 +1,65 @@
+
+struct NormalLogNormal{MX,SX,MY,SY}	
+    μ_x::MX	
+    σ_x::SX	
+    μ_y::MY	
+    Σ_y::SY	
+end	
+
+function LogDensityProblems.logdensity(model::NormalLogNormal, θ)	
+    @unpack μ_x, σ_x, μ_y, Σ_y = model	
+    logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])	
+end	
+
+function LogDensityProblems.dimension(model::NormalLogNormal)	
+    length(model.μ_y) + 1	
+end	
+
+function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})	
+    LogDensityProblems.LogDensityOrder{0}()	
+end	
+
+function Bijectors.bijector(model::NormalLogNormal)	
+    @unpack μ_x, σ_x, μ_y, Σ_y = model	
+    Bijectors.Stacked(	
+        Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]),	
+        [1:1, 2:1+length(μ_y)])	
+end	
+
+function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)	
+    n_dims = 5	
+
+    μ_x = randn(rng, realtype)	
+    σ_x = ℯ	
+    μ_y = randn(rng, realtype, n_dims)	
+    L_y = tril(I + ones(realtype, n_dims, n_dims))/2	
+    Σ_y = L_y*L_y' |> Hermitian	
+
+    model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))	
+
+    Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)	
+    Σ[1,1]         = σ_x^2	
+    Σ[2:end,2:end] = Σ_y	
+    Σ = Σ |> Hermitian	
+
+    μ = vcat(μ_x, μ_y)	
+    L = cholesky(Σ).L	
+
+    TestModel(model, μ, L, n_dims+1, false)	
+end	
+
+function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)	
+    n_dims = 5	
+
+    μ_x  = randn(rng, realtype)	
+    σ_x  = ℯ	
+    μ_y  = randn(rng, realtype, n_dims)	
+    σ_y  = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)	
+
+    model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))	
+
+    μ = vcat(μ_x, μ_y)	
+    L = vcat(σ_x, σ_y) |> Diagonal	
+
+    TestModel(model, μ, L, n_dims+1, true)	
+end	
diff --git a/test/optimisers.jl b/test/optimisers.jl
deleted file mode 100644
index fae652ed0..000000000
--- a/test/optimisers.jl
+++ /dev/null
@@ -1,17 +0,0 @@
-using Random, Test, LinearAlgebra, ForwardDiff
-using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply!
-
-θ = randn(10, 10)
-@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)]
-    θ_fit = randn(10, 10)
-    loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1))
-    for t = 1:10^4
-        x = rand(10)
-        Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit)
-        Δ = apply!(opt, θ_fit, Δ)
-        @. θ_fit = θ_fit - Δ
-    end
-    @test loss(rand(10, 100), θ_fit) < 0.01
-    @test length(opt.acc) == 1
-end
-
diff --git a/test/runtests.jl b/test/runtests.jl
index a305c25e5..b14b8b2ed 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,28 +1,42 @@
-using Test
-using Distributions, DistributionsAD
-using AdvancedVI
-
-include("optimisers.jl")
-
-target = MvNormal(ones(2))
-logπ(z) = logpdf(target, z)
-advi = ADVI(10, 1000)
 
-# Using a function z ↦ q(⋅∣z)
-getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4]))
-q = vi(logπ, advi, getq, randn(4))
+using Test
+using Test: @testset, @test
+
+using Bijectors
+using Random, StableRNGs
+using Statistics
+using Distributions
+using LinearAlgebra
+using SimpleUnPack: @unpack
+using FillArrays
+using PDMats
+
+using Functors
+using DistributionsAD
+@functor TuringDiagMvNormal
+
+using LogDensityProblems
+using Optimisers
+using ADTypes
+using ForwardDiff, ReverseDiff, Zygote
 
-xs = rand(target, 10)
-@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05
+using AdvancedVI
 
-# OR: implement `update` and pass a `Distribution`
-function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real})
-    return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end]))
+# Models for Inference Tests
+struct TestModel{M,L,S}
+    model::M
+    μ_true::L
+    L_true::S
+    n_dims::Int
+    is_meanfield::Bool
 end
+include("models/normal.jl")
+include("models/normallognormal.jl")
 
-q0 = TuringDiagMvNormal(zeros(2), ones(2))
-q = vi(logπ, advi, q0, randn(4))
-
-xs = rand(target, 10)
-@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05
+# Tests
+include("interface/ad.jl")
+include("interface/optimize.jl")
+include("interface/repgradelbo.jl")
 
+include("inference/repgradelbo_distributionsad.jl")
+include("inference/repgradelbo_distributionsad_bijectors.jl")