From 3e2ee9d30e53f09746e4858a0aeb9fc20d4702b1 Mon Sep 17 00:00:00 2001 From: Pietro Vertechi Date: Sun, 17 Mar 2019 12:50:39 +0000 Subject: [PATCH] allow comparison optimizations in TiedIndices (#56) --- REQUIRE | 1 + src/StructArrays.jl | 5 ++--- src/sort.jl | 10 +++++++++- test/REQUIRE | 1 - test/runtests.jl | 12 ++++++++++++ 5 files changed, 24 insertions(+), 5 deletions(-) diff --git a/REQUIRE b/REQUIRE index 87c49a18..9fe3e7e8 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,2 +1,3 @@ julia 0.7 +PooledArrays 0.5 Requires diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 85c95d93..18a8612b 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -1,6 +1,8 @@ module StructArrays import Requires +using PooledArrays: PooledArray + export StructArray, StructVector, LazyRow, LazyRows export collect_structarray, fieldarrays @@ -13,9 +15,6 @@ include("lazy.jl") function __init__() Requires.@require Tables="bd369af6-aec1-5ad0-b16a-f7cc5008161c" include("tables.jl") - Requires.@require PooledArrays="2dfb63ee-cc39-5dd5-95bd-886bf059d720" begin - fastpermute!(v::PooledArrays.PooledArray, p::AbstractVector) = permute!(v, p) - end Requires.@require WeakRefStrings="ea10d353-3f73-51f8-a26c-33c1cb351aa5" begin fastpermute!(v::WeakRefStrings.StringArray, p::AbstractVector) = permute!(v, p) end diff --git a/src/sort.jl b/src/sort.jl index b39007ac..4def9362 100644 --- a/src/sort.jl +++ b/src/sort.jl @@ -2,6 +2,14 @@ using Base.Sort, Base.Order fastpermute!(v::AbstractArray, p::AbstractVector) = copyto!(v, v[p]) fastpermute!(v::StructArray, p::AbstractVector) = permute!(v, p) +fastpermute!(v::PooledArray, p::AbstractVector) = permute!(v, p) + +optimize_isequal(v::AbstractArray) = v +optimize_isequal(v::PooledArray) = v.refs +optimize_isequal(v::StructArray{<:Union{Tuple, NamedTuple}}) = StructArray(map(optimize_isequal, fieldarrays(v))) + +pool(v::AbstractArray, condition = !isbitstype∘eltype) = condition(v) ? convert(PooledArray, v) : v +pool(v::StructArray, condition = !isbitstype∘eltype) = replace_storage(t -> pool(t, condition), v) function Base.permute!(c::StructArray, p::AbstractVector) foreachfield(v -> fastpermute!(v, p), c) @@ -76,7 +84,7 @@ function refine_perm!(p, cols, c, x, y′, lo, hi) order = Perm(Forward, y′) y = something(forward_vec(order), y′) nc = length(cols) - for (_, idxs) in TiedIndices(x, p, lo:hi) + for (_, idxs) in TiedIndices(optimize_isequal(x), p, lo:hi) i, i1 = extrema(idxs) if i1 > i sort_sub_by!(p, i, i1, y, order, temp) diff --git a/test/REQUIRE b/test/REQUIRE index 86384542..49277b0c 100644 --- a/test/REQUIRE +++ b/test/REQUIRE @@ -1,3 +1,2 @@ Tables -PooledArrays WeakRefStrings diff --git a/test/runtests.jl b/test/runtests.jl index ef7c3d03..9791d968 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,18 @@ end @test all(v.b .== v_pooled.b) @test !isa(v_pooled.a, PooledArrays.PooledArray) @test isa(v_pooled.b, PooledArrays.PooledArray) + @test v_pooled == StructArrays.pool(v) +end + +@testset "optimize_isequal" begin + a = ["a", "b", "a", "a"] + b = PooledArrays.PooledArray(["x", "y", "z", "x"]) + s = StructArray((a, b)) + t = StructArrays.optimize_isequal(s) + @test t[1] != t[2] + @test t[1] != t[3] + @test t[1] == t[4] + @test t[1][2] isa Integer end @testset "namedtuple" begin