Skip to content

Commit 06a18c7

Browse files
committed
optimize code & improve test coverage
1 parent 9daec23 commit 06a18c7

File tree

3 files changed

+239
-149
lines changed

3 files changed

+239
-149
lines changed

src/fillalgebra.jl

Lines changed: 36 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,18 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b))
8181

8282
*(a::AbstractFillMatrix, b::AbstractFillMatrix) = mult_fill(a,b)
8383
*(a::AbstractFillMatrix, b::AbstractFillVector) = mult_fill(a,b)
84-
for type in (AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector})
84+
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AdjointAbsVec{<:Any,<:AbstractFillVector}, TransposeAbsVec{<:Any,<:AbstractFillVector})
8585
@eval begin
8686
function *(A::AbstractFillVector, B::$type)
8787
size(A,2) == size(B,1) ||
8888
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
8989
Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2))
9090
end
91+
function *(A::AbstractFillMatrix, B::$type)
92+
size(A,2) == size(B,1) ||
93+
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
94+
Fill(getindex_value(A) * getindex_value(B), size(A, 1), size(B, 2))
95+
end
9196
end
9297
end
9398

@@ -97,11 +102,12 @@ end
97102
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector})
98103
@eval begin
99104
*(A::AbstractOnesVector, B::$type) = mult_ones(A, B)
105+
*(A::AbstractOnesMatrix, B::$type) = mult_ones(A, B)
100106
end
101107
end
102108

103109
for type2 in (AdjointAbsVec{<:Any,<:AbstractZerosVector}, TransposeAbsVec{<:Any,<:AbstractZerosVector})
104-
for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector)
110+
for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector, AbstractFillMatrix, AbstractZerosMatrix, AbstractOnesMatrix)
105111
@eval begin
106112
function *(A::$type1, B::$type2)
107113
size(A,2) == size(B,1) ||
@@ -119,6 +125,11 @@ for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:
119125
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
120126
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
121127
end
128+
function *(A::AbstractZerosMatrix, B::$type)
129+
size(A,2) == size(B,1) ||
130+
throw(DimensionMismatch("second dimension of A, $(size(A,2)) does not match first dimension of B, $(size(B,1))"))
131+
Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1), size(B, 2))
132+
end
122133
end
123134
end
124135

@@ -630,11 +641,11 @@ end
630641
# DiagonalFill Multiplication
631642
const DiagonalZeros{T,V<:AbstractZerosVector{T}} = Diagonal{T,V}
632643
const DiagonalOnes{T,V<:AbstractOnesVector{T}} = Diagonal{T,V}
633-
linearalgebra_types = (AbstractMatrix, Diagonal, RectDiagonal, AbstractZerosMatrix,
644+
mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix,
634645
AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular,
635646
LowerTriangular, UpperTriangular, LinearAlgebra.AbstractTriangular, Symmetric, Hermitian,
636647
SymTridiagonal, UpperHessenberg, AdjOrTransAbsVec{<:Any,<:AbstractZerosVector})#, OneElement)
637-
for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...)
648+
for type in tuple(AbstractVector, AbstractZerosVector, mat_types...)
638649
@eval begin
639650
function *(A::DiagonalFill, B::$type)
640651
check_matmul_sizes(A, B)
@@ -643,12 +654,13 @@ for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...)
643654
*(A::DiagonalZeros, B::$type) = Zeros(A) * B
644655
function *(A::DiagonalOnes, B::$type)
645656
check_matmul_sizes(A, B)
646-
one(eltype(A)) * B
657+
convert(AbstractArray{promote_type(eltype(A), eltype(B))}, deepcopy(B))
647658
end
648659
end
649660
end
661+
*(A::DiagonalOnes, B::AbstractRange) = one(eltype(A)) * B
650662

651-
for type in linearalgebra_types
663+
for type in mat_types
652664
@eval begin
653665
function *(A::$type, B::DiagonalFill)
654666
check_matmul_sizes(A, B)
@@ -657,7 +669,7 @@ for type in linearalgebra_types
657669
*(A::$type, B::DiagonalZeros) = A * Zeros(B)
658670
function *(A::$type, B::DiagonalOnes)
659671
check_matmul_sizes(A, B)
660-
one(eltype(B)) * A
672+
convert(AbstractMatrix{promote_type(eltype(A), eltype(B))}, deepcopy(A))
661673
end
662674
end
663675
end
@@ -669,53 +681,22 @@ for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
669681
*(A::$type1, B::$type2) = A * Zeros(B)
670682
end
671683
end
672-
end
673-
674-
for type in (DiagonalFill, DiagonalOnes, RectDiagonalFill)
675684
@eval begin
676-
*(A::$type, B::DiagonalZeros) = A * Zeros(B)
677-
*(A::DiagonalZeros, B::$type) = Zeros(A) * B
685+
*(A::Diagonal, B::$type1) = Diagonal(A.diag .* B.diag)
686+
*(A::$type1, B::Diagonal) = Diagonal(A.diag .* B.diag)
678687
end
679688
end
680-
function *(A::DiagonalZeros, B::DiagonalZeros)
681-
check_matmul_sizes(A, B)
682-
Zeros{promote_type(eltype(A),eltype(B))}(A)
683-
end
684689

