While tracking a latency issue with JETLS.jl's analysis, I found a performance issue with `JuliaLowering.expand_forms_1`.
Attached is the script I'm using, which does the following:
- step 1: It takes a script, iterates through its top-level expressions, and evaluates the minimum "top-level definitions" needed for lowering, such as `module` and `macro` definitions.
  - This implementation is pretty rough. It doesn't handle `include(...)` and has many edge cases.
  - In a way, you can think of this as a rough implementation of the tree iterator provided by Compiler frontend API / lowering iterator #60529.
- step 2: Then it runs the lowering passes up to `st4` on the "module context" built in step 1 and measures the time taken for each phase (see the sketch after this list).
  - Note that step 2 doesn't perform any "execution" with side effects; it only measures the lowering performance for the module context created in step 1.
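For orientation, this is the pass pipeline the script times on each top-level expression. A minimal sketch; the call signatures follow the full script below, `Main` stands in for the module context, and `st0` is assumed to be a single parsed top-level expression:

```julia
using JuliaSyntax: JuliaSyntax as JS
using JuliaLowering: JuliaLowering as JL

st0 = JS.parseall(JS.SyntaxTree, "f(x) = x + 1")[1]
world = Base.get_world_counter()
ctx1, st1 = JL.expand_forms_1(Main, st0, true, world) # macro expansion
ctx2, st2 = JL.expand_forms_2(ctx1, st1)              # desugaring
ctx3, st3 = JL.resolve_scopes(ctx2, st2)              # scope resolution
ctx4, st4 = JL.convert_closures(ctx3, st3)            # closure conversion
```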
jlbench.jl:

```julia
# Benchmark script for JuliaLowering steps
#
# Usage:
# julia --project=. experiments/jlbench.jl <file.jl> [module_name] [-n nruns] [--eval|--no-eval]
#
# Examples:
# julia --project=. experiments/jlbench.jl src/foo.jl
# julia --project=. experiments/jlbench.jl src/foo.jl Base
# julia --project=. experiments/jlbench.jl src/foo.jl JETLS
# julia --project=. experiments/jlbench.jl src/foo.jl Main -n 20
# julia --project=. experiments/jlbench.jl src/foo.jl --no-eval
#
# Measures cumulative time for each lowering step:
# 1. expand_forms_1 (macro expansion)
# 2. expand_forms_2 (desugaring)
# 3. resolve_scopes (scope resolution)
# 4. convert_closures (closure conversion)
using JuliaSyntax: JuliaSyntax as JS
using JuliaLowering: JuliaLowering as JL
function lookup_parent_module(name::AbstractString)
sym = Symbol(name)
for mod in Base.loaded_modules_array()
if nameof(mod) === sym
return mod
end
end
error("Module not found: $name. Available modules: $(join(nameof.(Base.loaded_modules_array()), ", "))")
end
function run_benchmark(filepath::String; parent_mod::Module=Main, nruns::Int=10, eval::Bool=true)
if !isfile(filepath)
error("File not found: $filepath")
end
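# Map each `module` expression's byte range to the module object it defines;
# the whole-file range falls back to the parent module.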
module_map = Dict{UnitRange{Int},Module}(1:typemax(Int)=>parent_mod)
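# Return the innermost module whose byte range contains `st`.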
function get_module_context(st)
thisrange = JS.byte_range(st)
thisinfo = nothing
for (rng, mod) in module_map
if thisrange ⊆ rng
if isnothing(thisinfo)
thisinfo = (rng, mod)
else
thisstart, thisend = first(rng), last(rng)
oldstart, oldend = first(thisinfo[1]), last(thisinfo[1])
if oldstart < thisstart && thisend < oldend
thisinfo = (rng, mod)
end
end
end
end
return last(@something thisinfo return parent_mod)
end
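# Walk the top-level expressions of `st0_top`, recursing into `toplevel`, `module`,
# and `doc` nodes, and invoke `callback` on each lowerable tree. With
# `define_module=true`, also create a dummy module for each `module` expression.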
function iterate_toplevel_tree(callback, st0_top::JS.SyntaxTree; define_module::Bool=false)
sl = JS.SyntaxList(st0_top)
push!(sl, st0_top)
while !isempty(sl)
st0 = pop!(sl)
if JS.kind(st0) === JS.K"toplevel"
for i = JS.numchildren(st0):-1:1 # reversed since we use `pop!`
push!(sl, st0[i])
end
elseif JS.kind(st0) === JS.K"module"
if define_module
thismod = get_module_context(st0)
modname = try st0[1].name_val catch; "DummyName" end
mod = Core.eval(thismod, :(module $(Symbol(modname)) end))
module_map[JS.byte_range(st0)] = mod
end
stblk = st0[end]
JS.kind(stblk) === JS.K"block" || continue
for i = JS.numchildren(stblk):-1:1 # reversed since we use `pop!`
push!(sl, stblk[i])
end
elseif JS.kind(st0) === JS.K"doc"
# skip docstring expressions for now
for i = JS.numchildren(st0):-1:1 # reversed since we use `pop!`
if JS.kind(st0[i]) !== JS.K"string"
push!(sl, st0[i])
end
end
else # st0 is lowerable tree
callback(st0)
end
end
end
src = read(filepath, String)
st0_top = JS.parseall(JS.SyntaxTree, src; filename=filepath)
# Step1: Evaluate minimum top-level definitions
n = 0
iterate_toplevel_tree(st0_top; define_module=true) do st0
n += 1
k = JS.kind(st0)
if k in JS.KSet"using import export function macro struct abstract primitive"
if eval || k ∉ JS.KSet"function macro struct abstract primitive" # with --no-eval, skip definitions whose execution could cause invalidations
thismod = get_module_context(st0)
try
JL.eval(thismod, st0)
catch err
@warn "Failed to evaluate $(JS.sourcetext(st0))" err
end
end
end
end
println("File: $filepath")
println("Module: $(nameof(parent_mod))")
println("Number of top-level expressions: $n")
println("Number of runs: $nruns")
println()
# Timing accumulators (sum over all runs)
t_expand1 = t_expand2 = t_scopes = t_closures = 0.0
# Step2: Run benchmark
for run = 1:nruns
world = Base.get_world_counter()
iterate_toplevel_tree(st0_top) do st0
thismod = get_module_context(st0)
t_expand1 += @elapsed begin
ctx1, st1 = JL.expand_forms_1(thismod, st0, true, world)
end
t_expand2 += @elapsed begin
ctx2, st2 = JL.expand_forms_2(ctx1, st1)
end
t_scopes += @elapsed begin
ctx3, st3 = JL.resolve_scopes(ctx2, st2)
end
t_closures += @elapsed begin
ctx4, st4 = JL.convert_closures(ctx3, st3)
end
end
end
t_expand1 /= nruns
t_expand2 /= nruns
t_scopes /= nruns
t_closures /= nruns
total = t_expand1 + t_expand2 + t_scopes + t_closures
println("Average cumulative times (over $nruns runs):")
println(" expand_forms_1: $(round(t_expand1 * 1000, digits=3)) ms")
println(" expand_forms_2: $(round(t_expand2 * 1000, digits=3)) ms")
println(" resolve_scopes: $(round(t_scopes * 1000, digits=3)) ms")
println(" convert_closures: $(round(t_closures * 1000, digits=3)) ms")
println(" Total: $(round(total * 1000, digits=3)) ms")
println()
println("Percentage breakdown:")
println(" expand_forms_1: $(round(t_expand1 / total * 100, digits=1))%")
println(" expand_forms_2: $(round(t_expand2 / total * 100, digits=1))%")
println(" resolve_scopes: $(round(t_scopes / total * 100, digits=1))%")
println(" convert_closures: $(round(t_closures / total * 100, digits=1))%")
return (; t_expand1, t_expand2, t_scopes, t_closures, total)
end
function (@main)(args::Vector{String})
positional = String[]
i = 1
nruns = 10
do_eval = true
while i <= length(args)
arg = args[i]
if arg == "-n"
i += 1
if i > length(args)
println("Error: -n requires a value")
exit(1)
end
nruns = parse(Int, args[i])
elseif arg == "--eval"
do_eval = true
elseif arg == "--no-eval"
do_eval = false
else
push!(positional, arg)
end
i += 1
end
if isempty(positional)
println("Usage: julia --project=. experiments/jlbench.jl <file.jl> [module_name] [-n nruns] [--eval|--no-eval]")
return Cint(1)
end
filepath = positional[1]
parent_mod = length(positional) >= 2 ? lookup_parent_module(positional[2]) : Main
run_benchmark(filepath; parent_mod, nruns, eval=do_eval)
return Cint(0)
end
```

I ran some benchmarks using this script.
For example, here are some results:
```
➜ ./julia --startup-file=no ./jlbench.jl ./jlbench.jl
File: experiments/jlbench.jl
Module: Main
Number of top-level expressions: 5
Number of runs: 10
Average cumulative times (over 10 runs):
expand_forms_1: 15.072 ms
expand_forms_2: 3.448 ms
resolve_scopes: 7.206 ms
convert_closures: 1.961 ms
Total: 27.686 ms
Percentage breakdown:
expand_forms_1: 54.4%
expand_forms_2: 12.5%
resolve_scopes: 26.0%
convert_closures: 7.1%
➜ ./julia --startup-file=no jlbench.jl ./base/char.jl Base --no-eval # --no-eval is necessary to avoid invalidations from execution of the Base function definitions
File: ./base/char.jl
Module: Base
Number of top-level expressions: 71
Number of runs: 10
Average cumulative times (over 10 runs):
expand_forms_1: 53.138 ms
expand_forms_2: 15.496 ms
resolve_scopes: 19.695 ms
convert_closures: 7.512 ms
Total: 95.841 ms
Percentage breakdown:
expand_forms_1: 55.4%
expand_forms_2: 16.2%
resolve_scopes: 20.5%
convert_closures: 7.8%
```
The above results are just fine; the problem is when we benchmark test files. For example, we can measure `test/char.jl` with this script (though we need to add `using Test` first), but the performance becomes very poor:
```
➜ ./julia --startup-file=no jlbench.jl ./test/char.jl
File: ./test/char.jl
Module: Main
Number of top-level expressions: 21
Number of runs: 10
Average cumulative times (over 10 runs):
expand_forms_1: 2390.805 ms
expand_forms_2: 28.726 ms
resolve_scopes: 87.579 ms
convert_closures: 20.907 ms
Total: 2528.017 ms
Percentage breakdown:
expand_forms_1: 94.6%
expand_forms_2: 1.1%
resolve_scopes: 3.5%
convert_closures: 0.8%
```
This trend is even more pronounced with big test files, and it gets especially bad for files that contain a lot of `@test` or `@testset`:
```
➜ ./julia --startup-file=no jlbench.jl ./test/atomics.jl
File: ./test/atomics.jl
Module: Main
Number of top-level expressions: 196
Number of runs: 10
Average cumulative times (over 10 runs):
expand_forms_1: 15730.569 ms
expand_forms_2: 108.182 ms
resolve_scopes: 148.664 ms
convert_closures: 34.109 ms
Total: 16021.523 ms
Percentage breakdown:
expand_forms_1: 98.2%
expand_forms_2: 0.7%
resolve_scopes: 0.9%
convert_closures: 0.2%
```
As the percentage breakdown clearly shows, the problem is in `expand_forms_1`. It seems like there's still room for optimization there.
Lowering alone taking over a second is a pretty critical problem for developer tools that use JuliaLowering as their core analysis engine, like JETLS. For example, it's obvious that this will make basic LSP features like diagnostics/completions feel unresponsive (meaning you'd have to wait ~15 s to get completions when working in a test file).
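For anyone who wants to dig into this, here is a minimal profiling sketch. It is a hypothetical usage example assuming the same staged API as the benchmark script above, not part of the script itself:

```julia
using Profile
using Test  # needed so `@test`/`@testset` macros can be expanded
using JuliaSyntax: JuliaSyntax as JS
using JuliaLowering: JuliaLowering as JL

src = read("test/atomics.jl", String)
st0_top = JS.parseall(JS.SyntaxTree, src; filename="test/atomics.jl")
st0 = st0_top[1]  # just the first top-level expression
world = Base.get_world_counter()
JL.expand_forms_1(Main, st0, true, world)   # warm up to exclude compilation time
@profile JL.expand_forms_1(Main, st0, true, world)
Profile.print(; mincount=10)                # hide rarely-sampled frames
```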