From 807bf74a7f9c094ce999d2229c60b2cc78d6d53f Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Tue, 26 Mar 2024 11:13:49 -0700 Subject: [PATCH] Add unrolled functions and unit tests --- Project.toml | 5 +- src/UnrolledUtilities.jl | 143 ++++++++++++++++++++++++++++++++++++++- test/runtests.jl | 1 + 3 files changed, 146 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 6ea7814..4b44887 100644 --- a/Project.toml +++ b/Project.toml @@ -8,10 +8,11 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "Test"] +test = ["Aqua", "JET", "OrderedCollections", "PrettyTables", "SafeTestsets", "Test"] diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl index 9191479..8a89b98 100644 --- a/src/UnrolledUtilities.jl +++ b/src/UnrolledUtilities.jl @@ -1,5 +1,146 @@ +""" + UnrolledUtilities + +A collection of generated functions in which all loops are unrolled. + +The functions exported by this module are +- `unrolled_foreach(f, itrs...)`: similar to `foreach` +- `unrolled_any(f, itrs...)`: similar to `any` +- `unrolled_all(f, itrs...)`: similar to `all` +- `unrolled_map(f, itrs...)`: similar to `map` +- `unrolled_reduce(op, itr; [init])`: similar to `reduce` +- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce` +- `unrolled_zip(itrs...)`: similar to `zip` +- `unrolled_in(item, itr)`: similar to `in` +- `unrolled_unique(itr)`: similar to `unique` +- `unrolled_filter(f, itr)`: similar to `filter` +- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but + without duplicate calls to `f` +- `unrolled_flatten(itr)`: similar to `Iterators.flatten` +- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap` +- `unrolled_product(itrs...)`: similar to `Iterators.product` +- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the + second argument wrapped in a `Val` +- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the + second argument wrapped in a `Val` + +These functions are guaranteed to be type-stable whenever they are given +iterators with inferrable lengths and element types, including when +- the iterators have nonuniform element types (with the exception of `map`, all + of the corresponding functions from `Base` encounter type-instabilities and + allocations when this is the case) +- the iterators have many elements (e.g., more than 32, which is the threshold + at which `map` becomes type-unstable for `Tuple`s) +- `f` and/or `op` recursively call the function to which they is passed, with an + arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will + be type-unstable when the recursion depth exceeds 3, but this will not be the + case with `unrolled_map`) +""" module UnrolledUtilities -# TODO: Add source code. +export unrolled_foreach, + unrolled_any, + unrolled_all, + unrolled_map, + unrolled_reduce, + unrolled_mapreduce, + unrolled_zip, + unrolled_in, + unrolled_unique, + unrolled_filter, + unrolled_split, + unrolled_flatten, + unrolled_flatmap, + unrolled_product, + unrolled_take, + unrolled_drop + +# TODO: Add support for iterators that are not Tuples, e.g., StaticArrays. This +# will require adding new methods for unrolled_map, unrolled_filter, etc. +inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types) + +function zipped_f_exprs(itr_types) + L = length(itr_types) + L == 0 && error("unrolled functions need at least one iterator as input") + N = minimum(inferred_length, itr_types) + return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N) +end + +function nested_op_expr(itr_type) + N = inferred_length(itr_type) + N == 0 && error("unrolled_reduce needs an `init` value for empty iterators") + item_exprs = (:(itr[$n]) for n in 1:N) + return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs) +end + +@generated unrolled_foreach(f, itrs...) = Expr(:block, zipped_f_exprs(itrs)...) +@generated unrolled_any(f, itrs...) = Expr(:||, zipped_f_exprs(itrs)...) +@generated unrolled_all(f, itrs...) = Expr(:&&, zipped_f_exprs(itrs)...) +@generated unrolled_map(f, itrs...) = Expr(:tuple, zipped_f_exprs(itrs)...) + +struct NoInit end +@generated unrolled_reduce_without_kwarg(op, itr) = nested_op_expr(itr) +unrolled_reduce(op, itr; init = NoInit()) = + unrolled_reduce_without_kwarg(op, init == NoInit() ? itr : (init, itr...)) + +unrolled_mapreduce(f, op, itrs...; kwarg...) = + unrolled_reduce(op, unrolled_map(f, itrs...); kwarg...) + +unrolled_zip(itrs...) = unrolled_map(tuple, itrs...) + +unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr) +# Note: Using === instead of == or isequal seems to improve type stability. + +unrolled_unique(itr) = + unrolled_reduce(itr; init = ()) do unique_items, item + unrolled_in(item, unique_items) ? unique_items : (unique_items..., item) + end + +unrolled_filter(f, itr) = + unrolled_reduce(itr; init = ()) do filtered_items, item + f(item) ? (filtered_items..., item) : filtered_items + end + +unrolled_split(f, itr) = + unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item + f(item) ? ((f_items..., item), not_f_items) : + (f_items, (not_f_items..., item)) + end + +unrolled_flatten(itr) = + unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ()) + +unrolled_flatmap(f, itrs...) = unrolled_flatten(unrolled_map(f, itrs...)) + +unrolled_product(itrs...) = + unrolled_reduce(itrs; init = ((),)) do product_itr, itr + unrolled_flatmap(itr) do item + unrolled_map(product_tuple -> (product_tuple..., item), product_itr) + end + end + +# Note: ntuple is unrolled via Base.@ntuple when its second argument is a Val. +unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N)) +unrolled_drop(itr, ::Val{N}) where {N} = + ntuple(i -> itr[N + i], Val(length(itr) - N)) + +# Drop the recursion limit for functions that take other functions as arguments. +@static if hasfield(Method, :recursion_relation) + const functions_without_recursion_limit = ( + unrolled_foreach, + unrolled_any, + unrolled_all, + unrolled_map, + unrolled_reduce_without_kwarg, + unrolled_reduce, + unrolled_mapreduce, + unrolled_filter, + unrolled_split, + unrolled_flatmap, + ) + for func in functions_without_recursion_limit, method in methods(func) + method.recursion_relation = (_...) -> true + end +end end diff --git a/test/runtests.jl b/test/runtests.jl index 2d1aac4..be00ae4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SafeTestsets #! format: off +@safetestset "Test and Analyze" begin @time include("test_and_analyze.jl") end @safetestset "Aqua" begin @time include("aqua.jl") end #! format: on