Skip to content

Commit

Permalink
Add multimodal HDI (#40)
Browse files Browse the repository at this point in the history
* Add function for computing the KDE with boundary reflection

* Fix issues with KDE grid padding

* Add utility function for KDE bounds checking

* Refactor kde_reflected into internal func

* Document kde_reflected

* Add improved Sheather-Jones bandwidth rule

* Add missing imports

* Add tests for reflected KDE and bandwidth

* Fix variable name

* Avoid internal KDE.jl functions

* Fix silverman bandwidth

* Define internal density estimation interface

* Add missing include statement

* Add multimodal HDI

* Use range signatures present in Julia v1.6

* Bump FFTW version

At least this version is required for downgrade CI

* Fix doctest

* Skip test on Julia v1.6

* Add tests for density estimation interface

* Bump StatsBase lower bound

So histrange works

* Fix version check

* Update other doctests

* Update HDI doctests

* Actually repair HDI docstrings

* Restore type-inferrability for Symbol methods

* Throw ArgumentrError for invalid prob

To match behavior of quantile/eti

* Make sure interval eltype matches sample eltype

* Consistently handle NaNs

* Update HDI tests

* Use `Compat.@constprop` for old Julia versions

* Define missing variable

* Document KDE kwargs

* Use ISJ bandwidth by default for KDE

* Add kde_reflected to utility docs

* Skip type inference tests for old Julia versions

* Fix doctest

* Give more details about HDI methods

* Bump FFTW compat

This version is required for multimodal HDI type inference
  • Loading branch information
sethaxen authored Nov 1, 2024
1 parent 6fbbf5c commit c52f54a
Show file tree
Hide file tree
Showing 13 changed files with 870 additions and 112 deletions.
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Expand All @@ -20,6 +23,7 @@ PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -33,10 +37,13 @@ Compat = "4.2.0"
DataInterpolations = "4, 5, 6"
Distributions = "0.25.57"
DocStringExtensions = "0.8, 0.9"
FFTW = "1.6.0"
FiniteDifferences = "0.12.17"
GLM = "1.8.0"
IntervalSets = "0.5, 0.6, 0.7"
IrrationalConstants = "0.1, 0.2"
IteratorInterfaceExtensions = "1"
KernelDensity = "0.6.3"
LinearAlgebra = "1.6"
LogExpFunctions = "0.3.3"
MCMCDiagnosticTools = "0.3.4"
Expand All @@ -49,9 +56,10 @@ PrettyTables = "2.1"
Printf = "1.6"
RCall = "0.13.11"
Random = "1.6"
Roots = "1, 2"
Setfield = "1"
Statistics = "1.6"
StatsBase = "0.33.7, 0.34"
StatsBase = "0.33.17, 0.34"
TableOperations = "1"
TableTraits = "1"
Tables = "1.9"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,6 @@ r2_score
### Utilities

```@docs
PosteriorStats.kde_reflected
PosteriorStats.smooth_data
```
8 changes: 7 additions & 1 deletion src/PosteriorStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ using Compat: @constprop
using DataInterpolations: DataInterpolations
using Distributions: Distributions
using DocStringExtensions: FIELDS, FUNCTIONNAME, TYPEDEF, TYPEDFIELDS, SIGNATURES
using FFTW: FFTW
using IrrationalConstants: sqrthalfπ, sqrtπ, sqrt2
using IteratorInterfaceExtensions: IteratorInterfaceExtensions
using LinearAlgebra: mul!, norm
using KernelDensity: KernelDensity
using LinearAlgebra: mul!, norm, normalize
using LogExpFunctions: LogExpFunctions
using Markdown: @doc_str
using MCMCDiagnosticTools: MCMCDiagnosticTools
Expand All @@ -15,6 +18,7 @@ using PrettyTables: PrettyTables
using Printf: Printf
using PSIS: PSIS, PSISResult, psis, psis!
using Random: Random
using Roots: Roots
using Setfield: Setfield
using Statistics: Statistics
using StatsBase: StatsBase
Expand Down Expand Up @@ -48,6 +52,8 @@ const DEFAULT_INTERVAL_PROB = 0.94
const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)

