Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include custom less-than in AVL Tree #729

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions src/avl_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ AVLTreeNode_or_null{T} = Union{AVLTreeNode{T}, Nothing}
mutable struct AVLTree{T}
root::AVLTreeNode_or_null{T}
count::Int
lt::Function
Copy link
Member

@eulerkochy eulerkochy Mar 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing I am deliberating is the member name, lt. I suggest comp for better readability of the code. Otherwise this looks good to me.

Also, try squashing the commits, into single one. A lot of Update src/avl_tree.jl doesn't make sense in the long scheme of things. A meaningful commit message such as Added a custom compare to AVLTree will be good.

The codes looks fine to be merged after these two changes.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In isolation I agree, but I went with lt for consistency with sort and all the other sorting functions, which were the only other Julia functions could think of which take a comparison as an optional parameter. If you still prefer comp I can switch it over though.

Copy link
Member

@oxinabox oxinabox Mar 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an abstractly typed field.
That is going to make things slow as it will cause a dynamic dispatch.

Instead do:

mutable struct AVLTree{T, F}
    root::AVLTreeNode_or_null{T}
    count::Int
    lt::F
end


AVLTree{T}() where T = new{T}(nothing, 0)
AVLTree{T}(; lt = isless) where T = new{T}(nothing, 0, lt)
end

AVLTree() = AVLTree{Any}()
Expand Down Expand Up @@ -58,7 +59,7 @@ end
"""
left_rotate(node_x::AVLTreeNode)

Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
"""
function left_rotate(z::AVLTreeNode)
y = z.rightChild
Expand All @@ -75,7 +76,7 @@ end
"""
right_rotate(node_x::AVLTreeNode)

Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node.
"""
function right_rotate(z::AVLTreeNode)
y = z.leftChild
Expand All @@ -90,9 +91,9 @@ function right_rotate(z::AVLTreeNode)
end

"""
minimum_node(tree::AVLTree, node::AVLTreeNode)
minimum_node(tree::AVLTree, node::AVLTreeNode)

Returns the AVLTreeNode with minimum value in subtree of `node`.
Returns the AVLTreeNode with minimum value in subtree of `node`.
"""
function minimum_node(node::Union{AVLTreeNode, Nothing})
while node != nothing && node.leftChild != nothing
Expand All @@ -107,17 +108,17 @@ function search_node(tree::AVLTree{K}, d::K) where K
while node != nothing && node.data != nothing && node.data != d

prev = node
if d < node.data
if tree.lt(d, node.data)
node = node.leftChild
else
node = node.rightChild
end
end

return (node == nothing) ? prev : node
end

function Base.haskey(tree::AVLTree{K}, d::K) where K
function Base.haskey(tree::AVLTree{K}, d::K) where K
(tree.root == nothing) && return false
node = search_node(tree, d)
return (node.data == d)
Expand All @@ -130,18 +131,18 @@ function Base.insert!(tree::AVLTree{K}, d::K) where K
function insert_node(node::Union{AVLTreeNode, Nothing}, key)
if node == nothing
return AVLTreeNode{K}(key)
elseif key < node.data
elseif tree.lt(key, node.data)
node.leftChild = insert_node(node.leftChild, key)
else
node.rightChild = insert_node(node.rightChild, key)
end

node.subsize = compute_subtree_size(node)
node.height = compute_height(node)
balance = get_balance(node)

if balance > 1
if key < node.leftChild.data
if tree.lt(key, node.leftChild.data)
return right_rotate(node)
else
node.leftChild = left_rotate(node.leftChild)
Expand All @@ -150,7 +151,7 @@ function Base.insert!(tree::AVLTree{K}, d::K) where K
end

if balance < -1
if key > node.rightChild.data
if tree.lt(node.rightChild.data, key)
return left_rotate(node)
else
node.rightChild = right_rotate(node.rightChild)
Expand All @@ -176,9 +177,9 @@ end
function Base.delete!(tree::AVLTree{K}, d::K) where K

function delete_node!(node::Union{AVLTreeNode, Nothing}, key)
if key < node.data
if tree.lt(key, node.data)
node.leftChild = delete_node!(node.leftChild, key)
elseif key > node.data
elseif tree.lt(node.data, key)
node.rightChild = delete_node!(node.rightChild, key)
else
if node.leftChild == nothing
Expand All @@ -187,13 +188,13 @@ function Base.delete!(tree::AVLTree{K}, d::K) where K
elseif node.rightChild == nothing
result = node.leftChild
return result
else
else
result = minimum_node(node.rightChild)
node.data = result.data
node.rightChild = delete_node!(node.rightChild, result.data)
end
end
end

node.subsize = compute_subtree_size(node)
node.height = compute_height(node)
balance = get_balance(node)
Expand All @@ -214,14 +215,14 @@ function Base.delete!(tree::AVLTree{K}, d::K) where K
node.rightChild = right_rotate(node.rightChild)
return left_rotate(node)
end
end
end

return node
end

# if the key is not in the tree, do nothing and return the tree
!haskey(tree, d) && return tree

# if the key is present, delete it from the tree
tree.root = delete_node!(tree.root, d)
tree.count -= 1
Expand All @@ -238,18 +239,18 @@ function sorted_rank(tree::AVLTree{K}, key::K) where K
node = tree.root
rank = 0
while node.data != key
if (node.data < key)
if tree.lt(node.data, key)
rank += (1 + get_subsize(node.leftChild))
node = node.rightChild
else
node = node.leftChild
end
end
end
rank += (1 + get_subsize(node.leftChild))
return rank
end

function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
@boundscheck (1 <= ind <= tree.count) || throw(BoundsError("$ind should be in between 1 and $(tree.count)"))
function traverse_tree(node::AVLTreeNode_or_null, idx)
if (node != nothing)
Expand All @@ -263,6 +264,6 @@ function Base.getindex(tree::AVLTree{K}, ind::Integer) where K
end
end
end
value = traverse_tree(tree.root, ind)
value = traverse_tree(tree.root, ind)
return value
end
end