Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dispatch.jl: remove LinearAlgebra.mul! overloads #182

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 0 additions & 245 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,251 +92,6 @@ function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable})
return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x))
end

################################################################################
# Interception of Base's matrix/vector arithmetic machinery
#
# Redirect calls with `eltype(ret) <: AbstractMutable` to `_mul!` to replace it
# with an implementation more efficient than `generic_matmatmul!` and
# `generic_matvecmul!` since it takes into account the mutability of the
# arithmetic. We need `args...` because SparseArrays` also gives `α` and `β`
# arguments.

function _mul!(output, A, B, α, β)
# See SparseArrays/src/linalg.jl
if !isone(β)
if iszero(β)
operate!(zero, output)
else
rmul!(output, scaling(β))
end
end
return operate!(add_mul, output, A, B, scaling(α))
end

function _mul!(output, A, B, α)
operate!(zero, output)
return operate!(add_mul, output, A, B, scaling(α))
end

# LinearAlgebra uses `Base.promote_op(LinearAlgebra.matprod, ...)` to try to
# infere the return type. If the operation is not supported, it returns
# `Union{}`.
function _mul!(output::AbstractArray{Union{}}, A, B)
# Normally, if the product is not supported, this should redirect to
# `MA.promote_operation(*, ...)` which redirects to
# `zero(...) * zero(...)` which should throw an appropriate error.
# For example, in JuMP, it would say that you cannot multiply quadratic
# expressions with an affine expression for instance.
ProdType = promote_array_mul(typeof(A), typeof(B))
# If we arrived here, it means that we have found a type for `output`, even
# if LinearAlgebra couldn't. This is most probably a but so let's provide
# extensive information to help debugging.
return error(
"Cannot multiply a `$(typeof(A))` with a `$(typeof(B))` because the " *
"sum of the product of a `$(eltype(A))` and a `$(eltype(B))` could " *
"not be inferred so a `$(typeof(output))` allocated to store the " *
"output of the multiplication instead of a `$ProdType`.",
)
end

_mul!(output, A, B) = operate_to!(output, *, A, B)

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVecOrMat,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.AbstractTriangular,
B::AbstractVector,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.AbstractTriangular,
B::AbstractMatrix,
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
)
return _mul!(ret, A, B)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVecOrMat,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::AbstractVecOrMat,
B::AbstractVector,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractVector,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix,
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

function LinearAlgebra.mul!(
ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
α::Number,
β::Number,
)
return _mul!(ret, A, B, α, β)
end

# SparseArrays promotes the element types of `A` and `B` to the same type which,
# always produce quadratic expressions for JuMP even if only one of them was
# affine and the other one constant. Moreover, it does not always go through
Expand Down
26 changes: 0 additions & 26 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,6 @@ end
@test MA.operate(convert, Int, 1) === 1
end

const EXPECTED_ERROR = string(
"Cannot multiply a `Matrix{NoProdMutable}` with a ",
"`Matrix{NoProdMutable}` because the sum of the product of a ",
"`NoProdMutable` and a `NoProdMutable` could not be inferred so a ",
"`Matrix{Union{}}` allocated to store the output of the ",
"multiplication instead of a `Matrix{$(Int)}`.",
)

struct NoProdMutable <: MA.AbstractMutable end
function MA.promote_operation(
::typeof(*),
::Type{NoProdMutable},
::Type{NoProdMutable},
)
return Int # Dummy result just to test error message
end

function unsupported_product()
A = [NoProdMutable() for i in 1:2, j in 1:2]
err = ErrorException(EXPECTED_ERROR)
@test_throws err A * A
end

@testset "Errors" begin
@testset "`promote_op` error" begin
AT = CustomArray{Int,3}
Expand Down Expand Up @@ -126,9 +103,6 @@ end
)
@test_throws err MA.operate!(+, A, B)
end
@testset "unsupported_product" begin
unsupported_product()
end
end

@testset "Matrix multiplication" begin
Expand Down