Skip to content

Commit

Permalink
support threadpools, and multiple array arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Jan 30, 2024
1 parent 54c0947 commit a09725d
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ StableTasks = "91464d47-22a1-43fe-8b7f-2d57ee82463f"
[compat]
BangBang = "0.4"
ChunkSplitters = "2.1"
StableTasks = "0.1.4"
julia = "1.6"

[extras]
Expand Down
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ Unlike most JuliaFolds2 packages, it is not built off of
Rather, OhMyThreads is meant to be a simpler, more maintainable, and more accessible alternative to packages
like [ThreadsX.jl](https://github.com/tkf/ThreadsX.jl) or [Folds.jl](https://github.com/JuliaFolds2/Folds.jl).

OhMyThreads.jl re-exports the function `chunks` from
[ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl), and provides the following functions:

<details><summary> tmapreduce </summary>
<p>

```
tmapreduce(f, op, A::AbstractArray;
tmapreduce(f, op, A::AbstractArray...;
[init],
nchunks::Int = nthreads(),
split::Symbol = :batch,
Expand Down Expand Up @@ -63,7 +63,7 @@ ____________________________
<p>

```
treducemap(op, f, A::AbstractArray;
treducemap(op, f, A::AbstractArray...;
[init],
nchunks::Int = nthreads(),
split::Symbol = :batch,
Expand Down Expand Up @@ -107,7 +107,8 @@ ____________________________
<p>

```
treduce(op, A::AbstractArray; [init],
treduce(op, A::AbstractArray...;
[init],
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic,
Expand Down Expand Up @@ -150,7 +151,7 @@ ____________________________
<p>

```
tmap(f, [OutputElementType], A::AbstractArray;
tmap(f, [OutputElementType], A::AbstractArray...;
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic)
Expand All @@ -174,7 +175,7 @@ ____________________________
<p>

```
tmap!(f, out, A::AbstractArray;
tmap!(f, out, A::AbstractArray...;
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic)
Expand All @@ -198,7 +199,7 @@ ____________________________
<p>

```
tforeach(f, A::AbstractArray;
tforeach(f, A::AbstractArray...;
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic) :: Nothing
Expand Down
13 changes: 7 additions & 6 deletions src/OhMyThreads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ChunkSplitters: chunks
export chunks, treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect

"""
tmapreduce(f, op, A::AbstractArray;
tmapreduce(f, op, A::AbstractArray...;
[init],
nchunks::Int = nthreads(),
split::Symbol = :batch,
Expand Down Expand Up @@ -42,7 +42,7 @@ needed if you are using a `:static` schedule, since the `:dynamic` schedule is u
function tmapreduce end

"""
treducemap(op, f, A::AbstractArray;
treducemap(op, f, A::AbstractArray...;
[init],
nchunks::Int = nthreads(),
split::Symbol = :batch,
Expand Down Expand Up @@ -79,7 +79,8 @@ function treducemap end


"""
treduce(op, A::AbstractArray; [init],
treduce(op, A::AbstractArray...;
[init],
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic,
Expand Down Expand Up @@ -114,7 +115,7 @@ needed if you are using a `:static` schedule, since the `:dynamic` schedule is u
function treduce end

"""
tforeach(f, A::AbstractArray;
tforeach(f, A::AbstractArray...;
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic) :: Nothing
Expand All @@ -134,7 +135,7 @@ A multithreaded function like `Base.foreach`. Apply `f` to each element of `A` o
function tforeach end

"""
tmap(f, [OutputElementType], A::AbstractArray;
tmap(f, [OutputElementType], A::AbstractArray...;
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic)
Expand All @@ -154,7 +155,7 @@ fewer allocations than the version where `OutputElementType` is not specified.
function tmap end

"""
tmap!(f, out, A::AbstractArray;
tmap!(f, out, A::AbstractArray...;
nchunks::Int = nthreads(),
split::Symbol = :batch,
schedule::Symbol =:dynamic)
Expand Down
67 changes: 43 additions & 24 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,77 +9,96 @@ using Base.Threads: nthreads, @threads

using BangBang: BangBang, append!!

# Public entry point: threaded mapreduce of `f` over one or more arrays, combined with `op`.
# Dispatches on the `schedule` keyword to the dynamic (default), interactive-threadpool,
# or static implementation. The detailed contract lives in the docstring attached to
# `function tmapreduce end` in OhMyThreads.jl.
function tmapreduce(f, op, Arrs...;
                    nchunks::Int=nthreads(),
                    split::Symbol=:batch,
                    schedule::Symbol=:dynamic,
                    outputtype::Type=Any,
                    kwargs...)
    if schedule === :dynamic
        # :default is the threadpool name for the regular dynamic scheduler
        _tmapreduce(f, op, Arrs, outputtype, nchunks, split, :default; kwargs...)
    elseif schedule === :interactive
        _tmapreduce(f, op, Arrs, outputtype, nchunks, split, :interactive; kwargs...)
    elseif schedule === :static
        _tmapreduce_static(f, op, Arrs, outputtype, nchunks, split; kwargs...)
    else
        schedule_err(schedule)
    end
end
# Internal: raised when an unknown `schedule` keyword is passed. `throw` the ArgumentError
# directly (wrapping it in `error` would stringify it into an ErrorException); list all
# three supported schedules, including :interactive.
@noinline schedule_err(s) = throw(ArgumentError("Invalid schedule option: $s, expected :dynamic, :interactive, or :static."))

# `treducemap(op, f, A...)` is `tmapreduce(f, op, A...)` with the first two arguments
# swapped, provided for callers who prefer the reduce-first argument order.
treducemap(op, f, A...; kwargs...) = tmapreduce(f, op, A...; kwargs...)

# Internal dynamic/interactive implementation. Chunks the (index-aligned) arrays, spawns one
# task per chunk on the requested threadpool (`schedule` is :default or :interactive), and
# reduces the fetched per-chunk results with `op`.
function _tmapreduce(f, op, Arrs, ::Type{OutputType}, nchunks, split, schedule; kwargs...)::OutputType where {OutputType}
    check_all_have_same_indices(Arrs)
    tasks = map(chunks(first(Arrs); n=nchunks, split)) do inds
        # Use views rather than `A[inds]` so each task avoids copying its chunk.
        args = map(A -> @view(A[inds]), Arrs)
        @spawn schedule mapreduce(f, op, args...; kwargs...)
    end
    mapreduce(fetch, op, tasks)
end

# Internal static-schedule implementation: pins chunk `c` to thread id `c` via `@spawnat`,
# so it requires `nchunks <= nthreads()`.
function _tmapreduce_static(f, op, Arrs, ::Type{OutputType}, nchunks, split; kwargs...) where {OutputType}
    nt = nthreads()
    check_all_have_same_indices(Arrs)
    if nchunks > nt
        # We could implement strategies, like round-robin, in the future
        throw(ArgumentError("We currently only support `nchunks <= nthreads()` for static scheduling."))
    end
    tasks = map(enumerate(chunks(first(Arrs); n=nchunks, split))) do (c, inds)
        tid = @inbounds nthtid(c)
        # Use views rather than `A[inds]` so each task avoids copying its chunk.
        args = map(A -> @view(A[inds]), Arrs)
        @spawnat tid mapreduce(f, op, args...; kwargs...)
    end
    mapreduce(fetch, op, tasks)
end


# Throw unless every array in `Arrs` has the same `eachindex` as the first one.
# The message is generic because this is used both for multiple input arrays and
# (in `tmap!`) for an output array bundled together with the inputs.
function check_all_have_same_indices(Arrs)
    A = first(Arrs)
    # Comparing every array (including the first) against the first is equivalent to
    # `Arrs[2:end]` but avoids allocating the tail tuple/slice.
    if !all(B -> eachindex(B) == eachindex(A), Arrs)
        error("All passed arrays must have the same indices.")
    end
    return nothing
end

#-------------------------------------------------------------

# Threaded reduce: `tmapreduce` with `identity` as the mapped function.
function treduce(op, A...; kwargs...)
    tmapreduce(identity, op, A...; kwargs...)
end

#-------------------------------------------------------------

# Threaded foreach: apply `f` for its side effects only. Implemented as a mapreduce whose
# combiner discards results; `init=nothing` and `outputtype=Nothing` make the reduction
# return `nothing` cheaply.
function tforeach(f, A...; kwargs...)::Nothing
    tmapreduce(f, (l, r) -> l, A...; kwargs..., init=nothing, outputtype=Nothing)
end

#-------------------------------------------------------------

# Threaded map with a caller-specified output element type `T`: allocate the output
# eagerly with `similar` and delegate to the in-place `tmap!`.
function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...) where {T}
    Arrs = (A, _Arrs...)
    tmap!(f, similar(A, T), Arrs...; kwargs...)
end

# Threaded map without a declared output type: map each chunk to a small result vector and
# concatenate them with `append!!`, which widens the element type as needed.
function tmap(f, A::AbstractArray, _Arrs::AbstractArray...; nchunks::Int=nthreads(), kwargs...)
    Arrs = (A, _Arrs...)
    check_all_have_same_indices(Arrs)
    the_chunks = collect(chunks(A; n=nchunks))
    # It's vital that we force split=:batch here because we're not doing a commutative operation!
    v = tmapreduce(append!!, the_chunks; kwargs..., nchunks, split=:batch) do inds
        # `B` deliberately avoids shadowing the outer `A`.
        args = map(B -> @view(B[inds]), Arrs)
        map(f, args...)
    end
    reshape(v, size(A)...)
end

# In-place threaded map: write `f` of the aligned elements of the input arrays into `out`.
# Bounds are checked once up front (elidable via `@inbounds` at the call site), so the
# per-element reads may use `@inbounds`.
@propagate_inbounds function tmap!(f, out, A::AbstractArray, _Arrs::AbstractArray...; kwargs...)
    Arrs = (A, _Arrs...)
    @boundscheck check_all_have_same_indices((out, Arrs...))
    # It's vital that we force split=:batch here because we're not doing a commutative operation!
    tforeach(eachindex(out); kwargs..., split=:batch) do i
        args = map(B -> @inbounds(B[i]), Arrs)
        res = f(args...)
        out[i] = res
    end
    out
end
Expand Down
44 changes: 22 additions & 22 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,43 +1,43 @@
using Test, OhMyThreads

@testset "Basics" begin
for (~, f, op, itr) [
(isapprox, sin, +, rand(ComplexF64, 10, 10)),
(isapprox, cos, max, 1:100000),
(==, round, vcat, randn(1000)),
(==, last, *, [1=>"a", 2=>"b", 3=>"c", 4=>"d", 5=>"e"])
for (~, f, op, itrs) [
(isapprox, sin∘*, +, (rand(ComplexF64, 10, 10), rand(-10:10, 10, 10))),
(isapprox, cos, max, (1:100000,)),
(==, round, vcat, (randn(1000),)),
(==, last, *, ([1=>"a", 2=>"b", 3=>"c", 4=>"d", 5=>"e"],))
]
@testset for schedule (:static, :dynamic,)
@testset for schedule (:static, :dynamic, :interactive)
@testset for split (:batch, :scatter)
if split == :scatter # scatter only works for commutative operators
if op (vcat, *)
continue
end
end
for nchunks (1, 2, 6, 10, 100)
for nchunks (1, 2, 6, 10)
if schedule == :static && nchunks > Threads.nthreads()
continue
end
kwargs = (; schedule, split, nchunks)
mapreduce_f_op_itr = mapreduce(f, op, itr)
@test tmapreduce(f, op, itr; kwargs...) ~ mapreduce_f_op_itr
@test treducemap(op, f, itr; kwargs...) ~ mapreduce_f_op_itr
@test treduce(op, f.(itr); kwargs...) ~ mapreduce_f_op_itr
mapreduce_f_op_itr = mapreduce(f, op, itrs...)
@test tmapreduce(f, op, itrs...; kwargs...) ~ mapreduce_f_op_itr
@test treducemap(op, f, itrs...; kwargs...) ~ mapreduce_f_op_itr
@test treduce(op, f.(itrs...); kwargs...) ~ mapreduce_f_op_itr

map_f_itr = map(f, itr)
@test all(tmap(f, Any, itr; kwargs...) .~ map_f_itr)
@test all(tcollect(Any, (f(x) for x in itr); kwargs...) .~ map_f_itr)
@test all(tcollect(Any, f.(itr); kwargs...) .~ map_f_itr)
map_f_itr = map(f, itrs...)
@test all(tmap(f, Any, itrs...; kwargs...) .~ map_f_itr)
@test all(tcollect(Any, (f(x...) for x in collect(zip(itrs...))); kwargs...) .~ map_f_itr)
@test all(tcollect(Any, f.(itrs...); kwargs...) .~ map_f_itr)

@test tmap(f, itr; kwargs...) ~ map_f_itr
@test tcollect((f(x) for x in itr); kwargs...) ~ map_f_itr
@test tcollect(f.(itr); kwargs...) ~ map_f_itr
@test tmap(f, itrs...; kwargs...) ~ map_f_itr
@test tcollect((f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(f.(itrs...); kwargs...) ~ map_f_itr

RT = Core.Compiler.return_type(f, Tuple{eltype(itr)})
RT = Core.Compiler.return_type(f, Tuple{eltype.(itrs)...})

@test tmap(f, RT, itr; kwargs...) ~ map_f_itr
@test tcollect(RT, (f(x) for x in itr); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itr); kwargs...) ~ map_f_itr
@test tmap(f, RT, itrs...; kwargs...) ~ map_f_itr
@test tcollect(RT, (f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itrs...); kwargs...) ~ map_f_itr
end
end
end
Expand Down

0 comments on commit a09725d

Please sign in to comment.