 using Pkg
-Pkg.develop(path=joinpath(@__DIR__, ".."))
+Pkg.develop(; path=joinpath(@__DIR__, ".."))

-using
-    AbstractGPs,
+using AbstractGPs,
     Chairmarks,
     CSV,
     DataFrames,
@@ -28,13 +27,13 @@ using Mooncake:

 using Mooncake.TestUtils: _deepcopy

-function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N}
+function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N}
     dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx)
     out, pb!! = __rrule!!(dx_f...)
     return pb!!(Mooncake.zero_rdata(primal(out)))
 end

-function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N}
+function zygote_to_benchmark(ctx, x::Vararg{Any,N}) where {N}
     out, pb = Zygote._pullback(ctx, x...)
     return pb(out)
 end
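Aside (not part of the diff): both helpers above run the forwards pass and then immediately invoke the pullback, so each benchmark sample times a complete reverse-mode sweep. A minimal sketch of the Zygote variant on a toy target; `sum` and the random input are arbitrary stand-ins chosen for illustration:

using Zygote

x = randn(10)
ctx = Zygote.Context()
# Forwards pass via Zygote._pullback, then the reverse pass seeded with the
# primal output, mirroring what the benchmark loop measures.
grads = zygote_to_benchmark(ctx, sum, x)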
 @model broadcast_demo(x) = begin
     μ ~ truncated(Normal(1, 2), 0.1, 10)
     σ ~ truncated(Normal(1, 2), 0.1, 10)
-    x .~ LogNormal(μ, σ)
+    x .~ LogNormal(μ, σ)
 end
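Aside (not part of the diff): broadcast_demo is a standard Turing model in which μ and σ are latent and every element of x is an observed LogNormal draw. A hypothetical instantiation, just to show how the model would be constructed; the data-generating parameters here are made up:

data = rand(LogNormal(1.0, 0.5), 100)   # synthetic observations
model = broadcast_demo(data)            # Turing/DynamicPPL model object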

 function build_turing_problem()
@@ -122,17 +121,21 @@ function build_turing_problem()
     return test_function, randn(rng, d)
 end

-run_turing_problem(f::F, x::X) where {F, X} = f(x)
+run_turing_problem(f::F, x::X) where {F,X} = f(x)

-should_run_benchmark(
+function should_run_benchmark(
     ::Val{:zygote}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
-) = false
-should_run_benchmark(
+)
+    return false
+end
+function should_run_benchmark(
     ::Val{:enzyme}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
-) = false
+)
+    return false
+end
 should_run_benchmark(::Val{:enzyme}, x...) = false

-@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N-1)) : x
+@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N - 1)) : x

 large_single_block(x::AbstractVector{<:Real}) = g(x[1], x[2], Val(400))

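Two asides on the definitions above (not part of the diff). The Val-based methods let particular framework/function pairs opt out of the cross-framework benchmarks, and the catch-all method disables Enzyme entirely; g multiplies its first argument by its second N times, which @inline turns into one long unrolled block inside large_single_block. Both claims follow directly from the code shown:

should_run_benchmark(Val(:enzyme), sum, randn(3))   # false: Enzyme is opted out everywhere
g(2.0, 3.0, Val(2))                                 # (2.0 * 3.0) * 3.0 == 18.0
large_single_block([2.0, 1.01])                     # 2.0 * 1.01^400 via 400 unrolled multiplies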
@@ -168,14 +171,12 @@ function generate_inter_framework_tests()
 end

 function benchmark_rules!!(test_case_data, default_ratios, include_other_frameworks::Bool)
-
     test_cases = reduce(vcat, map(first, test_case_data))
     memory = map(x -> x[2], test_case_data)
     ranges = reduce(vcat, map(x -> x[3], test_case_data))
     tags = reduce(vcat, map(x -> x[4], test_case_data))
     GC.@preserve memory begin
         return map(enumerate(test_cases)) do (n, args)
-
             @info "$n / $(length(test_cases))", _typeof(args)
             suite = Dict()

@@ -186,7 +187,7 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
                 () -> primals,
                 primals -> (primals[1], _deepcopy(primals[2:end])),
                 (a -> a[1]((a[2]...))),
-                _ -> true,
+                _ -> true;
                 evals=1,
             )

@@ -199,17 +200,19 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
                 () -> (rule, coduals),
                 identity,
                 a -> to_benchmark(a[1], a[2]...),
-                _ -> true,
+                _ -> true;
                 evals=1,
             )

             if include_other_frameworks
-
                 if should_run_benchmark(Val(:zygote), args...)
                     @info "Zygote"
                     suite["zygote"] = @be(
-                        _, _, zygote_to_benchmark($(Zygote.Context()), $primals...), _,
-                        evals=1,
+                        _,
+                        _,
+                        zygote_to_benchmark($(Zygote.Context()), $primals...),
+                        _,
+                        evals = 1,
                     )
                 end

@@ -219,21 +222,27 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
                     compiled_tape = ReverseDiff.compile(tape)
                     result = map(x -> randn(size(x)), primals[2:end])
                     suite["rd"] = @be(
-                        _, _, rd_to_benchmark!($result, $compiled_tape, $primals[2:end]), _,
-                        evals=1,
+                        _,
+                        _,
+                        rd_to_benchmark!($result, $compiled_tape, $primals[2:end]),
+                        _,
+                        evals = 1,
                     )
                 end

                 if should_run_benchmark(Val(:enzyme), args...)
                     @info "Enzyme"
                     dup_args = map(x -> Duplicated(x, randn(size(x))), primals[2:end])
                     suite["enzyme"] = @be(
-                        _, _, autodiff(Reverse, $primals[1], Active, $dup_args...), _,
-                        evals=1,
+                        _,
+                        _,
+                        autodiff(Reverse, $primals[1], Active, $dup_args...),
+                        _,
+                        evals = 1,
                     )
                 end
             end
-
+
             return combine_results((args, suite), tags[n], ranges[n], default_ratios)
         end
     end
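Aside (not part of the diff, and hedged, since it rests on my reading of the Chairmarks calling convention used throughout this file): in the @be calls above, the underscores are placeholders for the unused init/setup/teardown positions, so only the third expression is timed, once per sample (evals=1). A standalone call in the same shape, with an arbitrary target:

using Chairmarks

x = randn(1000)
# Same argument pattern as the suite entries: skip init/setup/teardown,
# interpolate the data, and evaluate once per sample.
b = @be(_, _, sum($x), _, evals=1)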
@@ -319,7 +328,7 @@ well-suited to the numbers typically found in this field.
 function plot_ratio_histogram!(df::DataFrame)
     bin = 10.0 .^ (-1.0:0.05:4.0)
     xlim = extrema(bin)
-    histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
+    return histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
 end

 function create_inter_ad_benchmarks()
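Aside on plot_ratio_histogram! above (not part of the diff): the bin edges are log-spaced, so the histogram covers AD-time/primal-time ratios from 0.1x to 10,000x in equal multiplicative steps:

bin = 10.0 .^ (-1.0:0.05:4.0)   # 101 edges: 10^-1.0, 10^-0.95, ..., 10^4.0
extrema(bin)                    # (0.1, 10000.0)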
@@ -328,7 +337,7 @@ function create_inter_ad_benchmarks()
     df = DataFrame(results)[:, [:tag, tools...]]

     # Plot graph of results.
-    plt = plot(yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
+    plt = plot(; yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
     for label in string.(tools)
         plot!(plt, df.tag, df[:, label]; label, marker=:circle, xrotation=45)
     end
@@ -337,7 +346,9 @@ function create_inter_ad_benchmarks()
     # Write table of results.
     formatted_cols = map(t -> t => string.(round.(df[:, t]; sigdigits=3)), tools)
     df_formatted = DataFrame(:Label => df.tag, formatted_cols...)
-    open(io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true)
+    return open(
+        io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true
+    )
 end

 function main()