include("utils.jl")
include("density_estimation.jl")
include("kde.jl")
include("eti.jl")
include("hdi.jl")
include("elpdresult.jl")
Expand Down
103 changes: 103 additions & 0 deletions src/density_estimation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Density estimation methods and utilities

"""
DensityEstimationMethod
Abstract type for density estimation methods.
Densities are defined wrt the Lebesgue measure on subintervals of the real line for
continuous data or the counting measure on unique values for discrete data.
Each method should implement:
- `bins_and_probs(::DensityEstimationMethod, x::AbstractVector{<:Real}) -> (bins, probs)`
Returns the density evaluated at regularly spaced points, normalized to sum to 1.
Empty bins may be omitted.
- `density_at(::DensityEstimationMethod, x::AbstractVector{<:Real}) -> densities`
Returns the density evaluated at the points in `x`
"""
abstract type DensityEstimationMethod end

"""
DefaultDensityEstimation <: DensityEstimationMethod
Select density estimation method based on data type.
"""
struct DefaultDensityEstimation <: DensityEstimationMethod end

"""
DiscreteDensityEstimation <: DensityEstimationMethod
Estimate density for integer-valued data using the counting measure.
"""
struct DiscreteDensityEstimation <: DensityEstimationMethod end

function bins_and_probs(::DiscreteDensityEstimation, x::AbstractVector{<:Real})
prop_map = OrderedCollections.OrderedDict(StatsBase.proportionmap(x))
sort!(prop_map)
bins = collect(keys(prop_map))
probs = collect(values(prop_map))
return bins, probs
end

"""
HistogramEstimation{K} <: DensityEstimationMethod
Estimate piecewise constant density using a histogram.
"""
struct HistogramEstimation{K} <: DensityEstimationMethod
hist_kwargs::K
end
HistogramEstimation(; hist_kwargs...) = HistogramEstimation(hist_kwargs)

function bins_and_probs(est::HistogramEstimation, x::AbstractVector{<:Real})
hist = StatsBase.fit(StatsBase.Histogram, x; est.hist_kwargs...)
return StatsBase.midpoints(hist.edges[1]), normalize(hist; mode=:probability).weights
end

function density_at(est::HistogramEstimation, x::AbstractVector{<:Real})
hist = normalize(
StatsBase.fit(StatsBase.Histogram, x; est.hist_kwargs...); mode=:density
)
return _histogram_density.(Ref(hist), x)
end

function _histogram_density(hist::StatsBase.Histogram, x::Real)
edges = only(hist.edges)
bin_index = _binindex(edges, hist.closed, x)
weights = hist.weights
return get(weights, bin_index, zero(eltype(weights)))
end

function _binindex(edges::AbstractVector, closed::Symbol, x::Real)
if closed === :right
return searchsortedfirst(edges, x) - 1
else
return searchsortedlast(edges, x)
end
end

"""
KDEstimation{F,K} <: DensityEstimationMethod
Estimate density as uniform mixture of identical data-centered kernels.
"""
struct KDEstimation{F,K} <: DensityEstimationMethod
"""Function to use to compute the KDE, with signature
`kde_func(x; kde_kwargs...) -> KernelDensity.UnivariateKDE`."""
kde_func::F
"""Keyword arguments for `kde_func`."""
kde_kwargs::K
end
KDEstimation(kde_func=kde_reflected; kde_kwargs...) = KDEstimation(kde_func, kde_kwargs)

function density_at(est::KDEstimation, x::AbstractVector{<:Real})
kde = est.kde_func(x; est.kde_kwargs...)
return KernelDensity.pdf(kde, x)
end

function bins_and_probs(est::KDEstimation, x::AbstractVector{<:Real})
kde = est.kde_func(x; est.kde_kwargs...)
bins = kde.x
probs = kde.density * step(kde.x)
return bins, probs
end
Loading

0 comments on commit c52f54a

Please sign in to comment.