Skip to content

Commit 94964bf

Browse files
authored
Allow for Vector of parameters and use Turing as dependency (#22)
1 parent 81fa7cc commit 94964bf

File tree

6 files changed

+80
-39
lines changed

6 files changed

+80
-39
lines changed

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
name = "Turkie"
22
uuid = "8156cc02-0533-41cd-9345-13411ebe105f"
33
authors = ["Theo Galy-Fajou <[email protected]> and contributors"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
AbstractPlotting = "537997a7-5e4e-5d89-9595-2241ea00577e"
88
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
99
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
10-
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
1110
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
1211
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
12+
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1313

1414
[compat]
1515
AbstractPlotting = "0.15"
1616
ColorSchemes = "3.10"
1717
Colors = "0.12"
18-
DynamicPPL = "0.9, 0.10"
1918
KernelDensity = "0.5, 0.6"
2019
OnlineStats = "1.5"
20+
Turing = "0.15"
2121
julia = "1.4"
2222

2323
[extras]

README.md

+6-3
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,18 @@ chain = sample(m, NUTS(0.65), 300; callback = cb) # Sample and plot at the same
5858
If you want to show only some variables you can give a `Dict` to `TurkieCallback` :
5959

6060
```julia
61-
cb = TurkieCallback(Dict(:m0 => [:trace, :mean],
62-
:s => [:autocov, :var]))
61+
cb = TurkieCallback(
62+
(m0 = [:trace, :mean], s = [:autocov, :var])
63+
)
6364

6465
```
6566

6667
You can also directly pass `OnlineStats` object :
6768
```julia
6869
using OnlineStats
69-
cb = TurkieCallback(Dict(:v => [Mean(), AutoCov(20)]))
70+
cb = TurkieCallback(
71+
(v = [Mean(), AutoCov(20)],)
72+
)
7073
```
7174

7275
If you want to record the video do

docs/src/index.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,13 @@ While sampling the callback object `cb` will be called and the statistics will b
6666
## Tuning the quantities
6767

6868
Of course the default is not always desirable.
69-
You can chose what variables and what quantities are shown by giving a `Dict` to `TurkieCallback` instead of a model.
69+
You can chose what variables and what quantities are shown by giving a `NamedTuple` to `TurkieCallback` instead of a model.
7070
For example,
7171
```julia
72-
cb = TurkieCallback(Dict(:v => [:trace, :mean],
73-
:s => [:autocov, :var]))
72+
cb = TurkieCallback(
73+
(v = [:trace, :mean],
74+
s = [:autocov, :var])
75+
)
7476
```
7577
will only show the trace and the sample mean of `v` and the auto-covariance and variance of `s`.
7678
Pairs should be of the type `{Symbol,AbstractVector}`.

src/Turkie.jl

+34-23
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using AbstractPlotting.MakieLayout # Layouting tool
88
using Colors, ColorSchemes # Colors tools
99
using KernelDensity # To be able to give a KDE
1010
using OnlineStats # Estimators
11-
using DynamicPPL: VarInfo, Model
11+
using Turing: DynamicPPL.VarInfo, DynamicPPL.Model, Inference._params_to_array
1212

1313
export TurkieCallback
1414

@@ -23,19 +23,31 @@ name(::Val{T}) where {T} = string(T)
2323
name(s::OnlineStat) = string(nameof(typeof(s)))
2424

2525
"""
26-
TurkieCallback(model::DynamicPPL.Model, plots::Series/AbstractVector; window=1000, kwargs...)
27-
26+
TurkieCallback(args...; kwargs....)
27+
28+
## Arguments
29+
- Option 1 :
30+
`model::DynamicPPL.Model, plots::Series/AbstractVector=[:histkde, Mean(Float32), Variance(Float32), AutoCov(20, Float32)]`
31+
32+
For each of the variables of the given model each `plot` from `plots` will be plotted
33+
Multidimensional variable will be automatically have indices added to them
34+
- Option 2 :
35+
`vars::NamedTuple/Dict`
36+
Will plot each pair of symbol and series of plots.
37+
Note that for multidimensional variable you should pass a Symbol as `Symbol("m[1]")` for example.
38+
See the docs for some examples.
2839
## Keyword arguments
2940
- `window=1000` : Use a window for plotting the trace
41+
- `refresh=false` : Restart the plots from scratch everytime `sample` is called again (still WIP)
3042
"""
3143
TurkieCallback
3244

33-
struct TurkieCallback
45+
struct TurkieCallback{TN<:NamedTuple,TD<:AbstractDict}
3446
scene::Scene
3547
data::Dict{Symbol, MovingWindow}
3648
axis_dict::Dict
37-
vars::Dict{Symbol, Any}
38-
params::Dict{Any, Any}
49+
vars::TN
50+
params::TD
3951
iter::Observable{Int}
4052
end
4153

@@ -44,38 +56,39 @@ function TurkieCallback(model::Model, plots::Series; kwargs...)
4456
end
4557

4658
function TurkieCallback(model::Model, plots::AbstractVector = [:histkde, Mean(Float32), Variance(Float32), AutoCov(20, Float32)]; kwargs...)
47-
variables = VarInfo(model).metadata
59+
vars, vals = _params_to_array([VarInfo(model)])
4860
return TurkieCallback(
49-
Dict(Pair.(keys(variables), Ref(plots)));
61+
(;Pair.(vars, Ref(plots))...); # Return a named Tuple
5062
kwargs...
51-
)
63+
)
5264
end
5365

54-
function TurkieCallback(varsdict::Dict; kwargs...)
55-
return TurkieCallback(varsdict, Dict{Symbol,Any}(kwargs...))
66+
function TurkieCallback(vars::Union{Dict, NamedTuple}; kwargs...)
67+
return TurkieCallback((;vars...), Dict{Symbol,Any}(kwargs...))
5668
end
5769

58-
function TurkieCallback(vars::Dict, params::Dict)
70+
function TurkieCallback(vars::NamedTuple, params::Dict)
5971
# Create a scene and a layout
6072
outer_padding = 5
6173
scene, layout = layoutscene(outer_padding, resolution = (1200, 700))
6274
window = get!(params, :window, 1000)
6375
refresh = get!(params, :refresh, false)
6476
params[:t0] = 0
65-
iter = Node(0)
77+
iter = Observable(0)
6678
data = Dict{Symbol, MovingWindow}(:iter => MovingWindow(window, Int))
6779
obs = Dict{Symbol, Any}()
6880
axis_dict = Dict()
69-
for (i, (variable, plots)) in enumerate(vars)
81+
for (i, variable) in enumerate(keys(vars))
82+
plots = vars[variable]
7083
data[variable] = MovingWindow(window, Float32)
7184
axis_dict[(variable, :varname)] = layout[i, 1, Left()] = Label(scene, string(variable), textsize = 30)
72-
axis_dict[(variable, :varname)].padding = (0, 50, 0, 0)
85+
axis_dict[(variable, :varname)].padding = (0, 60, 0, 0)
7386
onlineplot!(scene, layout, axis_dict, plots, iter, data, variable, i)
7487
end
7588
on(iter) do i
7689
if i > 1 # To deal with autolimits a certain number of samples are needed
77-
for (variable, plots) in vars
78-
for p in plots
90+
for variable in keys(vars)
91+
for p in vars[variable]
7992
autolimits!(axis_dict[(variable, p)])
8093
end
8194
end
@@ -97,12 +110,10 @@ function (cb::TurkieCallback)(rng, model, sampler, transition, iteration)
97110
end
98111
cb.params[:t0] = cb.iter[]
99112
end
100-
fit!(cb.data[:iter], iteration + cb.params[:t0])
101-
for (vals, ks) in values(transition.θ)
102-
for (k, val) in zip(ks, vals)
103-
if haskey(cb.data, Symbol(k))
104-
fit!(cb.data[Symbol(k)], Float32(val))
105-
end
113+
fit!(cb.data[:iter], iteration + cb.params[:t0]) # Update the iteration value
114+
for (variable, val) in zip(_params_to_array([transition])...)
115+
if haskey(cb.data, variable) # Check if symbol should be plotted
116+
fit!(cb.data[variable], Float32(val)) # Update its value
106117
end
107118
end
108119
cb.iter[] += 1

src/online_stats_plots.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ onlineplot!(axis, ::Val{:hist}, args...) = onlineplot!(axis, KHist(50, Float32),
2424
function onlineplot!(axis, stat::T, iter, data, iterations, i, j) where {T<:OnlineStat}
2525
window = data.b
2626
@eval TStat = $(nameof(T))
27-
stat = Node(TStat(Float32))
27+
stat = Observable(TStat(Float32))
2828
on(iter) do i
2929
stat[] = fit!(stat[], last(value(data)))
3030
end
31-
statvals = Node(MovingWindow(window, Float32))
31+
statvals = Observable(MovingWindow(window, Float32))
3232
on(stat) do s
3333
statvals[] = fit!(statvals[], Float32(value(s)))
3434
end
@@ -39,15 +39,15 @@ function onlineplot!(axis, stat::T, iter, data, iterations, i, j) where {T<:Onli
3939
end
4040

4141
function onlineplot!(axis, ::Val{:trace}, iter, data, iterations, i, j)
42-
trace = lift(iter; init = [Point2f0(0, 0f0)]) do i
42+
trace = lift(iter; init = [Point2f0(0f0, 0f0)]) do i
4343
Point2f0.(value(iterations), value(data))
4444
end
4545
lines!(axis, trace, color = std_colors[i]; linewidth = 3.0)
4646
end
4747

4848
function onlineplot!(axis, stat::KHist, iter, data, iterations, i, j)
4949
nbins = stat.k
50-
stat = Node(KHist(nbins, Float32))
50+
stat = Observable(KHist(nbins, Float32))
5151
on(iter) do i
5252
stat[] = fit!(stat[], last(value(data)))
5353
end
@@ -68,7 +68,7 @@ function expand_extrema(xs)
6868
end
6969

7070
function onlineplot!(axis, ::Val{:kde}, iter, data, iterations, i, j)
71-
interpkde = Node(InterpKDE(kde([1f0])))
71+
interpkde = Observable(InterpKDE(kde([1f0])))
7272
on(iter) do i
7373
interpkde[] = InterpKDE(kde(value(data)))
7474
end
@@ -90,7 +90,7 @@ end
9090

9191
function onlineplot!(axis, stat::AutoCov, iter, data, iterations, i, j)
9292
b = length(stat.cross)
93-
stat = Node(AutoCov(b, Float32))
93+
stat = Observable(AutoCov(b, Float32))
9494
on(iter) do i
9595
stat[] = fit!(stat[], last(value(data)))
9696
end

test/dev_test.jl

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Turing
22
using Turkie
33
using GLMakie # You could also use CairoMakie or another backend
4+
using CairoMakie
45
GLMakie.activate!()
6+
CairoMakie.activate!()
57
Turing.@model function demo(x) # Some random Turing model
68
m0 ~ Normal(0, 2)
79
s ~ InverseGamma(2, 3)
@@ -38,4 +40,27 @@ end
3840
cb = TurkieCallback(Dict(:m0 => [:trace, :mean],
3941
:s => [:autocov, :var]))
4042

41-
advancedHMC(sossdemo(), (x=xs,), 100; callback = cb)
43+
advancedHMC(sossdemo(), (x=xs,), 100; callback = cb)
44+
45+
## Test for array of parameters
46+
using LinearAlgebra
47+
D = 1
48+
N = 20
49+
Turing.@model function vectordemo(x, y, σ)
50+
m ~ Normal(0, 10)
51+
β ~ MvNormal(m * ones(D + 1), ones(D + 1))
52+
for i in eachindex(y)
53+
y[i] ~ Normal(dot(β, vcat(1, x[i])), σ)
54+
end
55+
end
56+
57+
x = [rand(D) for _ in 1:N]
58+
β = randn(D + 1) * 2
59+
σ = 0.1
60+
y = dot.(Ref(β), vcat.(1, x)) .+ σ * randn(N)
61+
62+
m = vectordemo(x, y, σ)
63+
64+
cb = TurkieCallback(m) # Create a callback function to be given to the sample function
65+
66+
chain = sample(m, NUTS(0.65), 200; callback = cb)

0 commit comments

Comments
 (0)