@@ -81,13 +81,18 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b))
8181
8282* (a:: AbstractFillMatrix , b:: AbstractFillMatrix ) = mult_fill(a,b)
8383* (a:: AbstractFillMatrix , b:: AbstractFillVector ) = mult_fill(a,b)
84- for type in (AdjointAbsVec{<: Any ,<: AbstractFillVector }, TransposeAbsVec{<: Any ,<: AbstractFillVector })
84+ for type in (AdjointAbsVec{<: Any ,<: AbstractOnesVector }, TransposeAbsVec{ <: Any , <: AbstractOnesVector }, AdjointAbsVec{ <: Any , <: AbstractFillVector }, TransposeAbsVec{<: Any ,<: AbstractFillVector })
8585 @eval begin
8686 function * (A:: AbstractFillVector , B:: $type )
8787 size(A,2 ) == size(B,1 ) ||
8888 throw(DimensionMismatch(" second dimension of A, $(size(A,2 )) does not match first dimension of B, $(size(B,1 )) " ))
8989 Fill(getindex_value(A) * getindex_value(B), size(A, 1 ), size(B, 2 ))
9090 end
91+ function * (A:: AbstractFillMatrix , B:: $type )
92+ size(A,2 ) == size(B,1 ) ||
93+ throw(DimensionMismatch(" second dimension of A, $(size(A,2 )) does not match first dimension of B, $(size(B,1 )) " ))
94+ Fill(getindex_value(A) * getindex_value(B), size(A, 1 ), size(B, 2 ))
95+ end
9196 end
9297end
9398
97102for type in (AdjointAbsVec{<: Any ,<: AbstractOnesVector }, TransposeAbsVec{<: Any ,<: AbstractOnesVector })
98103 @eval begin
99104 * (A:: AbstractOnesVector , B:: $type ) = mult_ones(A, B)
105+ * (A:: AbstractOnesMatrix , B:: $type ) = mult_ones(A, B)
100106 end
101107end
102108
103109for type2 in (AdjointAbsVec{<: Any ,<: AbstractZerosVector }, TransposeAbsVec{<: Any ,<: AbstractZerosVector })
104- for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector)
110+ for type1 in (AbstractFillVector, AbstractZerosVector, AbstractOnesVector, AbstractFillMatrix, AbstractZerosMatrix, AbstractOnesMatrix )
105111 @eval begin
106112 function * (A:: $type1 , B:: $type2 )
107113 size(A,2 ) == size(B,1 ) ||
@@ -119,6 +125,11 @@ for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:
119125 throw(DimensionMismatch(" second dimension of A, $(size(A,2 )) does not match first dimension of B, $(size(B,1 )) " ))
120126 Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1 ), size(B, 2 ))
121127 end
128+ function * (A:: AbstractZerosMatrix , B:: $type )
129+ size(A,2 ) == size(B,1 ) ||
130+ throw(DimensionMismatch(" second dimension of A, $(size(A,2 )) does not match first dimension of B, $(size(B,1 )) " ))
131+ Zeros{promote_type(eltype(A), eltype(B))}(size(A, 1 ), size(B, 2 ))
132+ end
122133 end
123134end
124135
@@ -630,11 +641,11 @@ end
630641# DiagonalFill Multiplication
631642const DiagonalZeros{T,V<: AbstractZerosVector{T} } = Diagonal{T,V}
632643const DiagonalOnes{T,V<: AbstractOnesVector{T} } = Diagonal{T,V}
633- linearalgebra_types = (AbstractMatrix, Diagonal , RectDiagonal, AbstractZerosMatrix,
644+ mat_types = (AbstractMatrix, RectDiagonal, AbstractZerosMatrix,
634645 AbstractFillMatrix, AdjointAbsVec, TransposeAbsVec, UnitUpperTriangular, UnitLowerTriangular,
635646 LowerTriangular, UpperTriangular, LinearAlgebra. AbstractTriangular, Symmetric, Hermitian,
636647 SymTridiagonal, UpperHessenberg, AdjOrTransAbsVec{<: Any ,<: AbstractZerosVector })# , OneElement)
637- for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types ... )
648+ for type in tuple(AbstractVector, AbstractZerosVector, mat_types ... )
638649 @eval begin
639650 function * (A:: DiagonalFill , B:: $type )
640651 check_matmul_sizes(A, B)
@@ -643,12 +654,13 @@ for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...)
643654 * (A:: DiagonalZeros , B:: $type ) = Zeros(A) * B
644655 function * (A:: DiagonalOnes , B:: $type )
645656 check_matmul_sizes(A, B)
646- one( eltype(A)) * B
657+ convert(AbstractArray{promote_type( eltype(A), eltype(B))}, deepcopy(B))
647658 end
648659 end
649660end
661+ * (A:: DiagonalOnes , B:: AbstractRange ) = one(eltype(A)) * B
650662
651- for type in linearalgebra_types
663+ for type in mat_types
652664 @eval begin
653665 function * (A:: $type , B:: DiagonalFill )
654666 check_matmul_sizes(A, B)
@@ -657,7 +669,7 @@ for type in linearalgebra_types
657669 * (A:: $type , B:: DiagonalZeros ) = A * Zeros(B)
658670 function * (A:: $type , B:: DiagonalOnes )
659671 check_matmul_sizes(A, B)
660- one( eltype(B)) * A
672+ convert(AbstractMatrix{promote_type( eltype(A), eltype( B))}, deepcopy(A))
661673 end
662674 end
663675end
@@ -669,53 +681,22 @@ for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
669681 * (A:: $type1 , B:: $type2 ) = A * Zeros(B)
670682 end
671683 end
672- end
673-
674- for type in (DiagonalFill, DiagonalOnes, RectDiagonalFill)
675684 @eval begin
676- * (A:: $type , B:: DiagonalZeros ) = A * Zeros(B )
677- * (A:: DiagonalZeros , B:: $type ) = Zeros(A) * B
685+ * (A:: Diagonal , B:: $type1 ) = Diagonal(A . diag .* B . diag )
686+ * (A:: $type1 , B:: Diagonal ) = Diagonal(A . diag . * B. diag)
678687 end
679688end
680- function * (A:: DiagonalZeros , B:: DiagonalZeros )
681- check_matmul_sizes(A, B)
682- Zeros{promote_type(eltype(A),eltype(B))}(A)
683- end
684689
685- for type1 in (DiagonalFill, DiagonalOnes)
686- for type2 in (DiagonalFill, DiagonalOnes)
687- if type1 != = DiagonalOnes || type2 != = DiagonalOnes
688- @eval begin
689- function * (A:: $type1 , B:: $type2 )
690- check_matmul_sizes(A, B)
691- getindex_value(A. diag) * B
692- end
693- end
694- end
695- end
696- end
697- function * (A:: DiagonalOnes , B:: DiagonalOnes )
698- check_matmul_sizes(A, B)
699- Diagonal(Ones{promote_type(eltype(A), eltype(B))}(size(A, 1 )))
700- end
701-
702- for type in (AdjointAbsVec{<: Any ,<: AbstractOnesVector }, TransposeAbsVec{<: Any ,<: AbstractOnesVector }, AbstractOnesMatrix, AbstractOnesVector)
703- @eval begin
704- function * (A:: DiagonalOnes , B:: $type )
705- check_matmul_sizes(A, B)
706- Ones{promote_type(eltype(A),eltype(B))}(size(B))
707- end
708- end
709- end
710- for type in (AdjointAbsVec{<: Any ,<: AbstractOnesVector }, TransposeAbsVec{<: Any ,<: AbstractOnesVector }, AbstractOnesMatrix)
711- @eval begin
712- function * (A:: $type , B:: DiagonalOnes )
713- check_matmul_sizes(A, B)
714- Ones{promote_type(eltype(A),eltype(B))}(size(A))
690+ for type1 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
691+ for type2 in (DiagonalFill, DiagonalOnes, DiagonalZeros)
692+ @eval begin
693+ * (A:: $type1 , B:: $type2 ) = Diagonal(A. diag .* B. diag)
715694 end
716695 end
717696end
718697
698+ * (A:: RectDiagonalFill , B:: DiagonalZeros ) = A * Zeros(B)
699+ * (A:: DiagonalZeros , B:: RectDiagonalFill ) = Zeros(A) * B
719700for type in (DiagonalFill, DiagonalOnes)
720701 @eval begin
721702 function * (A:: $type , B:: RectDiagonalFill )
@@ -732,54 +713,15 @@ for type in (DiagonalFill, DiagonalOnes)
732713 end
733714end
734715
735- for type1 in (AbstractMatrix, Diagonal)
736- for type2 in (Diagonal, DiagonalOnes, DiagonalFill)
737- @eval begin
738- function * (Da:: DiagonalZeros , A:: $type1 , Db:: $type2 )
739- check_matmul_sizes(A, Db)
740- Zeros(Da) * A
741- end
742- function * (Da:: $type2 , A:: $type1 , Db:: DiagonalZeros )
743- check_matmul_sizes(Da, A)
744- A * Zeros(Db)
745- end
746- end
747- end
748-
749- for type2 in (Diagonal, DiagonalFill)
750- @eval begin
751- function * (Da:: DiagonalOnes , A:: $type1 , Db:: $type2 )
752- check_matmul_sizes(Da, A)
753- ones(eltype(Da)) * A * Db
754- end
755- function * (Da:: $type2 , A:: $type1 , Db:: DiagonalOnes )
756- check_matmul_sizes(A, Db)
757- Da * A * ones(eltype(Db))
758- end
759- end
760- end
761-
762- @eval begin
763- * (Da:: DiagonalZeros , A:: $type1 , Db:: DiagonalZeros ) = Zeros(Da) * A * Zeros(Db)
764- function * (Da:: DiagonalOnes , A:: $type1 , Db:: DiagonalOnes )
765- check_matmul_sizes(Da, A)
766- check_matmul_sizes(A, Db)
767- (one(eltype(Da)) * one(eltype(Db))) * A
768- end
769-
770- function * (Da:: DiagonalFill , A:: $type1 , Db:: Diagonal )
771- check_matmul_sizes(Da, A)
772- getindex_value(Da. diag) * A * Db
773- end
774- function * (Da:: Diagonal , A:: $type1 , Db:: DiagonalFill )
775- check_matmul_sizes(A, Db)
776- Da * A * getindex_value(Db. diag)
777- end
778- function * (Da:: DiagonalFill , A:: $type1 , Db:: DiagonalFill )
779- check_matmul_sizes(Da, A)
780- check_matmul_sizes(A, Db)
781- (getindex_value(Da. diag) * getindex_value(Db. diag)) * A
782- end
716+ function * (Da:: Diagonal , A:: RectDiagonal , Db:: Diagonal )
717+ check_matmul_sizes(Da, A)
718+ check_matmul_sizes(A, Db)
719+ len = Base. OneTo(minimum(size(A)))
720+ diag = view(Da. diag, len) .* view(A. diag, len) .* view(Db. diag, len)
721+ if diag isa Zeros
722+ Zeros{eltype(diag)}(axes(A))
723+ else
724+ RectDiagonal(diag, axes(A))
783725 end
784726end
785727
0 commit comments