Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WeakKeyIdDict implementation (keeping WeakKeyDict) #28182

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions base/abstractdict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ end

function isequal(l::AbstractDict, r::AbstractDict)
l === r && return true
if isa(l,IdDict) != isa(r,IdDict)
if isa(l,IdDict) != isa(r,IdDict) || isa(l,WeakKeyDict) != isa(r,WeakKeyDict) || isa(l,WeakKeyIdDict) != isa(r,WeakKeyIdDict)
return false
end
if length(l) != length(r) return false end
Expand All @@ -467,7 +467,7 @@ end

function ==(l::AbstractDict, r::AbstractDict)
l === r && return true
if isa(l,IdDict) != isa(r,IdDict)
if isa(l,IdDict) != isa(r,IdDict) || isa(l,WeakKeyDict) != isa(r,WeakKeyDict) || isa(l,WeakKeyIdDict) != isa(r,WeakKeyIdDict)
return false
end
length(l) != length(r) && return false
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export
Vector,
VersionNumber,
WeakKeyDict,
WeakKeyIdDict,

# Ccall types
Cchar,
Expand Down
162 changes: 116 additions & 46 deletions base/weakkeydict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,79 @@

# weak key dictionaries

# Type to wrap a WeakRef to furbish it with object-id comparison and hashing.
struct WeakRefForWeakIdDict
w::WeakRef
WeakRefForWeakIdDict(wr::WeakRef) = new(wr)
end
WeakRefForWeakIdDict(val) = WeakRefForWeakIdDict(WeakRef(val))
==(wr1::WeakRefForWeakIdDict, wr2::WeakRefForWeakIdDict) = wr1.w.value===wr2.w.value
hash(wr::WeakRefForWeakIdDict, h::UInt) = hash_uint(3h - objectid(wr.w.value))

# Type to wrap a WeakRef to furbish it with == comparison and "normal" hashing of its value
struct WeakRefForWeakDict
w::WeakRef
WeakRefForWeakDict(wr::WeakRef) = new(wr)
end
WeakRefForWeakDict(val) = WeakRefForWeakDict(WeakRef(val))
==(wr1::WeakRefForWeakDict, wr2::WeakRefForWeakDict) = wr1.w.value==wr2.w.value
hash(wr::WeakRefForWeakDict, h::UInt) = hash(wr.w.value, h)


abstract type AbstractWeakKeyDict{K,V} <: AbstractDict{K,V} end
"""
WeakKeyIdDict([itr])

`WeakKeyIdDict()` constructs a hash table where the keys are weak
references to objects, and thus may be garbage collected even when
referenced in a hash table.

The hashing and comparison are based on object-id and === of the key.

See [`Dict`](@ref) for further help.
"""
mutable struct WeakKeyIdDict{K,V} <: AbstractWeakKeyDict{K,V}
ht::Dict{WeakRefForWeakIdDict,V}
lock::Threads.RecursiveSpinLock
finalizer::Function

# Constructors mirror Dict's
function WeakKeyIdDict{K,V}() where V where K
t = new(Dict{WeakRefForWeakIdDict,V}(), Threads.RecursiveSpinLock(), identity)
t.finalizer = function (k)
# when a weak key is finalized, remove from dictionary if it is still there
if islocked(t)
finalizer(t.finalizer, k)
return nothing
end
delete!(t, k)
end
return t
end
end
keytype_ht(::Type{<:WeakKeyIdDict}) = WeakRefForWeakIdDict
striptp(::Type{<:WeakKeyIdDict}) = WeakKeyIdDict

"""
WeakKeyDict([itr])

`WeakKeyDict()` constructs a hash table where the keys are weak
references to objects, and thus may be garbage collected even when
referenced in a hash table.

The hashing and comparison are based on the normal hash and == of the
key.

See [`Dict`](@ref) for further help.
"""
mutable struct WeakKeyDict{K,V} <: AbstractDict{K,V}
ht::Dict{WeakRef,V}
mutable struct WeakKeyDict{K,V} <: AbstractWeakKeyDict{K,V}
ht::Dict{WeakRefForWeakDict,V}
lock::Threads.RecursiveSpinLock
finalizer::Function

