Skip to content

Commit

Permalink
add maptiedindices (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
piever authored Mar 18, 2019
1 parent 3e2ee9d commit dc2a29b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
63 changes: 52 additions & 11 deletions src/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ optimize_isequal(v::AbstractArray) = v
optimize_isequal(v::PooledArray) = v.refs
optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v)))

recover_original(v::AbstractArray, el) = el
recover_original(v::PooledArray, el) = v.pool[el]
recover_original(v::StructArray{T}, el) where {T<:Union{Tuple, NamedTuple}} = T(map(recover_original, fieldarrays(v), el))

pool(v::AbstractArray, condition = !isbitstypeeltype) = condition(v) ? convert(PooledArray, v) : v
pool(v::StructArray, condition = !isbitstypeeltype) = replace_storage(t -> pool(t, condition), v)

Expand All @@ -16,9 +20,9 @@ function Base.permute!(c::StructArray, p::AbstractVector)
return c
end

struct TiedIndices{T<:AbstractVector, I<:Integer, U<:AbstractUnitRange}
struct TiedIndices{T<:AbstractVector, V<:AbstractVector{<:Integer}, U<:AbstractUnitRange}
vec::T
perm::Vector{I}
perm::V
within::U
end

Expand All @@ -44,17 +48,55 @@ function Base.iterate(n::TiedIndices, i = first(n.within))
return (row => i:(i1-1), i1)
end

tiedindices(args...) = TiedIndices(args...)
"""
`tiedindices(v, perm=sortperm(v))`
Given an abstract vector `v` and a permutation vector `perm`, return an iterator
of pairs `val => range` where `range` is a maximal interval such as `v[perm[range]]`
is constant: `val` is the unique value of `v[perm[range]]`.
"""
tiedindices(v, perm=sortperm(v)) = TiedIndices(v, perm)

"""
`maptiedindices(f, v, perm)`
Given a function `f`, compute the iterator `tiedindices(v, perm)` and return
in iterable object which yields `f(val, idxs)` where `val => idxs` are the pairs
iterated by `tiedindices(v, perm)`.
## Examples
`maptiedindices` is a low level building block that can be used to define grouping
operators. For example:
```jldoctest
julia> function mygroupby(f, keys, data)
perm = sortperm(keys)
StructArrays.maptiedindices(keys, perm) do key, idxs
key => f(data[perm[idxs]])
end
end
mygroupby (generic function with 1 method)
julia> StructArray(mygroupby(sum, [1, 2, 1, 3], [1, 4, 10, 11]))
3-element StructArray{Pair{Int64,Int64},1,NamedTuple{(:first, :second),Tuple{Array{Int64,1},Array{Int64,1}}}}:
1 => 11
2 => 4
3 => 11
```
"""
function maptiedindices(f, v, perm)
fast_v = optimize_isequal(v)
itr = TiedIndices(fast_v, perm)
(f(recover_original(v, val), idxs) for (val, idxs) in itr)
end

function uniquesorted(args...)
t = tiedindices(args...)
(row for (row, _) in t)
function uniquesorted(keys, perm=sortperm(keys))
maptiedindices((key, _) -> key, keys, perm)
end

function finduniquesorted(args...)
t = tiedindices(args...)
p = sortperm(t)
(row => p[idxs] for (row, idxs) in t)
function finduniquesorted(keys, perm=sortperm(keys))
maptiedindices((key, idxs) -> (key => perm[idxs]), keys, perm)
end

function Base.sortperm(c::StructVector{T}) where {T<:Union{Tuple, NamedTuple}}
Expand Down Expand Up @@ -148,4 +190,3 @@ function sort_int_range_sub_by!(x, ioffs, n, by, rangelen, minval, temp)
end
x
end

4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ end
@test t[1] != t[3]
@test t[1] == t[4]
@test t[1][2] isa Integer
@test StructArrays.recover_original(s, t[1]) == s[1]
@test StructArrays.recover_original(s, t[2]) == s[2]
@test StructArrays.recover_original(s, t[3]) == s[3]
@test StructArrays.recover_original(s, t[4]) == s[4]
end

@testset "namedtuple" begin
Expand Down

0 comments on commit dc2a29b

Please sign in to comment.