Skip to content

Commit

Permalink
Add unrolled functions and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Mar 28, 2024
1 parent 78c2be0 commit 9a9bdad
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 1 deletion.
143 changes: 142 additions & 1 deletion src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,146 @@
"""
UnrolledUtilities
A collection of generated functions in which all loops are unrolled.
The functions exported by this module are
- `unrolled_foreach(f, itrs...)`: similar to `foreach`
- `unrolled_any(f, itrs...)`: similar to `any`
- `unrolled_all(f, itrs...)`: similar to `all`
- `unrolled_map(f, itrs...)`: similar to `map`
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
- `unrolled_zip(itrs...)`: similar to `zip`
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but
without duplicate calls to `f`
- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
- `unrolled_product(itrs...)`: similar to `Iterators.product`
- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the
second argument wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the
second argument wrapped in a `Val`
These functions are guaranteed to be type-stable whenever they are given
iterators with inferrable lengths and element types, including when
- the iterators have nonuniform element types (with the exception of `map`, all
of the corresponding functions from `Base` encounter type-instabilities and
allocations when this is the case)
- the iterators have many elements (e.g., more than 32, which is the threshold
at which `map` becomes type-unstable for `Tuple`s)
- `f` and/or `op` recursively call the function to which they is passed, with an
arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will
be type-unstable when the recursion depth exceeds 3, but this will not be the
case with `unrolled_map`)
"""
module UnrolledUtilities

# TODO: Add source code.
export unrolled_foreach,
unrolled_any,
unrolled_all,
unrolled_map,
unrolled_reduce,
unrolled_mapreduce,
unrolled_zip,
unrolled_in,
unrolled_unique,
unrolled_filter,
unrolled_split,
unrolled_flatten,
unrolled_flatmap,
unrolled_product,
unrolled_take,
unrolled_drop

# TODO: Add support for iterators that are not Tuples, e.g., StaticArrays. This
# will require adding new methods for unrolled_map, unrolled_filter, etc.
inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types)

function zipped_f_exprs(itr_types)
L = length(itr_types)
L == 0 && error("unrolled functions need at least one iterator as input")
N = minimum(inferred_length, itr_types)
return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N)
end

function nested_op_expr(itr_type)
N = inferred_length(itr_type)
N == 0 && error("unrolled_reduce needs an `init` value for empty iterators")
item_exprs = (:(itr[$n]) for n in 1:N)
return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs)
end

@generated unrolled_foreach(f, itrs...) = Expr(:block, zipped_f_exprs(itrs)...)
@generated unrolled_any(f, itrs...) = Expr(:||, zipped_f_exprs(itrs)...)
@generated unrolled_all(f, itrs...) = Expr(:&&, zipped_f_exprs(itrs)...)
@generated unrolled_map(f, itrs...) = Expr(:tuple, zipped_f_exprs(itrs)...)

struct NoInit end
@generated unrolled_reduce_without_kwarg(op, itr) = nested_op_expr(itr)
unrolled_reduce(op, itr; init = NoInit()) =
unrolled_reduce_without_kwarg(op, init == NoInit() ? itr : (init, itr...))

unrolled_mapreduce(f, op, itrs...; kwarg...) =
unrolled_reduce(op, unrolled_map(f, itrs...); kwarg...)

unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)

unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
# Note: Using === instead of == or isequal seems to improve type stability.

unrolled_unique(itr) =
unrolled_reduce(itr; init = ()) do unique_items, item
unrolled_in(item, unique_items) ? unique_items : (unique_items..., item)
end

unrolled_filter(f, itr) =
unrolled_reduce(itr; init = ()) do filtered_items, item
f(item) ? (filtered_items..., item) : filtered_items
end

unrolled_split(f, itr) =
unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item
f(item) ? ((f_items..., item), not_f_items) :
(f_items, (not_f_items..., item))
end

unrolled_flatten(itr) =
unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ())

unrolled_flatmap(f, itrs...) = unrolled_flatten(unrolled_map(f, itrs...))

