Skip to content

Conversation

@yebai
Copy link
Member

@yebai yebai commented Jun 26, 2025

Fix #597 (again).

@github-actions
Copy link
Contributor

Mooncake.jl documentation for PR #620 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR620/

@yebai
Copy link
Member Author

yebai commented Jun 26, 2025

I don't understand the Aqua complaint about unbound type parameters. @MilesCranmer @penelopeysm any ideas?

Unbound type parameters detected:
[1] tangent_type(::Type{NoFData}, ::Type{Union{NoRData, T}}) where T<:Union{Float16, Float32, Float64} @ Mooncake ~/work/Mooncake.jl/Mooncake.jl/src/fwds_rvs_data.jl:864
[2] tangent_type(::Type{Union{NoFData, T}}, ::Type{NoRData}) where T<:(Array{<:Union{Float16, Float32, Float64}}) @ Mooncake ~/work/Mooncake.jl/Mooncake.jl/src/fwds_rvs_data.jl:867
Unbound type parameters: Test Failed at /home/runner/.julia/packages/Aqua/1UuaV/src/unbound_args.jl:37
  Expression: isempty(unbounds)
   Evaluated: isempty(Method[tangent_type(::Type{NoFData}, ::Type{Union{NoRData, T}}) where T<:Union{Float16, Float32, Float64} @ Mooncake ~/work/Mooncake.jl/Mooncake.jl/src/fwds_rvs_data.jl:864, tangent_type(::Type{Union{NoFData, T}}, ::Type{NoRData}) where T<:(Array{<:Union{Float16, Float32, Float64}}) @ Mooncake ~/work/Mooncake.jl/Mooncake.jl/src/fwds_rvs_data.jl:867])

@yebai yebai requested a review from MilesCranmer June 26, 2025 12:02
@MilesCranmer
Copy link
Collaborator

I wonder if that is an Aqua.jl bug? Usually that error would be for things like

f(x::Union{Float16,T}) where {T} = T

which is a real issue. But when inside Type{Union{Float16,T}} it's different, I suppose if T could be Float16, that would be an issue, but here T can never be NoRData so it seems fine to me.

@github-actions
Copy link
Contributor

github-actions bot commented Jun 26, 2025

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────┬─────────────┬─────────┐
│                      Label │   Primal │ Mooncake │  Zygote │ ReverseDiff │  Enzyme │
│                     String │   String │   String │  String │      String │  String │
├────────────────────────────┼──────────┼──────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │ 100.0 ns │      1.8 │     1.1 │        5.61 │    8.21 │
│                  _sum_1000 │ 941.0 ns │      6.7 │  1460.0 │        33.8 │    1.09 │
│               sum_sin_1000 │  6.55 μs │     2.24 │    1.69 │        10.5 │    2.17 │
│              _sum_sin_1000 │  5.45 μs │     2.56 │   265.0 │        12.7 │    2.41 │
│                   kron_sum │ 288.0 μs │     45.7 │    5.33 │       236.0 │    13.3 │
│              kron_view_sum │ 336.0 μs │     40.9 │    10.8 │       210.0 │     6.4 │
│      naive_map_sin_cos_exp │  2.16 μs │     2.13 │ missing │        7.13 │    2.31 │
│            map_sin_cos_exp │  2.16 μs │     2.39 │    1.45 │        5.98 │    2.86 │
│      broadcast_sin_cos_exp │  2.27 μs │     2.28 │    2.37 │        1.46 │    2.27 │
│                 simple_mlp │ 418.0 μs │     4.68 │    1.61 │        6.84 │    3.23 │
│                     gp_lml │ 243.0 μs │     8.91 │    4.01 │     missing │    5.75 │
│ turing_broadcast_benchmark │  1.97 ms │     3.49 │ missing │        24.5 │ missing │
│         large_single_block │ 380.0 ns │     4.51 │  4050.0 │        30.9 │    2.24 │
└────────────────────────────┴──────────┴──────────┴─────────┴─────────────┴─────────┘

@penelopeysm
Copy link
Collaborator

penelopeysm commented Jun 26, 2025

module Foo
struct Wat end
function f(::Type{Union{Wat,T}}) where {T <: Base.IEEEFloat}
    return T
