Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented WeakKeyIdDict #402

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ This package implements a variety of data structures, including
- RobinDict (implemented with [Robin Hood Hashing](https://cs.uwaterloo.ca/research/tr/1986/CS-86-14.pdf))
- SwissDict (inspired from [SwissTables](https://abseil.io/blog/20180927-swisstables))
- Dictionaries with Defaults
- Weak-Key dicts using object-id as hash
- Trie
- Linked List and Mutable Linked List
- Sorted Dict, Sorted Multi-Dict and Sorted Set
Expand Down
2 changes: 2 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ This package implements a variety of data structures, including
- RobinDict and OrderedRobinDict (implemented with [Robin Hood Hashing](https://cs.uwaterloo.ca/research/tr/1986/CS-86-14.pdf))
- SwissDict (inspired from [SwissTables](https://abseil.io/blog/20180927-swisstables))
- Dictionaries with Defaults
- Weak-Key dicts using object-id as hash
- Trie
- Linked List and Mutable Linked List
- Sorted Dict, Sorted Multi-Dict and Sorted Set
Expand All @@ -43,6 +44,7 @@ Pages = [
"heaps.md",
"ordered_containers.md",
"default_dict.md",
"weakkeyid_dict.md",
"robin_dict.md",
"trie.md",
"linked_list.md",
Expand Down
21 changes: 21 additions & 0 deletions docs/src/weakkeyid_dict.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# WeakKeyIdDict

`WeakKeyIdDict()` constructs a hash table where the keys are weak
references to objects, and thus may be garbage collected even when
referenced in a hash table. Unlike the Julia-Base `WeakKeyDict`, it uses
object-id for hashing and `===` for comparison, which is often more
appropriate.

```julia
A = [1]
wkid = WeakKeyIdDict(A => 1)
wk = WeakKeyDict(A => 1)

haskey(wkid, copy(A)) # false
haskey(wk, copy(A)) # true

A = 1
GC.gc() # make sure the [1] is garbage collected
haskey(wkid, A) # false
haskey(wk, A) # false
```
2 changes: 2 additions & 0 deletions src/DataStructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module DataStructures
using OrderedCollections: isordered
export OrderedDict, OrderedSet, LittleDict
export DefaultDict, DefaultOrderedDict
export WeakKeyIdDict

export complement, complement!

Expand Down Expand Up @@ -78,6 +79,7 @@ module DataStructures
include("heaps.jl")

include("default_dict.jl")
include("weakkeyid_dict.jl")
include("dict_support.jl")
include("trie.jl")

Expand Down
132 changes: 132 additions & 0 deletions src/weakkeyid_dict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# weak key dict using object-id hashing/equality

# Type to wrap a WeakRef to furbish it with objectid comparison and hashing.
struct WeakRefForWeakDict
w::WeakRef
WeakRefForWeakDict(wr::WeakRef) = new(wr)
end
WeakRefForWeakDict(val) = WeakRefForWeakDict(WeakRef(val))
Base.:(==)(wr1::WeakRefForWeakDict, wr2::WeakRefForWeakDict) = wr1.w.value===wr2.w.value
Base.hash(wr::WeakRefForWeakDict, h::UInt) = Base.hash_uint(3h - objectid(wr.w.value))

"""
WeakKeyIdDict([itr])
`WeakKeyIdDict()` constructs a hash table where the keys are weak
references to objects, and thus may be garbage collected even when
referenced in a hash table.
See [`Dict`](@ref) for further help.
"""
mutable struct WeakKeyIdDict{K,V} <: AbstractDict{K,V}
ht::Dict{WeakRefForWeakDict,V}
lock::ReentrantLock
finalizer::Function

# Constructors mirror Dict's
function WeakKeyIdDict{K,V}() where V where K
t = new(Dict{WeakRefForWeakDict,V}(), ReentrantLock(), identity)
t.finalizer = function (k)
# when a weak key is finalized, remove from dictionary if it is still there
if islocked(t)
finalizer(t.finalizer, k)
return nothing
end
delete!(t, k)
end
return t
end
end
function WeakKeyIdDict{K,V}(kv) where V where K
h = WeakKeyIdDict{K,V}()
for (k,v) in kv
h[k] = v
end
return h
end
WeakKeyIdDict{K,V}(p::Pair) where V where K = setindex!(WeakKeyIdDict{K,V}(), p.second, p.first)
function WeakKeyIdDict{K,V}(ps::Pair...) where V where K
h = WeakKeyIdDict{K,V}()
sizehint!(h, length(ps))
for p in ps
h[p.first] = p.second
end
return h
end
WeakKeyIdDict() = WeakKeyIdDict{Any,Any}()

WeakKeyIdDict(kv::Tuple{}) = WeakKeyIdDict()
Base.copy(d::WeakKeyIdDict) = WeakKeyIdDict(d)

WeakKeyIdDict(ps::Pair{K,V}...) where {K,V} = WeakKeyIdDict{K,V}(ps)
WeakKeyIdDict(ps::Pair{K}...) where {K} = WeakKeyIdDict{K,Any}(ps)
WeakKeyIdDict(ps::(Pair{K,V} where K)...) where {V} = WeakKeyIdDict{Any,V}(ps)
WeakKeyIdDict(ps::Pair...) = WeakKeyIdDict{Any,Any}(ps)

function WeakKeyIdDict(kv)
try
Base.dict_with_eltype((K, V) -> WeakKeyIdDict{K, V}, kv, eltype(kv))
catch e
if !Base.isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
throw(ArgumentError("WeakKeyIdDict(kv): kv needs to be an iterator of tuples or pairs"))
else
rethrow(e)
end
end
end

Base.empty(d::WeakKeyIdDict, ::Type{K}, ::Type{V}) where {K, V} = WeakKeyIdDict{K, V}()

Base.islocked(wkh::WeakKeyIdDict) = islocked(wkh.lock)
Base.lock(f, wkh::WeakKeyIdDict) = lock(f, wkh.lock)
Base.trylock(f, wkh::WeakKeyIdDict) = trylock(f, wkh.lock)

function Base.setindex!(wkh::WeakKeyIdDict{K}, v, key) where K
!isa(key, K) && throw(ArgumentError("$key is not a valid key for type $K"))
finalizer(wkh.finalizer, key)
lock(wkh) do
wkh.ht[WeakRefForWeakDict(key)] = v
end
return wkh
end

function Base.getkey(wkh::WeakKeyIdDict{K}, kk, default) where K
return lock(wkh) do
k = getkey(wkh.ht, WeakRefForWeakDict(kk), secret_table_token)
k === secret_table_token && return default
return k.w.value::K
end
end

Base.get(wkh::WeakKeyIdDict{K}, key, default) where {K} = lock(() -> get(wkh.ht, WeakRefForWeakDict(key), default), wkh)
Base.get(default::Base.Callable, wkh::WeakKeyIdDict{K}, key) where {K} = lock(() -> get(default, wkh.ht, WeakRefForWeakDict(key)), wkh)
Base.get!(wkh::WeakKeyIdDict{K}, key, default) where {K} = lock(() -> get!(wkh.ht, WeakRefForWeakDict(key), default), wkh)
Base.get!(default::Base.Callable, wkh::WeakKeyIdDict{K}, key) where {K} = lock(() -> get!(default, wkh.ht, WeakRefForWeakDict(key)), wkh)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to use Base.@lock here instead of the closure form. It can be more performant and keeps stacktraces shorter; I have a recursive call where this almost doubles the number of stack frames. It is exported on Julia 1.9 and present in many versions of Julia, but not present in 1.0 unfortunately, which DataStructures still supports.

Base.pop!(wkh::WeakKeyIdDict{K}, key) where {K} = lock(() -> pop!(wkh.ht, WeakRefForWeakDict(key)), wkh)
Base.pop!(wkh::WeakKeyIdDict{K}, key, default) where {K} = lock(() -> pop!(wkh.ht, WeakRefForWeakDict(key), default), wkh)
Base.delete!(wkh::WeakKeyIdDict, key) = lock(() -> delete!(wkh.ht, WeakRefForWeakDict(key)), wkh)
Base.empty!(wkh::WeakKeyIdDict) = (lock(() -> empty!(wkh.ht), wkh); wkh)
Base.haskey(wkh::WeakKeyIdDict{K}, key) where {K} = lock(() -> haskey(wkh.ht, WeakRefForWeakDict(key)), wkh)
Base.getindex(wkh::WeakKeyIdDict{K}, key) where {K} = lock(() -> getindex(wkh.ht, WeakRefForWeakDict(key)), wkh)
Base.isempty(wkh::WeakKeyIdDict) = isempty(wkh.ht)
Base.length(t::WeakKeyIdDict) = length(t.ht)

function Base.iterate(t::WeakKeyIdDict{K,V}) where V where K
gc_token = Ref{Bool}(false) # no keys will be deleted via finalizers until this token is gc'd
finalizer(gc_token) do r
if r[]
r[] = false
unlock(t.lock)
end
end
s = lock(t.lock)
iterate(t, (gc_token,))
end
function Base.iterate(t::WeakKeyIdDict{K,V}, state) where V where K
gc_token = first(state)
y = iterate(t.ht, Base.tail(state)...)
y === nothing && return nothing
wkv, i = y
kv = Pair{K,V}(wkv[1].w.value::K, wkv[2])
return (kv, (gc_token, i))
end

Base.filter!(f, d::WeakKeyIdDict) = Base.filter_in_one_pass!(f, d)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ tests = ["deprecations",
"mutable_binheap",
"minmax_heap",
"default_dict",
"weakkeyid_dict",
"trie",
"list",
"mutable_list",
Expand Down
68 changes: 68 additions & 0 deletions test/test_weakkeyid_dict.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
@testset "WeakKeyIdDict" begin
A = [1]
B = [2]
C = [3]

# construction
wkd = WeakKeyIdDict()
wkd[A] = 2
wkd[B] = 3
wkd[C] = 4
dd = convert(Dict{Any,Any},wkd)
@test WeakKeyIdDict(dd) == wkd
@test convert(WeakKeyIdDict{Any, Any}, dd) == wkd
@test isa(WeakKeyIdDict(dd), WeakKeyIdDict{Any,Any})
@test WeakKeyIdDict(A=>2, B=>3, C=>4) == wkd
@test isa(WeakKeyIdDict(A=>2, B=>3, C=>4), WeakKeyIdDict{Array{Int,1},Int})
@test WeakKeyIdDict(a=>i+1 for (i,a) in enumerate([A,B,C]) ) == wkd
@test WeakKeyIdDict([(A,2), (B,3), (C,4)]) == wkd
@test WeakKeyIdDict(Pair(A,2), Pair(B,3), Pair(C,4)) == wkd
@test copy(wkd) == wkd

@test length(wkd) == 3
@test !isempty(wkd)
res = pop!(wkd, C)
@test res == 4
@test C ∉ keys(wkd)
@test 4 ∉ values(wkd)
@test length(wkd) == 2
@test !isempty(wkd)
wkd = filter!( p -> p.first != B, wkd)
@test B ∉ keys(wkd)
@test 3 ∉ values(wkd)
@test length(wkd) == 1
@test WeakKeyIdDict(Pair(A, 2)) == wkd
@test !isempty(wkd)

wkd = empty!(wkd)
@test wkd == empty(wkd)
@test typeof(wkd) == typeof(empty(wkd))
@test length(wkd) == 0
@test isempty(wkd)
@test isa(wkd, WeakKeyIdDict)

@test_throws ArgumentError WeakKeyIdDict([1, 2, 3])

# WeakKeyIdDict does not convert keys
@test_throws ArgumentError WeakKeyIdDict{Int,Any}(5.0=>1)

# WeakKeyIdDict hashes with object-id
AA = copy(A)
GC.@preserve A AA begin
wkd = WeakKeyIdDict(A=>1, AA=>2)
@test length(wkd)==2
kk = collect(keys(wkd))
@test kk[1]==kk[2]
@test kk[1]!==kk[2]
end

# WeakKeyIdDict compares to other dicts:
@test IdDict(A=>1)!=WeakKeyIdDict(A=>1)
@test Dict(A=>1)==WeakKeyIdDict(A=>1)
@test Dict(copy(A)=>1)!=WeakKeyIdDict(A=>1)

# issue #26939
d26939 = WeakKeyIdDict()
d26939[big"1.0" + 1.1] = 1
GC.gc() # make sure this doesn't segfault
end