Skip to content

Commit

Permalink
dispatch.jl: remove LinearAlgebra.mul! overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Nov 21, 2022
1 parent 352aa8f commit f2a728d
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 268 deletions.
245 changes: 0 additions & 245 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,251 +92,6 @@ function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable})
return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x))
end

################################################################################
# Interception of Base's matrix/vector arithmetic machinery
#
# Redirect calls with `eltype(ret) <: AbstractMutable` to `_mul!` to replace it
# with an implementation more efficient than `generic_matmatmul!` and
# `generic_matvecmul!` since it takes into account the mutability of the
# arithmetic. We need `args...` because SparseArrays` also gives `α` and `β`
# arguments.

function _mul!(output, A, B, α, β)
# See SparseArrays/src/linalg.jl
if !isone(β)
if iszero(β)
operate!(zero, output)
else
rmul!(output, scaling(β))
end
end
return operate!(add_mul, output, A, B, scaling(α))
end

function _mul!(output, A, B, α)
operate!(zero, output)
return operate!(add_mul, output, A, B, scaling(α))
end

# LinearAlgebra uses `Base.promote_op(LinearAlgebra.matprod, ...)` to try to
# infere the return type. If the operation is not supported, it returns
# `Union{}`.
function _mul!(output::AbstractArray{Union{}}, A, B)
# Normally, if the product is not supported, this should redirect to
# `MA.promote_operation(*, ...)` which redirects to
# `zero(...) * zero(...)` which should throw an appropriate error.
# For example, in JuMP, it would say that you cannot multiply quadratic
# expressions with an affine expression for instance.
ProdType = promote_array_mul(typeof(A), typeof(B))
# If we arrived here, it means that we have found a type for `output`, even
# if LinearAlgebra couldn't. This is most probably a but so let's provide
# extensive information to help debugging.
return error(
"Cannot multiply a `$(typeof(A))` with a `$(typeof(B))` because the " *
"sum of the product of a `$(eltype(A))` and a `$(eltype(B))` could " *
"not be inferred so a `$(typeof(output))` allocated to store the " *
"output of the multiplication instead of a `$ProdType`.",
)
end

_mul!(output, A, B) = operate_to!(output, *, A, B)

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVecOrMat,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.AbstractTriangular,
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.AbstractTriangular,
B::AbstractMatrix,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVecOrMat,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVector,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

# SparseArrays promotes the element types of `A` and `B` to the same type which,
# always produce quadratic expressions for JuMP even if only one of them was
# affine and the other one constant. Moreover, it does not always go through
Expand Down
23 changes: 0 additions & 23 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,6 @@ end
@test MA.operate(convert, Int, 1) === 1
end

const EXPECTED_ERROR = string(
"Cannot multiply a `Matrix{NoProdMutable}` with a ",
"`Matrix{NoProdMutable}` because the sum of the product of a ",
"`NoProdMutable` and a `NoProdMutable` could not be inferred so a ",
"`Matrix{Union{}}` allocated to store the output of the ",
"multiplication instead of a `Matrix{$(Int)}`.",
)

struct NoProdMutable <: MA.AbstractMutable end
function MA.promote_operation(
::typeof(*),
::Type{NoProdMutable},
::Type{NoProdMutable},
)
return Int # Dummy result just to test error message
end

function unsupported_product()
A = [NoProdMutable() for i in 1:2, j in 1:2]
err = ErrorException(EXPECTED_ERROR)
@test_throws err A * A
end

@testset "Errors" begin
@testset "`promote_op` error" begin
AT = CustomArray{Int,3}
Expand Down

0 comments on commit f2a728d

Please sign in to comment.