From a548208666b36f3b769704685730823ba742c26c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 10 Oct 2024 13:38:57 +0200 Subject: [PATCH 1/2] Allow init argument for sum --- src/dispatch.jl | 4 ++-- src/reduce.jl | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 2960906b..20aee65b 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,8 +13,8 @@ abstract type AbstractMutable end -function Base.sum(a::AbstractArray{<:AbstractMutable}) - return operate(sum, a) +function Base.sum(a::AbstractArray{<:AbstractMutable}; kws...) + return operate(sum, a; kws...) end # When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting diff --git a/src/reduce.jl b/src/reduce.jl index c1644470..9a909eee 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -50,11 +50,10 @@ function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N} return accumulator end -function operate(::typeof(sum), a::AbstractArray) - return mapreduce( - identity, - add!!, - a; - init = zero(promote_operation(+, eltype(a), eltype(a))), - ) +function operate( + ::typeof(sum), + a::AbstractArray; + init = zero(promote_operation(+, eltype(a), eltype(a))), +) + return mapreduce(identity, add!!, a; init) end From 978e294e97cf00d926c0f11b960562d23c57404a Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 11 Oct 2024 09:04:45 +1300 Subject: [PATCH 2/2] Update --- src/dispatch.jl | 4 ++-- test/dispatch.jl | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 20aee65b..a5a1c28f 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -13,8 +13,8 @@ abstract type AbstractMutable end -function Base.sum(a::AbstractArray{<:AbstractMutable}; kws...) - return operate(sum, a; kws...) +function Base.sum(a::AbstractArray{<:AbstractMutable}; kwargs...) + return operate(sum, a; kwargs...) end # When doing `x'y` where the elements of `x` and/or `y` are arrays, redirecting diff --git a/test/dispatch.jl b/test/dispatch.jl index 84ba729d..722f0a0a 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -106,3 +106,28 @@ end end end end + +function non_mutable_sum_pr306(x) + y = zero(eltype(x)) + for xi in x + y += xi + end + return y +end + +@testset "sum_with_init" begin + x = convert(Vector{DummyBigInt}, 1:100) + # compilation + @allocated sum(x) + @allocated sum(x; init = DummyBigInt(0)) + @allocated non_mutable_sum_pr306(x) + # now test actual allocations + no_init = @allocated sum(x) + with_init = @allocated sum(x; init = DummyBigInt(0)) + no_ma = @allocated non_mutable_sum_pr306(x) + # There's an additional 16 bytes for kwarg version. Upper bound by 40 to be + # safe between Julia versions + @test with_init <= no_init + 40 + # MA is at least 10-times better than no MA for this example + @test 10 * with_init < no_ma +end