# Constructors mirror Dict's
function WeakKeyDict{K,V}() where V where K
t = new(Dict{Any,V}(), Threads.RecursiveSpinLock(), identity)
t = new(Dict{WeakRefForWeakDict,V}(), Threads.RecursiveSpinLock(), identity)
t.finalizer = function (k)
# when a weak key is finalized, remove from dictionary if it is still there
if islocked(t)
Expand All @@ -30,81 +86,95 @@ mutable struct WeakKeyDict{K,V} <: AbstractDict{K,V}
return t
end
end
function WeakKeyDict{K,V}(kv) where V where K
h = WeakKeyDict{K,V}()
keytype_ht(::Type{<:WeakKeyDict}) = WeakRefForWeakDict
striptp(::Type{<:WeakKeyDict}) = WeakKeyDict

# Constructors as for Dict
function (::Type{W})(kv) where W<:AbstractWeakKeyDict{K,V} where {K, V}
h = W()
for (k,v) in kv
h[k] = v
end
return h
end
WeakKeyDict{K,V}(p::Pair) where V where K = setindex!(WeakKeyDict{K,V}(), p.second, p.first)
function WeakKeyDict{K,V}(ps::Pair...) where V where K
h = WeakKeyDict{K,V}()
(::Type{W})(p::Pair) where W<:AbstractWeakKeyDict{K,V} where {K, V} = setindex!(W(), p.second, p.first)
function (::Type{W})(ps::Pair...) where W<:AbstractWeakKeyDict{K,V} where {K, V}
h = W()
sizehint!(h, length(ps))
for p in ps
h[p.first] = p.second
end
return h
end
WeakKeyDict() = WeakKeyDict{Any,Any}()
(::Type{W})() where W<:AbstractWeakKeyDict = W{Any,Any}()

WeakKeyDict(kv::Tuple{}) = WeakKeyDict()
copy(d::WeakKeyDict) = WeakKeyDict(d)
(::Type{W})(kv::Tuple{}) where W<:AbstractWeakKeyDict = W()
copy(d::W) where W<:AbstractWeakKeyDict = W(d)

WeakKeyDict(ps::Pair{K,V}...) where {K,V} = WeakKeyDict{K,V}(ps)
WeakKeyDict(ps::Pair{K}...) where {K} = WeakKeyDict{K,Any}(ps)
WeakKeyDict(ps::(Pair{K,V} where K)...) where {V} = WeakKeyDict{Any,V}(ps)
WeakKeyDict(ps::Pair...) = WeakKeyDict{Any,Any}(ps)
(::Type{W})(ps::Pair{K,V}...) where W<:AbstractWeakKeyDict where {K,V} = W{K,V}(ps)
(::Type{W})(ps::Pair{K}...) where W<:AbstractWeakKeyDict where {K} = W{K,Any}(ps)
(::Type{W})(ps::(Pair{K,V} where K)...) where W<:AbstractWeakKeyDict where {V} = W{Any,V}(ps)
(::Type{W})(ps::Pair...) where W<:AbstractWeakKeyDict = W{Any,Any}(ps)
(::Type{W})(ps::Pair) where W<:AbstractWeakKeyDict = W{Any,Any}(ps)

function WeakKeyDict(kv)
function (::Type{W})(kv) where W<:AbstractWeakKeyDict
try
Base.dict_with_eltype((K, V) -> WeakKeyDict{K, V}, kv, eltype(kv))
Base.dict_with_eltype((K, V) -> W{K, V}, kv, eltype(kv))
catch e
if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
throw(ArgumentError("WeakKeyDict(kv): kv needs to be an iterator of tuples or pairs"))
throw(ArgumentError("$W(kv): kv needs to be an iterator of tuples or pairs"))
else
rethrow(e)
end
end
end

empty(d::WeakKeyDict, ::Type{K}, ::Type{V}) where {K, V} = WeakKeyDict{K, V}()
empty(d::W, ::Type{K}, ::Type{V}) where W<:AbstractWeakKeyDict{K,V} where {K, V} = striptp(W){K, V}()

islocked(wkh::WeakKeyDict) = islocked(wkh.lock)
lock(f, wkh::WeakKeyDict) = lock(f, wkh.lock)
trylock(f, wkh::WeakKeyDict) = trylock(f, wkh.lock)
islocked(wkh::AbstractWeakKeyDict) = islocked(wkh.lock)
lock(f, wkh::AbstractWeakKeyDict) = lock(f, wkh.lock)
trylock(f, wkh::AbstractWeakKeyDict) = trylock(f, wkh.lock)

function setindex!(wkh::WeakKeyDict{K}, v, key) where K
k = convert(K, key)
finalizer(wkh.finalizer, k)
function setindex!(wkh::W, v, key) where W<:AbstractWeakKeyDict{K} where K
!isa(key, K) && throw(ArgumentError("$key is not a valid key for type $K"))
finalizer(wkh.finalizer, key)
lock(wkh) do
wkh.ht[WeakRef(k)] = v
wkh.ht[keytype_ht(W)(key)] = v
end
return wkh
end

