diff --git a/src/dispatch.jl b/src/dispatch.jl index 725c7b3f..1b617a13 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -92,251 +92,6 @@ function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable}) return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x)) end -################################################################################ -# Interception of Base's matrix/vector arithmetic machinery -# -# Redirect calls with `eltype(ret) <: AbstractMutable` to `_mul!` to replace it -# with an implementation more efficient than `generic_matmatmul!` and -# `generic_matvecmul!` since it takes into account the mutability of the -# arithmetic. We need `args...` because SparseArrays` also gives `α` and `β` -# arguments. - -function _mul!(output, A, B, α, β) - # See SparseArrays/src/linalg.jl - if !isone(β) - if iszero(β) - operate!(zero, output) - else - rmul!(output, scaling(β)) - end - end - return operate!(add_mul, output, A, B, scaling(α)) -end - -function _mul!(output, A, B, α) - operate!(zero, output) - return operate!(add_mul, output, A, B, scaling(α)) -end - -# LinearAlgebra uses `Base.promote_op(LinearAlgebra.matprod, ...)` to try to -# infere the return type. If the operation is not supported, it returns -# `Union{}`. -function _mul!(output::AbstractArray{Union{}}, A, B) - # Normally, if the product is not supported, this should redirect to - # `MA.promote_operation(*, ...)` which redirects to - # `zero(...) * zero(...)` which should throw an appropriate error. - # For example, in JuMP, it would say that you cannot multiply quadratic - # expressions with an affine expression for instance. - ProdType = promote_array_mul(typeof(A), typeof(B)) - # If we arrived here, it means that we have found a type for `output`, even - # if LinearAlgebra couldn't. This is most probably a but so let's provide - # extensive information to help debugging. - return error( - "Cannot multiply a `$(typeof(A))` with a `$(typeof(B))` because the " * - "sum of the product of a `$(eltype(A))` and a `$(eltype(B))` could " * - "not be inferred so a `$(typeof(output))` allocated to store the " * - "output of the multiplication instead of a `$ProdType`.", - ) -end - -_mul!(output, A, B) = operate_to!(output, *, A, B) - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::AbstractVecOrMat, - B::AbstractVecOrMat, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::AbstractVecOrMat, - B::AbstractVector, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - B::AbstractVector, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - B::AbstractVector, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::LinearAlgebra.AbstractTriangular, - B::AbstractVector, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - B::AbstractMatrix, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - B::AbstractMatrix, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.AbstractTriangular, - B::AbstractMatrix, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::AbstractMatrix, - B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::AbstractMatrix, - B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, -) - return _mul!(ret, A, B) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::AbstractVecOrMat, - B::AbstractVecOrMat, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::AbstractVecOrMat, - B::AbstractVector, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - B::AbstractVector, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractVector{<:AbstractMutable}, - A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - B::AbstractVector, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - B::AbstractMatrix, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - B::AbstractMatrix, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::AbstractMatrix, - B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::AbstractMatrix, - B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - -function LinearAlgebra.mul!( - ret::AbstractMatrix{<:AbstractMutable}, - A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, - α::Number, - β::Number, -) - return _mul!(ret, A, B, α, β) -end - # SparseArrays promotes the element types of `A` and `B` to the same type which, # always produce quadratic expressions for JuMP even if only one of them was # affine and the other one constant. Moreover, it does not always go through diff --git a/test/matmul.jl b/test/matmul.jl index bacbbef1..3debee31 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -64,29 +64,6 @@ end @test MA.operate(convert, Int, 1) === 1 end -const EXPECTED_ERROR = string( - "Cannot multiply a `Matrix{NoProdMutable}` with a ", - "`Matrix{NoProdMutable}` because the sum of the product of a ", - "`NoProdMutable` and a `NoProdMutable` could not be inferred so a ", - "`Matrix{Union{}}` allocated to store the output of the ", - "multiplication instead of a `Matrix{$(Int)}`.", -) - -struct NoProdMutable <: MA.AbstractMutable end -function MA.promote_operation( - ::typeof(*), - ::Type{NoProdMutable}, - ::Type{NoProdMutable}, -) - return Int # Dummy result just to test error message -end - -function unsupported_product() - A = [NoProdMutable() for i in 1:2, j in 1:2] - err = ErrorException(EXPECTED_ERROR) - @test_throws err A * A -end - @testset "Errors" begin @testset "`promote_op` error" begin AT = CustomArray{Int,3}