unrolled_product(itrs...) =
unrolled_reduce(itrs; init = ((),)) do product_itr, itr
unrolled_flatmap(itr) do item
unrolled_map(product_tuple -> (product_tuple..., item), product_itr)
end
end

# Note: ntuple is unrolled via Base.@ntuple when its second argument is a Val.
unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
unrolled_drop(itr, ::Val{N}) where {N} =
ntuple(i -> itr[N + i], Val(length(itr) - N))

# Drop the recursion limit for functions that take other functions as arguments.
@static if hasfield(Method, :recursion_relation)
const functions_without_recursion_limit = (
unrolled_foreach,
unrolled_any,
unrolled_all,
unrolled_map,
unrolled_reduce_without_kwarg,
unrolled_reduce,
unrolled_mapreduce,
unrolled_filter,
unrolled_split,
unrolled_flatmap,
)
for func in functions_without_recursion_limit, method in methods(func)
method.recursion_relation = (_...) -> true
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SafeTestsets

#! format: off
@safetestset "Unit Tests" begin @time include("unit_tests.jl") end
@safetestset "Aqua" begin @time include("aqua.jl") end
#! format: on
226 changes: 226 additions & 0 deletions test/unit_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
using Test, JET, UnrolledUtilities

function code_instance(f, args...)
available_methods = methods(f, Tuple{map(typeof, args)...})
@assert length(available_methods) == 1
(; specializations) = available_methods[1]
specTypes = Tuple{typeof(f), map(typeof, args)...}
return if specializations isa Core.MethodInstance
@assert specializations.specTypes == specTypes
specializations.cache
else
matching_specialization_indices =
findall(specializations) do specialization
!isnothing(specialization) &&
specialization.specTypes == specTypes
end
@assert length(matching_specialization_indices) == 1
specializations[matching_specialization_indices[1]].cache
end
end

macro test_unrolled(args_expr, unrolled_expr, reference_expr, args_str_expr)
@assert Meta.isexpr(args_expr, :tuple)
esc_args = map(esc, args_expr.args)
esc_args_str = :(lpad($(esc(args_str_expr)), 69))
unpadded_reference_str =
replace(string(reference_expr), r"\s*#=.+=#" => "", r"\s+" => ' ')
reference_str = rpad(unpadded_reference_str, 86)
quote
unrolled_func($(args_expr.args...)) = $unrolled_expr
reference_func($(args_expr.args...)) = $reference_expr

# Record the compilation times for later. If both of these functions
# have constant outputs, their runtimes will essentially be 0.
unrolled_time = @elapsed unrolled_func($(esc_args...))
reference_time = @elapsed reference_func($(esc_args...))

# Test for correctness.
@test unrolled_func($(esc_args...)) == reference_func($(esc_args...))

unrolled_func_with_nothing($(args_expr.args...)) =
(unrolled_func($(args_expr.args...)); nothing)
reference_func_with_nothing($(args_expr.args...)) =
(reference_func($(args_expr.args...)); nothing)

unrolled_func_with_nothing($(esc_args...)) # Run once to compile.
reference_func_with_nothing($(esc_args...))

# Test for allocations.
@test (@allocated unrolled_func_with_nothing($(esc_args...))) == 0
is_reference_non_allocating =
(@allocated reference_func_with_nothing($(esc_args...))) == 0

# Test for type-stability.
@test_opt unrolled_func($(esc_args...))
is_reference_stable =
isempty(JET.get_reports(@report_opt reference_func($(esc_args...))))

unrolled_instance = code_instance(unrolled_func, $(esc_args...))
reference_instance = code_instance(reference_func, $(esc_args...))

# Test for constant propagation.
@test isdefined(unrolled_instance, :rettype_const)
is_reference_const = isdefined(reference_instance, :rettype_const)

