From dfcbd9d21461837dc1caa493f1f4e26a314f1c3b Mon Sep 17 00:00:00 2001 From: piever Date: Thu, 30 Nov 2023 16:21:02 +0100 Subject: [PATCH] index_type cleanup --- src/structarray.jl | 16 ++++------------ test/runtests.jl | 28 ++++++++++++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index c8231bc..f901fb9 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -18,21 +18,13 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N} ax = findconsistentvalue(axes, c) (ax === nothing) && throw(ArgumentError("all component arrays must have the same shape")) length(ax) == N || throw(ArgumentError("wrong number of dimensions")) - new{T, N, C, index_type(c)}(c) + # Compute optimal type to use for indexing as a function of components + I = IndexStyle(c...) isa IndexLinear ? Int : CartesianIndex{N} + return new{T, N, C, I}(c) end end -# compute optimal type to use for indexing as a function of components -index_type(components::NamedTuple) = index_type(values(components)) -index_type(::Tuple{}) = Int -function index_type(components::Tuple) - f, ls = first(components), tail(components) - return IndexStyle(f) isa IndexCartesian ? CartesianIndex{ndims(f)} : index_type(ls) -end -# Only check first component if the all the component types match -index_type(components::NTuple) = invoke(index_type, Tuple{Tuple}, (first(components),)) -# Return the index type parameter as a function of the StructArray type or instance -index_type(s::StructArray) = index_type(typeof(s)) +# Return the index type parameter as a function of the StructArray type index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C) diff --git a/test/runtests.jl b/test/runtests.jl index 611dfc6..8d74747 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -122,7 +122,7 @@ end s = StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10)) T = typeof(s) @test IndexStyle(T) === IndexCartesian() - @test StructArrays.index_type(s) == CartesianIndex{2} + @test StructArrays.index_type(T) == CartesianIndex{2} @test s[100] == s[10, 10] == (a=s.a[10,10], b=s.b[10,10]) s[100] = (a=1, b=1) @test s[100] == s[10, 10] == (a=1, b=1) @@ -131,7 +131,7 @@ end @inferred IndexStyle(StructArray(a=rand(10,10), b=rand(10,10))) s = StructArray(a=rand(10,10), b=rand(10,10)) T = typeof(s) - @test StructArrays.index_type(s) == Int + @test StructArrays.index_type(T) == Int @inferred IndexStyle(s) @test s[100] == s[10, 10] == (a=s.a[10,10], b=s.b[10,10]) s[100] = (a=1, b=1) @@ -139,16 +139,24 @@ end s[10, 10] = (a=0, b=0) @test s[100] == s[10, 10] == (a=0, b=0) - # inference for "many" types, both for linear ad Cartesian indexing - @inferred StructArrays.index_type(ntuple(_ -> rand(5), 2)) - @inferred StructArrays.index_type(ntuple(_ -> rand(5, 5), 3)) - @inferred StructArrays.index_type(ntuple(_ -> rand(5, 5, 5), 4)) + # inference for "many" types, both for linear and Cartesian indexing + s = @inferred StructArray(ntuple(_ -> rand(5), 2)) + @test StructArrays.index_type(typeof(s)) === Int + s = @inferred StructArray(ntuple(_ -> rand(5, 5), 3)) + @test StructArrays.index_type(typeof(s)) === Int + s = @inferred StructArray(ntuple(_ -> rand(5, 5, 5), 4)) + @test StructArrays.index_type(typeof(s)) === Int - @inferred StructArrays.index_type(ntuple(_ -> view(rand(5), 1:3), 2)) - @inferred StructArrays.index_type(ntuple(_ -> view(rand(5, 5), 1:3, 1:2), 3)) - @inferred StructArrays.index_type(ntuple(_ -> view(rand(5, 5, 5), 1:3, 1:2, 1:4), 4)) + s = @inferred StructArray(ntuple(_ -> view(rand(5), 1:3), 2)) + @test StructArrays.index_type(typeof(s)) === Int + s = @inferred StructArray(ntuple(_ -> view(rand(5, 5), 1:3, 1:2), 3)) + @test StructArrays.index_type(typeof(s)) === CartesianIndex{2} + s = @inferred StructArray(ntuple(_ -> view(rand(5, 5, 5), 1:3, 1:2, 1:4), 4)) + @test StructArrays.index_type(typeof(s)) === CartesianIndex{3} + + s = @inferred StructArray(ntuple(n -> n == 1 ? rand(2, 3) : view(rand(5, 5), 1:2, 1:3), 5)) + @test StructArrays.index_type(typeof(s)) === CartesianIndex{2} - @inferred StructArrays.index_type(ntuple(n -> n == 1 ? rand(5, 5) : view(rand(5, 5), 1:2, 1:3), 5)) @inferred IndexStyle(StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10))) end