From c589ecba2e83f23a34a8dc5d14b49f20d8fd591e Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 21 Nov 2022 07:55:12 +1300 Subject: [PATCH 1/2] dispatch.jl: simplify overloads for SparseMatrix * AbstractArray --- src/dispatch.jl | 222 +++++++----------------------------------------- 1 file changed, 29 insertions(+), 193 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 725c7b3f..61d7f832 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -343,199 +343,35 @@ end # `LinearAlgebra.mul!` which prevents us from using mutability of the # arithmetic. For this reason we intercept the calls and redirect them to `mul`. -# A few are overwritten below but many more need to be redirected to `mul` in -# `linalg.jl`. - -Base.:*(A::_SparseMat{<:AbstractMutable}, x::StridedVector) = mul(A, x) - -Base.:*(A::_SparseMat, x::StridedVector{<:AbstractMutable}) = mul(A, x) - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -# These six methods are needed on Julia v1.2 and earlier -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - x::StridedVector, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - x::StridedVector, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:Any,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -Base.:*(A::_SparseMat{<:Any}, B::_SparseMat{<:AbstractMutable}) = mul(A, B) - -Base.:*(A::_SparseMat{<:AbstractMutable}, B::_SparseMat{<:Any}) = mul(A, B) - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, -) - return mul(A, B) -end - -function Base.:*( - A::_SparseMat{<:Any}, - B::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, -) - return mul(A, B) -end - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:Any}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:Any,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:Any}, -) - return mul(A, B) -end - -function Base.:*( - A::StridedMatrix{<:AbstractMutable}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -Base.:*(A::StridedMatrix{<:Any}, B::_SparseMat{<:AbstractMutable}) = mul(A, B) - -Base.:*(A::StridedMatrix{<:AbstractMutable}, B::_SparseMat{<:Any}) = mul(A, B) - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -Base.:*(A::_SparseMat{<:Any}, B::StridedMatrix{<:AbstractMutable}) = mul(A, B) - -Base.:*(A::_SparseMat{<:AbstractMutable}, B::StridedMatrix{<:Any}) = mul(A, B) - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:Any}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:Any,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:Any}, -) - return mul(A, B) +const _LinearAlgebraWrappers = ( + LinearAlgebra.Adjoint, + LinearAlgebra.Transpose, + # TODO(odow): we could expand these overloads to other LinearAlgebra types. + # LinearAlgebra.Symmetric, + # LinearAlgebra.Hermitian, + # LinearAlgebra.Diagonal, + # LinearAlgebra.LowerTriangular, + # LinearAlgebra.UpperTriangular, + # LinearAlgebra.UnitLowerTriangular, + # LinearAlgebra.UnitUpperTriangular, +) + +const _MatrixLike = vcat( + Any[T -> LA{<:T,<:_SparseMat} for LA in _LinearAlgebraWrappers], + Any[T -> _SparseMat{<:T}, T -> StridedMatrix{<:T}], +) + +for f_A in _MatrixLike, f_B in vcat(_MatrixLike, T -> StridedVector{<:T}) + A, mut_A = f_A(Any), f_A(AbstractMutable) + B, mut_B = f_B(Any), f_B(AbstractMutable) + if A <: StridedMatrix && B <: StridedMatrix + continue + end + @eval begin + Base.:*(a::$(mut_A), b::$(B)) = mul(a, b) + Base.:*(a::$(A), b::$(mut_B)) = mul(a, b) + Base.:*(a::$(mut_A), b::$(mut_B)) = mul(a, b) + end end const StridedMaybeAdjOrTransMat{T} = Union{ From fe5e4a1be653c1e897cdd8167d35664c9d6f2e8f Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 21 Nov 2022 08:35:12 +1300 Subject: [PATCH 2/2] Fix formatting --- src/dispatch.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 61d7f832..3ffa4e36 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -358,7 +358,7 @@ const _LinearAlgebraWrappers = ( const _MatrixLike = vcat( Any[T -> LA{<:T,<:_SparseMat} for LA in _LinearAlgebraWrappers], - Any[T -> _SparseMat{<:T}, T -> StridedMatrix{<:T}], + Any[T->_SparseMat{<:T}, T->StridedMatrix{<:T}], ) for f_A in _MatrixLike, f_B in vcat(_MatrixLike, T -> StridedVector{<:T})