Skip to content

[BlockSparseArrays] Fix adjoint and transpose #1470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.10"
version = "0.3.11"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using LinearAlgebra: Adjoint, Transpose
using ..SparseArrayInterface:
SparseArrayInterface,
SparseArrayStyle,
Expand Down Expand Up @@ -73,6 +74,20 @@ function Base.copyto!(a_dest::LayoutArray, a_src::BlockSparseArrayLike)
return a_dest
end

function Base.copyto!(
a_dest::AbstractMatrix, a_src::Transpose{T,<:AbstractBlockSparseMatrix{T}}
) where {T}
sparse_copyto!(a_dest, a_src)
return a_dest
end

function Base.copyto!(
a_dest::AbstractMatrix, a_src::Adjoint{T,<:AbstractBlockSparseMatrix{T}}
) where {T}
sparse_copyto!(a_dest, a_src)
return a_dest
end

function Base.permutedims!(a_dest, a_src::BlockSparseArrayLike, perm)
sparse_permutedims!(a_dest, a_src, perm)
return a_dest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ using SplitApplyCombine: groupcount

using Adapt: Adapt, WrappedArray

const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{
T,N,<:AbstractBlockSparseArray,<:AbstractBlockSparseArray{T,N}
const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
}

# TODO: Rename `AnyBlockSparseArray`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using BlockArrays:
blocks,
blocklengths,
findblockindex
using LinearAlgebra: Adjoint, Transpose
using ..SparseArrayInterface: perm, iperm, nstored
## using MappedArrays: mappedarray

Expand Down Expand Up @@ -86,35 +87,96 @@ end

# BlockArrays

using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray
using ..SparseArrayInterface:
SparseArrayInterface, AbstractSparseArray, AbstractSparseMatrix

# Represents the array of arrays of a `SubArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
_getindices(t::Tuple, indices) = map(i -> t[i], indices)
_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices))

# Represents the array of arrays of a `PermutedDimsArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <:
AbstractSparseArray{T,N}
array::Array
end
function blocksparse_blocks(a::PermutedDimsArray)
return SparsePermutedDimsArrayBlocks(a)
end
_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
_getindices(t::Tuple, indices) = map(i -> t[i], indices)
_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices))
function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks)
return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array))))
end
function Base.size(a::SparsePermutedDimsArrayBlocks)
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
end
function Base.getindex(a::SparsePermutedDimsArrayBlocks, index::Vararg{Int})
function Base.getindex(
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
) where {N}
return PermutedDimsArray(
blocks(parent(a.array))[_getindices(index, _perm(a.array))...], _perm(a.array)
)
end
function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks)
return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.nstored(a::SparsePermutedDimsArrayBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparsePermutedDimsArrayBlocks)
return error("Not implemented")
end

reverse_index(index) = reverse(index)
reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))

