@@ -648,7 +648,6 @@ for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...)
648648 end
649649end
650650
651- # TODO : add dim check to all abstract ones multiplication
652651for type in linearalgebra_types
653652 @eval begin
654653 function * (A:: $type , B:: DiagonalFill )
@@ -702,53 +701,85 @@ end
702701
703702for type in (AdjointAbsVec{<: Any ,<: AbstractOnesVector }, TransposeAbsVec{<: Any ,<: AbstractOnesVector }, AbstractOnesMatrix, AbstractOnesVector)
704703 @eval begin
705- * (A:: DiagonalOnes , B:: $type ) = Ones{promote_type(eltype(A),eltype(B))}(size(B))
704+ function * (A:: DiagonalOnes , B:: $type )
705+ check_matmul_sizes(A, B)
706+ Ones{promote_type(eltype(A),eltype(B))}(size(B))
707+ end
706708 end
707709end
708710for type in (AdjointAbsVec{<: Any ,<: AbstractOnesVector }, TransposeAbsVec{<: Any ,<: AbstractOnesVector }, AbstractOnesMatrix)
709711 @eval begin
710- * (A:: $type , B:: DiagonalOnes ) = Ones{promote_type(eltype(A),eltype(B))}(size(A))
712+ function * (A:: $type , B:: DiagonalOnes )
713+ check_matmul_sizes(A, B)
714+ Ones{promote_type(eltype(A),eltype(B))}(size(A))
715+ end
711716 end
712717end
713718
714719for type in (DiagonalFill, DiagonalOnes)
715720 @eval begin
716721 function * (A:: $type , B:: RectDiagonalFill )
717722 check_matmul_sizes(A, B)
718- len = minimum(size(B))
719- RectDiagonal(view(A. diag, Base . OneTo( len)) .* view(B. diag, Base . OneTo( len) ), size(B))
723+ len = Base . OneTo( minimum(size(B) ))
724+ RectDiagonal(view(A. diag, len) .* view(B. diag, len), size(B))
720725 end
721726
722727 function * (A:: RectDiagonalFill , B:: $type )
723728 check_matmul_sizes(A, B)
724- len = minimum(size(A))
725- RectDiagonal(view(A. diag, Base . OneTo( len)) .* view(B. diag, Base . OneTo( len) ), size(A))
729+ len = Base . OneTo( minimum(size(A) ))
730+ RectDiagonal(view(A. diag, len) .* view(B. diag, len), size(A))
726731 end
727732 end
728733end
729734
730735for type1 in (AbstractMatrix, Diagonal)
731736 for type2 in (Diagonal, DiagonalOnes, DiagonalFill)
732737 @eval begin
733- * (Da:: DiagonalZeros , A:: $type1 , Db:: $type2 ) = Zeros(Da) * A
734- * (Da:: $type2 , A:: $type1 , Db:: DiagonalZeros ) = A * Zeros(Db)
738+ function * (Da:: DiagonalZeros , A:: $type1 , Db:: $type2 )
739+ check_matmul_sizes(A, Db)
740+ Zeros(Da) * A
741+ end
742+ function * (Da:: $type2 , A:: $type1 , Db:: DiagonalZeros )
743+ check_matmul_sizes(Da, A)
744+ A * Zeros(Db)
745+ end
735746 end
736747 end
737748
738749 for type2 in (Diagonal, DiagonalFill)
739750 @eval begin
740- * (Da:: DiagonalOnes , A:: $type1 , Db:: $type2 ) = A * Db
741- * (Da:: $type2 , A:: $type1 , Db:: DiagonalOnes ) = Da * A
751+ function * (Da:: DiagonalOnes , A:: $type1 , Db:: $type2 )
752+ check_matmul_sizes(Da, A)
753+ ones(eltype(Da)) * A * Db
754+ end
755+ function * (Da:: $type2 , A:: $type1 , Db:: DiagonalOnes )
756+ check_matmul_sizes(A, Db)
757+ Da * A * ones(eltype(Db))
758+ end
742759 end
743760 end
744761
745762 @eval begin
746763 * (Da:: DiagonalZeros , A:: $type1 , Db:: DiagonalZeros ) = Zeros(Da) * A * Zeros(Db)
747- * (Da:: DiagonalOnes , A:: $type1 , Db:: DiagonalOnes ) = A
764+ function * (Da:: DiagonalOnes , A:: $type1 , Db:: DiagonalOnes )
765+ check_matmul_sizes(Da, A)
766+ check_matmul_sizes(A, Db)
767+ (one(eltype(Da)) * one(eltype(Db))) * A
768+ end
748769
749- * (Da:: DiagonalFill , A:: $type1 , Db:: Diagonal ) = getindex_value(Da. diag) * A * Db
750- * (Da:: Diagonal , A:: $type1 , Db:: DiagonalFill ) = Da * A * getindex_value(Db. diag)
751- * (Da:: DiagonalFill , A:: $type1 , Db:: DiagonalFill ) = getindex_value(Da. diag) * getindex_value(Db. diag) * A
770+ function * (Da:: DiagonalFill , A:: $type1 , Db:: Diagonal )
771+ check_matmul_sizes(Da, A)
772+ getindex_value(Da. diag) * A * Db
773+ end
774+ function * (Da:: Diagonal , A:: $type1 , Db:: DiagonalFill )
775+ check_matmul_sizes(A, Db)
776+ Da * A * getindex_value(Db. diag)
777+ end
778+ function * (Da:: DiagonalFill , A:: $type1 , Db:: DiagonalFill )
779+ check_matmul_sizes(Da, A)
780+ check_matmul_sizes(A, Db)
781+ (getindex_value(Da. diag) * getindex_value(Db. diag)) * A
782+ end
752783 end
753784end
754785
0 commit comments