diff --git a/Project.toml b/Project.toml index 6e651f8f..2e923ad2 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ julia = "1.6" [extras] OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [targets] -test = ["OffsetArrays"] +test = ["OffsetArrays", "Random"] diff --git a/src/implementations/SparseArrays.jl b/src/implementations/SparseArrays.jl index b00bf9d5..19a7d5df 100644 --- a/src/implementations/SparseArrays.jl +++ b/src/implementations/SparseArrays.jl @@ -213,6 +213,7 @@ function operate!( B::_SparseMat, α::Vararg{Union{T,Scaling},N}, ) where {T,N} + rhs_constants = prod(α) _dim_check(ret, A, B) rowvalA = SparseArrays.rowvals(A) nzvalA = SparseArrays.nonzeros(A) @@ -234,7 +235,7 @@ function operate!( ret.colptr[i] = ip0 = ip k0 = ip - 1 for jp in SparseArrays.nzrange(B, i) - nzB = nzvalB[jp] + nzB = nzvalB[jp] * rhs_constants j = rowvalB[jp] for kp in SparseArrays.nzrange(A, j) k = rowvalA[kp] diff --git a/test/SparseArrays.jl b/test/SparseArrays.jl new file mode 100644 index 00000000..17743c74 --- /dev/null +++ b/test/SparseArrays.jl @@ -0,0 +1,71 @@ +# Copyright (c) 2019 MutableArithmetics.jl contributors +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain +# one at http://mozilla.org/MPL/2.0/. + +module TestInterfaceSparseArrays + +using Test + +import MutableArithmetics +import Random +import SparseArrays + +const MA = MutableArithmetics + +function runtests() + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +function test_spmatmul() + Random.seed!(1234) + for m in [1, 2, 3, 5, 11] + for n in [1, 2, 3, 5, 11] + A = SparseArrays.sprand(Float64, m, n, 0.5) + B = SparseArrays.sprand(Float64, n, m, 0.5) + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, B) + @test ret ≈ A * B + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, A') + @test ret ≈ A * A' + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, B, 2.0) + @test ret ≈ A * B * 2.0 + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, B, 2.0, 1.5) + @test ret ≈ A * B * 2.0 * 1.5 + end + end + return +end + +function test_spmatmul_prefer_sort() + Random.seed!(1234) + m = n = 100 + p = 0.01 + A = SparseArrays.sprand(Float64, m, n, p) + B = SparseArrays.sprand(Float64, n, m, p) + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, B) + @test ret ≈ A * B + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, B, 2.0) + @test ret ≈ A * B * 2.0 + ret = SparseArrays.spzeros(Float64, m, m) + MA.operate!(MA.add_mul, ret, A, B, 2.0, 1.5) + @test ret ≈ A * B * 2.0 * 1.5 + return +end + +end # module + +TestInterfaceSparseArrays.runtests() diff --git a/test/runtests.jl b/test/runtests.jl index ae1c093c..92d41edc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,8 @@ include("matmul.jl") include("dispatch.jl") include("rewrite.jl") +include("SparseArrays.jl") + # It is easy to introduce macro scoping issues into MutableArithmetics, # particularly ones that rely on `MA` or `MutableArithmetics` being present in # the current scope. To work around that, include the "hygiene" script in a