# Represents the array of arrays of a `Transpose`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T}
array::Array
end
function blocksparse_blocks(a::Transpose)
return SparseTransposeBlocks(a)
end
function Base.size(a::SparseTransposeBlocks)
return reverse(size(blocks(parent(a.array))))
end
function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2})
return transpose(blocks(parent(a.array))[reverse(index)...])
end
function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks)
return map(reverse_index, stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.nstored(a::SparseTransposeBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseTransposeBlocks)
return error("Not implemented")
end

# Represents the array of arrays of a `Adjoint`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T}
array::Array
end
function blocksparse_blocks(a::Adjoint)
return SparseAdjointBlocks(a)
end
function Base.size(a::SparseAdjointBlocks)
return reverse(size(blocks(parent(a.array))))
end
function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2})
return blocks(parent(a.array))[reverse(index)...]'
end
function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks)
return map(reverse_index, stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.nstored(a::SparseAdjointBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseAdjointBlocks)
return error("Not implemented")
end

# TODO: Move to `BlockArraysExtensions`.
# This takes a range of indices `indices` of array `a`
# and maps it to the range of indices within block `block`.
Expand Down Expand Up @@ -167,9 +229,6 @@ end
function Base.size(a::SparseSubArrayBlocks)
return length.(axes(a))
end
function SparseArrayInterface.stored_indices(a::SparseSubArrayBlocks)
return stored_indices(view(blocks(parent(a.array)), axes(a)...))
end
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N}
return a[Tuple(I)...]
end
Expand All @@ -192,6 +251,13 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
# TODO: Implement this properly.
return true
end
function SparseArrayInterface.stored_indices(a::SparseSubArrayBlocks)
return stored_indices(view(blocks(parent(a.array)), axes(a)...))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.nstored(a::SparseSubArrayBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks)
return error("Not implemented")
end
Expand Down
93 changes: 87 additions & 6 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(a) == 2
@test nstored(a) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = similar(a, complex(elt))
@test eltype(b) == complex(eltype(a))
@test iszero(b)
Expand All @@ -56,37 +59,58 @@ include("TestBlockSparseArraysUtils.jl")
@test size(b) == size(a)
@test blocksize(b) == blocksize(a)

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
b[1, 1] = 11
@test b[1, 1] == 11
@test a[1, 1] ≠ 11

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
b .*= 2
@test b ≈ 2a

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
b ./= 2
@test b ≈ a / 2

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = 2 * a
@test Array(b) ≈ 2 * Array(a)
@test eltype(b) == elt
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = (2 + 3im) * a
@test Array(b) ≈ (2 + 3im) * Array(a)
@test eltype(b) == complex(elt)
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a + a
@test Array(b) ≈ 2 * Array(a)
@test eltype(b) == elt
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
x = BlockSparseArray{elt}(undef, ([3, 4], [2, 3]))
x[Block(1, 2)] = randn(elt, size(@view(x[Block(1, 2)])))
x[Block(2, 1)] = randn(elt, size(@view(x[Block(2, 1)])))
Expand All @@ -96,12 +120,18 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = permutedims(a, (2, 1))
@test Array(b) ≈ permutedims(Array(a), (2, 1))
@test eltype(b) == elt
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = map(x -> 2x, a)
@test Array(b) ≈ 2 * Array(a)
@test eltype(b) == elt
Expand All @@ -110,6 +140,9 @@ include("TestBlockSparseArraysUtils.jl")
@test block_nstored(b) == 2
@test nstored(b) == 2 * 4 + 3 * 3

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[[Block(2), Block(1)], [Block(2), Block(1)]]
@test b[Block(1, 1)] == a[Block(2, 2)]
@test b[Block(1, 2)] == a[Block(2, 1)]
Expand All @@ -120,13 +153,19 @@ include("TestBlockSparseArraysUtils.jl")
@test nstored(b) == nstored(a)
@test block_nstored(b) == 2

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[Block(1):Block(2), Block(1):Block(2)]
@test b == a
@test size(b) == size(a)
@test blocksize(b) == (2, 2)
@test nstored(b) == nstored(a)
@test block_nstored(b) == 2

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[Block(1):Block(1), Block(1):Block(2)]
@test b == Array(a)[1:2, 1:end]
@test b[Block(1, 1)] == a[Block(1, 1)]
Expand All @@ -136,41 +175,83 @@ include("TestBlockSparseArraysUtils.jl")
@test nstored(b) == nstored(a[Block(1, 2)])
@test block_nstored(b) == 1

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[2:4, 2:4]
@test b == Array(a)[2:4, 2:4]
@test size(b) == (3, 3)
@test blocksize(b) == (2, 2)
@test nstored(b) == 1 * 1 + 2 * 2
@test block_nstored(b) == 2

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[Block(2, 1)[1:2, 2:3]]
@test b == Array(a)[3:4, 2:3]
@test size(b) == (2, 2)
@test blocksize(b) == (1, 1)
@test nstored(b) == 2 * 2
@test block_nstored(b) == 1

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = PermutedDimsArray(a, (2, 1))
@test block_nstored(b) == 2
@test Array(b) == permutedims(Array(a), (2, 1))
c = 2 * b
@test block_nstored(c) == 2
@test Array(c) == 2 * permutedims(Array(a), (2, 1))

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a'
@test block_nstored(b) == 2
@test Array(b) == Array(a)'
c = 2 * b
@test block_nstored(c) == 2
@test Array(c) == 2 * Array(a)'

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = transpose(a)
@test block_nstored(b) == 2
@test Array(b) == transpose(Array(a))
c = 2 * b
@test block_nstored(c) == 2
@test Array(c) == 2 * transpose(Array(a))

## Broken, need to fix.

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
@test_broken a[Block(1), Block(1):Block(2)]

# This is outputting only zero blocks.
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = a[Block(2):Block(2), Block(1):Block(2)]
@test_broken block_nstored(b) == 1
@test_broken b == Array(a)[3:5, 1:end]

b = a'
@test_broken block_nstored(b) == 2

b = transpose(a)
@test_broken block_nstored(b) == 2

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
x = randn(size(@view(a[Block(2, 2)])))
b[Block(2), Block(2)] = x
@test_broken b[Block(2, 2)] == x

# Doesnt' set the block
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
b = copy(a)
b[Block(1, 1)] .= 1
@test_broken b[1, 1] == trues(size(@view(b[1, 1])))
Expand Down
Loading