Skip to content

Commit a05cc0a

Browse files
authored
Update Turkie to new versions of Turing and Makie (#26)
* General updates, update AbstractPlotting to Makie * Working version * Patch bump
1 parent 238a31c commit a05cc0a

7 files changed

+46
-40
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ test/Manifest.toml
44
test/video.gif
55
test/video.webm
66
docs/build/
7-
docs/Manifest.toml
7+
docs/Manifest.toml
8+
.vscode/settings.json

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
name = "Turkie"
22
uuid = "8156cc02-0533-41cd-9345-13411ebe105f"
33
authors = ["Theo Galy-Fajou <[email protected]> and contributors"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
7-
AbstractPlotting = "537997a7-5e4e-5d89-9595-2241ea00577e"
87
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
98
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
109
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
10+
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
1111
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
1212
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1313

1414
[compat]
15-
AbstractPlotting = "0.15, 0.16"
1615
ColorSchemes = "3.10"
1716
Colors = "0.12"
1817
KernelDensity = "0.5, 0.6"
18+
Makie = "0.13"
1919
OnlineStats = "1.5"
2020
Turing = "0.15"
2121
julia = "1.4"

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Small example:
3939
```julia
4040
using Turing
4141
using Turkie
42-
using Makie # You could also use CairoMakie or another backend
42+
using GLMakie # You could also use CairoMakie or another backend
4343
@model function demo(x) # Some random Turing model
4444
m0 ~ Normal(0, 2)
4545
s ~ InverseGamma(2, 3)
@@ -76,7 +76,7 @@ If you want to record the video do
7676

7777
```julia
7878
using Makie
79-
record(cb.scene, joinpath(@__DIR__, "video.webm")) do io
79+
record(cb, joinpath(@__DIR__, "video.webm")) do io
8080
addIO!(cb, io)
8181
sample(m, NUTS(0.65), 300; callback = cb)
8282
end

src/Turkie.jl

+15-14
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
module Turkie
22

3-
using AbstractPlotting: Scene, Point2f0
4-
using AbstractPlotting: barplot!, lines!, scatter! # Plotting tools
5-
using AbstractPlotting: Observable, Node, lift, on # Observable tools
6-
using AbstractPlotting: recordframe! # Recording tools
7-
using AbstractPlotting.MakieLayout # Layouting tool
3+
using Makie: Figure, Scene, Point2f0
4+
using Makie: barplot!, lines!, scatter! # Plotting tools
5+
using Makie: Observable, Node, lift, on # Observable tools
6+
using Makie: recordframe! # Recording tools
7+
using Makie.MakieLayout # Layouting tool
88
using Colors, ColorSchemes # Colors tools
99
using KernelDensity # To be able to give a KDE
1010
using OnlineStats # Estimators
@@ -16,6 +16,7 @@ export addIO!, record
1616

1717
include("online_stats_plots.jl")
1818

19+
# Uses the colorblind scheme of seaborn by default
1920
const std_colors = ColorSchemes.seaborn_colorblind
2021

2122
name(s::Symbol) = name(Val(s))
@@ -43,7 +44,7 @@ See the docs for some examples.
4344
TurkieCallback
4445

4546
struct TurkieCallback{TN<:NamedTuple,TD<:AbstractDict}
46-
scene::Scene
47+
figure::Figure
4748
data::Dict{Symbol, MovingWindow}
4849
axis_dict::Dict
4950
vars::TN
@@ -70,7 +71,7 @@ end
7071
function TurkieCallback(vars::NamedTuple, params::Dict)
7172
# Create a scene and a layout
7273
outer_padding = 5
73-
scene, layout = layoutscene(outer_padding, resolution = (1200, 700))
74+
fig = Figure(;resolution = (1200, 700), figure_padding=outer_padding)
7475
window = get!(params, :window, 1000)
7576
refresh = get!(params, :refresh, false)
7677
params[:t0] = 0
@@ -81,9 +82,9 @@ function TurkieCallback(vars::NamedTuple, params::Dict)
8182
for (i, variable) in enumerate(keys(vars))
8283
plots = vars[variable]
8384
data[variable] = MovingWindow(window, Float32)
84-
axis_dict[(variable, :varname)] = layout[i, 1, Left()] = Label(scene, string(variable), textsize = 30)
85+
axis_dict[(variable, :varname)] = fig[i, 1, Left()] = Label(fig, string(variable), textsize = 30)
8586
axis_dict[(variable, :varname)].padding = (0, 60, 0, 0)
86-
onlineplot!(scene, layout, axis_dict, plots, iter, data, variable, i)
87+
onlineplot!(fig, axis_dict, plots, iter, data, variable, i)
8788
end
8889
on(iter) do i
8990
if i > 1 # To deal with autolimits a certain number of samples are needed
@@ -94,16 +95,16 @@ function TurkieCallback(vars::NamedTuple, params::Dict)
9495
end
9596
end
9697
end
97-
MakieLayout.trim!(layout)
98-
display(scene)
99-
TurkieCallback(scene, data, axis_dict, vars, params, iter)
98+
MakieLayout.trim!(fig.layout)
99+
display(fig)
100+
return TurkieCallback(fig, data, axis_dict, vars, params, iter)
100101
end
101102

102103
function addIO!(cb::TurkieCallback, io)
103104
cb.params[:io] = io
104105
end
105106

106-
function (cb::TurkieCallback)(rng, model, sampler, transition, iteration)
107+
function (cb::TurkieCallback)(rng, model, sampler, transition, state, iteration; kwargs...)
107108
if iteration == 1
108109
if cb.params[:refresh]
109110
refresh_plots!(cb)
@@ -116,7 +117,7 @@ function (cb::TurkieCallback)(rng, model, sampler, transition, iteration)
116117
fit!(cb.data[variable], Float32(val)) # Update its value
117118
end
118119
end
119-
cb.iter[] += 1
120+
cb.iter[] = cb.iter[] + 1
120121
if haskey(cb.params, :io)
121122
recordframe!(cb.params[:io])
122123
end

src/online_stats_plots.jl

+18-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
function onlineplot!(scene, layout, axis_dict, stats::AbstractVector, iter, data, variable, i)
1+
function onlineplot!(fig, axis_dict, stats::AbstractVector, iter, data, variable, i)
22
for (j, stat) in enumerate(stats)
3-
axis_dict[(variable, stat)] = layout[i, j] = Axis(scene, title = "$(name(stat))")
3+
axis_dict[(variable, stat)] = fig[i, j] = Axis(fig, title="$(name(stat))")
44
limits!(axis_dict[(variable, stat)], 0.0, 10.0, -1.0, 1.0)
55
onlineplot!(axis_dict[(variable, stat)], stat, iter, data[variable], data[:iter], i, j)
66
tight_ticklabel_spacing!(axis_dict[(variable, stat)])
@@ -19,27 +19,26 @@ onlineplot!(axis, ::Val{:autocov}, args...) = onlineplot!(axis, AutoCov(20), arg
1919

2020
onlineplot!(axis, ::Val{:hist}, args...) = onlineplot!(axis, KHist(50, Float32), args...)
2121

22-
2322
# Generic fallback for OnlineStat objects
2423
function onlineplot!(axis, stat::T, iter, data, iterations, i, j) where {T<:OnlineStat}
2524
window = data.b
2625
@eval TStat = $(nameof(T))
2726
stat = Observable(TStat(Float32))
28-
on(iter) do i
27+
on(iter) do _
2928
stat[] = fit!(stat[], last(value(data)))
3029
end
3130
statvals = Observable(MovingWindow(window, Float32))
3231
on(stat) do s
3332
statvals[] = fit!(statvals[], Float32(value(s)))
3433
end
35-
statpoints = lift(statvals; init = Point2f0.([0], [0])) do v
34+
statpoints = map!(Observable(Point2f0.([0], [0])), statvals) do v
3635
Point2f0.(value(iterations), value(v))
3736
end
3837
lines!(axis, statpoints, color = std_colors[i], linewidth = 3.0)
3938
end
4039

4140
function onlineplot!(axis, ::Val{:trace}, iter, data, iterations, i, j)
42-
trace = lift(iter; init = [Point2f0(0f0, 0f0)]) do i
41+
trace = map!(Observable([Point2f0(0, 0)]), iter) do _
4342
Point2f0.(value(iterations), value(data))
4443
end
4544
lines!(axis, trace, color = std_colors[i]; linewidth = 3.0)
@@ -48,15 +47,17 @@ end
4847
function onlineplot!(axis, stat::KHist, iter, data, iterations, i, j)
4948
nbins = stat.k
5049
stat = Observable(KHist(nbins, Float32))
51-
on(iter) do i
50+
on(iter) do _
5251
stat[] = fit!(stat[], last(value(data)))
5352
end
54-
hist_vals = lift(stat; init = Point2f0.(range(0, 1, length = nbins), zeros(Float32, nbins))) do h
53+
hist_vals = Node(Point2f0.(collect(range(0f0, 1f0, length=nbins)), zeros(Float32, nbins)))
54+
on(stat) do h
5555
edges, weights = OnlineStats.xy(h)
5656
weights = nobs(h) > 1 ? weights / OnlineStats.area(h) : weights
57-
return Point2f0.(edges, weights)
57+
hist_vals[] = Point2f0.(edges, weights)
5858
end
59-
barplot!(axis, hist_vals, color = std_colors[i])
59+
barplot!(axis, hist_vals; color=std_colors[i])
60+
# barplot!(axis, rand(4), rand(4))
6061
end
6162

6263
function expand_extrema(xs)
@@ -69,19 +70,20 @@ end
6970

7071
function onlineplot!(axis, ::Val{:kde}, iter, data, iterations, i, j)
7172
interpkde = Observable(InterpKDE(kde([1f0])))
72-
on(iter) do i
73+
on(iter) do _
7374
interpkde[] = InterpKDE(kde(value(data)))
7475
end
75-
xs = lift(iter; init = range(0.0, 2.0, length = 200)) do i
76-
range(expand_extrema(extrema(value(data)))..., length = 200)
76+
xs = Observable(range(0, 2, length=10))
77+
on(iter) do _
78+
xs[] = range(expand_extrema(extrema(value(data)))..., length = 200)
7779
end
7880
kde_pdf = lift(xs) do xs
7981
pdf.(Ref(interpkde[]), xs)
8082
end
8183
lines!(axis, xs, kde_pdf, color = std_colors[i], linewidth = 3.0)
8284
end
8385

84-
name(s::Val{:histkde}) = "Hist + KDE"
86+
name(s::Val{:histkde}) = "Hist. + KDE"
8587

8688
function onlineplot!(axis, ::Val{:histkde}, iter, data, iterations, i, j)
8789
onlineplot!(axis, KHist(50), iter, data, iterations, i, j)
@@ -91,10 +93,10 @@ end
9193
function onlineplot!(axis, stat::AutoCov, iter, data, iterations, i, j)
9294
b = length(stat.cross)
9395
stat = Observable(AutoCov(b, Float32))
94-
on(iter) do i
96+
on(iter) do _
9597
stat[] = fit!(stat[], last(value(data)))
9698
end
97-
statvals = lift(stat; init = zeros(Float32, b + 1)) do s
99+
statvals = map!(Observable(zeros(Float32, b + 1)), stat) do s
98100
value(s)
99101
end
100102
scatter!(axis, Point2f0.([0.0, b], [-0.1, 1.0]), markersize = 0.0, color = RGBA(0.0, 0.0, 0.0, 0.0)) # Invisible points to keep limits fixed

test/dev_test.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Pkg; Pkg.activate("..")
12
using Turing
23
using Turkie
34
using GLMakie # You could also use CairoMakie or another backend
@@ -19,7 +20,7 @@ cb = TurkieCallback(m) # Create a callback function to be given to the sample fu
1920
chain = sample(m, NUTS(0.65), 30; callback = cb)
2021

2122

22-
record(cb.scene, joinpath(@__DIR__, "video.gif")) do io
23+
record(cb.figure, joinpath(@__DIR__, "video.gif")) do io
2324
addIO!(cb, io)
2425
sample(m, NUTS(0.65), 50; callback = cb)
2526
end

test/runtests.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Turkie
22
using Test
33
using CairoMakie
4+
CairoMakie.activate!()
45
using OnlineStats
56
using Turing
67

@@ -23,7 +24,7 @@ using Turing
2324
@test Turkie.name(OnlineStats.Mean(Float32)) == "Mean"
2425

2526
cb = TurkieCallback(model; blah=2.0)
26-
@test cb.scene isa Scene
27+
@test cb.figure isa Figure
2728
@test sort(collect(keys(cb.data))) == sort(vcat(vars, :iter))
2829
@test cb.data[:m] isa MovingWindow{Float32}
2930
@test sort(collect(keys(cb.vars))) == sort(vars)
@@ -41,13 +42,13 @@ using Turing
4142
@testset "Vector of symbols" begin
4243
for stat in [:histkde, :kde, :hist, :mean, :var, :trace, :autocov]
4344
cb = TurkieCallback(Dict(:m => [stat]))
44-
sample(model, MH(), 50; callback = cb)
45+
sample(model, MH(), 50; callback=cb)
4546
end
4647
end
4748
@testset "Series" begin
4849
for stat in [Mean(Float32), Variance(Float32)]
4950
cb = TurkieCallback(model, OnlineStats.Series(stat))
50-
sample(model, MH(), 50; callback = cb)
51+
sample(model, MH(), 50; callback=cb)
5152
end
5253
end
5354
end

0 commit comments

Comments
 (0)