Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include custom less-than in AVL Tree #729

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions src/avl_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@ AVLTreeNode(d) = AVLTreeNode{Any}(d)

AVLTreeNode_or_null{T} = Union{AVLTreeNode{T}, Nothing}

mutable struct AVLTree{T}
mutable struct AVLTree{T, Ord<:Ordering}
root::AVLTreeNode_or_null{T}
count::Int
ord::Ord

AVLTree{T}() where T = new{T}(nothing, 0)
AVLTree{T, Ord}(o::Ord=Forward) where {T, Ord<:Ordering} = new{T, Ord}(nothing, 0, o)
end

AVLTree() = AVLTree{Any}()
AVLTree() = AVLTree{Any, ForwardOrdering}(Forward)

AVLTree{T}() where T = AVLTree{T, ForwardOrdering}(Forward)

AVLTree{T}(o::Ord) where {T, Ord<:Ordering} = AVLTree{T, Ord}(o)

Base.length(tree::AVLTree) = tree.count

Expand Down Expand Up @@ -58,7 +63,7 @@ end
"""
left_rotate(node_x::AVLTreeNode)

Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
"""
function left_rotate(z::AVLTreeNode)
y = z.rightChild
Expand All @@ -75,7 +80,7 @@ end
"""
right_rotate(node_x::AVLTreeNode)

Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
"""
function right_rotate(z::AVLTreeNode)
y = z.leftChild
Expand All @@ -90,9 +95,9 @@ function right_rotate(z::AVLTreeNode)
end

"""
minimum_node(tree::AVLTree, node::AVLTreeNode)
minimum_node(tree::AVLTree, node::AVLTreeNode)

Returns the AVLTreeNode with minimum value in subtree of `node`.
Returns the AVLTreeNode with minimum value in subtree of `node`.
"""
function minimum_node(node::Union{AVLTreeNode, Nothing})
while node != nothing && node.leftChild != nothing
Expand All @@ -104,23 +109,23 @@ end
function search_node(tree::AVLTree{K}, d::K) where K
prev = nothing
node = tree.root
while node != nothing && node.data != nothing && node.data != d
while node != nothing && node.data != nothing && !eq(tree.ord, node.data, d)

prev = node
if d < node.data
if lt(tree.ord, d, node.data)
node = node.leftChild
else
node = node.rightChild
end
end

return (node == nothing) ? prev : node
end

function Base.haskey(tree::AVLTree{K}, d::K) where K
function Base.haskey(tree::AVLTree{K}, d::K) where K
(tree.root == nothing) && return false
node = search_node(tree, d)
return (node.data == d)
return eq(tree.ord, node.data, d)
end

Base.in(key, tree::AVLTree) = haskey(tree, key)
Expand All @@ -130,18 +135,18 @@ function Base.insert!(tree::AVLTree{K}, d::K) where K
function insert_node(node::Union{AVLTreeNode, Nothing}, key)
if node == nothing
return AVLTreeNode{K}(key)
elseif key < node.data
elseif lt(tree.ord, key, node.data)
node.leftChild = insert_node(node.leftChild, key)
else
node.rightChild = insert_node(node.rightChild, key)
end

node.subsize = compute_subtree_size(node)
node.height = compute_height(node)
balance = get_balance(node)

if balance > 1
if key < node.leftChild.data
if lt(tree.ord, key, node.leftChild.data)
return right_rotate(node)
else
node.leftChild = left_rotate(node.leftChild)
Expand All @@ -150,7 +155,7 @@ function Base.insert!(tree::AVLTree{K}, d::K) where K
end

if balance < -1
if key > node.rightChild.data
if lt(tree.ord, node.rightChild.data, key)
return left_rotate(node)
else
node.rightChild = right_rotate(node.rightChild)
Expand All @@ -176,9 +181,9 @@ end
function Base.delete!(tree::AVLTree{K}, d::K) where K

function delete_node!(node::Union{AVLTreeNode, Nothing}, key)
if key < node.data
if lt(tree.ord, key, node.data)
node.leftChild = delete_node!(node.leftChild, key)
elseif key > node.data
elseif lt(tree.ord, node.data, key)
node.rightChild = delete_node!(node.rightChild, key)
else
if node.leftChild == nothing
Expand All @@ -187,13 +192,13 @@ function Base.delete!(tree::AVLTree{K}, d::K) where K
elseif node.rightChild == nothing
result = node.leftChild
return result
else
else
result = minimum_node(node.rightChild)
node.data = result.data
node.rightChild = delete_node!(node.rightChild, result.data)
end
end
end

node.subsize = compute_subtree_size(node)
node.height = compute_height(node)
balance = get_balance(node)
Expand All @@ -214,14 +219,14 @@ function Base.delete!(tree::AVLTree{K}, d::K) where K
node.rightChild = right_rotate(node.rightChild)
return left_rotate(node)
end
end
end

return node
end

# if the key is not in the tree, do nothing and return the tree
!haskey(tree, d) && return tree

# if the key is present, delete it from the tree
tree.root = delete_node!(tree.root, d)
tree.count -= 1
Expand All @@ -237,19 +242,19 @@ function sorted_rank(tree::AVLTree{K}, key::K) where K
!haskey(tree, key) && throw(KeyError(key))
node = tree.root
rank = 0
while node.data != key
if (node.data < key)
while !eq(tree.ord, node.data, key)
if lt(tree.ord, node.data, key)
rank += (1 + get_subsize(node.leftChild))
node = node.rightChild
else
node = node.leftChild
end
end
end
rank += (1 + get_subsize(node.leftChild))
return rank
end

function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
@boundscheck (1 <= ind <= tree.count) || throw(BoundsError("$ind should be in between 1 and $(tree.count)"))
function traverse_tree(node::AVLTreeNode_or_null, idx)
if (node != nothing)
Expand All @@ -263,6 +268,6 @@ function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
end
end
end
value = traverse_tree(tree.root, ind)
value = traverse_tree(tree.root, ind)
return value
end
end