From c51bad0bc855dc281b6425d0d94ccf3947b0c14f Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 4 Feb 2025 12:38:41 +0100 Subject: [PATCH 1/6] Use `BLAS.trsm!` instead of `LAPACK.trtrs!` in left-triangular solves Co-authored-by: Alexis Montoison --- src/triangular.jl | 8 +++++--- test/triangular.jl | 7 ++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/triangular.jl b/src/triangular.jl index 5c6b518..dee8f90 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1223,11 +1223,13 @@ function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, end end # division -function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat} +generic_trimatdiv!(C::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVector{T}) where {T<:BlasFloat} = + BLAS.trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B)) +function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat} if stride(C,1) == stride(A,1) == 1 - LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B)) + BLAS.trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) else # incompatible with LAPACK - @invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat) + @invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix) end end function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} diff --git a/test/triangular.jl b/test/triangular.jl index a5fa45d..8eeeb66 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -886,8 +886,13 @@ end end end -@testset "(l/r)mul! and (l/r)div! for non-contiguous matrices" begin +@testset "(l/r)mul! and (l/r)div! for non-contiguous arrays" begin U = UpperTriangular(reshape(collect(3:27.0),5,5)) + b = float.(1:10) + b2 = copy(b); b2v = view(b2, 1:2:9); b2vc = copy(b2v) + @test lmul!(U, b2v) == lmul!(U, b2vc) + b2 = copy(b); b2v = view(b2, 1:2:9); b2vc = copy(b2v) + @test ldiv!(U, b2v) ≈ ldiv!(U, b2vc) B = float.(collect(reshape(1:100, 10,10))) B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v) @test lmul!(U, B2v) == lmul!(U, B2vc) From 054c8c6fe21bdca1c41936c3961fec815b98b3b5 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 4 Feb 2025 20:45:13 +0100 Subject: [PATCH 2/6] add singularity checks --- src/blas.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/blas.jl b/src/blas.jl index 3c15630..d99482f 100644 --- a/src/blas.jl +++ b/src/blas.jl @@ -1369,6 +1369,11 @@ for (fname, elty) in ((:dtrsv_,:Float64), throw(DimensionMismatch(lazy"size of A is $n != length(x) = $(length(x))")) end chkstride1(A) + if diag == 'N' + for i in 1:n + iszero(A[i,i]) && throw(SingularException(i)) + end + end px, stx = vec_pointer_stride(x, ArgumentError("input vector with 0 stride is not allowed")) GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid, (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, @@ -2217,6 +2222,12 @@ for (mmname, smname, elty) in end chkstride1(A) chkstride1(B) + if diag == 'N' + M = side == 'L' ? A : B + for i in 1:n + iszero(M[i,i]) && throw(SingularException(i)) + end + end ccall((@blasfunc($smname), libblastrampoline), Cvoid, (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty}, From c563b4b5e3dee44c8a9b68d92d35e17878dcef89 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 4 Feb 2025 20:47:23 +0100 Subject: [PATCH 3/6] add tests --- test/testtriag.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/testtriag.jl b/test/testtriag.jl index 7542d84..6ed7a0d 100644 --- a/test/testtriag.jl +++ b/test/testtriag.jl @@ -493,6 +493,8 @@ function test_triangular(elty1_types) @test_throws DimensionMismatch transpose(Ann) \ bm if t1 == UpperTriangular || t1 == LowerTriangular @test_throws SingularException ldiv!(t1(zeros(elty1, n, n)), fill(eltyB(1), n)) + @test_throws SingularException ldiv!(t1(zeros(elty1, n, n)), fill(eltyB(1), n, 2)) + @test_throws SingularException rdiv!(fill(eltyB(1), n, n), t1(zeros(elty1, n, n))) end @test B / A1 ≈ B / M1 @test B / transpose(A1) ≈ B / transpose(M1) From 711aceb8831833d620d1cad53a20262ac6ea47f0 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 4 Feb 2025 21:48:05 +0100 Subject: [PATCH 4/6] fix mistake --- src/blas.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/blas.jl b/src/blas.jl index d99482f..67d1425 100644 --- a/src/blas.jl +++ b/src/blas.jl @@ -2223,9 +2223,8 @@ for (mmname, smname, elty) in chkstride1(A) chkstride1(B) if diag == 'N' - M = side == 'L' ? A : B - for i in 1:n - iszero(M[i,i]) && throw(SingularException(i)) + for i in 1:k + iszero(A[i,i]) && throw(SingularException(i)) end end ccall((@blasfunc($smname), libblastrampoline), Cvoid, From f8850f04f8a89b8c80d7b962233c9198ed0bf0e8 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 4 Feb 2025 21:49:39 +0100 Subject: [PATCH 5/6] use SingularException --- src/blas.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blas.jl b/src/blas.jl index 67d1425..04e590c 100644 --- a/src/blas.jl +++ b/src/blas.jl @@ -84,7 +84,8 @@ export trsm!, trsm -using ..LinearAlgebra: libblastrampoline, BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, checksquare, chkstride1 +using ..LinearAlgebra: libblastrampoline, BlasReal, BlasComplex, BlasFloat, BlasInt, + DimensionMismatch, checksquare, chkstride1, SingularException include("lbt.jl") From dfc4c721006a026be9db1469559e75bd80b6d6bd Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 5 Feb 2025 12:03:28 +0100 Subject: [PATCH 6/6] fix signature --- src/triangular.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/triangular.jl b/src/triangular.jl index dee8f90..3faa12d 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1225,7 +1225,7 @@ end # division generic_trimatdiv!(C::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVector{T}) where {T<:BlasFloat} = BLAS.trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B)) -function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat} +function generic_trimatdiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat} if stride(C,1) == stride(A,1) == 1 BLAS.trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) else # incompatible with LAPACK