diff --git a/src/avl_tree.jl b/src/avl_tree.jl index bacf1cbb2..e8aad8359 100644 --- a/src/avl_tree.jl +++ b/src/avl_tree.jl @@ -16,14 +16,19 @@ AVLTreeNode(d) = AVLTreeNode{Any}(d) AVLTreeNode_or_null{T} = Union{AVLTreeNode{T}, Nothing} -mutable struct AVLTree{T} +mutable struct AVLTree{T, Ord<:Ordering} root::AVLTreeNode_or_null{T} count::Int + ord::Ord - AVLTree{T}() where T = new{T}(nothing, 0) + AVLTree{T, Ord}(o::Ord=Forward) where {T, Ord<:Ordering} = new{T, Ord}(nothing, 0, o) end -AVLTree() = AVLTree{Any}() +AVLTree() = AVLTree{Any, ForwardOrdering}(Forward) + +AVLTree{T}() where T = AVLTree{T, ForwardOrdering}(Forward) + +AVLTree{T}(o::Ord) where {T, Ord<:Ordering} = AVLTree{T, Ord}(o) Base.length(tree::AVLTree) = tree.count @@ -58,7 +63,7 @@ end """ left_rotate(node_x::AVLTreeNode) -Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node. +Performs a left-rotation on `node_x`, updates height of the nodes, and returns the rotated node. """ function left_rotate(z::AVLTreeNode) y = z.rightChild @@ -75,7 +80,7 @@ end """ right_rotate(node_x::AVLTreeNode) -Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node. +Performs a right-rotation on `node_x`, updates height of the nodes, and returns the rotated node. """ function right_rotate(z::AVLTreeNode) y = z.leftChild @@ -90,9 +95,9 @@ function right_rotate(z::AVLTreeNode) end """ - minimum_node(tree::AVLTree, node::AVLTreeNode) + minimum_node(tree::AVLTree, node::AVLTreeNode) -Returns the AVLTreeNode with minimum value in subtree of `node`. +Returns the AVLTreeNode with minimum value in subtree of `node`. """ function minimum_node(node::Union{AVLTreeNode, Nothing}) while node != nothing && node.leftChild != nothing @@ -104,23 +109,23 @@ end function search_node(tree::AVLTree{K}, d::K) where K prev = nothing node = tree.root - while node != nothing && node.data != nothing && node.data != d + while node != nothing && node.data != nothing && !eq(tree.ord, node.data, d) prev = node - if d < node.data + if lt(tree.ord, d, node.data) node = node.leftChild else node = node.rightChild end end - + return (node == nothing) ? prev : node end -function Base.haskey(tree::AVLTree{K}, d::K) where K +function Base.haskey(tree::AVLTree{K}, d::K) where K (tree.root == nothing) && return false node = search_node(tree, d) - return (node.data == d) + return eq(tree.ord, node.data, d) end Base.in(key, tree::AVLTree) = haskey(tree, key) @@ -130,18 +135,18 @@ function Base.insert!(tree::AVLTree{K}, d::K) where K function insert_node(node::Union{AVLTreeNode, Nothing}, key) if node == nothing return AVLTreeNode{K}(key) - elseif key < node.data + elseif lt(tree.ord, key, node.data) node.leftChild = insert_node(node.leftChild, key) else node.rightChild = insert_node(node.rightChild, key) end - + node.subsize = compute_subtree_size(node) node.height = compute_height(node) balance = get_balance(node) - + if balance > 1 - if key < node.leftChild.data + if lt(tree.ord, key, node.leftChild.data) return right_rotate(node) else node.leftChild = left_rotate(node.leftChild) @@ -150,7 +155,7 @@ function Base.insert!(tree::AVLTree{K}, d::K) where K end if balance < -1 - if key > node.rightChild.data + if lt(tree.ord, node.rightChild.data, key) return left_rotate(node) else node.rightChild = right_rotate(node.rightChild) @@ -176,9 +181,9 @@ end function Base.delete!(tree::AVLTree{K}, d::K) where K function delete_node!(node::Union{AVLTreeNode, Nothing}, key) - if key < node.data + if lt(tree.ord, key, node.data) node.leftChild = delete_node!(node.leftChild, key) - elseif key > node.data + elseif lt(tree.ord, node.data, key) node.rightChild = delete_node!(node.rightChild, key) else if node.leftChild == nothing @@ -187,13 +192,13 @@ function Base.delete!(tree::AVLTree{K}, d::K) where K elseif node.rightChild == nothing result = node.leftChild return result - else + else result = minimum_node(node.rightChild) node.data = result.data node.rightChild = delete_node!(node.rightChild, result.data) - end + end end - + node.subsize = compute_subtree_size(node) node.height = compute_height(node) balance = get_balance(node) @@ -214,14 +219,14 @@ function Base.delete!(tree::AVLTree{K}, d::K) where K node.rightChild = right_rotate(node.rightChild) return left_rotate(node) end - end - + end + return node end # if the key is not in the tree, do nothing and return the tree !haskey(tree, d) && return tree - + # if the key is present, delete it from the tree tree.root = delete_node!(tree.root, d) tree.count -= 1 @@ -237,19 +242,19 @@ function sorted_rank(tree::AVLTree{K}, key::K) where K !haskey(tree, key) && throw(KeyError(key)) node = tree.root rank = 0 - while node.data != key - if (node.data < key) + while !eq(tree.ord, node.data, key) + if lt(tree.ord, node.data, key) rank += (1 + get_subsize(node.leftChild)) node = node.rightChild else node = node.leftChild end - end + end rank += (1 + get_subsize(node.leftChild)) return rank end -function Base.getindex(tree::AVLTree{K}, ind::Integer) where K +function Base.getindex(tree::AVLTree{K}, ind::Integer) where K @boundscheck (1 <= ind <= tree.count) || throw(BoundsError("$ind should be in between 1 and $(tree.count)")) function traverse_tree(node::AVLTreeNode_or_null, idx) if (node != nothing) @@ -263,6 +268,6 @@ function Base.getindex(tree::AVLTree{K}, ind::Integer) where K end end end - value = traverse_tree(tree.root, ind) + value = traverse_tree(tree.root, ind) return value -end \ No newline at end of file +end