diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl index 9191479..8a89b98 100644 --- a/src/UnrolledUtilities.jl +++ b/src/UnrolledUtilities.jl @@ -1,5 +1,146 @@ +""" + UnrolledUtilities + +A collection of generated functions in which all loops are unrolled. + +The functions exported by this module are +- `unrolled_foreach(f, itrs...)`: similar to `foreach` +- `unrolled_any(f, itrs...)`: similar to `any` +- `unrolled_all(f, itrs...)`: similar to `all` +- `unrolled_map(f, itrs...)`: similar to `map` +- `unrolled_reduce(op, itr; [init])`: similar to `reduce` +- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce` +- `unrolled_zip(itrs...)`: similar to `zip` +- `unrolled_in(item, itr)`: similar to `in` +- `unrolled_unique(itr)`: similar to `unique` +- `unrolled_filter(f, itr)`: similar to `filter` +- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but + without duplicate calls to `f` +- `unrolled_flatten(itr)`: similar to `Iterators.flatten` +- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap` +- `unrolled_product(itrs...)`: similar to `Iterators.product` +- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the + second argument wrapped in a `Val` +- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the + second argument wrapped in a `Val` + +These functions are guaranteed to be type-stable whenever they are given +iterators with inferrable lengths and element types, including when +- the iterators have nonuniform element types (with the exception of `map`, all + of the corresponding functions from `Base` encounter type-instabilities and + allocations when this is the case) +- the iterators have many elements (e.g., more than 32, which is the threshold + at which `map` becomes type-unstable for `Tuple`s) +- `f` and/or `op` recursively call the function to which they is passed, with an + arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will + be type-unstable when the recursion depth exceeds 3, but this will not be the + case with `unrolled_map`) +""" module UnrolledUtilities -# TODO: Add source code. +export unrolled_foreach, + unrolled_any, + unrolled_all, + unrolled_map, + unrolled_reduce, + unrolled_mapreduce, + unrolled_zip, + unrolled_in, + unrolled_unique, + unrolled_filter, + unrolled_split, + unrolled_flatten, + unrolled_flatmap, + unrolled_product, + unrolled_take, + unrolled_drop + +# TODO: Add support for iterators that are not Tuples, e.g., StaticArrays. This +# will require adding new methods for unrolled_map, unrolled_filter, etc. +inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types) + +function zipped_f_exprs(itr_types) + L = length(itr_types) + L == 0 && error("unrolled functions need at least one iterator as input") + N = minimum(inferred_length, itr_types) + return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N) +end + +function nested_op_expr(itr_type) + N = inferred_length(itr_type) + N == 0 && error("unrolled_reduce needs an `init` value for empty iterators") + item_exprs = (:(itr[$n]) for n in 1:N) + return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs) +end + +@generated unrolled_foreach(f, itrs...) = Expr(:block, zipped_f_exprs(itrs)...) +@generated unrolled_any(f, itrs...) = Expr(:||, zipped_f_exprs(itrs)...) +@generated unrolled_all(f, itrs...) = Expr(:&&, zipped_f_exprs(itrs)...) +@generated unrolled_map(f, itrs...) = Expr(:tuple, zipped_f_exprs(itrs)...) + +struct NoInit end +@generated unrolled_reduce_without_kwarg(op, itr) = nested_op_expr(itr) +unrolled_reduce(op, itr; init = NoInit()) = + unrolled_reduce_without_kwarg(op, init == NoInit() ? itr : (init, itr...)) + +unrolled_mapreduce(f, op, itrs...; kwarg...) = + unrolled_reduce(op, unrolled_map(f, itrs...); kwarg...) + +unrolled_zip(itrs...) = unrolled_map(tuple, itrs...) + +unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr) +# Note: Using === instead of == or isequal seems to improve type stability. + +unrolled_unique(itr) = + unrolled_reduce(itr; init = ()) do unique_items, item + unrolled_in(item, unique_items) ? unique_items : (unique_items..., item) + end + +unrolled_filter(f, itr) = + unrolled_reduce(itr; init = ()) do filtered_items, item + f(item) ? (filtered_items..., item) : filtered_items + end + +unrolled_split(f, itr) = + unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item + f(item) ? ((f_items..., item), not_f_items) : + (f_items, (not_f_items..., item)) + end + +unrolled_flatten(itr) = + unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ()) + +unrolled_flatmap(f, itrs...) = unrolled_flatten(unrolled_map(f, itrs...)) + +unrolled_product(itrs...) = + unrolled_reduce(itrs; init = ((),)) do product_itr, itr + unrolled_flatmap(itr) do item + unrolled_map(product_tuple -> (product_tuple..., item), product_itr) + end + end + +# Note: ntuple is unrolled via Base.@ntuple when its second argument is a Val. +unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N)) +unrolled_drop(itr, ::Val{N}) where {N} = + ntuple(i -> itr[N + i], Val(length(itr) - N)) + +# Drop the recursion limit for functions that take other functions as arguments. +@static if hasfield(Method, :recursion_relation) + const functions_without_recursion_limit = ( + unrolled_foreach, + unrolled_any, + unrolled_all, + unrolled_map, + unrolled_reduce_without_kwarg, + unrolled_reduce, + unrolled_mapreduce, + unrolled_filter, + unrolled_split, + unrolled_flatmap, + ) + for func in functions_without_recursion_limit, method in methods(func) + method.recursion_relation = (_...) -> true + end +end end diff --git a/test/runtests.jl b/test/runtests.jl index 2d1aac4..84acf5c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SafeTestsets #! format: off +@safetestset "Unit Tests" begin @time include("unit_tests.jl") end @safetestset "Aqua" begin @time include("aqua.jl") end #! format: on diff --git a/test/unit_tests.jl b/test/unit_tests.jl new file mode 100644 index 0000000..c2f1c4a --- /dev/null +++ b/test/unit_tests.jl @@ -0,0 +1,226 @@ +using Test, JET, UnrolledUtilities + +function code_instance(f, args...) + available_methods = methods(f, Tuple{map(typeof, args)...}) + @assert length(available_methods) == 1 + (; specializations) = available_methods[1] + specTypes = Tuple{typeof(f), map(typeof, args)...} + return if specializations isa Core.MethodInstance + @assert specializations.specTypes == specTypes + specializations.cache + else + matching_specialization_indices = + findall(specializations) do specialization + !isnothing(specialization) && + specialization.specTypes == specTypes + end + @assert length(matching_specialization_indices) == 1 + specializations[matching_specialization_indices[1]].cache + end +end + +macro test_unrolled(args_expr, unrolled_expr, reference_expr, args_str_expr) + @assert Meta.isexpr(args_expr, :tuple) + esc_args = map(esc, args_expr.args) + esc_args_str = :(lpad($(esc(args_str_expr)), 69)) + unpadded_reference_str = + replace(string(reference_expr), r"\s*#=.+=#" => "", r"\s+" => ' ') + reference_str = rpad(unpadded_reference_str, 86) + quote + unrolled_func($(args_expr.args...)) = $unrolled_expr + reference_func($(args_expr.args...)) = $reference_expr + + # Record the compilation times for later. If both of these functions + # have constant outputs, their runtimes will essentially be 0. + unrolled_time = @elapsed unrolled_func($(esc_args...)) + reference_time = @elapsed reference_func($(esc_args...)) + + # Test for correctness. + @test unrolled_func($(esc_args...)) == reference_func($(esc_args...)) + + unrolled_func_with_nothing($(args_expr.args...)) = + (unrolled_func($(args_expr.args...)); nothing) + reference_func_with_nothing($(args_expr.args...)) = + (reference_func($(args_expr.args...)); nothing) + + unrolled_func_with_nothing($(esc_args...)) # Run once to compile. + reference_func_with_nothing($(esc_args...)) + + # Test for allocations. + @test (@allocated unrolled_func_with_nothing($(esc_args...))) == 0 + is_reference_non_allocating = + (@allocated reference_func_with_nothing($(esc_args...))) == 0 + + # Test for type-stability. + @test_opt unrolled_func($(esc_args...)) + is_reference_stable = + isempty(JET.get_reports(@report_opt reference_func($(esc_args...)))) + + unrolled_instance = code_instance(unrolled_func, $(esc_args...)) + reference_instance = code_instance(reference_func, $(esc_args...)) + + # Test for constant propagation. + @test isdefined(unrolled_instance, :rettype_const) + is_reference_const = isdefined(reference_instance, :rettype_const) + + # TODO: Print this information in a table. + if !is_reference_non_allocating + @info "for $($esc_args_str), $($reference_str) is allocating" + elseif !is_reference_stable + @info "for $($esc_args_str), $($reference_str) is type-unstable" + elseif !is_reference_const + @info "for $($esc_args_str), $($reference_str) is not constant" + else + ratio = round(reference_time / unrolled_time; sigdigits = 3) + @info "for $($esc_args_str), $($reference_str) takes $ratio times \ + as long to compile" + end + end +end + +for n in (1, 10, 33), is_uniform in (n == 1 ? (true,) : (true, false)) + itr1 = ntuple(i -> is_uniform ? () : ntuple(Val, (i - 1) % 7), n) + itr2 = ntuple(i -> is_uniform ? ((),) : ntuple(Val, (i - 1) % 7 + 1), n) + if n == 1 + str1 = "a tuple of 1 (empty) singleton value" + str2 = "a tuple of 1 (nonempty) singleton value" + full_str = "tuples of 1 singleton value" + else + type_str = (is_uniform ? "" : "non-") * "uniformly typed" + str1 = "a tuple of $n (empty & nonempty) $type_str singleton values" + str2 = "a tuple of $n (nonempty) $type_str singleton values" + full_str = "tuples of $n $type_str singleton values" + end + @testset "$full_str" begin + for (itr, str) in ((itr1, str1), (itr2, str2)) + @test_unrolled( + (itr,), + unrolled_foreach(item -> (@assert length(item) <= 7), itr), + foreach(item -> (@assert length(item) <= 7), itr), + str, + ) + + @test_unrolled (itr,) unrolled_any(isempty, itr) any(isempty, itr) str + @test_unrolled (itr,) unrolled_any(!isempty, itr) any(!isempty, itr) str + + @test_unrolled (itr,) unrolled_all(isempty, itr) all(isempty, itr) str + @test_unrolled (itr,) unrolled_all(!isempty, itr) all(!isempty, itr) str + + @test_unrolled (itr,) unrolled_map(length, itr) map(length, itr) str + + @test_unrolled (itr,) unrolled_reduce(tuple, itr) reduce(tuple, itr) str + @test_unrolled( + (itr,), + unrolled_reduce(tuple, itr; init = ()), + reduce(tuple, itr; init = ()), + str, + ) + + @test_unrolled( + (itr,), + unrolled_mapreduce(length, +, itr), + mapreduce(length, +, itr), + str, + ) + @test_unrolled( + (itr,), + unrolled_mapreduce(length, +, itr; init = false), + mapreduce(length, +, itr; init = false), + str, + ) + + @test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str + + @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str + @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str + @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str + + @test_unrolled (itr,) unrolled_unique(itr) Tuple(unique(itr)) str + + @test_unrolled( + (itr,), + unrolled_filter(!isempty, itr), + filter(!isempty, itr), + str, + ) + + @test_unrolled( + (itr,), + unrolled_split(isempty, itr), + (filter(isempty, itr), filter(!isempty, itr)), + str, + ) + + @test_unrolled( + (itr,), + unrolled_flatten(itr), + Tuple(Iterators.flatten(itr)), + str, + ) + + @test_unrolled( + (itr,), + unrolled_flatmap(reverse, itr), + Tuple(Iterators.flatmap(reverse, itr)), + str, + ) + + @test_unrolled( + (itr,), + unrolled_product(itr), + Tuple(Iterators.product(itr)), + str, + ) + + if n > 1 + @test_unrolled( + (itr,), + unrolled_take(itr, Val(7)), + itr[1:7], + str, + ) + @test_unrolled( + (itr,), + unrolled_drop(itr, Val(7)), + itr[8:end], + str, + ) + end + end + + @test_unrolled( + (itr1, itr2), + unrolled_foreach( + (item1, item2) -> (@assert length(item1) < length(item2)), + itr1, + itr2, + ), + foreach( + (item1, item2) -> (@assert length(item1) < length(item2)), + itr1, + itr2, + ), + full_str, + ) + @test_unrolled( + (itr1, itr2), + unrolled_zip(itr1, itr2), + Tuple(zip(itr1, itr2)), + full_str, + ) + @test_unrolled( + (itr1, itr2), + unrolled_product(itr1, itr2), + Tuple(Iterators.product(itr1, itr2)), + full_str, + ) + if n <= 10 # This takes a long time to compile for large tuples. + @test_unrolled( + (itr1, itr2), + unrolled_product(itr1, itr2, itr1), + Tuple(Iterators.product(itr1, itr2, itr1)), + full_str, + ) + end + end +end