Range operator fails with TracedRNumbers #837

Open
jumerckx opened this issue Mar 3, 2025 · 1 comment · May be fixed by #839

Comments

@jumerckx
Collaborator

jumerckx commented Mar 3, 2025

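I believe this happens because there is no specialized range (:) implementation for TracedRNumbers, so the generic implementation is used, and it compares a and b. While tracing, that comparison yields a TracedRNumber{Bool} rather than a Bool, which cannot be used as a branch condition.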
(cc @mofeing)

using Reactant

function f(a, b)
    it = collect(a:b)
    return it
end

a, b = ConcreteRNumber(1), ConcreteRNumber(3)

@code_hlo f(a, b)
ERROR: TypeError: non-boolean (Reactant.TracedRNumber{Bool}) used in boolean context
Stacktrace:
  [1] Colon
    @ ./range.jl:7 [inlined]
  [2] (::Colon)(none::Reactant.TracedRNumber{Int64}, none::Reactant.TracedRNumber{Int64})
    @ Reactant ./<missing>:0
  [3] NamedTuple
    @ ./boot.jl:727 [inlined]
  [4] >=
    @ ~/.julia/packages/Reactant/Memqo/src/TracedRNumber.jl:166 [inlined]
  [5] Colon
    @ ./range.jl:7 [inlined]
  [6] call_with_reactant(::Reactant.MustThrowError, ::Colon, ::Reactant.TracedRNumber{Int64}, ::Reactant.TracedRNumber{Int64})
    @ Reactant ~/.julia/packages/Reactant/Memqo/src/utils.jl:0
  [7] f
    @ ./REPL[18]:2 [inlined]
  [8] f(none::Reactant.TracedRNumber{Int64}, none::Reactant.TracedRNumber{Int64})
    @ Reactant ./<missing>:0
  [9] f
    @ ./REPL[18]:2 [inlined]
 [10] call_with_reactant(::typeof(f), ::Reactant.TracedRNumber{Int64}, ::Reactant.TracedRNumber{Int64})
    @ Reactant ~/.julia/packages/Reactant/Memqo/src/utils.jl:0
 [11] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/Memqo/src/TracedUtils.jl:256
 [12] make_mlir_fn
    @ ~/.julia/packages/Reactant/Memqo/src/TracedUtils.jl:153 [inlined]
 [13] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, callcache::Dict{…}, sdycache::IdDict{…}; optimize::Bool, no_nan::Bool, backend::String, fn_kwargs::@NamedTuple{}, raise::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/Memqo/src/Compiler.jl:655
 [14] compile_mlir!
    @ ~/.julia/packages/Reactant/Memqo/src/Compiler.jl:613 [inlined]
 [15] (::Reactant.Compiler.var"#7#8"{@Kwargs{fn_kwargs::@NamedTuple{}, no_nan::Bool, raise::Bool, optimize::Bool}, typeof(f), Tuple{ConcretePJRTNumber{…}, ConcretePJRTNumber{…}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/Memqo/src/Compiler.jl:509
 [16] context!(f::Reactant.Compiler.var"#7#8"{@Kwargs{fn_kwargs::@NamedTuple{}, no_nan::Bool, raise::Bool, optimize::Bool}, typeof(f), Tuple{ConcretePJRTNumber{…}, ConcretePJRTNumber{…}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/Memqo/src/mlir/IR/Context.jl:76
 [17] compile_mlir(f::Function, args::Tuple{ConcretePJRTNumber{…}, ConcretePJRTNumber{…}}; client::Nothing, kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, no_nan::Bool, raise::Bool, optimize::Bool})
    @ Reactant.Compiler ~/.julia/packages/Reactant/Memqo/src/Compiler.jl:506
 [18] top-level scope
    @ ~/.julia/packages/Reactant/Memqo/src/Compiler.jl:1098
Some type information was truncated. Use `show(err)` to see complete types.
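The failure mode can be reproduced without Reactant. This plain-Julia sketch uses two hypothetical stand-in types, MockTraced and MockTracedBool (not part of Reactant), where comparing two MockTraced values returns a wrapper instead of a Bool, roughly like TracedRNumber{Bool} during tracing. Because MockTraced has no specialized range method, a:b falls through to Base's generic fallback, which branches on the comparison. It assumes that fallback behaves like the one in the stacktrace above; the exact Base code may differ between Julia versions.

```julia
# Hypothetical stand-ins (not Reactant types): comparing two MockTraced values
# returns a MockTracedBool wrapper instead of a Bool.
struct MockTracedBool
    val::Bool
end

struct MockTraced <: Number
    val::Int
end

Base.:(>=)(a::MockTraced, b::MockTraced) = MockTracedBool(a.val >= b.val)
Base.:(-)(a::MockTraced, b::MockTraced) = MockTraced(a.val - b.val)

# Build the range and return the exception it throws, or `nothing` on success.
function range_error(a, b)
    try
        a:b
        return nothing
    catch err
        return err
    end
end

# With no specialized `:` method, Base's generic fallback branches on
# `stop >= start`; a MockTracedBool in that branch raises a TypeError.
err = range_error(MockTraced(1), MockTraced(3))
println(err isa TypeError)
```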
@mofeing
Collaborator

mofeing commented Mar 3, 2025

The failure seems to come from these lines of Base's range.jl:

https://github.com/JuliaLang/julia/blob/c4eeabf2f64ddfc133b02be04f8b6557d5f722e9/base/range.jl#L5-L7

Adding a new method for Base.:(:) on TracedRNumber seems to fix it:

julia> Base.:(:)(start::TracedRNumber{T}, stop::TracedRNumber{T}) where {T} = UnitRange{TracedRNumber{T}}(start, stop)

julia> function f(a,b)
           x = zero(a)
           @trace for i in a:b
               x += i
           end
           return x
       end
f (generic function with 1 method)

julia> a = ConcreteRNumber(1)
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1)

julia> b = ConcreteRNumber(4)
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(4)

julia> fre = @compile f(a,b)
Reactant.Compiler.Thunk{typeof(f), Symbol("##f_reactant#244"), Tuple{ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, false}(f)

julia> fre(a,b)
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(10)

julia> 1 + 2 + 3 + 4
10
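The idea behind the fix can also be sketched in plain Julia: give the traced type its own Base.:(:) method so that building the range never branches on a comparison. Here MockTraced and MockUnitRange are hypothetical stand-ins (not Reactant's TracedRNumber or Base's UnitRange), and mock_sum is an illustrative helper that loops over the concrete payloads to mirror f(a, b) above. This is a sketch of the dispatch idea only, not the change made in the linked pull request.

```julia
# Hypothetical stand-in for a traced integer (not a Reactant type).
struct MockTraced <: Number
    val::Int
end

# Hypothetical range type: it only stores the endpoints and never compares
# them, so constructing it does not require a concrete Bool.
struct MockUnitRange
    start::MockTraced
    stop::MockTraced
end

# Specialized `:` method; it is more specific than any generic Base fallback
# for two arguments of the same type, so it wins dispatch.
Base.:(:)(start::MockTraced, stop::MockTraced) = MockUnitRange(start, stop)

# Illustrative consumer: sums the concrete payloads, mirroring f(a, b) above.
mock_sum(r::MockUnitRange) = sum(r.start.val:r.stop.val)

r = MockTraced(1):MockTraced(4)
println(r isa MockUnitRange)  # true
println(mock_sum(r))          # 10
```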

mofeing linked a pull request on Mar 3, 2025 that will close this issue