end
end

using Aqua
Aqua.test_unbound_args(Foo)
# ERROR: ...

Foo.f(Foo.Wat)
# Union {}

Foo.Wat isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
# true

I have to confess that the last line is entirely unintuitive to me (but I suppose given this behaviour, Aqua is technically correct...?). (Edit: And then:)

Float64 isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
# false

@MilesCranmer
Copy link
Collaborator

Ugh. Maybe it's special handling of Type{Union{A,B}} gets converted into Type{<:Union{A,B}} which is similar to Vararg...

@MilesCranmer
Copy link
Collaborator

julia> Float64 isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
false

julia> Foo.Wat isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
true

Very weird indeed

@MilesCranmer
Copy link
Collaborator

Even weirder!

julia> Union{Foo.Wat,Float32} isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
true

Ok @yebai I think the best solution is to just manually declare for each element of Base.IEEEFloat rather than use the union.

@yebai
Copy link
Member Author

yebai commented Jun 26, 2025

It appears to be an edge behaviour of Julia that we don't understand...

julia> Float64 isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
false

julia> Foo.Wat isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
true

julia> Union{Float16} isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
false

julia> Union{Float16, Float32} isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
false

julia> Union{Float16, Float32, Float64} isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
false

julia> Union{Float16, Float32, Foo.Wat} isa Type{Union{Foo.Wat,T}} where {T <: Base.IEEEFloat}
true

@ChrisRackauckas, @ViralBShah, maybe you have additional insights?

@yebai
Copy link
Member Author

yebai commented Jun 26, 2025

@MilesCranmer, this is ready for another look.

@MilesCranmer
Copy link
Collaborator

Thanks. Got some explanation here: https://discourse.julialang.org/t/weirdness-of-type-union-a-b-where-b-superb/130234/2?u=milescranmer.

But at the moment I'm not sure any efficient way of writing this so let's just go with the current

yebai and others added 2 commits June 26, 2025 15:29
Co-authored-by: Miles Cranmer <[email protected]>
Signed-off-by: Hong Ge <[email protected]>
@MilesCranmer
Copy link
Collaborator

MilesCranmer commented Jun 26, 2025

Actually, I think it's fine as-is. Because that unbounded form is never actually matched!

julia> struct Foo end

julia> f(::Type{Union{Foo,T}}) where {T<:Base.IEEEFloat} = T
f (generic function with 1 method)

julia> f(Foo)  # If we hadn't defined another method - we WOULD see the unbounded form
Union{}

julia> f(::Type{Foo}) = Foo
f (generic function with 2 methods)

julia> f(Foo)
Foo

julia> f(Union{Foo,Float32})
Float32

julia> f(Union{Foo})
Foo

So just need to skip that single Aqua error.

I think it is preferable to match all Array{T,N} than avoid the appearance of an unbounded (but not real, due to method dispatch) method.

@yebai
Copy link
Member Author

yebai commented Jun 26, 2025

Should be okay now -- Aqua has been annoying (though useful)!

@yebai
Copy link
Member Author

yebai commented Jun 26, 2025

@MilesCranmer, I have invited you to this organisation so you can make new releases using @JuliaRegistrator.

@ViralBShah
Copy link

@oscardssmith

Comment on lines +864 to +873
for T in [Float16, Float32, Float64]
@eval @foldable tangent_type(::Type{NoFData}, ::Type{Union{NoRData,$T}}) = Union{
NoTangent,tangent_type($T)
}
for N in 0:5 # Just go up to N=5 until general solution to https://github.com/chalk-lab/Mooncake.jl/pull/620 available
@eval @foldable tangent_type(::Type{Union{NoFData,Array{$T,$N}}}, ::Type{NoRData}) = Union{
NoTangent,tangent_type(Array{$T,$N})
}
end
end
Copy link

@nsajko nsajko Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure it would be nicer to the compiler not to eval three duplicate methods. Something like this, for example:

Suggested change
for T in [Float16, Float32, Float64]
@eval @foldable tangent_type(::Type{NoFData}, ::Type{Union{NoRData,$T}}) = Union{
NoTangent,tangent_type($T)
}
for N in 0:5 # Just go up to N=5 until general solution to https://github.com/chalk-lab/Mooncake.jl/pull/620 available
@eval @foldable tangent_type(::Type{Union{NoFData,Array{$T,$N}}}, ::Type{NoRData}) = Union{
NoTangent,tangent_type(Array{$T,$N})
}
end
end
let f(@nospecialize x::Type) = Type{Union{NoRData, x}},
f16 = f(Float16),
f32 = f(Float32),
f64 = f(Float64)
global tangent_type
@foldable function tangent_type(::Type{NoFData}, u::Union{f16, f32, f64})
if u === f16
t = Float16
elseif u === f32
t = Float32
elseif u === f64
t = Float64
end
Union{NoTangent, tangent_type(t)}
end
end
for T in [Float16, Float32, Float64]
for N in 0:5 # Just go up to N=5 until general solution to https://github.com/chalk-lab/Mooncake.jl/pull/620 available
@eval @foldable tangent_type(::Type{Union{NoFData,Array{$T,$N}}}, ::Type{NoRData}) = Union{
NoTangent,tangent_type(Array{$T,$N})
}
end
end

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will apply the same deduplication to the for N in 0:5 loop, if you wish.

Copy link
Collaborator

@MilesCranmer MilesCranmer Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could just use the unbounded form right? It's only an issue if there is ambiguity. But since Type{NoFData}, Type{NoRData} are defined, it never gets matched.

julia> struct Foo end

julia> foo(::Type{Union{Foo,B}}) where {B<:AbstractFloat} = B
foo (generic function with 1 method)

julia> foo(Foo)  # Only issue here
Union{}

julia> foo(::Type{Foo}) = Foo
foo (generic function with 2 methods)

julia> foo(Foo)
Foo

julia> foo(Union{Foo,Float32})
Float32

which would let us match to any N for the arrays rather than only up to 5.

Comment on lines +864 to +873
for T in [Float16, Float32, Float64]
@eval @foldable tangent_type(::Type{NoFData}, ::Type{Union{NoRData,$T}}) = Union{
NoTangent,tangent_type($T)
}
for N in 0:5 # Just go up to N=5 until general solution to https://github.com/chalk-lab/Mooncake.jl/pull/620 available
@eval @foldable tangent_type(::Type{Union{NoFData,Array{$T,$N}}}, ::Type{NoRData}) = Union{
NoTangent,tangent_type(Array{$T,$N})
}
end
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for T in [Float16, Float32, Float64]
@eval @foldable tangent_type(::Type{NoFData}, ::Type{Union{NoRData,$T}}) = Union{
NoTangent,tangent_type($T)
}
for N in 0:5 # Just go up to N=5 until general solution to https://github.com/chalk-lab/Mooncake.jl/pull/620 available
@eval @foldable tangent_type(::Type{Union{NoFData,Array{$T,$N}}}, ::Type{NoRData}) = Union{
NoTangent,tangent_type(Array{$T,$N})
}
end
end
tangent_type(::Type{NoFData}, ::Type{Union{NoRData,T}}) where {T<:Base.IEEEFloat} = Union{
NoTangent,tangent_type(T)
}
tangent_type(::Type{Union{NoFData,Array{T,N}}}, ::Type{NoRData}) where {T<:Base.IEEEFloat,N} = Union{
NoTangent,tangent_type(Array{T,N})
}

since NoFData and NoRData already have a method, the unbounded method issue never actually occurs. So just need to somehow inform Aqua that this particular instance is not an issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When T = Union{}, Array{T,N} would be problematic. But since this case is never called, as you noticed, it might be okay.

@yebai
Copy link
Member Author

yebai commented Jun 26, 2025

Please feel free to take over this PR or replace it with a new one!

@yebai yebai merged commit 3b2f444 into main Jun 26, 2025
78 of 79 checks passed
@yebai yebai deleted the hg/union branch June 26, 2025 18:17
@codecov
Copy link

codecov bot commented Jun 26, 2025

Codecov Report

Attention: Patch coverage is 60.00000% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/fwds_rvs_data.jl 0.00% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Missing tangent_type for Union{Mooncake.NoRData, Float32}

6 participants