function getkey(wkh::WeakKeyDict{K}, kk, default) where K
function getkey(wkh::W, kk, default) where W<:AbstractWeakKeyDict{K} where K
return lock(wkh) do
k = getkey(wkh.ht, kk, secret_table_token)
k = getkey(wkh.ht, keytype_ht(W)(kk), secret_table_token)
k === secret_table_token && return default
return k.value::K
return k.w.value::K
end
end

get(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> get(wkh.ht, key, default), wkh)
get(default::Callable, wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get(default, wkh.ht, key), wkh)
get!(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> get!(wkh.ht, key, default), wkh)
get!(default::Callable, wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get!(default, wkh.ht, key), wkh)
pop!(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> pop!(wkh.ht, key), wkh)
pop!(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> pop!(wkh.ht, key, default), wkh)
delete!(wkh::WeakKeyDict, key) = lock(() -> delete!(wkh.ht, key), wkh)
empty!(wkh::WeakKeyDict) = (lock(() -> empty!(wkh.ht), wkh); wkh)
haskey(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> haskey(wkh.ht, key), wkh)
getindex(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> getindex(wkh.ht, key), wkh)
isempty(wkh::WeakKeyDict) = isempty(wkh.ht)
length(t::WeakKeyDict) = length(t.ht)

function iterate(t::WeakKeyDict{K,V}) where V where K
get(wkh::W, key, default) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> get(wkh.ht, keytype_ht(W)(key), default), wkh)
get(default::Callable, wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> get(default, wkh.ht, keytype_ht(W)(key)), wkh)
get!(wkh::W, key, default) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> get!(wkh.ht, keytype_ht(W)(key), default), wkh)
get!(default::Callable, wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> get!(default, wkh.ht, keytype_ht(W)(key)), wkh)
pop!(wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> pop!(wkh.ht, keytype_ht(W)(key)), wkh)
pop!(wkh::W, key, default) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> pop!(wkh.ht, keytype_ht(W)(key), default), wkh)
delete!(wkh::W, key) where W<:AbstractWeakKeyDict =
lock(() -> delete!(wkh.ht, keytype_ht(W)(key)), wkh)
empty!(wkh::AbstractWeakKeyDict) = (lock(() -> empty!(wkh.ht), wkh); wkh)
haskey(wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> haskey(wkh.ht, keytype_ht(W)(key)), wkh)
getindex(wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} =
lock(() -> getindex(wkh.ht, keytype_ht(W)(key)), wkh)
isempty(wkh::AbstractWeakKeyDict) = isempty(wkh.ht)
length(t::AbstractWeakKeyDict) = length(t.ht)

function iterate(t::AbstractWeakKeyDict)
gc_token = Ref{Bool}(false) # no keys will be deleted via finalizers until this token is gc'd
finalizer(gc_token) do r
if r[]
Expand All @@ -115,13 +185,13 @@ function iterate(t::WeakKeyDict{K,V}) where V where K
s = lock(t.lock)
iterate(t, (gc_token,))
end
function iterate(t::WeakKeyDict{K,V}, state) where V where K
function iterate(t::W, state) where W<:AbstractWeakKeyDict{K,V} where {K,V}
gc_token = first(state)
y = iterate(t.ht, tail(state)...)
y === nothing && return nothing
wkv, i = y
kv = Pair{K,V}(wkv[1].value::K, wkv[2])
kv = Pair{K,V}(wkv[1].w.value::K, wkv[2])
return (kv, (gc_token, i))
end

filter!(f, d::WeakKeyDict) = filter_in_one_pass!(f, d)
filter!(f, d::AbstractWeakKeyDict) = filter_in_one_pass!(f, d)
2 changes: 2 additions & 0 deletions doc/src/base/collections.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Fully implemented by:
* [`IdDict`](@ref)
* [`Dict`](@ref)
* [`WeakKeyDict`](@ref)
* [`WeakKeyIdDict`](@ref)
* `EachLine`
* `AbstractString`
* [`Set`](@ref)
Expand Down Expand Up @@ -79,6 +80,7 @@ Fully implemented by:
* [`IdDict`](@ref)
* [`Dict`](@ref)
* [`WeakKeyDict`](@ref)
* [`WeakKeyIdDict`](@ref)
* `AbstractString`
* [`Set`](@ref)
* [`NamedTuple`](@ref)
Expand Down
Loading