Introduce a isa_wrapped_array function #460

Open
avik-pal opened this issue Dec 15, 2024 · 8 comments

Comments

@avik-pal
Contributor

Most details are in EnzymeAD/Reactant.jl#369 (comment). I will copy over the important parts.

We introduce a function isa_wrapped_array that downstream packages can use to mark that their array type wraps another array. Using a union type from Adapt doesn't solve this problem, because that fundamentally doesn't extend to new array types.

With this function, we can override functions inside our custom interpreter. Consider this simple example of extending LinearAlgebra.diag:

Base.Experimental.@overlay REACTANT_METHOD_TABLE function LinearAlgebra.diag(
    x::AbstractArray{T,2}, k::Integer=0
) where {T}
    if isa_wrapped_array(x) && ancestor(x) isa TracedRArray
        y = materialize_traced_array(x) # convert it to a known type
        return diag(y, k)
    else
        # invoke diag(x) on NativeInterpreter
    end
end

@rafaqz
Contributor

rafaqz commented Dec 15, 2024

How does materialize_traced_array work? Is the idea to get the wrapped TracedArray out of x?

Can that work through multiple layers of wrappers?

OK, I think I get it. ancestor will recurse down through wrapped arrays to find the first non-wrapper array?

(and isa_wrapped_array must mean wraps without changing values or size?)

@avik-pal
Contributor Author

OK, I think I get it. ancestor will recurse down through wrapped arrays to find the first non-wrapper array?

Correct. I am still debating whether it should be ancestor or ancestors. For example, Tridiagonal stores 3 arrays underneath, so it returns a 3-tuple. For a Diagonal or SubArray we return a 1-tuple.

and isa_wrapped_array must mean wraps without changing values or size?

Broader than that. Any array type that wraps over "primitive" array types (examples being Array, JLArray, CuArray, etc.) would count. For example:

julia> dl = ones(3); d = ones(4); du = ones(3);

julia> X = Tridiagonal(dl, d, du)
4×4 Tridiagonal{Float64, Vector{Float64}}:
 1.0  1.0   ⋅    ⋅ 
 1.0  1.0  1.0   ⋅ 
  ⋅   1.0  1.0  1.0
  ⋅    ⋅   1.0  1.0

julia> X .*= 2
4×4 Tridiagonal{Float64, Vector{Float64}}:
 2.0  2.0   ⋅    ⋅ 
 2.0  2.0  2.0   ⋅ 
  ⋅   2.0  2.0  2.0
  ⋅    ⋅   2.0  2.0

julia> dl
3-element Vector{Float64}:
 2.0
 2.0
 2.0

Tridiagonal would count as a wrapper even though its size is not the same as the size of the underlying buffers.

@rafaqz
Contributor

rafaqz commented Dec 15, 2024

Doesn't that risk losing information from the wrapper? E.g. would SubArray be a wrapper?

@avik-pal
Contributor Author

Doesn't that risk losing information from the wrapper? E.g. would SubArray be a wrapper?

Do you mean in the pseudo-code implementation of diag I shared?

@rafaqz
Contributor

rafaqz commented Dec 15, 2024

I meant this kind of method generally, but yes, I assume getting the ancestor through a SubArray would be broken in your example.

@ChrisRackauckas
Member

Makes sense to me.

@avik-pal
Contributor Author

I meant this kind of method generally, but yes, I assume getting the ancestor through a SubArray would be broken in your example.

Yes, it is lossy, and that is intentional. Let me summarize the functions first, and then I can clarify why being lossy is fine.

Functions Introduced

  1. isa_wrapped_array lets external packages extend what counts as a wrapper type. In most cases this is essentially parent(x) !== x, but that doesn't hold for Tridiagonal, Diagonal, etc. (as demonstrated above).
  2. ancestors returns the leaves of a wrapper type. An example:
struct MyVectorWrapper{T} <: AbstractVector{T}
    v::Vector{T}
end

Base.parent(v::MyVectorWrapper) = v.v

z = Tridiagonal(MyVectorWrapper(...), ....)

Calling ancestors(z) should return (z.dl.v, z.d.v, z.du.v). Why is this useful?

  • We can check if any of them are TracedRArray and then use that information for a specialized dispatch.
  • This is also useful for writing specialized CUDA dispatches (and for GPU backend determination in general). Example: https://github.com/LuxDL/Lux.jl/blob/main/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl#L35 uses AnyCuArray, but if the user has a Hermitian(CuArray(rand(4, 4))), it returns an incorrect result (CPUDevice()) since AnyCuArray is a Union type.
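
As a rough illustration of the two functions above, here is a minimal sketch that runs on CPU arrays only. The children and has_ancestor helpers are hypothetical names introduced for this example, and the definitions of isa_wrapped_array and ancestors are assumptions about how they could look, not the proposed implementation.

```julia
using LinearAlgebra

# Assumed default from the thread: an array is a wrapper if parent(x) is a different object.
isa_wrapped_array(x::AbstractArray) = parent(x) !== x
# Explicit opt-in for structured types (mirrors what a package could do for its own types).
isa_wrapped_array(::Union{Diagonal,Tridiagonal,Hermitian}) = true

# Hypothetical helper: the arrays stored directly inside a wrapper.
children(x::AbstractArray) = (parent(x),)
children(x::Diagonal) = (x.diag,)
children(x::Tridiagonal) = (x.dl, x.d, x.du)
children(x::Hermitian) = (x.data,)

# Recurse through wrappers and return the non-wrapper leaves as a flat tuple.
function ancestors(x::AbstractArray)
    isa_wrapped_array(x) || return (x,)
    return foldl((acc, c) -> (acc..., ancestors(c)...), children(x); init=())
end

# Hypothetical helper: does any leaf have type T (e.g. TracedRArray or CuArray)?
has_ancestor(::Type{T}, x::AbstractArray) where {T} = any(a -> a isa T, ancestors(x))

# The wrapper type from the example above, with the minimal AbstractVector interface.
struct MyVectorWrapper{T} <: AbstractVector{T}
    v::Vector{T}
end
Base.parent(w::MyVectorWrapper) = w.v
Base.size(w::MyVectorWrapper) = size(w.v)
Base.getindex(w::MyVectorWrapper, i::Int) = w.v[i]

z = Tridiagonal(MyVectorWrapper(ones(3)), MyVectorWrapper(ones(4)), MyVectorWrapper(ones(3)))
leaves = ancestors(z)                      # the three underlying Vector{Float64} buffers
h = Hermitian(rand(4, 4))
found = has_ancestor(Matrix, h)            # detects the Matrix behind the Hermitian wrapper
```

Because the check walks the wrapper structure at runtime, a new wrapper type only needs its own isa_wrapped_array and children methods, rather than being added to a fixed Union.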

Why is lossy information still useful?

Consider any array type that doesn't support fast_scalar_indexing. If we don't detect the ancestors correctly, it will fall back to the Base implementations, which often simply loop over the values.

If we detect a special ancestor type (say TracedRArray), we use a fallback implementation that materializes the wrapper into a dense array. While this is not going to be optimal for performance (in general), it is almost certainly not going to error. Note that even in cases where a loop doesn't error, materializing a dense array will still be more performant than falling back to a loop.
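
The fallback pattern described above can be sketched as follows. This is only an illustration on CPU arrays: materialize and diag_via_dense are hypothetical names, and Array(x) stands in for materialize_traced_array, whose real implementation is not shown in this thread.

```julia
using LinearAlgebra

# Hypothetical stand-in for materialize_traced_array: copy the wrapper into a dense Array.
# Structure information (e.g. "this is tridiagonal") is lost at this point, by design.
materialize(x::AbstractArray) = Array(x)

# Fallback: densify first, then call the implementation that is known to work on dense arrays.
function diag_via_dense(x::AbstractMatrix, k::Integer=0)
    y = materialize(x)
    return diag(y, k)
end

X = Tridiagonal(ones(3), 2 .* ones(4), 3 .* ones(3))
superdiag = diag_via_dense(X, 1)   # same values as diag(X, 1), computed on the dense copy
```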

Makes sense to me.

I will open a draft PR on this. It might help us give correct results for https://github.com/JuliaArrays/ArrayInterface.jl/blob/master/ext/ArrayInterfaceGPUArraysCoreExt.jl#L8 in cases not covered by the union type.

@rafaqz
Contributor

rafaqz commented Dec 16, 2024

It could also do ancestors(TracedRArray, A) to be more general. That would allow getting closer ancestors like Diagonal with ancestors(Diagonal, A) rather than the vector it wraps.
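
A minimal sketch of what such a type-targeted variant could look like, assuming a hypothetical children helper that lists the arrays stored directly inside a wrapper. None of these definitions come from an existing package.

```julia
using LinearAlgebra

# Hypothetical helper: arrays stored directly inside x (empty tuple for non-wrappers).
children(x::AbstractArray) = parent(x) === x ? () : (parent(x),)
children(x::Diagonal) = (x.diag,)
children(x::Tridiagonal) = (x.dl, x.d, x.du)

# Collect the closest ancestors of type T: stop descending a branch at the first match.
function ancestors(::Type{T}, x::AbstractArray) where {T}
    x isa T && return (x,)
    return foldl((acc, c) -> (acc..., ancestors(T, c)...), children(x); init=())
end

v = [1.0, 2.0, 3.0]
D = Diagonal(v)
A = view(D, 1:2, 1:2)

closest_diag = ancestors(Diagonal, A)   # stops at the Diagonal wrapper
leaf_vectors = ancestors(Vector, A)     # keeps descending to the underlying Vector
no_match = ancestors(Matrix, A)         # empty tuple: no Matrix anywhere in the chain
```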