# TODO: Print this information in a table.
if !is_reference_non_allocating
@info "for $($esc_args_str), $($reference_str) is allocating"
elseif !is_reference_stable
@info "for $($esc_args_str), $($reference_str) is type-unstable"
elseif !is_reference_const
@info "for $($esc_args_str), $($reference_str) is not constant"
else
ratio = round(reference_time / unrolled_time; sigdigits = 3)
@info "for $($esc_args_str), $($reference_str) takes $ratio times \
as long to compile"
end
end
end

for n in (1, 10, 33), is_uniform in (n == 1 ? (true,) : (true, false))
itr1 = ntuple(i -> is_uniform ? () : ntuple(Val, (i - 1) % 7), n)
itr2 = ntuple(i -> is_uniform ? ((),) : ntuple(Val, (i - 1) % 7 + 1), n)
if n == 1
str1 = "a tuple of 1 (empty) singleton value"
str2 = "a tuple of 1 (nonempty) singleton value"
full_str = "tuples of 1 singleton value"
else
type_str = (is_uniform ? "" : "non-") * "uniformly typed"
str1 = "a tuple of $n (empty & nonempty) $type_str singleton values"
str2 = "a tuple of $n (nonempty) $type_str singleton values"
full_str = "tuples of $n $type_str singleton values"
end
@testset "$full_str" begin
for (itr, str) in ((itr1, str1), (itr2, str2))
@test_unrolled(
(itr,),
unrolled_foreach(item -> (@assert length(item) <= 7), itr),
foreach(item -> (@assert length(item) <= 7), itr),
str,
)

@test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str
@test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str

@test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str
@test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str

@test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str

@test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str
@test_unrolled(
(itr,),
unrolled_reduce(tuple, itr; init = ()),
reduce(tuple, itr; init = ()),
str,
)

@test_unrolled(
(itr,),
unrolled_mapreduce(length, +, itr),
mapreduce(length, +, itr),
str,
)
@test_unrolled(
(itr,),
unrolled_mapreduce(length, +, itr; init = false),
mapreduce(length, +, itr; init = false),
str,
)

@test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str

@test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str
@test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str
@test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str

@test_unrolled (itr,) unrolled_unique(itr) Tuple(unique(itr)) str

@test_unrolled(
(itr,),
unrolled_filter(!isempty, itr),
filter(!isempty, itr),
str,
)

@test_unrolled(
(itr,),
unrolled_split(isempty, itr),
(filter(isempty, itr), filter(!isempty, itr)),
str,
)

@test_unrolled(
(itr,),
unrolled_flatten(itr),
Tuple(Iterators.flatten(itr)),
str,
)

@test_unrolled(
(itr,),
unrolled_flatmap(reverse, itr),
Tuple(Iterators.flatmap(reverse, itr)),
str,
)

@test_unrolled(
(itr,),
unrolled_product(itr),
Tuple(Iterators.product(itr)),
str,
)

if n > 1
@test_unrolled(
(itr,),
unrolled_take(itr, Val(7)),
itr[1:7],
str,
)
@test_unrolled(
(itr,),
unrolled_drop(itr, Val(7)),
itr[8:end],
str,
)
end
end

@test_unrolled(
(itr1, itr2),
unrolled_foreach(
(item1, item2) -> (@assert length(item1) < length(item2)),
itr1,
itr2,
),
foreach(
(item1, item2) -> (@assert length(item1) < length(item2)),
itr1,
itr2,
),
full_str,
)
@test_unrolled(
(itr1, itr2),
unrolled_zip(itr1, itr2),
Tuple(zip(itr1, itr2)),
full_str,
)
@test_unrolled(
(itr1, itr2),
unrolled_product(itr1, itr2),
Tuple(Iterators.product(itr1, itr2)),
full_str,
)
if n <= 10 # This takes a long time to compile for large tuples.
@test_unrolled(
(itr1, itr2),
unrolled_product(itr1, itr2, itr1),
Tuple(Iterators.product(itr1, itr2, itr1)),
full_str,
)
end
end
end

0 comments on commit 9a9bdad

Please sign in to comment.