Skip to content

Commit

Permalink
Fix sum(::AbstractArray{<:AbstractMutable}; dims)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Oct 15, 2024
1 parent 3f811de commit 9851148
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@

abstract type AbstractMutable end

function Base.sum(a::AbstractArray{<:AbstractMutable}; kwargs...)
return operate(sum, a; kwargs...)
function Base.sum(
a::AbstractArray{T};
dims = missing,
init = zero(promote_operation(+, T, T)),
) where {T<:AbstractMutable}
if !ismissing(dims)
return mapreduce(identity, Base.add_sum, a; dims, init)
end
return operate(sum, a; init)
end

# When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting
Expand Down
12 changes: 12 additions & 0 deletions test/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,15 @@ end
# MA is at least 10-times better than no MA for this example
@test 10 * with_init < no_ma
end

@testset "sum_with_init_and_dims" begin
x = reshape(convert(Vector{DummyBigInt}, 1:12), 3, 4)
X = reshape(1:12, 3, 4)
for dims in (1, 2, :, 1:2, (1, 2))
# Without (; init)
@test MA.isequal_canonical(sum(x; dims), DummyBigInt.(sum(X; dims)))
# With (; init)
y = sum(x; init = DummyBigInt(0), dims)
@test MA.isequal_canonical(y, DummyBigInt.(sum(X; dims)))
end
end

0 comments on commit 9851148

Please sign in to comment.