diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl index 220f9790..6b00608a 100644 --- a/src/disjoint_set.jl +++ b/src/disjoint_set.jl @@ -242,6 +242,10 @@ end Find the root element of the subset in `s` which has the element `x` as a member. """ find_root!(s::DisjointSet{T}, x::T) where {T} = s.revmap[find_root!(s.internal, s.intmap[x])] +find_root!(s::DisjointSet{T}, x::T, ::PCIterative) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCIterative())] +find_root!(s::DisjointSet{T}, x::T, ::PCRecursive) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCRecursive())] +find_root!(s::DisjointSet{T}, x::T, ::PCHalving) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCHalving())] +find_root!(s::DisjointSet{T}, x::T, ::PCSplitting) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCSplitting())] """ in_same_set(s::DisjointSet{T}, x::T, y::T) diff --git a/test/bench_disjoint_set.jl b/test/bench_disjoint_set.jl index 8cc4bb09..f7c0796f 100644 --- a/test/bench_disjoint_set.jl +++ b/test/bench_disjoint_set.jl @@ -37,7 +37,7 @@ benchmark `find` operation function create_disjoint_set_struct(n::Int) parents = [1; collect(1:n-1)] # each element's parent is its predecessor ranks = zeros(Int, n) # ranks are all zero - DataStructures.IntDisjointSet(parents, ranks, n) + IntDisjointSet(parents, ranks, n) end # benchmarking function @@ -47,20 +47,20 @@ function benchmark_find_root(n::Int) println("Recursive may path compression may encounter stack-overflow; skipping") else s = create_disjoint_set_struct(n) - @btime DataStructures.find_root!($s, $n, DataStructures.PCRecursive()) + @btime find_root!($s, $n, PCRecursive()) end println("Benchmarking iterative path compression implementation (find_root_iterative!):") s = create_disjoint_set_struct(n) # reset parents - @btime DataStructures.find_root!($s, $n, DataStructures.PCIterative()) + @btime find_root!($s, $n, PCIterative()) println("Benchmarking path-halving implementation (find_root_halving!):") s = create_disjoint_set_struct(n) # reset parents - @btime DataStructures.find_root!($s, $n, DataStructures.PCHalving()) + @btime find_root!($s, $n, PCHalving()) println("Benchmarking path-splitting implementation (find_root_path_splitting!):") s = create_disjoint_set_struct(n) # reset parents - @btime DataStructures.find_root!($s, $n, DataStructures.PCSplitting()) + @btime find_root!($s, $n, PCSplitting()) end # run benchmark tests diff --git a/test/test_disjoint_set.jl b/test/test_disjoint_set.jl index e96b3d9c..d146f309 100644 --- a/test/test_disjoint_set.jl +++ b/test/test_disjoint_set.jl @@ -29,10 +29,16 @@ @test num_groups(s) == T(9) @test in_same_set(s, T(2), T(3)) @test find_root!(s, T(3)) == T(2) + @test find_root!(s, T(3), PCIterative()) == T(2) + @test find_root!(s, T(3), PCHalving()) == T(2) + @test find_root!(s, T(3), PCSplitting()) == T(2) union!(s, T(3), T(2)) @test num_groups(s) == T(9) @test in_same_set(s, T(2), T(3)) @test find_root!(s, T(3)) == T(2) + @test find_root!(s, T(3), PCIterative()) == T(2) + @test find_root!(s, T(3), PCHalving()) == T(2) + @test find_root!(s, T(3), PCSplitting()) == T(2) end @testset "more tests" begin @@ -48,19 +54,19 @@ @test union!(s, T(8), T(5)) == T(8) @test num_groups(s) == T(7) @test find_root!(s, T(6)) == T(8) - @test find_root!(s, T(6), :iterative) == T(8) - @test find_root!(s, T(6), :halving) == T(8) - @test find_root!(s, T(6), :splitting) == T(8) + @test find_root!(s, T(6), PCIterative()) == T(8) + @test find_root!(s, T(6), PCHalving()) == T(8) + @test find_root!(s, T(6), PCSplitting()) == T(8) union!(s, T(2), T(6)) @test find_root!(s, T(2)) == T(8) root1 = find_root!(s, T(6)) - root1 = find_root!(s, T(6), :iterative) - root1 = find_root!(s, T(6), :halving) - root1 = find_root!(s, T(6), :splitting) + root1 = find_root!(s, T(6), PCIterative()) + root1 = find_root!(s, T(6), PCHalving()) + root1 = find_root!(s, T(6), PCSplitting()) root2 = find_root!(s, T(2)) - root2 = find_root!(s, T(2), :iterative) - root2 = find_root!(s, T(2), :halving) - root2 = find_root!(s, T(2), :splitting) + root2 = find_root!(s, T(2), PCIterative()) + root2 = find_root!(s, T(2), PCHalving()) + root2 = find_root!(s, T(2), PCSplitting()) @test root_union!(s, T(root1), T(root2)) == T(8) @test union!(s, T(5), T(6)) == T(8) end @@ -107,6 +113,12 @@ r = [find_root!(s, i) for i in 1 : 10] @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCIterative()) for i in 1 : 10] + @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCHalving()) for i in 1 : 10] + @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCSplitting()) for i in 1 : 10] + @test isequal(r, collect(1:10)) end @testset "union!" begin @@ -126,6 +138,57 @@ @test num_groups(s) == 2 end + @testset "union! PCIterative" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCIterative()) == find_root!(s, y, PCIterative()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCIterative()) + @test union!(s, 3, 5) == find_root!(s, 1, PCIterative()) + @test union!(s, 7, 9) == find_root!(s, 7, PCIterative()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + + @testset "union! PCHalving" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCHalving()) == find_root!(s, y, PCHalving()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCHalving()) + @test union!(s, 3, 5) == find_root!(s, 1, PCHalving()) + @test union!(s, 7, 9) == find_root!(s, 7, PCHalving()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + + @testset "union! PCSplitting" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCSplitting()) == find_root!(s, y, PCSplitting()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCSplitting()) + @test union!(s, 3, 5) == find_root!(s, 1, PCSplitting()) + @test union!(s, 7, 9) == find_root!(s, 7, PCSplitting()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + @testset "r0" begin r0 = [ find_root!(s,i) for i in 1:10 ] # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 @@ -139,30 +202,51 @@ @test isequal(r, r0) end + @testset "r0 Iterative" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCIterative()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + + @testset "r0 Splitting" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCSplitting()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + + @testset "r0 Halving" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCHalving()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + @testset "root_union!" begin root1 = find_root!(s, 7) root2 = find_root!(s, 3) @test root1 != root2 root_union!(s, 7, 3) @test find_root!(s, 7) == find_root!(s, 3) - - root1 = find_root!(s, 7, :iterative) - root2 = find_root!(s, 3, :iterative) - @test root1 != root2 - root_union!(s, 7, 3) - @test find_root!(s, 7, :iterative) == find_root!(s, 3, :iterative) - - root1 = find_root!(s, 7, :halving) - root2 = find_root!(s, 3, :halving) - @test root1 != root2 - root_union!(s, 7, 3) - @test find_root!(s, 7, :halving) == find_root!(s, 3, :halving) - - root1 = find_root!(s, 7, :splitting) - root2 = find_root!(s, 3, :splitting) - @test root1 != root2 - root_union!(s, 7, 3) - @test find_root!(s, 7, :splitting) == find_root!(s, 3, :splitting) end @testset "Some tests using non-integer disjoint sets" begin