Skip to content

Commit

Permalink
switch from Ref to AtomicRef
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Jan 30, 2024
1 parent 697e1fb commit 3705f60
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
14 changes: 10 additions & 4 deletions src/StableTasks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ module StableTasks
macro spawn end
macro spawnat end

using Base: RefValue
struct StableTask{T}
t::Task
ret::RefValue{T}
mutable struct AtomicRef{T}
@atomic x::T
AtomicRef{T}() where {T} = new{T}()
AtomicRef(x::T) where {T} = new{T}(x)
AtomicRef{T}(x) where {T} = new{T}(convert(T, x))
end

mutable struct StableTask{T}
const t::Task
ret::AtomicRef{T}
end

include("internals.jl")
Expand Down
38 changes: 27 additions & 11 deletions src/internals.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module Internals

import StableTasks: @spawn, @spawnat, StableTask
import StableTasks: @spawn, @spawnat, StableTask, AtomicRef

Base.getindex(r::AtomicRef) = @atomic r.x
Base.setindex!(r::AtomicRef{T}, x) where {T} = @atomic r.x = convert(T, x)

function Base.fetch(t::StableTask{T}) where {T}
fetch(t.t)
Expand All @@ -25,32 +28,45 @@ Base.schedule(t::StableTask) = (schedule(t.t); t)
Base.schedule(t, val; error=false) = (schedule(t.t, val; error); t)


macro spawn(ex)
macro spawn(args...)
tp = QuoteNode(:default)
na = length(args)
if na == 2
ttype, ex = args
if ttype isa QuoteNode
ttype = ttype.value
if ttype !== :interactive && ttype !== :default
throw(ArgumentError("unsupported threadpool in StableTasks.@spawn: $ttype"))
end
tp = QuoteNode(ttype)
else
tp = ttype
end
elseif na == 1
ex = args[1]
else
throw(ArgumentError("wrong number of arguments in @spawn"))
end
letargs = _lift_one_interp!(ex)

thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
var = esc(Base.sync_varname) # This is for the @sync macro which sets a local variable whose name is
# the symbol bound to Base.sync_varname
# I asked on slack and this is apparently safe to consider a public API
set_pool = if VERSION < v"1.9"
nothing
else
:(Threads._spawn_set_thrpool(task, :default))
end
quote
let $(letargs...)
f = $thunk
T = Core.Compiler.return_type(f, Tuple{})
ref = Ref{T}()
ref = AtomicRef{T}()
f_wrap = () -> (ref[] = f(); nothing)
task = Task(f_wrap)
task.sticky = false
$set_pool
Threads._spawn_set_thrpool(task, $(esc(tp)))
if $(Expr(:islocal, var))
put!($var, task) # Sync will set up a Channel, and we want our task to be in there.
end
schedule(task)
StableTask(task, ref)
StableTask{T}(task, ref)
end
end
end
Expand All @@ -75,7 +91,7 @@ macro spawnat(thrdid, ex)
let $(letargs...)
thunk = $thunk
RT = Core.Compiler.return_type(thunk, Tuple{})
ret = Ref{RT}()
ret = AtomicRef{RT}()
thunk_wrap = () -> (ret[] = thunk(); nothing)
local task = Task(thunk_wrap)
task.sticky = true
Expand Down

0 comments on commit 3705f60

Please sign in to comment.