@@ -8,7 +8,7 @@ using AbstractPlotting.MakieLayout # Layouting tool
8
8
using Colors, ColorSchemes # Colors tools
9
9
using KernelDensity # To be able to give a KDE
10
10
using OnlineStats # Estimators
11
- using DynamicPPL : VarInfo, Model
11
+ using Turing : DynamicPPL . VarInfo, DynamicPPL . Model, Inference . _params_to_array
12
12
13
13
export TurkieCallback
14
14
@@ -23,19 +23,31 @@ name(::Val{T}) where {T} = string(T)
23
23
name (s:: OnlineStat ) = string (nameof (typeof (s)))
24
24
25
25
"""
26
- TurkieCallback(model::DynamicPPL.Model, plots::Series/AbstractVector; window=1000, kwargs...)
27
-
26
+ TurkieCallback(args...; kwargs....)
27
+
28
+ ## Arguments
29
+ - Option 1 :
30
+ `model::DynamicPPL.Model, plots::Series/AbstractVector=[:histkde, Mean(Float32), Variance(Float32), AutoCov(20, Float32)]`
31
+
32
+ For each of the variables of the given model each `plot` from `plots` will be plotted
33
+ Multidimensional variable will be automatically have indices added to them
34
+ - Option 2 :
35
+ `vars::NamedTuple/Dict`
36
+ Will plot each pair of symbol and series of plots.
37
+ Note that for multidimensional variable you should pass a Symbol as `Symbol("m[1]")` for example.
38
+ See the docs for some examples.
28
39
## Keyword arguments
29
40
- `window=1000` : Use a window for plotting the trace
41
+ - `refresh=false` : Restart the plots from scratch everytime `sample` is called again (still WIP)
30
42
"""
31
43
TurkieCallback
32
44
33
- struct TurkieCallback
45
+ struct TurkieCallback{TN <: NamedTuple ,TD <: AbstractDict }
34
46
scene:: Scene
35
47
data:: Dict{Symbol, MovingWindow}
36
48
axis_dict:: Dict
37
- vars:: Dict{Symbol, Any}
38
- params:: Dict{Any, Any}
49
+ vars:: TN
50
+ params:: TD
39
51
iter:: Observable{Int}
40
52
end
41
53
@@ -44,38 +56,39 @@ function TurkieCallback(model::Model, plots::Series; kwargs...)
44
56
end
45
57
46
58
function TurkieCallback (model:: Model , plots:: AbstractVector = [:histkde , Mean (Float32), Variance (Float32), AutoCov (20 , Float32)]; kwargs... )
47
- variables = VarInfo (model). metadata
59
+ vars, vals = _params_to_array ([ VarInfo (model)])
48
60
return TurkieCallback (
49
- Dict ( Pair .(keys (variables) , Ref (plots)));
61
+ (; Pair .(vars , Ref (plots))... ); # Return a named Tuple
50
62
kwargs...
51
- )
63
+ )
52
64
end
53
65
54
- function TurkieCallback (varsdict :: Dict ; kwargs... )
55
- return TurkieCallback (varsdict , Dict {Symbol,Any} (kwargs... ))
66
+ function TurkieCallback (vars :: Union{ Dict, NamedTuple} ; kwargs... )
67
+ return TurkieCallback ((;vars ... ) , Dict {Symbol,Any} (kwargs... ))
56
68
end
57
69
58
- function TurkieCallback (vars:: Dict , params:: Dict )
70
+ function TurkieCallback (vars:: NamedTuple , params:: Dict )
59
71
# Create a scene and a layout
60
72
outer_padding = 5
61
73
scene, layout = layoutscene (outer_padding, resolution = (1200 , 700 ))
62
74
window = get! (params, :window , 1000 )
63
75
refresh = get! (params, :refresh , false )
64
76
params[:t0 ] = 0
65
- iter = Node (0 )
77
+ iter = Observable (0 )
66
78
data = Dict {Symbol, MovingWindow} (:iter => MovingWindow (window, Int))
67
79
obs = Dict {Symbol, Any} ()
68
80
axis_dict = Dict ()
69
- for (i, (variable, plots)) in enumerate (vars)
81
+ for (i, variable) in enumerate (keys (vars))
82
+ plots = vars[variable]
70
83
data[variable] = MovingWindow (window, Float32)
71
84
axis_dict[(variable, :varname )] = layout[i, 1 , Left ()] = Label (scene, string (variable), textsize = 30 )
72
- axis_dict[(variable, :varname )]. padding = (0 , 50 , 0 , 0 )
85
+ axis_dict[(variable, :varname )]. padding = (0 , 60 , 0 , 0 )
73
86
onlineplot! (scene, layout, axis_dict, plots, iter, data, variable, i)
74
87
end
75
88
on (iter) do i
76
89
if i > 1 # To deal with autolimits a certain number of samples are needed
77
- for ( variable, plots) in vars
78
- for p in plots
90
+ for variable in keys ( vars)
91
+ for p in vars[variable]
79
92
autolimits! (axis_dict[(variable, p)])
80
93
end
81
94
end
@@ -97,12 +110,10 @@ function (cb::TurkieCallback)(rng, model, sampler, transition, iteration)
97
110
end
98
111
cb. params[:t0 ] = cb. iter[]
99
112
end
100
- fit! (cb. data[:iter ], iteration + cb. params[:t0 ])
101
- for (vals, ks) in values (transition. θ)
102
- for (k, val) in zip (ks, vals)
103
- if haskey (cb. data, Symbol (k))
104
- fit! (cb. data[Symbol (k)], Float32 (val))
105
- end
113
+ fit! (cb. data[:iter ], iteration + cb. params[:t0 ]) # Update the iteration value
114
+ for (variable, val) in zip (_params_to_array ([transition])... )
115
+ if haskey (cb. data, variable) # Check if symbol should be plotted
116
+ fit! (cb. data[variable], Float32 (val)) # Update its value
106
117
end
107
118
end
108
119
cb. iter[] += 1
0 commit comments