Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow non-string indices for Trie #759

Merged
merged 1 commit into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 53 additions & 35 deletions src/trie.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,56 @@
mutable struct Trie{T}
value::T
children::Dict{Char,Trie{T}}
mutable struct Trie{K,V}
value::V
children::Dict{K,Trie{K,V}}
is_key::Bool

function Trie{T}() where T
self = new{T}()
self.children = Dict{Char,Trie{T}}()
function Trie{K,V}() where {K,V}
self = new{K,V}()
self.children = Dict{K,Trie{K,V}}()
self.is_key = false
return self
end

function Trie{T}(ks, vs) where T
t = Trie{T}()
for (k, v) in zip(ks, vs)
t[k] = v
end
return t
function Trie{K,V}(ks, vs) where {K,V}
return Trie{K,V}(zip(ks, vs))
end

function Trie{T}(kv) where T
t = Trie{T}()
function Trie{K,V}(kv) where {K,V}
t = Trie{K,V}()
for (k,v) in kv
t[k] = v
end
return t
end
end

Trie() = Trie{Any}()
Trie(ks::AbstractVector{K}, vs::AbstractVector{V}) where {K<:AbstractString,V} = Trie{V}(ks, vs)
Trie(kv::AbstractVector{Tuple{K,V}}) where {K<:AbstractString,V} = Trie{V}(kv)
Trie(kv::AbstractDict{K,V}) where {K<:AbstractString,V} = Trie{V}(kv)
Trie(ks::AbstractVector{K}) where {K<:AbstractString} = Trie{Nothing}(ks, similar(ks, Nothing))
Trie() = Trie{Any,Any}()
Trie(ks::AbstractVector{K}, vs::AbstractVector{V}) where {K,V} = Trie{eltype(K),V}(ks, vs)
Trie(kv::AbstractVector{Tuple{K,V}}) where {K,V} = Trie{eltype(K),V}(kv)
Trie(kv::AbstractDict{K,V}) where {K,V} = Trie{eltype(K),V}(kv)
Trie(ks::AbstractVector{K}) where {K} = Trie{eltype(K),Nothing}(ks, similar(ks, Nothing))

function Base.setindex!(t::Trie{T}, val, key::AbstractString) where T
value = convert(T, val) # we don't want to iterate before finding out it fails
function Base.setindex!(t::Trie{K,V}, val, key) where {K,V}
value = convert(V, val) # we don't want to iterate before finding out it fails
node = t
for char in key
if !haskey(node.children, char)
node.children[char] = Trie{T}()
node.children[char] = Trie{K,V}()
end
node = node.children[char]
end
node.is_key = true
node.value = value
end

function Base.getindex(t::Trie, key::AbstractString)
function Base.getindex(t::Trie, key)
node = subtrie(t, key)
if node != nothing && node.is_key
return node.value
end
throw(KeyError("key not found: $key"))
end

function subtrie(t::Trie, prefix::AbstractString)
function subtrie(t::Trie, prefix)
node = t
for char in prefix
if !haskey(node.children, char)
Expand All @@ -66,30 +62,38 @@ function subtrie(t::Trie, prefix::AbstractString)
return node
end

function Base.haskey(t::Trie, key::AbstractString)
function Base.haskey(t::Trie, key)
node = subtrie(t, key)
node != nothing && node.is_key
end

function Base.get(t::Trie, key::AbstractString, notfound)
function Base.get(t::Trie, key, notfound)
node = subtrie(t, key)
if node != nothing && node.is_key
return node.value
end
return notfound
end

function Base.keys(t::Trie, prefix::AbstractString="", found=AbstractString[])
# Return a new key prefix consisting of `prefix` followed by the single key element `char`.
# `prefix` itself is never mutated, so sibling branches of a traversal can share it.
#
# The string method accepts any `AbstractString`/`AbstractChar` (e.g. a `SubString` prefix
# supplied by a caller) and always yields a `String`. The array method accepts any
# `AbstractVector{T}` (e.g. a range or a view) and yields whatever `vcat` produces for it,
# which for a plain `Vector{T}` is a new `Vector{T}`.
_concat(prefix::AbstractString, char::AbstractChar) = string(prefix, char)
_concat(prefix::AbstractVector{T}, char::T) where {T} = vcat(prefix, char)
Copy link
Member

@oxinabox oxinabox Oct 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is all we need to add in order to support arbitrary iterators just:

Suggested change
_concat(prefix::Vector{T}, char::T) where {T} = vcat(prefix, char)
_concat(prefix, char) = Iterators.flatten(prefix, (char,))

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean Iterators.flatten((prefix, (char,)))?
That would cause the types to explode:

julia> p1 = [1]
1-element Vector{Int64}:
 1

julia> p2 = Iterators.flatten((p1, (2,)))
Base.Iterators.Flatten{Tuple{Vector{Int64}, Tuple{Int64}}}(([1], (2,)))

julia> p3 = Iterators.flatten((p2, (3,)))
Base.Iterators.Flatten{Tuple{Base.Iterators.Flatten{Tuple{Vector{Int64}, Tuple{Int64}}}, Tuple{Int64}}}((Base.Iterators.Flatten{Tuple{Vector{Int64}, Tuple{Int64}}}(([1], (2,))), (3,)))

As _concat only receives objects derived from _empty_prefix, which are either Strings or Vectors, I don't think this needs to be generalized further. Instead, it might be good to wrap the 3-argument keys (Base.keys(t) = _keys(t)), such that arguments 2 and 3 are not accessible to the average user.


