diff --git a/src/structarray.jl b/src/structarray.jl index 121a5ab7..230bec1c 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -21,28 +21,23 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N} axes(c[i]) == ax || error("all field arrays must have same shape") end end - new{T, N, C, index_type(C)}(c) + new{T, N, C, index_type(c)}(c) end end -# compute optimal type to use for indexing as a function of component types -index_type(::Type{NamedTuple{names, types}}) where {names, types} = index_type(types) -index_type(::Type{Tuple{}}) = Int -function index_type(::Type{T}) where {T<:Tuple} - S, U = tuple_type_head(T), tuple_type_tail(T) - return _index_type(S, U) +# compute optimal type to use for indexing as a function of components +index_type(components::NamedTuple) = index_type(values(components)) +index_type(::Tuple{}) = Int +function index_type(components::Tuple) + f, ls = first(components), tail(components) + return IndexStyle(f) isa IndexCartesian ? CartesianIndex{ndims(f)} : index_type(ls) end -# Julia v1.7.0-beta3 doesn't seem to specialize `index_type` as defined above -# for tuple types with "many" elements (three or four, depending on the concrete -# types). However, we can help the compiler for homogeneous types by defining -# the specialization below. -index_type(::Type{NTuple{N, S}}) where {N, S} = _index_type(S) +# Only check first component if the all the component types match +index_type(components::NTuple) = invoke(index_type, Tuple{Tuple}, (first(components),)) +# Return the index type parameter as a function of the StructArray type or instance +index_type(s::StructArray) = index_type(typeof(s)) index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I -function _index_type(::Type{S}, ::Type{U}=Tuple{}) where {S, U} - return IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U) -end - array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C) array_types(::Type{NamedTuple{names, types}}) where {names, types} = types array_types(::Type{TT}) where {TT<:Tuple} = TT @@ -168,7 +163,6 @@ function buildfromslices(::Type{T}, unwrap::F, slices) where {T,F} end end - function Base.IndexStyle(::Type{S}) where {S<:StructArray} index_type(S) === Int ? IndexLinear() : IndexCartesian() end diff --git a/test/runtests.jl b/test/runtests.jl index f8137ad5..70ac927f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,11 +50,10 @@ end end @testset "indexstyle" begin - @inferred IndexStyle(StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10))) s = StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10)) T = typeof(s) @test IndexStyle(T) === IndexCartesian() - @test StructArrays.index_type(T) == CartesianIndex{2} + @test StructArrays.index_type(s) == CartesianIndex{2} @test s[100] == s[10, 10] == (a=s.a[10,10], b=s.b[10,10]) s[100] = (a=1, b=1) @test s[100] == s[10, 10] == (a=1, b=1) @@ -63,8 +62,8 @@ end @inferred IndexStyle(StructArray(a=rand(10,10), b=rand(10,10))) s = StructArray(a=rand(10,10), b=rand(10,10)) T = typeof(s) - @test IndexStyle(T) === IndexLinear() - @test StructArrays.index_type(T) == Int + @test StructArrays.index_type(s) == Int + @inferred IndexStyle(s) @test s[100] == s[10, 10] == (a=s.a[10,10], b=s.b[10,10]) s[100] = (a=1, b=1) @test s[100] == s[10, 10] == (a=1, b=1) @@ -72,13 +71,16 @@ end @test s[100] == s[10, 10] == (a=0, b=0) # inference for "many" types, both for linear ad Cartesian indexing - @inferred StructArrays.index_type(NTuple{2, Vector{Float64}}) - @inferred StructArrays.index_type(NTuple{3, Matrix{Float64}}) - @inferred StructArrays.index_type(NTuple{4, Array{Float64, 3}}) + @inferred StructArrays.index_type(ntuple(_ -> rand(5), 2)) + @inferred StructArrays.index_type(ntuple(_ -> rand(5, 5), 3)) + @inferred StructArrays.index_type(ntuple(_ -> rand(5, 5, 5), 4)) + + @inferred StructArrays.index_type(ntuple(_ -> view(rand(5), 1:3), 2)) + @inferred StructArrays.index_type(ntuple(_ -> view(rand(5, 5), 1:3, 1:2), 3)) + @inferred StructArrays.index_type(ntuple(_ -> view(rand(5, 5, 5), 1:3, 1:2, 1:4), 4)) - @inferred StructArrays.index_type(NTuple{2, SubArray{Float64, 1, Array{Float64, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}) - @inferred StructArrays.index_type(NTuple{3, SubArray{Float64, 1, Array{Float64, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}) - @inferred StructArrays.index_type(NTuple{4, SubArray{Float64, 1, Array{Float64, 2}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}) + @inferred StructArrays.index_type(ntuple(n -> n == 1 ? rand(5, 5) : view(rand(5, 5), 1:2, 1:3), 5)) + @inferred IndexStyle(StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10))) end @testset "replace_storage" begin