Skip to content

Commit

Permalink
index_type cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
piever committed Dec 6, 2023
1 parent 15b044b commit dfcbd9d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
16 changes: 4 additions & 12 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,13 @@ struct StructArray{T, N, C<:Tup, I} <: AbstractArray{T, N}
ax = findconsistentvalue(axes, c)
(ax === nothing) && throw(ArgumentError("all component arrays must have the same shape"))
length(ax) == N || throw(ArgumentError("wrong number of dimensions"))
new{T, N, C, index_type(c)}(c)
# Compute optimal type to use for indexing as a function of components
I = IndexStyle(c...) isa IndexLinear ? Int : CartesianIndex{N}
return new{T, N, C, I}(c)
end
end

# compute optimal type to use for indexing as a function of components
index_type(components::NamedTuple) = index_type(values(components))
index_type(::Tuple{}) = Int
function index_type(components::Tuple)
f, ls = first(components), tail(components)
return IndexStyle(f) isa IndexCartesian ? CartesianIndex{ndims(f)} : index_type(ls)
end
# Only check first component if the all the component types match
index_type(components::NTuple) = invoke(index_type, Tuple{Tuple}, (first(components),))
# Return the index type parameter as a function of the StructArray type or instance
index_type(s::StructArray) = index_type(typeof(s))
# Return the index type parameter as a function of the StructArray type
index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I

array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C)
Expand Down
28 changes: 18 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ end
s = StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10))
T = typeof(s)
@test IndexStyle(T) === IndexCartesian()
@test StructArrays.index_type(s) == CartesianIndex{2}
@test StructArrays.index_type(T) == CartesianIndex{2}
@test s[100] == s[10, 10] == (a=s.a[10,10], b=s.b[10,10])
s[100] = (a=1, b=1)
@test s[100] == s[10, 10] == (a=1, b=1)
Expand All @@ -131,24 +131,32 @@ end
@inferred IndexStyle(StructArray(a=rand(10,10), b=rand(10,10)))
s = StructArray(a=rand(10,10), b=rand(10,10))
T = typeof(s)
@test StructArrays.index_type(s) == Int
@test StructArrays.index_type(T) == Int
@inferred IndexStyle(s)
@test s[100] == s[10, 10] == (a=s.a[10,10], b=s.b[10,10])
s[100] = (a=1, b=1)
@test s[100] == s[10, 10] == (a=1, b=1)
s[10, 10] = (a=0, b=0)
@test s[100] == s[10, 10] == (a=0, b=0)

# inference for "many" types, both for linear ad Cartesian indexing
@inferred StructArrays.index_type(ntuple(_ -> rand(5), 2))
@inferred StructArrays.index_type(ntuple(_ -> rand(5, 5), 3))
@inferred StructArrays.index_type(ntuple(_ -> rand(5, 5, 5), 4))
# inference for "many" types, both for linear and Cartesian indexing
s = @inferred StructArray(ntuple(_ -> rand(5), 2))
@test StructArrays.index_type(typeof(s)) === Int
s = @inferred StructArray(ntuple(_ -> rand(5, 5), 3))
@test StructArrays.index_type(typeof(s)) === Int
s = @inferred StructArray(ntuple(_ -> rand(5, 5, 5), 4))
@test StructArrays.index_type(typeof(s)) === Int

@inferred StructArrays.index_type(ntuple(_ -> view(rand(5), 1:3), 2))
@inferred StructArrays.index_type(ntuple(_ -> view(rand(5, 5), 1:3, 1:2), 3))
@inferred StructArrays.index_type(ntuple(_ -> view(rand(5, 5, 5), 1:3, 1:2, 1:4), 4))
s = @inferred StructArray(ntuple(_ -> view(rand(5), 1:3), 2))
@test StructArrays.index_type(typeof(s)) === Int
s = @inferred StructArray(ntuple(_ -> view(rand(5, 5), 1:3, 1:2), 3))
@test StructArrays.index_type(typeof(s)) === CartesianIndex{2}
s = @inferred StructArray(ntuple(_ -> view(rand(5, 5, 5), 1:3, 1:2, 1:4), 4))
@test StructArrays.index_type(typeof(s)) === CartesianIndex{3}

s = @inferred StructArray(ntuple(n -> n == 1 ? rand(2, 3) : view(rand(5, 5), 1:2, 1:3), 5))
@test StructArrays.index_type(typeof(s)) === CartesianIndex{2}

@inferred StructArrays.index_type(ntuple(n -> n == 1 ? rand(5, 5) : view(rand(5, 5), 1:2, 1:3), 5))
@inferred IndexStyle(StructArray(a=rand(10,10), b=view(rand(100,100), 1:10, 1:10)))
end

Expand Down

0 comments on commit dfcbd9d

Please sign in to comment.