diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d241221..50a15b5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -10,7 +10,6 @@ jobs: fail-fast: false matrix: version: - - '1.6' - '1.9' - '1.10.0' - 'nightly' diff --git a/Project.toml b/Project.toml index 89ea512..e8b41a8 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = ["Mason Protter "] version = "0.1.3" [compat] -julia = "1.6" +julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index 36b763f..8f744d5 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,7 @@ StableTasks is a simple package with one main API `StableTasks.@spawn` (not exported by default). -It works like `Threads.@spawn`, except it is *type stable* to `fetch` from (and it does not yet support threadpools -other than the default threadpool). +It works like `Threads.@spawn`, except it is *type stable* to `fetch` from. ``` julia julia> Core.Compiler.return_type(() -> fetch(StableTasks.@spawn 1 + 1), Tuple{}) diff --git a/src/StableTasks.jl b/src/StableTasks.jl index c2b2dc9..7db8f72 100644 --- a/src/StableTasks.jl +++ b/src/StableTasks.jl @@ -3,10 +3,16 @@ module StableTasks macro spawn end macro spawnat end -using Base: RefValue +mutable struct AtomicRef{T} + @atomic x::T + AtomicRef{T}() where {T} = new{T}() + AtomicRef(x::T) where {T} = new{T}(x) + AtomicRef{T}(x) where {T} = new{T}(convert(T, x)) +end + struct StableTask{T} t::Task - ret::RefValue{T} + ret::AtomicRef{T} end include("internals.jl") diff --git a/src/internals.jl b/src/internals.jl index 469d958..34bb325 100644 --- a/src/internals.jl +++ b/src/internals.jl @@ -1,6 +1,9 @@ module Internals -import StableTasks: @spawn, @spawnat, StableTask +import StableTasks: @spawn, @spawnat, StableTask, AtomicRef + +Base.getindex(r::AtomicRef) = @atomic r.x +Base.setindex!(r::AtomicRef{T}, x) where {T} = @atomic r.x = convert(T, x) function Base.fetch(t::StableTask{T}) where {T} fetch(t.t) @@ -25,41 +28,58 @@ Base.schedule(t::StableTask) = (schedule(t.t); t) Base.schedule(t, val; error=false) = (schedule(t.t, val; error); t) """ -Similar to `Threads.@spawn` but type-stable. Creates a `Task` and schedules it to run on any available thread in the `:default` threadpool. + @spawn [:default|:interactive] expr + +Similar to `Threads.@spawn` but type-stable. Creates a `Task` and schedules it to run on any available +thread in the specified threadpool (defaults to the `:default` threadpool). """ -macro spawn(ex) +macro spawn(args...) + tp = QuoteNode(:default) + na = length(args) + if na == 2 + ttype, ex = args + if ttype isa QuoteNode + ttype = ttype.value + if ttype !== :interactive && ttype !== :default + throw(ArgumentError("unsupported threadpool in StableTasks.@spawn: $ttype")) + end + tp = QuoteNode(ttype) + else + tp = ttype + end + elseif na == 1 + ex = args[1] + else + throw(ArgumentError("wrong number of arguments in @spawn")) + end + letargs = _lift_one_interp!(ex) thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__) var = esc(Base.sync_varname) # This is for the @sync macro which sets a local variable whose name is # the symbol bound to Base.sync_varname # I asked on slack and this is apparently safe to consider a public API - set_pool = if VERSION < v"1.9" - nothing - else - :(Threads._spawn_set_thrpool(task, :default)) - end quote let $(letargs...) f = $thunk T = Core.Compiler.return_type(f, Tuple{}) - ref = Ref{T}() + ref = AtomicRef{T}() f_wrap = () -> (ref[] = f(); nothing) task = Task(f_wrap) task.sticky = false - $set_pool + Threads._spawn_set_thrpool(task, $(esc(tp))) if $(Expr(:islocal, var)) put!($var, task) # Sync will set up a Channel, and we want our task to be in there. end schedule(task) - StableTask(task, ref) + StableTask{T}(task, ref) end end end """ Similar to `StableTasks.@spawn` but creates a **sticky** `Task` and schedules it to run on the thread with the given id (`thrdid`). -The task is guaranteed to stay on this thread (it won't migrate to another thread). +The task is guaranteed to stay on this thread (it won't migrate to another thread). """ macro spawnat(thrdid, ex) letargs = _lift_one_interp!(ex) @@ -81,7 +101,7 @@ macro spawnat(thrdid, ex) let $(letargs...) thunk = $thunk RT = Core.Compiler.return_type(thunk, Tuple{}) - ret = Ref{RT}() + ret = AtomicRef{RT}() thunk_wrap = () -> (ret[] = thunk(); nothing) local task = Task(thunk_wrap) task.sticky = true diff --git a/test/runtests.jl b/test/runtests.jl index beb7d39..ef32a6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,15 @@ using StableTasks: @spawn, @spawnat t = @eval @spawn inv([1 2 ; 3 4]) @test inv([1 2 ; 3 4]) == @inferred fetch(t) + @test 2 == @inferred fetch(@spawn :interactive 1 + 1) + t = @eval @spawn :interactive inv([1 2 ; 3 4]) + @test inv([1 2 ; 3 4]) == @inferred fetch(t) + + s = :default + @test 2 == @inferred fetch(@spawn s 1 + 1) + t = @eval @spawn $(QuoteNode(s)) inv([1 2 ; 3 4]) + @test inv([1 2 ; 3 4]) == @inferred fetch(t) + @test 2 == @inferred fetch(@spawnat 1 1 + 1) t = @eval @spawnat 1 inv([1 2 ; 3 4]) @test inv([1 2 ; 3 4]) == @inferred fetch(t)