Skip to content

Commit d9a6952

Browse files
authored
Use JuliaFormatter (#390)
* Add JuliaFormatter style and test * Fix path * Apply formatter * Bump patch, rename Aqua test set
1 parent 5397d28 commit d9a6952

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+3109
-2241
lines changed

.JuliaFormatter.toml

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
style = "blue"

.github/workflows/CI.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
fail-fast: false
2222
matrix:
2323
test_group: [
24-
'aqua',
24+
'quality',
2525
'basic',
2626
'rrules/avoiding_non_differentiable_code',
2727
'rrules/blas',

Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.4.49"
4+
version = "0.4.50"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -55,6 +55,7 @@ FunctionWrappers = "1.1.3"
5555
Graphs = "1"
5656
InteractiveUtils = "1"
5757
JET = "0.9"
58+
JuliaFormatter = "1.0"
5859
LinearAlgebra = "1"
5960
LuxLib = "1"
6061
MistyClosures = "2"
@@ -74,9 +75,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7475
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
7576
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
7677
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
78+
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
7779
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7880
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
7981
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8082

8183
[targets]
82-
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "Pkg", "StableRNGs", "Test"]
84+
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "JuliaFormatter", "Pkg", "StableRNGs", "Test"]

bench/run_benchmarks.jl

+38-27
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
using Pkg
2-
Pkg.develop(path=joinpath(@__DIR__, ".."))
2+
Pkg.develop(; path=joinpath(@__DIR__, ".."))
33

4-
using
5-
AbstractGPs,
4+
using AbstractGPs,
65
Chairmarks,
76
CSV,
87
DataFrames,
@@ -28,13 +27,13 @@ using Mooncake:
2827

2928
using Mooncake.TestUtils: _deepcopy
3029

31-
function to_benchmark(__rrule!!::R, dx::Vararg{CoDual, N}) where {R, N}
30+
function to_benchmark(__rrule!!::R, dx::Vararg{CoDual,N}) where {R,N}
3231
dx_f = Mooncake.tuple_map(x -> CoDual(primal(x), Mooncake.fdata(tangent(x))), dx)
3332
out, pb!! = __rrule!!(dx_f...)
3433
return pb!!(Mooncake.zero_rdata(primal(out)))
3534
end
3635

37-
function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N}
36+
function zygote_to_benchmark(ctx, x::Vararg{Any,N}) where {N}
3837
out, pb = Zygote._pullback(ctx, x...)
3938
return pb(out)
4039
end
@@ -107,7 +106,7 @@ end
107106
@model broadcast_demo(x) = begin
108107
μ ~ truncated(Normal(1, 2), 0.1, 10)
109108
σ ~ truncated(Normal(1, 2), 0.1, 10)
110-
x .~ LogNormal(μ, σ)
109+
x .~ LogNormal(μ, σ)
111110
end
112111

113112
function build_turing_problem()
@@ -122,17 +121,21 @@ function build_turing_problem()
122121
return test_function, randn(rng, d)
123122
end
124123

125-
run_turing_problem(f::F, x::X) where {F, X} = f(x)
124+
run_turing_problem(f::F, x::X) where {F,X} = f(x)
126125

127-
should_run_benchmark(
126+
function should_run_benchmark(
128127
::Val{:zygote}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
129-
) = false
130-
should_run_benchmark(
128+
)
129+
return false
130+
end
131+
function should_run_benchmark(
131132
::Val{:enzyme}, ::Base.Fix1{<:typeof(DynamicPPL.LogDensityProblems.logdensity)}, x...
132-
) = false
133+
)
134+
return false
135+
end
133136
should_run_benchmark(::Val{:enzyme}, x...) = false
134137

135-
@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N-1)) : x
138+
@inline g(x, a, ::Val{N}) where {N} = N > 0 ? g(x * a, a, Val(N - 1)) : x
136139

