Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 30a0295

Browse files
committedOct 31, 2024·
Add SVD support for BlockDiagonal
1 parent f382da6 commit 30a0295

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed
 

‎src/BlockArrays.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ include("blocks.jl")
6666

6767
include("blockbroadcast.jl")
6868
include("blockcholesky.jl")
69-
include("blocksvd.jl")
7069
include("blocklinalg.jl")
7170
include("blockproduct.jl")
7271
include("show.jl")
7372
include("blockreduce.jl")
7473
include("blockdeque.jl")
7574
include("blockarrayinterface.jl")
7675
include("blockbanded.jl")
76+
include("blocksvd.jl")
7777

7878
@deprecate getblock(A::AbstractBlockArray{T,N}, I::Vararg{Integer, N}) where {T,N} view(A, Block(I))
7979
@deprecate getblock!(X, A::AbstractBlockArray{T,N}, I::Vararg{Integer, N}) where {T,N} copyto!(X, view(A, Block(I)))

‎src/blocksvd.jl

+14
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@ SVD on blockmatrices:
33
Interpret the matrix as a linear map acting on vector spaces with a direct sum structure, which is reflected in the structure of U and V.
44
In the generic case, the SVD does not preserve this structure, and can mix up the blocks, so S becomes a single block.
55
(Implementation-wise, also most efficiently done by first mapping to a `BlockedArray`)
6+
In the case of `BlockDiagonal` however, the structure is preserved and carried over to the structure of `S`.
67
=#
78

89
LinearAlgebra.eigencopy_oftype(A::AbstractBlockMatrix, S) = BlockedArray(Array{S}(A), blocksizes(A, 1), blocksizes(A, 2))
910

11+
function LinearAlgebra.eigencopy_oftype(A::BlockDiagonal, S)
12+
diag = map(Base.Fix2(LinearAlgebra.eigencopy_oftype, S), A.blocks.diag)
13+
return BlockDiagonal(diag)
14+
end
15+
1016
function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
1117
USV = svd!(parent(A); full, alg)
1218

@@ -19,3 +25,11 @@ function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgeb
1925
vt = BlockedArray(USV.Vt, bsz2, bsz3)
2026
return SVD(u, s, vt)
2127
end
28+
29+
function LinearAlgebra.svd!(A::BlockDiagonal; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
30+
USVs = map(a -> svd!(a; full, alg), A.blocks.diag)
31+
Us = map(Base.Fix2(getproperty, :U), USVs)
32+
Ss = map(Base.Fix2(getproperty, :S), USVs)
33+
Vts = map(Base.Fix2(getproperty, :Vt), USVs)
34+
return SVD(BlockDiagonal(Us), mortar(Ss), BlockDiagonal(Vts))
35+
end

‎test/test_blocksvd.jl

+34
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TestBlockSVD
22

33
using BlockArrays, Test, LinearAlgebra, Random
4+
using BlockArrays: BlockDiagonal
45

56
Random.seed!(0)
67

@@ -70,4 +71,37 @@ end
7071
@test U_blocked * Diagonal(S_blocked) * Vt_blocked y
7172
end
7273

74+
@testset "BlockDiagonal SVD ($T)" for T in eltypes
75+
blocksz = (2, 3, 1)
76+
y = BlockDiagonal([rand(T, d, d) for d in blocksz])
77+
x = Array(y)
78+
79+
USV = svd(x)
80+
U, S, Vt = USV.U, USV.S, USV.Vt
81+
82+
# https://github.com/JuliaArrays/BlockArrays.jl/issues/425
83+
# USV_blocked = @inferred svd(y)
84+
USV_block = svd(y)
85+
U_block, S_block, Vt_block = USV_block.U, USV_block.S, USV_block.Vt
86+
87+
# test types
88+
@test U_block isa BlockDiagonal
89+
@test eltype(U_block) == float(T)
90+
@test S_block isa BlockVector
91+
@test eltype(S_block) == real(float(T))
92+
@test Vt_block isa BlockDiagonal
93+
@test eltype(Vt_block) == float(T)
94+
95+
# test structure
96+
@test blocksizes(U_block, 1) == blocksizes(y, 1)
97+
@test length(blocksizes(U_block, 2)) == length(blocksz)
98+
@test blocksizes(Vt_block, 2) == blocksizes(y, 2)
99+
@test length(blocksizes(Vt_block, 1)) == length(blocksz)
100+
101+
# test correctness: SVD is not unique, so cannot compare to dense
102+
@test U_block * BlockDiagonal(Diagonal.(S_block.blocks)) * Vt_block y
103+
@test U_block' * U_block LinearAlgebra.I
104+
@test Vt_block * Vt_block' LinearAlgebra.I
105+
end
106+
73107
end # module

0 commit comments

Comments
 (0)