685-
for type1 in (DiagonalFill, DiagonalOnes)
686-
for type2 in (DiagonalFill, DiagonalOnes)
687-
if type1 !== DiagonalOnes || type2 !== DiagonalOnes
688-
@eval begin
689-
function *(A::$type1, B::$type2)
690-
check_matmul_sizes(A, B)
691-
getindex_value(A.diag) * B
692-
end
693-
end
694-
end
695-
end
696-
end
697-
function *(A::DiagonalOnes, B::DiagonalOnes)
698-
check_matmul_sizes(A, B)
699-
Diagonal(Ones{promote_type(eltype(A), eltype(B))}(size(A, 1)))
700-
end
701-
702-
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix, AbstractOnesVector)
703-
@eval begin
704-
function *(A::DiagonalOnes, B::$type)
705-
check_matmul_sizes(A, B)
706-
Ones{promote_type(eltype(A),eltype(B))}(size(B))
707-
end
708-
end
709-
end
710-
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix)
711-
@eval begin
712-
function *(A::$type, B::DiagonalOnes)
713-
check_matmul_sizes(A, B)
714-
Ones{promote_type(eltype(A),eltype(B))}(size(A))
690+
for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
691+
for type2 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
692+
@eval begin
693+
*(A::$type1, B::$type2) = Diagonal(A.diag .* B.diag)
715694
end
716695
end
717696
end
718697

698+
*(A::RectDiagonalFill, B::DiagonalZeros) = A * Zeros(B)
699+
*(A::DiagonalZeros, B::RectDiagonalFill) = Zeros(A) * B
719700
for type in (DiagonalFill, DiagonalOnes)
720701
@eval begin
721702
function *(A::$type, B::RectDiagonalFill)
@@ -732,54 +713,15 @@ for type in (DiagonalFill, DiagonalOnes)
732713
end
733714
end
734715

735-
for type1 in (AbstractMatrix, Diagonal)
736-
for type2 in (Diagonal, DiagonalOnes, DiagonalFill)
737-
@eval begin
738-
function *(Da::DiagonalZeros, A::$type1, Db::$type2)
739-
check_matmul_sizes(A, Db)
740-
Zeros(Da) * A
741-
end
742-
function *(Da::$type2, A::$type1, Db::DiagonalZeros)
743-
check_matmul_sizes(Da, A)
744-
A * Zeros(Db)
745-
end
746-
end
747-
end
748-
749-
for type2 in (Diagonal, DiagonalFill)
750-
@eval begin
751-
function *(Da::DiagonalOnes, A::$type1, Db::$type2)
752-
check_matmul_sizes(Da, A)
753-
ones(eltype(Da)) * A * Db
754-
end
755-
function *(Da::$type2, A::$type1, Db::DiagonalOnes)
756-
check_matmul_sizes(A, Db)
757-
Da * A * ones(eltype(Db))
758-
end
759-
end
760-
end
761-
762-
@eval begin
763-
*(Da::DiagonalZeros, A::$type1, Db::DiagonalZeros) = Zeros(Da) * A * Zeros(Db)
764-
function *(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes)
765-
check_matmul_sizes(Da, A)
766-
check_matmul_sizes(A, Db)
767-
(one(eltype(Da)) * one(eltype(Db))) * A
768-
end
769-
770-
function *(Da::DiagonalFill, A::$type1, Db::Diagonal)
771-
check_matmul_sizes(Da, A)
772-
getindex_value(Da.diag) * A * Db
773-
end
774-
function *(Da::Diagonal, A::$type1, Db::DiagonalFill)
775-
check_matmul_sizes(A, Db)
776-
Da * A * getindex_value(Db.diag)
777-
end
778-
function *(Da::DiagonalFill, A::$type1, Db::DiagonalFill)
779-
check_matmul_sizes(Da, A)
780-
check_matmul_sizes(A, Db)
781-
(getindex_value(Da.diag) * getindex_value(Db.diag)) * A
782-
end
716+
function *(Da::Diagonal, A::RectDiagonal, Db::Diagonal)
717+
check_matmul_sizes(Da, A)
718+
check_matmul_sizes(A, Db)
719+
len = Base.OneTo(minimum(size(A)))
720+
diag = view(Da.diag, len) .* view(A.diag, len) .* view(Db.diag, len)
721+
if diag isa Zeros
722+
Zeros{eltype(diag)}(axes(A))
723+
else
724+
RectDiagonal(diag, axes(A))
783725
end
784726
end
785727

src/fillbroadcast.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,21 @@ broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {
259259
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r))
260260
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r))
261261

262+
# ternary broadcasting
263+
for type1 in (AbstractArray, AbstractFill, AbstractZeros)
264+
for type2 in (AbstractArray, AbstractFill, AbstractZeros)
265+
for type3 in (AbstractArray, AbstractFill, AbstractZeros)
266+
if type1 === AbstractZeros || type2 === AbstractZeros || type3 === AbstractZeros
267+
@eval begin
268+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::$type1, b::$type2, c::$type3) = Zeros{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c)))
269+
end
270+
end
271+
end
272+
end
273+
end
274+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractOnes, b::AbstractOnes, c::AbstractOnes) = Ones{promote_type(eltype(a),eltype(b),eltype(c))}(broadcast_shape(axes(a), axes(b), axes(c)))
275+
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractFill, b::AbstractFill, c::AbstractFill) = Fill(getindex_value(a)*getindex_value(b)*getindex_value(c), broadcast_shape(axes(a), axes(b), axes(c)))
276+
262277
# support AbstractFill .^ k
263278
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r))
264279
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r))

0 commit comments

Comments
 (0)