137140
large_single_block(x::AbstractVector{<:Real}) = g(x[1], x[2], Val(400))
138141

@@ -168,14 +171,12 @@ function generate_inter_framework_tests()
168171
end
169172

170173
function benchmark_rules!!(test_case_data, default_ratios, include_other_frameworks::Bool)
171-
172174
test_cases = reduce(vcat, map(first, test_case_data))
173175
memory = map(x -> x[2], test_case_data)
174176
ranges = reduce(vcat, map(x -> x[3], test_case_data))
175177
tags = reduce(vcat, map(x -> x[4], test_case_data))
176178
GC.@preserve memory begin
177179
return map(enumerate(test_cases)) do (n, args)
178-
179180
@info "$n / $(length(test_cases))", _typeof(args)
180181
suite = Dict()
181182

@@ -186,7 +187,7 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
186187
() -> primals,
187188
primals -> (primals[1], _deepcopy(primals[2:end])),
188189
(a -> a[1]((a[2]...))),
189-
_ -> true,
190+
_ -> true;
190191
evals=1,
191192
)
192193

@@ -199,17 +200,19 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
199200
() -> (rule, coduals),
200201
identity,
201202
a -> to_benchmark(a[1], a[2]...),
202-
_ -> true,
203+
_ -> true;
203204
evals=1,
204205
)
205206

206207
if include_other_frameworks
207-
208208
if should_run_benchmark(Val(:zygote), args...)
209209
@info "Zygote"
210210
suite["zygote"] = @be(
211-
_, _, zygote_to_benchmark($(Zygote.Context()), $primals...), _,
212-
evals=1,
211+
_,
212+
_,
213+
zygote_to_benchmark($(Zygote.Context()), $primals...),
214+
_,
215+
evals = 1,
213216
)
214217
end
215218

