[JuliaLowering] expand_forms_1 performance issue #60756

@aviatesk

While tracking a latency issue with JETLS.jl's analysis, I found a performance issue with JuliaLowering.expand_forms_1.
The attached script is what I'm using. It does the following:

  • step 1: It takes a script, iterates through its top-level expressions, and evaluates the minimum "top-level definitions" needed for lowering, like module and macro.
  • step 2: Then, it runs lowering passes up to st4 on the "module context" built in step 1, and measures the time taken for each phase.
    • Note that step 2 doesn't perform any "execution" with side effects.
    • This phase only measures the lowering performance for the module context created in step 1.
jlbench.jl
# Benchmark script for JuliaLowering steps
#
# Usage:
#   julia --project=. experiments/jlbench.jl <file.jl> [module_name] [-n nruns] [--eval|--no-eval]
#
# Examples:
#   julia --project=. experiments/jlbench.jl src/foo.jl
#   julia --project=. experiments/jlbench.jl src/foo.jl Base
#   julia --project=. experiments/jlbench.jl src/foo.jl JETLS
#   julia --project=. experiments/jlbench.jl src/foo.jl Main -n 20
#   julia --project=. experiments/jlbench.jl src/foo.jl --no-eval
#
# Measures cumulative time for each lowering step:
#   1. expand_forms_1 (macro expansion)
#   2. expand_forms_2 (desugaring)
#   3. resolve_scopes (scope resolution)
#   4. convert_closures (closure conversion)

using JuliaSyntax: JuliaSyntax as JS
using JuliaLowering: JuliaLowering as JL

function lookup_parent_module(name::AbstractString)
    sym = Symbol(name)
    for mod in Base.loaded_modules_array()
        if nameof(mod) === sym
            return mod
        end
    end
    error("Module not found: $name. Available modules: $(join(nameof.(Base.loaded_modules_array()), ", "))")
end

function run_benchmark(filepath::String; parent_mod::Module=Main, nruns::Int=10, eval::Bool=true)
    if !isfile(filepath)
        error("File not found: $filepath")
    end

    module_map = Dict{UnitRange{Int},Module}(1:typemax(Int)=>parent_mod)
    function get_module_context(st)
        thisrange = JS.byte_range(st)
        thisinfo = nothing
        for (rng, mod) in module_map
            if thisrange ⊆ rng
                if isnothing(thisinfo)
                    thisinfo = (rng, mod)
                else
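                    # compare range endpoints (destructuring a range would yield its first two elements)
                    thisstart, thisend = first(rng), last(rng)
                    oldstart, oldend = first(thisinfo[1]), last(thisinfo[1])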
                    if oldstart < thisstart && thisend < oldend
                        thisinfo = (rng, mod)
                    end
                end
            end
        end
        return last(@something thisinfo return parent_mod)
    end

    function iterate_toplevel_tree(callback, st0_top::JS.SyntaxTree; define_module::Bool=false)
        sl = JS.SyntaxList(st0_top)
        push!(sl, st0_top)
        while !isempty(sl)
            st0 = pop!(sl)
            if JS.kind(st0) === JS.K"toplevel"
                for i = JS.numchildren(st0):-1:1 # reversed since we use `pop!`
                    push!(sl, st0[i])
                end
            elseif JS.kind(st0) === JS.K"module"
                if define_module
                    thismod = get_module_context(st0)
                    modname = try st0[1].name_val catch; "DummyName" end
                    mod = Core.eval(thismod, :(module $(Symbol(modname)) end))
                    module_map[JS.byte_range(st0)] = mod
                end
                stblk = st0[end]
                JS.kind(stblk) === JS.K"block" || continue
                for i = JS.numchildren(stblk):-1:1 # reversed since we use `pop!`
                    push!(sl, stblk[i])
                end
            elseif JS.kind(st0) === JS.K"doc"
                # skip docstring expressions for now
                for i = JS.numchildren(st0):-1:1 # reversed since we use `pop!`
                    if JS.kind(st0[i]) !== JS.K"string"
                        push!(sl, st0[i])
                    end
                end
            else # st0 is lowerable tree
                callback(st0)
            end
        end
    end

    src = read(filepath, String)
    st0_top = JS.parseall(JS.SyntaxTree, src; filename=filepath)

    # Step1: Evaluate minimum top-level definitions
    n = 0
    iterate_toplevel_tree(st0_top; define_module=true) do st0
        n += 1
        k = JS.kind(st0)
        if k in JS.KSet"using import export function macro struct abstract primitive"
            # with --no-eval, skip definitions and only evaluate using/import/export
            if eval || !(k in JS.KSet"function macro struct abstract primitive")
                thismod = get_module_context(st0)
                try
                    JL.eval(thismod, st0)
                catch err
                    @warn "Failed to evaluate $(JS.sourcetext(st0))" err
                end
            end
        end
    end

    println("File: $filepath")
    println("Module: $(nameof(parent_mod))")
    println("Number of top-level expressions: $n")
    println("Number of runs: $nruns")
    println()

    # Timing accumulators (sum over all runs)
    t_expand1 = t_expand2 = t_scopes = t_closures = 0.0
    
    # Step2: Run benchmark
    for run = 1:nruns
        world = Base.get_world_counter()
        iterate_toplevel_tree(st0_top) do st0
            thismod = get_module_context(st0)
            t_expand1 += @elapsed begin
                ctx1, st1 = JL.expand_forms_1(thismod, st0, true, world)
            end

            t_expand2 += @elapsed begin
                ctx2, st2 = JL.expand_forms_2(ctx1, st1)
            end

            t_scopes += @elapsed begin
                ctx3, st3 = JL.resolve_scopes(ctx2, st2)
            end

            t_closures += @elapsed begin
                ctx4, st4 = JL.convert_closures(ctx3, st3)
            end
        end
    end

    t_expand1 /= nruns
    t_expand2 /= nruns
    t_scopes /= nruns
    t_closures /= nruns
    total = t_expand1 + t_expand2 + t_scopes + t_closures

    println("Average cumulative times (over $nruns runs):")
    println("  expand_forms_1:              $(round(t_expand1 * 1000, digits=3)) ms")
    println("  expand_forms_2:              $(round(t_expand2 * 1000, digits=3)) ms")
    println("  resolve_scopes:              $(round(t_scopes * 1000, digits=3)) ms")
    println("  convert_closures:            $(round(t_closures * 1000, digits=3)) ms")
    println("  Total:                       $(round(total * 1000, digits=3)) ms")
    println()
    println("Percentage breakdown:")
    println("  expand_forms_1:              $(round(t_expand1 / total * 100, digits=1))%")
    println("  expand_forms_2:              $(round(t_expand2 / total * 100, digits=1))%")
    println("  resolve_scopes:              $(round(t_scopes / total * 100, digits=1))%")
    println("  convert_closures:            $(round(t_closures / total * 100, digits=1))%")

    return (; t_expand1, t_expand2, t_scopes, t_closures, total)
end

function (@main)(args::Vector{String})
    positional = String[]
    i = 1
    nruns = 10
    do_eval = true
    while i <= length(args)
        arg = args[i]
        if arg == "-n"
            i += 1
            if i > length(args)
                println("Error: -n requires a value")
                exit(1)
            end
            nruns = parse(Int, args[i])
        elseif arg == "--eval"
            do_eval = true
        elseif arg == "--no-eval"
            do_eval = false
        else
            push!(positional, arg)
        end
        i += 1
    end
    if isempty(positional)
        println("Usage: julia --project=. experiments/jlbench.jl <file.jl> [module_name] [-n nruns] [--eval|--no-eval]")
        return Cint(1)
    end
    filepath = positional[1]
    parent_mod = length(positional) >= 2 ? lookup_parent_module(positional[2]) : Main
    run_benchmark(filepath; parent_mod, nruns, eval=do_eval)
    return Cint(0)
end
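
The module lookup in get_module_context picks the innermost byte range that contains a given expression. A minimal standalone sketch of that idea follows. It is not the script's actual code: the `ranges` table and the `innermost` helper are hypothetical, module names are plain strings rather than `Module` objects, and "innermost" is assumed to mean "shortest enclosing range".

```julia
# Hypothetical stand-in for `module_map`: byte ranges mapped to module names.
# Strings are used instead of `Module` objects to keep the sketch self-contained.
ranges = Dict{UnitRange{Int},String}(
    1:typemax(Int) => "Main",    # fallback covering the whole file
    10:200         => "Outer",
    50:120         => "Inner",
)

# Return the name whose range is the shortest one containing `query`,
# or `nothing` if no range contains it.
function innermost(ranges::Dict{UnitRange{Int},String}, query::UnitRange{Int})
    best = nothing
    for (rng, name) in ranges
        query ⊆ rng || continue
        if best === nothing || length(rng) < length(best[1])
            best = (rng, name)
        end
    end
    return best === nothing ? nothing : best[2]
end

innermost(ranges, 60:70)    # contained in Inner, Outer and the fallback
innermost(ranges, 150:160)  # contained in Outer and the fallback
innermost(ranges, 300:310)  # contained only in the fallback
```

Comparing range lengths avoids depending on the iteration order of the `Dict`, which is unspecified.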

I ran some benchmarks using this script.
For example, here are some results:

➜ ./julia --startup-file=no ./jlbench.jl ./jlbench.jl
File: experiments/jlbench.jl
Module: Main
Number of top-level expressions: 5
Number of runs: 10

Average cumulative times (over 10 runs):
  expand_forms_1:              15.072 ms
  expand_forms_2:              3.448 ms
  resolve_scopes:              7.206 ms
  convert_closures:            1.961 ms
  Total:                       27.686 ms

Percentage breakdown:
  expand_forms_1:              54.4%
  expand_forms_2:              12.5%
  resolve_scopes:              26.0%
  convert_closures:            7.1%
➜ ./julia --startup-file=no jlbench.jl ./base/char.jl Base --no-eval # --no-eval is necessary to avoid invalidations from executing Base function definitions
File: ./base/char.jl
Module: Base
Number of top-level expressions: 71
Number of runs: 10

Average cumulative times (over 10 runs):
  expand_forms_1:              53.138 ms
  expand_forms_2:              15.496 ms
  resolve_scopes:              19.695 ms
  convert_closures:            7.512 ms
  Total:                       95.841 ms

Percentage breakdown:
  expand_forms_1:              55.4%
  expand_forms_2:              16.2%
  resolve_scopes:              20.5%
  convert_closures:            7.8%

The results above are fine, but the problem shows up when we benchmark test files. For example, we can measure test/char.jl with this script (after adding using Test to the file), and the performance becomes very poor:

➜ ./julia --startup-file=no jlbench.jl ./test/char.jl
File: ./test/char.jl
Module: Main
Number of top-level expressions: 21
Number of runs: 10

Average cumulative times (over 10 runs):
  expand_forms_1:              2390.805 ms
  expand_forms_2:              28.726 ms
  resolve_scopes:              87.579 ms
  convert_closures:            20.907 ms
  Total:                       2528.017 ms

Percentage breakdown:
  expand_forms_1:              94.6%
  expand_forms_2:              1.1%
  resolve_scopes:              3.5%
  convert_closures:            0.8%

This trend is especially noticeable with big test files. The timings get even worse for files that contain a lot of @test or @testset:

➜ ./julia --startup-file=no jlbench.jl ./test/atomics.jl
File: ./test/atomics.jl
Module: Main
Number of top-level expressions: 196
Number of runs: 10

Average cumulative times (over 10 runs):
  expand_forms_1:              15730.569 ms
  expand_forms_2:              108.182 ms
  resolve_scopes:              148.664 ms
  convert_closures:            34.109 ms
  Total:                       16021.523 ms

Percentage breakdown:
  expand_forms_1:              98.2%
  expand_forms_2:              0.7%
  resolve_scopes:              0.9%
  convert_closures:            0.2%
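
Dividing the expand_forms_1 time by the number of top-level expressions makes the gap between source files and test files easier to compare. The sketch below only redoes arithmetic on the figures reported above; the `results` and `per_expr` names are illustrative.

```julia
# Figures copied from the benchmark output above:
# (file, number of top-level expressions, expand_forms_1 time in ms)
results = [
    ("jlbench.jl",      5,   15.072),
    ("base/char.jl",    71,  53.138),
    ("test/char.jl",    21,  2390.805),
    ("test/atomics.jl", 196, 15730.569),
]

# Average expand_forms_1 cost per top-level expression, in ms.
per_expr = Dict(file => ms / n for (file, n, ms) in results)

for (file, n, ms) in results
    println(rpad(file, 18), round(ms / n, digits=2), " ms per top-level expression")
end
```

By this measure the two test files cost on the order of 100 times more per top-level expression than base/char.jl. A single top-level @testset can contain many nested macro calls, so this is only a coarse normalization.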

The problem, as the percentages show, is clearly in expand_forms_1. It seems like there's still room for optimization.
Lowering alone taking over 1 second is a pretty critical problem for developer tools that use JuliaLowering as their core analysis, like JETLS. For example, it will obviously make basic LSP features like diagnostics/completions feel unresponsive (meaning you'd have to wait 15s to get completions when using a test file).
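
As a rough point of comparison, one could time the existing `macroexpand` from Base on a single @test expression. This is only a sketch, and it assumes that macro expansion of @test is what dominates expand_forms_1. The expression and run count are arbitrary, and this does not go through JuliaLowering at all.

```julia
using Test

# The expression is quoted, so `@test` is not executed here, only expanded below.
ex = :(@test 1 + 1 == 2)

# Warm up once so the timing below excludes most compilation overhead.
macroexpand(@__MODULE__, ex)

nruns = 100
t = @elapsed for _ in 1:nruns
    macroexpand(@__MODULE__, ex)
end

expanded = macroexpand(@__MODULE__, ex)
println("macroexpand of `@test`: ", round(t / nruns * 1000, digits=3), " ms per call")
```

A large difference between this figure and the per-expression expand_forms_1 cost would suggest the overhead lies in the JuliaLowering expansion path rather than in the macro body itself.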
