diff --git a/src/macro.jl b/src/macro.jl index 62f1ffe81..17ea854a6 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -157,7 +157,125 @@ end # https://github.com/tkf/ThreadsX.jl/pull/106. But this should be # done automatically in Transducers.jl. +""" + localize(body::Expr) -> body′::Expr + +Add `local`s to make non-accumulators local in the loop `body`. + +It adds `local v` to the inner most scope-creating lexical block that contain +the outer most lexical block that has an assignment to `v`. +""" +localize(body::Expr) = localize!(body) + +struct VariableEnv + current::Vector{Symbol} + outer::Union{VariableEnv,Nothing} +end + +function Base.in(v::Symbol, env::VariableEnv) + v in env.current && return true + outer = env.outer + return outer !== nothing && v in outer +end + +function Base.push!(env::VariableEnv, v::Symbol) + v in env || push!(env.current, v) + return env +end + +Base.append!(env::VariableEnv, variables) = foldl(push!, variables; init = env) + +localize!(@nospecialize(body), ::VariableEnv) = body +function localize!(body::Expr, env::VariableEnv = VariableEnv(Symbol[], nothing)) + add_new_assignments!(env, body) + body = localize_nested(body, env) + if isempty(env.current) + return body + else + return Expr(:block, Expr(:local, env.current...), body) + end +end + +# Not that this doesn't handle `local lhs = rhs` but it's OK since it doesn't +# matter exactly which scope the variable comes from. +function add_new_assignments!(env::VariableEnv, ex::Expr) + @match ex begin + Expr(:meta, _...) => nothing + Expr(:loopinfo, _...) => nothing + + Expr(:function, Expr(:call, f::Symbol, _...), _...) => push!(env, f) + Expr(:(=), lhs, rhs) => begin + @match lhs begin + Expr(:call, f::Symbol, _...) => push!(env, f) + # TODO: handle where + _ => begin + if rhs isa Expr + add_new_assignments!(env, rhs) + end + append!(env, vars_in(lhs)) + end + end + end + + # Scope-creating + Expr(:let, _...) => nothing + Expr(:function, _...) => nothing + Expr(:->, _...) => nothing + + Expr(_, args...) => begin + for x in args + if x isa Expr + add_new_assignments!(env, x) + end + end + end + end +end + +localize_nested(@nospecialize(body), ::VariableEnv) = body +function localize_nested(body::Expr, env::VariableEnv) + @match body begin + Expr(:let, let_bindings_, let_body) => begin + let_bindings = @match let_bindings_ begin + Expr(:block, args...) => collect(args) + b => [b] + end + let_vars = mapfoldl(vars_in, append!, let_bindings; init = Symbol[]) + localize!(let_body, VariableEnv(let_vars, env)) + end + Expr(:(=), lhs, rhs) => begin + if isexpr(lhs, :call) + Expr(:(=), lhs, localize!(rhs, VariableEnv(Symbol[], env))) + else + Expr(:(=), lhs, localize_nested(rhs, env)) + end + end + Expr(:function, call, def) => Expr( + :function, + localize_nested(call, env), + localize!(def, VariableEnv(Symbol[], env)), + ) + Expr(:->, lhs, rhs) => + Expr(:->, localize_nested(lhs, env), localize!(rhs, VariableEnv(Symbol[], env))) + + Expr(:meta, _...) => body + Expr(:loopinfo, _...) => body + + Expr(head, args...) => begin + args = mapfoldl(push!, args; init = []) do x + if x isa Expr + localize_nested(x, env) + else + x + end + end + Expr(head, args...) + end + end +end + function transform_loop_body(body, state_vars) + # body = localize(body) # TODO: enable this for sequential case as well external_labels::Vector{Symbol} = setdiff(gotos_in(body), labels_in(body)) # state_vars = extract_state_vars(body) pack_state = :(($(state_vars...),)) @@ -236,7 +354,7 @@ end vars_in(x::Symbol) = [x] function vars_in(ex) @match ex begin - Expr(:tuple, vars...) => vars + Expr(:tuple, vars...) => mapfoldl(vars_in, append!, vars; init = Symbol[]) _ => Symbol[] end end diff --git a/src/reduce.jl b/src/reduce.jl index 807a7efe0..62f2859fa 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -763,7 +763,7 @@ function as_parallel_loop(ctx::MacroContext, rf_arg, coll, body0::Expr, simd, ex end check_invariance() - body2, info = transform_loop_body(body1, accs_symbols) + body2, info = transform_loop_body(localize(body1), accs_symbols) @gensym oninit_function reducing_function combine_function result context_function if ctx.module_ === Main