# Unified git diff touching three files. The line breaks inside each hunk have
# been collapsed, so this text is a record of a change, not loadable Julia
# source.
#
# src/StructArrays.jl
#   Inside the `Requires.@require WeakRefStrings ... begin` block, adds a
#   `roweq(a, i, j)` method for `WeakRefStrings.StringArray{String}`. It
#   converts `a` to a `StringArray{WeakRefString{UInt8}}` and compares entries
#   `i` and `j` with `isequal`.
#   NOTE(review): the `convert` runs on every call; presumably it is cheap,
#   but confirm against WeakRefStrings before relying on it in hot loops.
#
# src/sort.jl
#   Removes `optimize_isequal`, `recover_original`, `TiedIndices`,
#   `tiedindices` and `maptiedindices` (with their docstrings).
#   Adds `GroupPerm(vec, perm, within)`, an iterator whose `iterate` yields
#   ranges `i:(i1-1)` of positions in `perm`; a range is extended while
#   `roweq(vec, perm[i], perm[i1])` holds and `i1 <= last(within)`. Its
#   `IteratorSize` is `SizeUnknown()` and its `eltype` is declared as
#   `UnitRange{Int}`; `sortperm(g)` returns the stored `perm`.
#   Adds `roweq` methods: `AbstractVector` (`isequal` on elements),
#   `PooledArray` (`==` on `refs`), and a `@generated` method for
#   `StructVector` that chains `roweq` over the field arrays with `&&`.
#   `uniquesorted`, `finduniquesorted` and the loop in `refine_perm!` are
#   switched from `TiedIndices`/`maptiedindices` to `GroupPerm`.
#   The two `pool` methods are moved below `Base.permute!`; their text is
#   unchanged.
#   NOTE(review): the added `iterate` binds a local named `pi`, which shadows
#   the `pi` constant inside that method; consider a different name.
#   NOTE(review): the generated `roweq` always starts from
#   `getfield(fieldarrays(c), 1)`, so a StructVector with no field arrays is
#   not handled; confirm that case cannot reach this method.
#
# test/runtests.jl
#   Replaces the "optimize_isequal" testset with a "roweq" testset (StructArray
#   of a String vector and a PooledArray, plus a WeakRefStrings.StringArray).
#   Updates the "iterators" testset to build `GroupPerm(c)` and to expect
#   ranges only, instead of `value => range` pairs.
diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 18a8612b..babfd6e1 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -17,6 +17,10 @@ function __init__() Requires.@require Tables="bd369af6-aec1-5ad0-b16a-f7cc5008161c" include("tables.jl") Requires.@require WeakRefStrings="ea10d353-3f73-51f8-a26c-33c1cb351aa5" begin fastpermute!(v::WeakRefStrings.StringArray, p::AbstractVector) = permute!(v, p) + @inline function roweq(a::WeakRefStrings.StringArray{String}, i, j) + weaksa = convert(WeakRefStrings.StringArray{WeakRefStrings.WeakRefString{UInt8}}, a) + @inbounds isequal(weaksa[i], weaksa[j]) + end end end diff --git a/src/sort.jl b/src/sort.jl index d9c443e2..0fe2ffce 100644 --- a/src/sort.jl +++ b/src/sort.jl @@ -4,99 +4,61 @@ fastpermute!(v::AbstractArray, p::AbstractVector) = copyto!(v, v[p]) fastpermute!(v::StructArray, p::AbstractVector) = permute!(v, p) fastpermute!(v::PooledArray, p::AbstractVector) = permute!(v, p) -optimize_isequal(v::AbstractArray) = v -optimize_isequal(v::PooledArray) = v.refs -optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v))) - -recover_original(v::AbstractArray, el) = el -recover_original(v::PooledArray, el) = v.pool[el] -recover_original(v::StructArray{T}, el) where {T<:Union{Tuple, NamedTuple}} = T(map(recover_original, fieldarrays(v), el)) - -pool(v::AbstractArray, condition = !isbitstype∘eltype) = condition(v) ? convert(PooledArray, v) : v -pool(v::StructArray, condition = !isbitstype∘eltype) = replace_storage(t -> pool(t, condition), v) - function Base.permute!(c::StructArray, p::AbstractVector) foreachfield(v -> fastpermute!(v, p), c) return c end -struct TiedIndices{T<:AbstractVector, V<:AbstractVector{<:Integer}, U<:AbstractUnitRange} - vec::T - perm::V +pool(v::AbstractArray, condition = !isbitstype∘eltype) = condition(v) ? 
convert(PooledArray, v) : v +pool(v::StructArray, condition = !isbitstype∘eltype) = replace_storage(t -> pool(t, condition), v) + +struct GroupPerm{V<:AbstractVector, P<:AbstractVector{<:Integer}, U<:AbstractUnitRange} + vec::V + perm::P within::U end -TiedIndices(vec::AbstractVector, perm=sortperm(vec)) = - TiedIndices(vec, perm, axes(vec, 1)) - -Base.IteratorSize(::Type{<:TiedIndices}) = Base.SizeUnknown() +GroupPerm(vec, perm=sortperm(vec)) = GroupPerm(vec, perm, axes(vec, 1)) -Base.eltype(::Type{<:TiedIndices{T}}) where {T} = - Pair{eltype(T), UnitRange{Int}} +Base.sortperm(g::GroupPerm) = g.perm -Base.sortperm(t::TiedIndices) = t.perm - -function Base.iterate(n::TiedIndices, i = first(n.within)) - vec, perm = n.vec, n.perm - l = last(n.within) +function Base.iterate(g::GroupPerm, i = first(g.within)) + vec, perm = g.vec, g.perm + l = last(g.within) i > l && return nothing - @inbounds row = vec[perm[i]] + @inbounds pi = perm[i] i1 = i+1 - @inbounds while i1 <= l && isequal(row, vec[perm[i1]]) + @inbounds while i1 <= l && roweq(vec, pi, perm[i1]) i1 += 1 end - return (row => i:(i1-1), i1) + return (i:(i1-1), i1) end -""" -`tiedindices(v, perm=sortperm(v))` - -Given an abstract vector `v` and a permutation vector `perm`, return an iterator -of pairs `val => range` where `range` is a maximal interval such as `v[perm[range]]` -is constant: `val` is the unique value of `v[perm[range]]`. -""" -tiedindices(v, perm=sortperm(v)) = TiedIndices(v, perm) - -""" -`maptiedindices(f, v, perm)` - -Given a function `f`, compute the iterator `tiedindices(v, perm)` and return -in iterable object which yields `f(val, idxs)` where `val => idxs` are the pairs -iterated by `tiedindices(v, perm)`. - -## Examples - -`maptiedindices` is a low level building block that can be used to define grouping -operators. 
For example: - -```jldoctest -julia> function mygroupby(f, keys, data) - perm = sortperm(keys) - StructArrays.maptiedindices(keys, perm) do key, idxs - key => f(data[perm[idxs]]) - end - end -mygroupby (generic function with 1 method) - -julia> StructArray(mygroupby(sum, [1, 2, 1, 3], [1, 4, 10, 11])) -3-element StructArray{Pair{Int64,Int64},1,NamedTuple{(:first, :second),Tuple{Array{Int64,1},Array{Int64,1}}}}: - 1 => 11 - 2 => 4 - 3 => 11 -``` -""" -function maptiedindices(f, v, perm) - fast_v = optimize_isequal(v) - itr = TiedIndices(fast_v, perm) - (f(recover_original(v, val), idxs) for (val, idxs) in itr) +Base.IteratorSize(::Type{<:GroupPerm}) = Base.SizeUnknown() + +Base.eltype(::Type{<:GroupPerm}) = UnitRange{Int} + +@inline roweq(x::AbstractVector, i, j) = (@inbounds eq=isequal(x[i], x[j]); eq) +@inline roweq(a::PooledArray, i, j) = (@inbounds x=a.refs[i] == a.refs[j]; x) +@generated function roweq(c::StructVector{D,C}, i, j) where {D,C} + N = fieldcount(C) + ex = :(roweq(getfield(fieldarrays(c),1), i, j)) + for n in 2:N + ex = :(($ex) && (roweq(getfield(fieldarrays(c),$n), i, j))) + end + ex end function uniquesorted(keys, perm=sortperm(keys)) - maptiedindices((key, _) -> key, keys, perm) + (keys[perm[idxs[1]]] for idxs in GroupPerm(keys, perm)) end function finduniquesorted(keys, perm=sortperm(keys)) - maptiedindices((key, idxs) -> (key => perm[idxs]), keys, perm) + func = function (idxs) + p_idxs = perm[idxs] + return keys[p_idxs[1]] => p_idxs + end + (func(idxs) for idxs in GroupPerm(keys, perm)) end function Base.sortperm(c::StructVector{T}) where {T<:Union{Tuple, NamedTuple}} @@ -126,7 +88,7 @@ function refine_perm!(p, cols, c, x, y′, lo, hi) order = Perm(Forward, y′) y = something(forward_vec(order), y′) nc = length(cols) - for (_, idxs) in TiedIndices(optimize_isequal(x), p, lo:hi) + for idxs in GroupPerm(x, p, lo:hi) i, i1 = extrema(idxs) if i1 > i sort_sub_by!(p, i, i1, y, order, temp) diff --git a/test/runtests.jl b/test/runtests.jl index 
cb8d0fa9..bc91d0eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,19 +30,19 @@ end @test v_pooled == StructArrays.pool(v) end -@testset "optimize_isequal" begin +@testset "roweq" begin a = ["a", "b", "a", "a"] b = PooledArrays.PooledArray(["x", "y", "z", "x"]) s = StructArray((a, b)) - t = StructArrays.optimize_isequal(s) - @test t[1] != t[2] - @test t[1] != t[3] - @test t[1] == t[4] - @test t[1][2] isa Integer - @test StructArrays.recover_original(s, t[1]) == s[1] - @test StructArrays.recover_original(s, t[2]) == s[2] - @test StructArrays.recover_original(s, t[3]) == s[3] - @test StructArrays.recover_original(s, t[4]) == s[4] + @test StructArrays.roweq(s, 1, 1) + @test !StructArrays.roweq(s, 1, 2) + @test !StructArrays.roweq(s, 1, 3) + @test StructArrays.roweq(s, 1, 4) + strs = WeakRefStrings.StringArray(["a", "a", "b"]) + @test StructArrays.roweq(strs, 1, 1) + @test StructArrays.roweq(strs, 1, 2) + @test !StructArrays.roweq(strs, 1, 3) + @test !StructArrays.roweq(strs, 2, 3) end @testset "namedtuple" begin @@ -95,11 +95,12 @@ end @testset "iterators" begin c = [1, 2, 3, 1, 1] - d = StructArrays.tiedindices(c) - @test eltype(d) == Pair{Int, UnitRange{Int}} + d = StructArrays.GroupPerm(c) + @test eltype(d) == UnitRange{Int} + @test Base.IteratorEltype(d) == Base.HasEltype() + @test sortperm(d) == sortperm(c) s = collect(d) - @test first.(s) == [1, 2, 3] - @test last.(s) == [1:3, 4:4, 5:5] + @test s == [1:3, 4:4, 5:5] t = collect(StructArrays.finduniquesorted(c)) @test first.(t) == [1, 2, 3] @test last.(t) == [[1, 4, 5], [2], [3]]