diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 148db495..f83dcf9a 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -138,6 +138,28 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real end *(a::Transpose{T, <:AbstractMatrix{T}}, b::Zeros{T, 1}) where T<:Real = mult_zeros(a, b) +function dot(a::AbstractArray, b::AbstractFill) + length(a) == length(b) || throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) + adjoint(sum(a)) * getindex_value(b) +end +function dot(a::AbstractFill, b::AbstractArray) + length(a) == length(b) || throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) + adjoint(getindex_value(a)) * sum(b) +end +function dot(a::AbstractFill, b::AbstractFill) + length(a) == length(b) || throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b)).")) + dot(getindex_value(a), getindex_value(b)) * length(a) +end + +function dot(a::Diagonal, b::AbstractFill{<:Any,2}) + size(a) == size(b) || throw(DimensionMismatch("Matrix sizes $(size(a)) and $(size(a)) differ")) + adjoint(sum(a.diag)) * getindex_value(b) +end +function dot(a::AbstractFill{<:Any,2}, b::Diagonal) + size(a) == size(b) || throw(DimensionMismatch("Matrix sizes $(size(a)) and $(size(a)) differ")) + adjoint(getindex_value(a)) * sum(b.diag) +end + function dot(u::AbstractVector, E::Eye, v::AbstractVector) length(u) == size(E,1) && length(v) == size(E,2) || throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(E))×$(length(v))")) diff --git a/test/runtests.jl b/test/runtests.jl index e3c69d24..590617ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1111,13 +1111,35 @@ end n = 15 o = Ones(1:n) z = Zeros(1:n) - D = Diagonal(o) + D = Diagonal(o) # Eye Z = Diagonal(z) Random.seed!(5) - u = rand(n) - v = rand(n) + u = rand(ComplexF64, n) + v = rand(ComplexF64, n) + X = Fill(rand(ComplexF64), n) + Y = Fill(rand(ComplexF64), n) + # 2-arg dot + @test dot(u, X) ≈ dot(u, Array(X)) + @test dot(X, v) ≈ dot(Array(X), v) + @test dot(X, Y) ≈ dot(Array(X), Array(Y)) + + @test dot(u, o) ≈ dot(u, Array(o)) + @test dot(X, o) ≈ dot(Array(X), Array(o)) + + M = Fill(rand(ComplexF64), n, n) + D2 = Diagonal(u) + + @test dot(M, D) ≈ dot(Array(M), Array(D)) + @test dot(D, M) ≈ dot(Array(D), Array(M)) + + @test dot(M, D2) ≈ dot(Array(M), Array(D2)) + @test dot(D2, M) ≈ dot(Array(D2), Array(M)) + + @test_throws DimensionMismatch dot(u, X[1:end-1]) + + # 3-arg dot @test dot(u, D, v) == dot(u, v) @test dot(u, 2D, v) == 2dot(u, v) @test dot(u, Z, v) == 0