diff --git a/src/utils.jl b/src/utils.jl index 4ca569b7..4a311e99 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,25 +28,40 @@ else const _getproperty = getproperty end -function _foreachfield(names, L) +array_names_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_names_types(C) +array_names_types(::Type{NamedTuple{names, types}}) where {names, types} = zip(names, types.parameters) +array_names_types(::Type{T}) where {T<:Tuple} = enumerate(T.parameters) + +function apply_f_to_vars_fields(names_types, vars) + exprs = Expr[] + for (name, type) in names_types + sym = QuoteNode(name) + args = [Expr(:call, :_getproperty, var, sym) for var in vars] + expr = if type <: StructArray + apply_f_to_vars_fields(array_names_types(type), args) + else + Expr(:call, :f, args...) + end + push!(exprs, expr) + end + return Expr(:block, exprs...) +end + +function _foreachfield(names_types, L) vars = ntuple(i -> gensym(), L) exprs = Expr[] for (i, v) in enumerate(vars) push!(exprs, Expr(:(=), v, Expr(:call, :getfield, :xs, i))) end - for field in names - sym = QuoteNode(field) - args = [Expr(:call, :_getproperty, var, sym) for var in vars] - push!(exprs, Expr(:call, :f, args...)) - end + push!(exprs, apply_f_to_vars_fields(names_types, vars)) push!(exprs, :(return nothing)) return Expr(:block, exprs...) end -@generated foreachfield_gen(::NamedTuple{names}, f, xs::Vararg{Any, L}) where {names, L} = - _foreachfield(names, L) -@generated foreachfield_gen(::NTuple{N, Any}, f, xs::Vararg{Any, L}) where {N, L} = - _foreachfield(Base.OneTo(N), L) +@generated foreachfield_gen(::NT, f, xs::Vararg{Any, L}) where {NT<:NamedTuple, L} = + _foreachfield(array_names_types(NT), L) +@generated foreachfield_gen(::T, f, xs::Vararg{Any, L}) where {T<:Tuple, L} = + _foreachfield(array_names_types(T), L) foreachfield(f, x::StructArray, xs...) = foreachfield_gen(fieldarrays(x), f, x, xs...)