Skip to content

Commit

Permalink
Add tests and fix for SparseArray matrix multiplication (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 17, 2022
1 parent 201aafb commit b46f175
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ julia = "1.6"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[targets]
test = ["OffsetArrays"]
test = ["OffsetArrays", "Random"]
3 changes: 2 additions & 1 deletion src/implementations/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ function operate!(
B::_SparseMat,
α::Vararg{Union{T,Scaling},N},
) where {T,N}
rhs_constants = prod(α)
_dim_check(ret, A, B)
rowvalA = SparseArrays.rowvals(A)
nzvalA = SparseArrays.nonzeros(A)
Expand All @@ -234,7 +235,7 @@ function operate!(
ret.colptr[i] = ip0 = ip
k0 = ip - 1
for jp in SparseArrays.nzrange(B, i)
nzB = nzvalB[jp]
nzB = nzvalB[jp] * rhs_constants
j = rowvalB[jp]
for kp in SparseArrays.nzrange(A, j)
k = rowvalA[kp]
Expand Down
71 changes: 71 additions & 0 deletions test/SparseArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2019 MutableArithmetics.jl contributors
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
# one at http://mozilla.org/MPL/2.0/.

module TestInterfaceSparseArrays

using Test

import MutableArithmetics
import Random
import SparseArrays

const MA = MutableArithmetics

function runtests()
for name in names(@__MODULE__; all = true)
if startswith("$(name)", "test_")
@testset "$(name)" begin
getfield(@__MODULE__, name)()
end
end
end
return
end

function test_spmatmul()
Random.seed!(1234)
for m in [1, 2, 3, 5, 11]
for n in [1, 2, 3, 5, 11]
A = SparseArrays.sprand(Float64, m, n, 0.5)
B = SparseArrays.sprand(Float64, n, m, 0.5)
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, B)
@test ret A * B
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, A')
@test ret A * A'
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, B, 2.0)
@test ret A * B * 2.0
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, B, 2.0, 1.5)
@test ret A * B * 2.0 * 1.5
end
end
return
end

function test_spmatmul_prefer_sort()
Random.seed!(1234)
m = n = 100
p = 0.01
A = SparseArrays.sprand(Float64, m, n, p)
B = SparseArrays.sprand(Float64, n, m, p)
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, B)
@test ret A * B
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, B, 2.0)
@test ret A * B * 2.0
ret = SparseArrays.spzeros(Float64, m, m)
MA.operate!(MA.add_mul, ret, A, B, 2.0, 1.5)
@test ret A * B * 2.0 * 1.5
return
end

end # module

TestInterfaceSparseArrays.runtests()
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ include("matmul.jl")
include("dispatch.jl")
include("rewrite.jl")

include("SparseArrays.jl")

# It is easy to introduce macro scoping issues into MutableArithmetics,
# particularly ones that rely on `MA` or `MutableArithmetics` being present in
# the current scope. To work around that, include the "hygiene" script in a
Expand Down

0 comments on commit b46f175

Please sign in to comment.