diff --git a/src/sort.jl b/src/sort.jl index 4def9362..9ff99a65 100644 --- a/src/sort.jl +++ b/src/sort.jl @@ -8,6 +8,10 @@ optimize_isequal(v::AbstractArray) = v optimize_isequal(v::PooledArray) = v.refs optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v))) +recover_original(v::AbstractArray, el) = el +recover_original(v::PooledArray, el) = v.pool[el] +recover_original(v::StructArray{T}, el) where {T<:Union{Tuple, NamedTuple}} = T(map(recover_original, fieldarrays(v), el)) + pool(v::AbstractArray, condition = !isbitstype∘eltype) = condition(v) ? convert(PooledArray, v) : v pool(v::StructArray, condition = !isbitstype∘eltype) = replace_storage(t -> pool(t, condition), v) @@ -16,9 +20,9 @@ function Base.permute!(c::StructArray, p::AbstractVector) return c end -struct TiedIndices{T<:AbstractVector, I<:Integer, U<:AbstractUnitRange} +struct TiedIndices{T<:AbstractVector, V<:AbstractVector{<:Integer}, U<:AbstractUnitRange} vec::T - perm::Vector{I} + perm::V within::U end @@ -44,17 +48,55 @@ function Base.iterate(n::TiedIndices, i = first(n.within)) return (row => i:(i1-1), i1) end -tiedindices(args...) = TiedIndices(args...) +""" +`tiedindices(v, perm=sortperm(v))` + +Given an abstract vector `v` and a permutation vector `perm`, return an iterator +of pairs `val => range` where `range` is a maximal interval such as `v[perm[range]]` +is constant: `val` is the unique value of `v[perm[range]]`. +""" +tiedindices(v, perm=sortperm(v)) = TiedIndices(v, perm) + +""" +`maptiedindices(f, v, perm)` + +Given a function `f`, compute the iterator `tiedindices(v, perm)` and return +in iterable object which yields `f(val, idxs)` where `val => idxs` are the pairs +iterated by `tiedindices(v, perm)`. + +## Examples + +`maptiedindices` is a low level building block that can be used to define grouping +operators. For example: + +```jldoctest +julia> function mygroupby(f, keys, data) + perm = sortperm(keys) + StructArrays.maptiedindices(keys, perm) do key, idxs + key => f(data[perm[idxs]]) + end + end +mygroupby (generic function with 1 method) + +julia> StructArray(mygroupby(sum, [1, 2, 1, 3], [1, 4, 10, 11])) +3-element StructArray{Pair{Int64,Int64},1,NamedTuple{(:first, :second),Tuple{Array{Int64,1},Array{Int64,1}}}}: + 1 => 11 + 2 => 4 + 3 => 11 +``` +""" +function maptiedindices(f, v, perm) + fast_v = optimize_isequal(v) + itr = TiedIndices(fast_v, perm) + (f(recover_original(v, val), idxs) for (val, idxs) in itr) +end -function uniquesorted(args...) - t = tiedindices(args...) - (row for (row, _) in t) +function uniquesorted(keys, perm=sortperm(keys)) + maptiedindices((key, _) -> key, keys, perm) end -function finduniquesorted(args...) - t = tiedindices(args...) - p = sortperm(t) - (row => p[idxs] for (row, idxs) in t) +function finduniquesorted(keys, perm=sortperm(keys)) + maptiedindices((key, idxs) -> (key => perm[idxs]), keys, perm) end function Base.sortperm(c::StructVector{T}) where {T<:Union{Tuple, NamedTuple}} @@ -148,4 +190,3 @@ function sort_int_range_sub_by!(x, ioffs, n, by, rangelen, minval, temp) end x end - diff --git a/test/runtests.jl b/test/runtests.jl index 9791d968..cb8d0fa9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,6 +39,10 @@ end @test t[1] != t[3] @test t[1] == t[4] @test t[1][2] isa Integer + @test StructArrays.recover_original(s, t[1]) == s[1] + @test StructArrays.recover_original(s, t[2]) == s[2] + @test StructArrays.recover_original(s, t[3]) == s[3] + @test StructArrays.recover_original(s, t[4]) == s[4] end @testset "namedtuple" begin