@@ -219,21 +222,27 @@ function benchmark_rules!!(test_case_data, default_ratios, include_other_framewo
219222
compiled_tape = ReverseDiff.compile(tape)
220223
result = map(x -> randn(size(x)), primals[2:end])
221224
suite["rd"] = @be(
222-
_, _, rd_to_benchmark!($result, $compiled_tape, $primals[2:end]), _,
223-
evals=1,
225+
_,
226+
_,
227+
rd_to_benchmark!($result, $compiled_tape, $primals[2:end]),
228+
_,
229+
evals = 1,
224230
)
225231
end
226232

227233
if should_run_benchmark(Val(:enzyme), args...)
228234
@info "Enzyme"
229235
dup_args = map(x -> Duplicated(x, randn(size(x))), primals[2:end])
230236
suite["enzyme"] = @be(
231-
_, _, autodiff(Reverse, $primals[1], Active, $dup_args...), _,
232-
evals=1,
237+
_,
238+
_,
239+
autodiff(Reverse, $primals[1], Active, $dup_args...),
240+
_,
241+
evals = 1,
233242
)
234243
end
235244
end
236-
245+
237246
return combine_results((args, suite), tags[n], ranges[n], default_ratios)
238247
end
239248
end
@@ -319,7 +328,7 @@ well-suited to the numbers typically found in this field.
319328
function plot_ratio_histogram!(df::DataFrame)
320329
bin = 10.0 .^ (-1.0:0.05:4.0)
321330
xlim = extrema(bin)
322-
histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
331+
return histogram(df.Mooncake; xscale=:log10, xlim, bin, title="log", label="")
323332
end
324333

325334
function create_inter_ad_benchmarks()
@@ -328,7 +337,7 @@ function create_inter_ad_benchmarks()
328337
df = DataFrame(results)[:, [:tag, tools...]]
329338

330339
# Plot graph of results.
331-
plt = plot(yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
340+
plt = plot(; yscale=:log10, legend=:topright, title="AD Time / Primal Time (Log Scale)")
332341
for label in string.(tools)
333342
plot!(plt, df.tag, df[:, label]; label, marker=:circle, xrotation=45)
334343
end
@@ -337,7 +346,9 @@ function create_inter_ad_benchmarks()
337346
# Write table of results.
338347
formatted_cols = map(t -> t => string.(round.(df[:, t]; sigdigits=3)), tools)
339348
df_formatted = DataFrame(:Label => df.tag, formatted_cols...)
340-
open(io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true)
349+
return open(
350+
io -> pretty_table(io, df_formatted), "bench/benchmark_results.txt"; write=true
351+
)
341352
end
342353

343354
function main()

docs/make.jl

+7-15
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,18 @@ DocMeta.setdocmeta!(
99
recursive=true,
1010
)
1111

12-
makedocs(
12+
makedocs(;
1313
sitename="Mooncake.jl",
1414
format=Documenter.HTML(;
15-
mathengine = Documenter.KaTeX(
16-
Dict(
17-
:macros => Dict(
18-
"\\RR" => "\\mathbb{R}",
19-
),
20-
)
21-
),
15+
mathengine=Documenter.KaTeX(Dict(:macros => Dict("\\RR" => "\\mathbb{R}"))),
2216
size_threshold_ignore=[
23-
joinpath("developer_documentation", "internal_docstrings.md"),
17+
joinpath("developer_documentation", "internal_docstrings.md")
2418
],
2519
),
2620
modules=[Mooncake],
2721
checkdocs=:none,
28-
plugins=[
29-
CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"); style=:numeric),
30-
],
31-
pages = [
22+
plugins=[CitationBibliography(joinpath(@__DIR__, "src", "refs.bib"); style=:numeric)],
23+
pages=[
3224
"Mooncake.jl" => "index.md",
3325
"Understanding Mooncake.jl" => [
3426
joinpath("understanding_mooncake", "introduction.md"),
@@ -46,7 +38,7 @@ makedocs(
4638
joinpath("developer_documentation", "internal_docstrings.md"),
4739
],
4840
"known_limitations.md",
49-
]
41+
],
5042
)
5143

52-
deploydocs(repo="github.com/compintell/Mooncake.jl.git", push_preview=true)
44+
deploydocs(; repo="github.com/compintell/Mooncake.jl.git", push_preview=true)

ext/MooncakeAllocCheckExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ module MooncakeAllocCheckExt
33
using AllocCheck, Mooncake
44
import Mooncake.TestUtils: check_allocs, Shim
55

6-
@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any, N}}) where {F, N} = f(x...)
6+
@check_allocs check_allocs(::Shim, f::F, x::Tuple{Vararg{Any,N}}) where {F,N} = f(x...)
77

88
end

ext/MooncakeCUDAExt.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0
3838
_add_to_primal(x::P, y::P, ::Bool) where {P<:CuArray{<:IEEEFloat}} = x + y
3939
_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y
4040
_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y))
41-
_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y
41+
_scale(x::Float64, y::P) where {T<:IEEEFloat,P<:CuArray{T}} = T(x) * y
4242
function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray)
4343
k = pointer_from_objref(p)
4444
v = pointer_from_objref(t)
@@ -55,9 +55,7 @@ end
5555

5656
# Basic rules for operating on CuArrays.
5757

58-
@is_primitive(
59-
MinimalCtx, Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N},
60-
)
58+
@is_primitive(MinimalCtx, Tuple{Type{<:CuArray},UndefInitializer,Vararg{Int,N}} where {N},)
6159
function rrule!!(
6260
p::CoDual{Type{P}}, init::CoDual{UndefInitializer}, dims::CoDual{Int}...
6361
) where {P<:CuArray{<:Base.IEEEFloat}}

ext/MooncakeDynamicPPLExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ using DynamicPPL: DynamicPPL, istrans
44
using Mooncake: Mooncake
55

66
# This is purely an optimisation.
7-
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg}
7+
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
88

99
end # module

0 commit comments

Comments
 (0)