Skip to content

Commit

Permalink
Add unrolled functions and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Mar 28, 2024
1 parent 78c2be0 commit 807bf74
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 3 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "Test"]
test = ["Aqua", "JET", "OrderedCollections", "PrettyTables", "SafeTestsets", "Test"]
143 changes: 142 additions & 1 deletion src/UnrolledUtilities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,146 @@
"""
UnrolledUtilities
A collection of generated functions in which all loops are unrolled.
The functions exported by this module are
- `unrolled_foreach(f, itrs...)`: similar to `foreach`
- `unrolled_any(f, itrs...)`: similar to `any`
- `unrolled_all(f, itrs...)`: similar to `all`
- `unrolled_map(f, itrs...)`: similar to `map`
- `unrolled_reduce(op, itr; [init])`: similar to `reduce`
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce`
- `unrolled_zip(itrs...)`: similar to `zip`
- `unrolled_in(item, itr)`: similar to `in`
- `unrolled_unique(itr)`: similar to `unique`
- `unrolled_filter(f, itr)`: similar to `filter`
- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but
without duplicate calls to `f`
- `unrolled_flatten(itr)`: similar to `Iterators.flatten`
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap`
- `unrolled_product(itrs...)`: similar to `Iterators.product`
- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the
second argument wrapped in a `Val`
- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the
second argument wrapped in a `Val`
These functions are guaranteed to be type-stable whenever they are given
iterators with inferrable lengths and element types, including when
- the iterators have nonuniform element types (with the exception of `map`, all
of the corresponding functions from `Base` encounter type-instabilities and
allocations when this is the case)
- the iterators have many elements (e.g., more than 32, which is the threshold
at which `map` becomes type-unstable for `Tuple`s)
- `f` and/or `op` recursively call the function to which they is passed, with an
arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will
be type-unstable when the recursion depth exceeds 3, but this will not be the
case with `unrolled_map`)
"""
module UnrolledUtilities

# TODO: Add source code.
export unrolled_foreach,
unrolled_any,
unrolled_all,
unrolled_map,
unrolled_reduce,
unrolled_mapreduce,
unrolled_zip,
unrolled_in,
unrolled_unique,
unrolled_filter,
unrolled_split,
unrolled_flatten,
unrolled_flatmap,
unrolled_product,
unrolled_take,
unrolled_drop

# TODO: Add support for iterators that are not Tuples, e.g., StaticArrays. This
# will require adding new methods for unrolled_map, unrolled_filter, etc.
inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types)

function zipped_f_exprs(itr_types)
L = length(itr_types)
L == 0 && error("unrolled functions need at least one iterator as input")
N = minimum(inferred_length, itr_types)
return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N)
end

function nested_op_expr(itr_type)
N = inferred_length(itr_type)
N == 0 && error("unrolled_reduce needs an `init` value for empty iterators")
item_exprs = (:(itr[$n]) for n in 1:N)
return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs)
end

@generated unrolled_foreach(f, itrs...) = Expr(:block, zipped_f_exprs(itrs)...)
@generated unrolled_any(f, itrs...) = Expr(:||, zipped_f_exprs(itrs)...)
@generated unrolled_all(f, itrs...) = Expr(:&&, zipped_f_exprs(itrs)...)
@generated unrolled_map(f, itrs...) = Expr(:tuple, zipped_f_exprs(itrs)...)

struct NoInit end
@generated unrolled_reduce_without_kwarg(op, itr) = nested_op_expr(itr)
unrolled_reduce(op, itr; init = NoInit()) =
unrolled_reduce_without_kwarg(op, init == NoInit() ? itr : (init, itr...))

unrolled_mapreduce(f, op, itrs...; kwarg...) =
unrolled_reduce(op, unrolled_map(f, itrs...); kwarg...)

unrolled_zip(itrs...) = unrolled_map(tuple, itrs...)

unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr)
# Note: Using === instead of == or isequal seems to improve type stability.

unrolled_unique(itr) =
unrolled_reduce(itr; init = ()) do unique_items, item
unrolled_in(item, unique_items) ? unique_items : (unique_items..., item)
end

unrolled_filter(f, itr) =
unrolled_reduce(itr; init = ()) do filtered_items, item
f(item) ? (filtered_items..., item) : filtered_items
end

unrolled_split(f, itr) =
unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item
f(item) ? ((f_items..., item), not_f_items) :
(f_items, (not_f_items..., item))
end

unrolled_flatten(itr) =
unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ())

unrolled_flatmap(f, itrs...) = unrolled_flatten(unrolled_map(f, itrs...))

unrolled_product(itrs...) =
unrolled_reduce(itrs; init = ((),)) do product_itr, itr
unrolled_flatmap(itr) do item
unrolled_map(product_tuple -> (product_tuple..., item), product_itr)
end
end

# Note: ntuple is unrolled via Base.@ntuple when its second argument is a Val.
unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N))
unrolled_drop(itr, ::Val{N}) where {N} =
ntuple(i -> itr[N + i], Val(length(itr) - N))

# Drop the recursion limit for functions that take other functions as arguments.
@static if hasfield(Method, :recursion_relation)
const functions_without_recursion_limit = (
unrolled_foreach,
unrolled_any,
unrolled_all,
unrolled_map,
unrolled_reduce_without_kwarg,
unrolled_reduce,
unrolled_mapreduce,
unrolled_filter,
unrolled_split,
unrolled_flatmap,
)
for func in functions_without_recursion_limit, method in methods(func)
method.recursion_relation = (_...) -> true
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SafeTestsets

#! format: off
@safetestset "Test and Analyze" begin @time include("test_and_analyze.jl") end
@safetestset "Aqua" begin @time include("aqua.jl") end
#! format: on

0 comments on commit 807bf74

Please sign in to comment.