Skip to content

Commit

Permalink
Fix implementation of mul! for AbstractMatrix and AbstractVector (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jan 9, 2024
1 parent c096dd7 commit 85f5ff2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 65 deletions.
93 changes: 28 additions & 65 deletions src/implementations/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,41 +223,6 @@ function promote_array_mul(
return Vector{promote_sum_mul(S, T)}
end

################################################################################
# We roll our own matmul here (instead of using Julia's generic fallbacks)
# because doing so allows us to accumulate the expressions for the inner loops
# in-place.
# Additionally, Julia's generic fallbacks can be finnicky when your array
# elements aren't `<:Number`.

# This method of `mul!` is adapted from upstream Julia. Note that we
# confuse transpose with adjoint.
#=
> Copyright (c) 2009-2018: Jeff Bezanson, Stefan Karpinski, Viral B. Shah,
> and other contributors:
>
> https://github.com/JuliaLang/julia/contributors
>
> Permission is hereby granted, free of charge, to any person obtaining
> a copy of this software and associated documentation files (the
> "Software"), to deal in the Software without restriction, including
> without limitation the rights to use, copy, modify, merge, publish,
> distribute, sublicense, and/or sell copies of the Software, and to
> permit persons to whom the Software is furnished to do so, subject to
> the following conditions:
>
> The above copyright notice and this permission notice shall be
> included in all copies or substantial portions of the Software.
>
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
> EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
> MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
> NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
> LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
> OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
> WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
=#

function _dim_check(C::AbstractVector, A::AbstractMatrix, B::AbstractVector)
mB = length(B)
mA, nA = size(A)
Expand Down Expand Up @@ -298,46 +263,44 @@ function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
return
end

function _add_mul_array(buffer, C::Vector, A::AbstractMatrix, B::AbstractVector)
Astride = size(A, 1)
# We need a buffer to hold the intermediate multiplication.
@inbounds begin
for k in eachindex(B)
aoffs = (k - 1) * Astride
b = B[k]
for i in Base.OneTo(size(A, 1))
C[i] = buffered_operate!!(buffer, add_mul, C[i], A[aoffs+i], b)
end
end
end # @inbounds
return C
end

# This is incorrect if `C` is `LinearAlgebra.Symmetric` as we modify twice the
# same diagonal element.
function _add_mul_array(buffer, C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
@inbounds begin
for i in 1:size(A, 1), j in 1:size(B, 2)
Ctmp = C[i, j]
for k in 1:size(A, 2)
Ctmp =
buffered_operate!!(buffer, add_mul, Ctmp, A[i, k], B[k, j])
end
C[i, j] = Ctmp
function buffered_operate!(
buffer,
::typeof(add_mul),
C::Vector,
A::AbstractMatrix,
B::AbstractVector,
)
_dim_check(C, A, B)
for (ci, ai) in zip(axes(C, 1), axes(A, 1))
for (aj, bj) in zip(axes(A, 2), axes(B, 1))
C[ci] = buffered_operate!!(buffer, add_mul, C[ci], A[ai, aj], B[bj])
end
end # @inbounds
end
return C
end

function buffered_operate!(
buffer,
::typeof(add_mul),
C::VecOrMat,
C::Matrix,
A::AbstractMatrix,
B::AbstractVecOrMat,
B::AbstractMatrix,
)
_dim_check(C, A, B)
return _add_mul_array(buffer, C, A, B)
for (ci, ai) in zip(axes(C, 1), axes(A, 1))
for (cj, bj) in zip(axes(C, 2), axes(B, 2))
for (aj, bi) in zip(axes(A, 2), axes(B, 1))
C[ci, cj] = buffered_operate!!(
buffer,
add_mul,
C[ci, cj],
A[ai, aj],
B[bi, bj],
)
end
end
end
return C
end

function buffer_for(
Expand Down
2 changes: 2 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ Base.size(x::Issue65Matrix) = size(x.x)
Base.getindex(x::Issue65Matrix, args...) = getindex(x.x, args...)
Base.axes(x::Issue65Matrix, n) = Issue65OneTo(size(x.x, n))
Base.convert(::Type{Base.OneTo}, x::Issue65OneTo) = Base.OneTo(x.N)
Base.iterate(x::Issue65OneTo) = iterate(Base.OneTo(x.N))
Base.iterate(x::Issue65OneTo, arg) = iterate(Base.OneTo(x.N), arg)

@testset "Issue #65" begin
x = [1.0 2.0; 3.0 4.0]
Expand Down

0 comments on commit 85f5ff2

Please sign in to comment.