Skip to content

Commit 9daec23

Browse files
committed
Check matmul sizes
1 parent 194df93 commit 9daec23

File tree

1 file changed

+46
-15
lines changed

1 file changed

+46
-15
lines changed

src/fillalgebra.jl

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,6 @@ for type in tuple(AbstractVector, AbstractZerosVector, linearalgebra_types...)
648648
end
649649
end
650650

651-
# TODO: add dim check to all abstract ones multiplication
652651
for type in linearalgebra_types
653652
@eval begin
654653
function *(A::$type, B::DiagonalFill)
@@ -702,53 +701,85 @@ end
702701

703702
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix, AbstractOnesVector)
704703
@eval begin
705-
*(A::DiagonalOnes, B::$type) = Ones{promote_type(eltype(A),eltype(B))}(size(B))
704+
function *(A::DiagonalOnes, B::$type)
705+
check_matmul_sizes(A, B)
706+
Ones{promote_type(eltype(A),eltype(B))}(size(B))
707+
end
706708
end
707709
end
708710
for type in (AdjointAbsVec{<:Any,<:AbstractOnesVector}, TransposeAbsVec{<:Any,<:AbstractOnesVector}, AbstractOnesMatrix)
709711
@eval begin
710-
*(A::$type, B::DiagonalOnes) = Ones{promote_type(eltype(A),eltype(B))}(size(A))
712+
function *(A::$type, B::DiagonalOnes)
713+
check_matmul_sizes(A, B)
714+
Ones{promote_type(eltype(A),eltype(B))}(size(A))
715+
end
711716
end
712717
end
713718

714719
for type in (DiagonalFill, DiagonalOnes)
715720
@eval begin
716721
function *(A::$type, B::RectDiagonalFill)
717722
check_matmul_sizes(A, B)
718-
len = minimum(size(B))
719-
RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), size(B))
723+
len = Base.OneTo(minimum(size(B)))
724+
RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(B))
720725
end
721726

722727
function *(A::RectDiagonalFill, B::$type)
723728
check_matmul_sizes(A, B)
724-
len = minimum(size(A))
725-
RectDiagonal(view(A.diag, Base.OneTo(len)) .* view(B.diag, Base.OneTo(len)), size(A))
729+
len = Base.OneTo(minimum(size(A)))
730+
RectDiagonal(view(A.diag, len) .* view(B.diag, len), size(A))
726731
end
727732
end
728733
end
729734

730735
for type1 in (AbstractMatrix, Diagonal)
731736
for type2 in (Diagonal, DiagonalOnes, DiagonalFill)
732737
@eval begin
733-
*(Da::DiagonalZeros, A::$type1, Db::$type2) = Zeros(Da) * A
734-
*(Da::$type2, A::$type1, Db::DiagonalZeros) = A * Zeros(Db)
738+
function *(Da::DiagonalZeros, A::$type1, Db::$type2)
739+
check_matmul_sizes(A, Db)
740+
Zeros(Da) * A
741+
end
742+
function *(Da::$type2, A::$type1, Db::DiagonalZeros)
743+
check_matmul_sizes(Da, A)
744+
A * Zeros(Db)
745+
end
735746
end
736747
end
737748

738749
for type2 in (Diagonal, DiagonalFill)
739750
@eval begin
740-
*(Da::DiagonalOnes, A::$type1, Db::$type2) = A * Db
741-
*(Da::$type2, A::$type1, Db::DiagonalOnes) = Da * A
751+
function *(Da::DiagonalOnes, A::$type1, Db::$type2)
752+
check_matmul_sizes(Da, A)
753+
ones(eltype(Da)) * A * Db
754+
end
755+
function *(Da::$type2, A::$type1, Db::DiagonalOnes)
756+
check_matmul_sizes(A, Db)
757+
Da * A * ones(eltype(Db))
758+
end
742759
end
743760
end
744761

745762
@eval begin
746763
*(Da::DiagonalZeros, A::$type1, Db::DiagonalZeros) = Zeros(Da) * A * Zeros(Db)
747-
*(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes) = A
764+
function *(Da::DiagonalOnes, A::$type1, Db::DiagonalOnes)
765+
check_matmul_sizes(Da, A)
766+
check_matmul_sizes(A, Db)
767+
(one(eltype(Da)) * one(eltype(Db))) * A
768+
end
748769

749-
*(Da::DiagonalFill, A::$type1, Db::Diagonal) = getindex_value(Da.diag) * A * Db
750-
*(Da::Diagonal, A::$type1, Db::DiagonalFill) = Da * A * getindex_value(Db.diag)
751-
*(Da::DiagonalFill, A::$type1, Db::DiagonalFill) = getindex_value(Da.diag) * getindex_value(Db.diag) * A
770+
function *(Da::DiagonalFill, A::$type1, Db::Diagonal)
771+
check_matmul_sizes(Da, A)
772+
getindex_value(Da.diag) * A * Db
773+
end
774+
function *(Da::Diagonal, A::$type1, Db::DiagonalFill)
775+
check_matmul_sizes(A, Db)
776+
Da * A * getindex_value(Db.diag)
777+
end
778+
function *(Da::DiagonalFill, A::$type1, Db::DiagonalFill)
779+
check_matmul_sizes(Da, A)
780+
check_matmul_sizes(A, Db)
781+
(getindex_value(Da.diag) * getindex_value(Db.diag)) * A
782+
end
752783
end
753784
end
754785

0 commit comments

Comments
 (0)