diff --git a/base/abstractdict.jl b/base/abstractdict.jl index c71a98241b719..79b776de50607 100644 --- a/base/abstractdict.jl +++ b/base/abstractdict.jl @@ -453,7 +453,7 @@ end function isequal(l::AbstractDict, r::AbstractDict) l === r && return true - if isa(l,IdDict) != isa(r,IdDict) + if isa(l,IdDict) != isa(r,IdDict) || isa(l,WeakKeyDict) != isa(r,WeakKeyDict) || isa(l,WeakKeyIdDict) != isa(r,WeakKeyIdDict) return false end if length(l) != length(r) return false end @@ -467,7 +467,7 @@ end function ==(l::AbstractDict, r::AbstractDict) l === r && return true - if isa(l,IdDict) != isa(r,IdDict) + if isa(l,IdDict) != isa(r,IdDict) || isa(l,WeakKeyDict) != isa(r,WeakKeyDict) || isa(l,WeakKeyIdDict) != isa(r,WeakKeyIdDict) return false end length(l) != length(r) && return false diff --git a/base/exports.jl b/base/exports.jl index d87889eb047df..3c246cfd80e09 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -93,6 +93,7 @@ export Vector, VersionNumber, WeakKeyDict, + WeakKeyIdDict, # Ccall types Cchar, diff --git a/base/weakkeydict.jl b/base/weakkeydict.jl index 32470135045de..01eb06935fc5a 100644 --- a/base/weakkeydict.jl +++ b/base/weakkeydict.jl @@ -2,6 +2,59 @@ # weak key dictionaries +# Type to wrap a WeakRef to furnish it with object-id comparison and hashing. 
+struct WeakRefForWeakIdDict + w::WeakRef + WeakRefForWeakIdDict(wr::WeakRef) = new(wr) +end +WeakRefForWeakIdDict(val) = WeakRefForWeakIdDict(WeakRef(val)) +==(wr1::WeakRefForWeakIdDict, wr2::WeakRefForWeakIdDict) = wr1.w.value===wr2.w.value +hash(wr::WeakRefForWeakIdDict, h::UInt) = hash_uint(3h - objectid(wr.w.value)) + +# Type to wrap a WeakRef to furnish it with == comparison and "normal" hashing of its value +struct WeakRefForWeakDict + w::WeakRef + WeakRefForWeakDict(wr::WeakRef) = new(wr) +end +WeakRefForWeakDict(val) = WeakRefForWeakDict(WeakRef(val)) +==(wr1::WeakRefForWeakDict, wr2::WeakRefForWeakDict) = wr1.w.value==wr2.w.value +hash(wr::WeakRefForWeakDict, h::UInt) = hash(wr.w.value, h) + + +abstract type AbstractWeakKeyDict{K,V} <: AbstractDict{K,V} end +""" + WeakKeyIdDict([itr]) + +`WeakKeyIdDict()` constructs a hash table where the keys are weak +references to objects, and thus may be garbage collected even when +referenced in a hash table. + +The hashing and comparison are based on object-id and === of the key. + +See [`Dict`](@ref) for further help. +""" +mutable struct WeakKeyIdDict{K,V} <: AbstractWeakKeyDict{K,V} + ht::Dict{WeakRefForWeakIdDict,V} + lock::Threads.RecursiveSpinLock + finalizer::Function + + # Constructors mirror Dict's + function WeakKeyIdDict{K,V}() where V where K + t = new(Dict{WeakRefForWeakIdDict,V}(), Threads.RecursiveSpinLock(), identity) + t.finalizer = function (k) + # when a weak key is finalized, remove from dictionary if it is still there + if islocked(t) + finalizer(t.finalizer, k) + return nothing + end + delete!(t, k) + end + return t + end +end +keytype_ht(::Type{<:WeakKeyIdDict}) = WeakRefForWeakIdDict +striptp(::Type{<:WeakKeyIdDict}) = WeakKeyIdDict + """ WeakKeyDict([itr]) @@ -9,16 +62,19 @@ references to objects, and thus may be garbage collected even when referenced in a hash table. +The hashing and comparison are based on the normal hash and == of the +key. + See [`Dict`](@ref) for further help. 
""" -mutable struct WeakKeyDict{K,V} <: AbstractDict{K,V} - ht::Dict{WeakRef,V} +mutable struct WeakKeyDict{K,V} <: AbstractWeakKeyDict{K,V} + ht::Dict{WeakRefForWeakDict,V} lock::Threads.RecursiveSpinLock finalizer::Function # Constructors mirror Dict's function WeakKeyDict{K,V}() where V where K - t = new(Dict{Any,V}(), Threads.RecursiveSpinLock(), identity) + t = new(Dict{WeakRefForWeakDict,V}(), Threads.RecursiveSpinLock(), identity) t.finalizer = function (k) # when a weak key is finalized, remove from dictionary if it is still there if islocked(t) @@ -30,81 +86,95 @@ mutable struct WeakKeyDict{K,V} <: AbstractDict{K,V} return t end end -function WeakKeyDict{K,V}(kv) where V where K - h = WeakKeyDict{K,V}() +keytype_ht(::Type{<:WeakKeyDict}) = WeakRefForWeakDict +striptp(::Type{<:WeakKeyDict}) = WeakKeyDict + +# Constructors as for Dict +function (::Type{W})(kv) where W<:AbstractWeakKeyDict{K,V} where {K, V} + h = W() for (k,v) in kv h[k] = v end return h end -WeakKeyDict{K,V}(p::Pair) where V where K = setindex!(WeakKeyDict{K,V}(), p.second, p.first) -function WeakKeyDict{K,V}(ps::Pair...) where V where K - h = WeakKeyDict{K,V}() +(::Type{W})(p::Pair) where W<:AbstractWeakKeyDict{K,V} where {K, V} = setindex!(W(), p.second, p.first) +function (::Type{W})(ps::Pair...) where W<:AbstractWeakKeyDict{K,V} where {K, V} + h = W() sizehint!(h, length(ps)) for p in ps h[p.first] = p.second end return h end -WeakKeyDict() = WeakKeyDict{Any,Any}() +(::Type{W})() where W<:AbstractWeakKeyDict = W{Any,Any}() -WeakKeyDict(kv::Tuple{}) = WeakKeyDict() -copy(d::WeakKeyDict) = WeakKeyDict(d) +(::Type{W})(kv::Tuple{}) where W<:AbstractWeakKeyDict = W() +copy(d::W) where W<:AbstractWeakKeyDict = W(d) -WeakKeyDict(ps::Pair{K,V}...) where {K,V} = WeakKeyDict{K,V}(ps) -WeakKeyDict(ps::Pair{K}...) where {K} = WeakKeyDict{K,Any}(ps) -WeakKeyDict(ps::(Pair{K,V} where K)...) where {V} = WeakKeyDict{Any,V}(ps) -WeakKeyDict(ps::Pair...) 
= WeakKeyDict{Any,Any}(ps) +(::Type{W})(ps::Pair{K,V}...) where W<:AbstractWeakKeyDict where {K,V} = W{K,V}(ps) +(::Type{W})(ps::Pair{K}...) where W<:AbstractWeakKeyDict where {K} = W{K,Any}(ps) +(::Type{W})(ps::(Pair{K,V} where K)...) where W<:AbstractWeakKeyDict where {V} = W{Any,V}(ps) +(::Type{W})(ps::Pair...) where W<:AbstractWeakKeyDict = W{Any,Any}(ps) +(::Type{W})(ps::Pair) where W<:AbstractWeakKeyDict = W{Any,Any}(ps) -function WeakKeyDict(kv) +function (::Type{W})(kv) where W<:AbstractWeakKeyDict try - Base.dict_with_eltype((K, V) -> WeakKeyDict{K, V}, kv, eltype(kv)) + Base.dict_with_eltype((K, V) -> W{K, V}, kv, eltype(kv)) catch e if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv) - throw(ArgumentError("WeakKeyDict(kv): kv needs to be an iterator of tuples or pairs")) + throw(ArgumentError("$W(kv): kv needs to be an iterator of tuples or pairs")) else rethrow(e) end end end -empty(d::WeakKeyDict, ::Type{K}, ::Type{V}) where {K, V} = WeakKeyDict{K, V}() +empty(d::W, ::Type{K}, ::Type{V}) where W<:AbstractWeakKeyDict{K,V} where {K, V} = striptp(W){K, V}() -islocked(wkh::WeakKeyDict) = islocked(wkh.lock) -lock(f, wkh::WeakKeyDict) = lock(f, wkh.lock) -trylock(f, wkh::WeakKeyDict) = trylock(f, wkh.lock) +islocked(wkh::AbstractWeakKeyDict) = islocked(wkh.lock) +lock(f, wkh::AbstractWeakKeyDict) = lock(f, wkh.lock) +trylock(f, wkh::AbstractWeakKeyDict) = trylock(f, wkh.lock) -function setindex!(wkh::WeakKeyDict{K}, v, key) where K - k = convert(K, key) - finalizer(wkh.finalizer, k) +function setindex!(wkh::W, v, key) where W<:AbstractWeakKeyDict{K} where K + !isa(key, K) && throw(ArgumentError("$key is not a valid key for type $K")) + finalizer(wkh.finalizer, key) lock(wkh) do - wkh.ht[WeakRef(k)] = v + wkh.ht[keytype_ht(W)(key)] = v end return wkh end -function getkey(wkh::WeakKeyDict{K}, kk, default) where K +function getkey(wkh::W, kk, default) where W<:AbstractWeakKeyDict{K} where K return lock(wkh) do - k = getkey(wkh.ht, kk, 
secret_table_token) + k = getkey(wkh.ht, keytype_ht(W)(kk), secret_table_token) k === secret_table_token && return default - return k.value::K + return k.w.value::K end end -get(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> get(wkh.ht, key, default), wkh) -get(default::Callable, wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get(default, wkh.ht, key), wkh) -get!(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> get!(wkh.ht, key, default), wkh) -get!(default::Callable, wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get!(default, wkh.ht, key), wkh) -pop!(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> pop!(wkh.ht, key), wkh) -pop!(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> pop!(wkh.ht, key, default), wkh) -delete!(wkh::WeakKeyDict, key) = lock(() -> delete!(wkh.ht, key), wkh) -empty!(wkh::WeakKeyDict) = (lock(() -> empty!(wkh.ht), wkh); wkh) -haskey(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> haskey(wkh.ht, key), wkh) -getindex(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> getindex(wkh.ht, key), wkh) -isempty(wkh::WeakKeyDict) = isempty(wkh.ht) -length(t::WeakKeyDict) = length(t.ht) - -function iterate(t::WeakKeyDict{K,V}) where V where K +get(wkh::W, key, default) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> get(wkh.ht, keytype_ht(W)(key), default), wkh) +get(default::Callable, wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> get(default, wkh.ht, keytype_ht(W)(key)), wkh) +get!(wkh::W, key, default) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> get!(wkh.ht, keytype_ht(W)(key), default), wkh) +get!(default::Callable, wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> get!(default, wkh.ht, keytype_ht(W)(key)), wkh) +pop!(wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> pop!(wkh.ht, keytype_ht(W)(key)), wkh) +pop!(wkh::W, key, default) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> pop!(wkh.ht, keytype_ht(W)(key), default), wkh) 
+delete!(wkh::W, key) where W<:AbstractWeakKeyDict = + lock(() -> delete!(wkh.ht, keytype_ht(W)(key)), wkh) +empty!(wkh::AbstractWeakKeyDict) = (lock(() -> empty!(wkh.ht), wkh); wkh) +haskey(wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> haskey(wkh.ht, keytype_ht(W)(key)), wkh) +getindex(wkh::W, key) where W<:AbstractWeakKeyDict{K} where {K} = + lock(() -> getindex(wkh.ht, keytype_ht(W)(key)), wkh) +isempty(wkh::AbstractWeakKeyDict) = isempty(wkh.ht) +length(t::AbstractWeakKeyDict) = length(t.ht) + +function iterate(t::AbstractWeakKeyDict) gc_token = Ref{Bool}(false) # no keys will be deleted via finalizers until this token is gc'd finalizer(gc_token) do r if r[] @@ -115,13 +185,13 @@ function iterate(t::WeakKeyDict{K,V}) where V where K s = lock(t.lock) iterate(t, (gc_token,)) end -function iterate(t::WeakKeyDict{K,V}, state) where V where K +function iterate(t::W, state) where W<:AbstractWeakKeyDict{K,V} where {K,V} gc_token = first(state) y = iterate(t.ht, tail(state)...) 
y === nothing && return nothing wkv, i = y - kv = Pair{K,V}(wkv[1].value::K, wkv[2]) + kv = Pair{K,V}(wkv[1].w.value::K, wkv[2]) return (kv, (gc_token, i)) end -filter!(f, d::WeakKeyDict) = filter_in_one_pass!(f, d) +filter!(f, d::AbstractWeakKeyDict) = filter_in_one_pass!(f, d) diff --git a/doc/src/base/collections.md b/doc/src/base/collections.md index bc7c6a41dd40a..4a9b6ff007de0 100644 --- a/doc/src/base/collections.md +++ b/doc/src/base/collections.md @@ -43,6 +43,7 @@ Fully implemented by: * [`IdDict`](@ref) * [`Dict`](@ref) * [`WeakKeyDict`](@ref) + * [`WeakKeyIdDict`](@ref) * `EachLine` * `AbstractString` * [`Set`](@ref) @@ -79,6 +80,7 @@ Fully implemented by: * [`IdDict`](@ref) * [`Dict`](@ref) * [`WeakKeyDict`](@ref) + * [`WeakKeyIdDict`](@ref) * `AbstractString` * [`Set`](@ref) * [`NamedTuple`](@ref) diff --git a/test/dict.jl b/test/dict.jl index ec5dabab136e4..8e1e85c50f210 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -780,60 +780,90 @@ end Dict(1 => rand(2,3), 'c' => "asdf") # just make sure this does not trigger a deprecation @testset "WeakKeyDict" begin + for WD in [WeakKeyDict, WeakKeyIdDict] + A = [1] + B = [2] + C = [3] + local x = 0 + local y = 0 + local z = 0 + finalizer(a->(x+=1), A) + finalizer(b->(y+=1), B) + finalizer(c->(z+=1), C) + + # construction + wkd = WD() + wkd[A] = 2 + wkd[B] = 3 + wkd[C] = 4 + dd = convert(Dict{Any,Any},wkd) + @test WD(dd) == wkd + @test convert(WD{Any, Any}, dd) == wkd + @test isa(WD(dd), WD{Any,Any}) + @test WD(A=>2, B=>3, C=>4) == wkd + @test isa(WD(A=>2, B=>3, C=>4), WD{Array{Int,1},Int}) + @test WD(a=>i+1 for (i,a) in enumerate([A,B,C]) ) == wkd + @test WD([(A,2), (B,3), (C,4)]) == wkd + @test WD(Pair(A,2), Pair(B,3), Pair(C,4)) == wkd + @test copy(wkd) == wkd + + @test length(wkd) == 3 + @test !isempty(wkd) + res = pop!(wkd, C) + @test res == 4 + @test C ∉ keys(wkd) + @test 4 ∉ values(wkd) + @test length(wkd) == 2 + @test !isempty(wkd) + wkd = filter!( p -> p.first != B, wkd) + @test B ∉ keys(wkd) + @test 
3 ∉ values(wkd) + @test length(wkd) == 1 + @test WD(Pair(A, 2)) == wkd + @test !isempty(wkd) + + wkd = empty!(wkd) + @test wkd == empty(wkd) + @test typeof(wkd) == typeof(empty(wkd)) + @test length(wkd) == 0 + @test isempty(wkd) + @test isa(wkd, WD) + + @test_throws ArgumentError WD([1, 2, 3]) + + # WeakKeyDict does not convert keys + @test_throws ArgumentError WD{Int,Any}(5.0=>1) + + # WeakKeyDict compares false to non-WeakKeyDict + @test IdDict(A=>1)!=WD(A=>1) + @test Dict(A=>1)!=WD(A=>1) + + # issue #26939 + d26939 = WD() + d26939[big"1.0" + 1.1] = 1 + GC.gc() # make sure this doesn't segfault + end + + # WeakKeyIdDict hashes with object-id A = [1] - B = [2] - C = [3] - local x = 0 - local y = 0 - local z = 0 - finalizer(a->(x+=1), A) - finalizer(b->(y+=1), B) - finalizer(c->(z+=1), C) - - # construction - wkd = WeakKeyDict() - wkd[A] = 2 - wkd[B] = 3 - wkd[C] = 4 - dd = convert(Dict{Any,Any},wkd) - @test WeakKeyDict(dd) == wkd - @test convert(WeakKeyDict{Any, Any}, dd) == wkd - @test isa(WeakKeyDict(dd), WeakKeyDict{Any,Any}) - @test WeakKeyDict(A=>2, B=>3, C=>4) == wkd - @test isa(WeakKeyDict(A=>2, B=>3, C=>4), WeakKeyDict{Array{Int,1},Int}) - @test WeakKeyDict(a=>i+1 for (i,a) in enumerate([A,B,C]) ) == wkd - @test WeakKeyDict([(A,2), (B,3), (C,4)]) == wkd - @test WeakKeyDict(Pair(A,2), Pair(B,3), Pair(C,4)) == wkd - @test copy(wkd) == wkd - - @test length(wkd) == 3 - @test !isempty(wkd) - res = pop!(wkd, C) - @test res == 4 - @test C ∉ keys(wkd) - @test 4 ∉ values(wkd) - @test length(wkd) == 2 - @test !isempty(wkd) - wkd = filter!( p -> p.first != B, wkd) - @test B ∉ keys(wkd) - @test 3 ∉ values(wkd) - @test length(wkd) == 1 - @test WeakKeyDict(Pair(A, 2)) == wkd - @test !isempty(wkd) - - wkd = empty!(wkd) - @test wkd == empty(wkd) - @test typeof(wkd) == typeof(empty(wkd)) - @test length(wkd) == 0 - @test isempty(wkd) - @test isa(wkd, WeakKeyDict) - - @test_throws ArgumentError WeakKeyDict([1, 2, 3]) - - # issue #26939 - d26939 = WeakKeyDict() - 
d26939[big"1.0" + 1.1] = 1 - GC.gc() # make sure this doesn't segfault + AA = copy(A) + GC.@preserve A AA begin + wkd = WeakKeyIdDict(A=>1, AA=>2) + @test length(wkd)==2 + kk = collect(keys(wkd)) + @test kk[1]==kk[2] + @test kk[1]!==kk[2] + end + + # WeakKeyDict uses normal hashes + GC.@preserve A AA begin + wkd = WeakKeyDict(A=>1, AA=>2) + @test length(wkd)==1 + kk = collect(keys(wkd)) + @test kk[1]==A + @test kk[1]==AA + end + end @testset "issue #19995, hash of dicts" begin