# Starting prefix for enumerating the keys below a trie node: an empty `String` when the
# key elements are `Char`s, otherwise an empty `Vector` of the key element type `K`.
_empty_prefix(::Trie{Char}) = ""
_empty_prefix(::Trie{K}) where {K} = Vector{K}()

function Base.keys(t::Trie{K,V},
prefix=_empty_prefix(t),
found=Vector{typeof(prefix)}()) where {K,V}
if t.is_key
push!(found, prefix)
end
for (char,child) in t.children
keys(child, string(prefix,char), found)
keys(child, _concat(prefix, char), found)
end
return found
end

function keys_with_prefix(t::Trie, prefix::AbstractString)
function keys_with_prefix(t::Trie, prefix)
st = subtrie(t, prefix)
st != nothing ? keys(st,prefix) : []
end
Expand All @@ -101,7 +105,7 @@ end
# see the comments and implementation below for details.
struct TrieIterator
t::Trie
str::AbstractString
str
end

# At the start, there is no previous iteration,
Expand All @@ -120,11 +124,11 @@ function Base.iterate(it::TrieIterator, (t, i) = (it.t, 0))
end
end

partial_path(t::Trie, str::AbstractString) = TrieIterator(t, str)
partial_path(t::Trie, str) = TrieIterator(t, str)
Base.IteratorSize(::Type{TrieIterator}) = Base.SizeUnknown()

"""
find_prefixes(t::Trie, str::AbstractString)
find_prefixes(t::Trie, str)

Find all keys from the `Trie` that are prefixes of the given string (or, more generally, key sequence) `str`.

Expand All @@ -137,10 +141,24 @@ julia> find_prefixes(t, "ABCDE")
"A"
"ABC"
"ABCD"

julia> t′ = Trie([1:1, 1:3, 1:4, 2:4]);

julia> find_prefixes(t′, 1:5)
3-element Vector{UnitRange{Int64}}:
1:1
1:3
1:4

julia> find_prefixes(t′, [1,2,3,4,5])
3-element Vector{Vector{Int64}}:
[1]
[1, 2, 3]
[1, 2, 3, 4]
```
"""
function find_prefixes(t::Trie, str::AbstractString)
prefixes = AbstractString[]
function find_prefixes(t::Trie, str::T) where {T}
prefixes = T[]
it = partial_path(t, str)
idx = 0
for t in it
Expand All @@ -150,4 +168,4 @@ function find_prefixes(t::Trie, str::AbstractString)
idx = nextind(str, idx)
end
return prefixes
end
end
2 changes: 1 addition & 1 deletion test/test_deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# These are the tests for deprecated features; they should be deleted along with those features.

@testset "Trie: path iterator" begin
t = Trie{Int}()
t = Trie{Char,Int}()
t["rob"] = 27
t["roger"] = 52
t["kevin"] = Int8(11)
Expand Down
26 changes: 19 additions & 7 deletions test/test_trie.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testset "Trie" begin
@testset "Core Functionality" begin
t = Trie{Int}()
t = Trie{Char,Int}()
t["amy"] = 56
t["ann"] = 15
t["emma"] = 30
Expand All @@ -19,14 +19,14 @@
ks = ["amy", "ann", "emma", "rob", "roger"]
vs = [56, 15, 30, 27, 52]
kvs = collect(zip(ks, vs))
@test isa(Trie(ks, vs), Trie{Int})
@test isa(Trie(kvs), Trie{Int})
@test isa(Trie(Dict(kvs)), Trie{Int})
@test isa(Trie(ks), Trie{Nothing})
@test isa(Trie(ks, vs), Trie{Char,Int})
@test isa(Trie(kvs), Trie{Char,Int})
@test isa(Trie(Dict(kvs)), Trie{Char,Int})
@test isa(Trie(ks), Trie{Char,Nothing})
end

@testset "partial_path iterator" begin
t = Trie{Int}()
t = Trie{Char,Int}()
t["rob"] = 27
t["roger"] = 52
t["kevin"] = Int8(11)
Expand All @@ -53,7 +53,7 @@
@test collect(partial_path(t, "東京")) == [t0, t1, t2]
@test collect(partial_path(t, "東京スカイツリー")) == [t0, t1, t2]
end

@testset "find_prefixes" begin
t = Trie(["A", "ABC", "ABD", "BCD"])
prefixes = find_prefixes(t, "ABCDE")
Expand All @@ -66,4 +66,16 @@
@test prefixes == ["東京都", "東京都渋谷区"]
end

@testset "non-string indexing" begin
# Exercise a trie whose keys are sequences of integers instead of strings.
# First type parameter: key element type; second: stored value type.
t = Trie{Int,Int}()
t[[1,2,3,4]] = 1
# [1,2] is a key in its own right and also a proper prefix of the key above.
t[[1,2]] = 2
@test haskey(t, [1,2])
@test get(t, [1,2], nothing) == 2
# Descend along the elements 1, 2, 3; the result is the node below that path.
st = subtrie(t, [1,2,3])
# Keys of a subtrie are reported relative to that subtrie, not to the original root.
@test keys(st) == [[4]]
@test st[[4]] == 1
# [1,2,3,4] is not a prefix of the query, so only [1,2] is expected.
@test find_prefixes(t, [1,2,3,5]) == [[1,2]]
# Prefixes come back in the query's own container type (here a range).
@test find_prefixes(t, 1:3) == [1:2]
end
end # @testset Trie