-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unrolled functions and unit tests
- Loading branch information
1 parent
78c2be0
commit 807bf74
Showing
3 changed files
with
146 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,146 @@ | ||
""" | ||
UnrolledUtilities | ||
A collection of generated functions in which all loops are unrolled. | ||
The functions exported by this module are | ||
- `unrolled_foreach(f, itrs...)`: similar to `foreach` | ||
- `unrolled_any(f, itrs...)`: similar to `any` | ||
- `unrolled_all(f, itrs...)`: similar to `all` | ||
- `unrolled_map(f, itrs...)`: similar to `map` | ||
- `unrolled_reduce(op, itr; [init])`: similar to `reduce` | ||
- `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce` | ||
- `unrolled_zip(itrs...)`: similar to `zip` | ||
- `unrolled_in(item, itr)`: similar to `in` | ||
- `unrolled_unique(itr)`: similar to `unique` | ||
- `unrolled_filter(f, itr)`: similar to `filter` | ||
- `unrolled_split(f, itr)`: similar to `(filter(f, itr), filter(!f, itr))`, but | ||
without duplicate calls to `f` | ||
- `unrolled_flatten(itr)`: similar to `Iterators.flatten` | ||
- `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap` | ||
- `unrolled_product(itrs...)`: similar to `Iterators.product` | ||
- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take`, but with the | ||
second argument wrapped in a `Val` | ||
- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop`, but with the | ||
second argument wrapped in a `Val` | ||
These functions are guaranteed to be type-stable whenever they are given | ||
iterators with inferrable lengths and element types, including when | ||
- the iterators have nonuniform element types (with the exception of `map`, all | ||
of the corresponding functions from `Base` encounter type-instabilities and | ||
allocations when this is the case) | ||
- the iterators have many elements (e.g., more than 32, which is the threshold | ||
at which `map` becomes type-unstable for `Tuple`s) | ||
- `f` and/or `op` recursively call the function to which they is passed, with an | ||
arbitrarily large recursion depth (e.g., if `f` calls `map(f, itrs)`, it will | ||
be type-unstable when the recursion depth exceeds 3, but this will not be the | ||
case with `unrolled_map`) | ||
""" | ||
module UnrolledUtilities | ||
|
||
# TODO: Add source code. | ||
export unrolled_foreach, | ||
unrolled_any, | ||
unrolled_all, | ||
unrolled_map, | ||
unrolled_reduce, | ||
unrolled_mapreduce, | ||
unrolled_zip, | ||
unrolled_in, | ||
unrolled_unique, | ||
unrolled_filter, | ||
unrolled_split, | ||
unrolled_flatten, | ||
unrolled_flatmap, | ||
unrolled_product, | ||
unrolled_take, | ||
unrolled_drop | ||
|
||
# TODO: Add support for iterators that are not Tuples, e.g., StaticArrays. This | ||
# will require adding new methods for unrolled_map, unrolled_filter, etc. | ||
inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types) | ||
|
||
function zipped_f_exprs(itr_types) | ||
L = length(itr_types) | ||
L == 0 && error("unrolled functions need at least one iterator as input") | ||
N = minimum(inferred_length, itr_types) | ||
return (:(f($((:(itrs[$l][$n]) for l in 1:L)...))) for n in 1:N) | ||
end | ||
|
||
function nested_op_expr(itr_type) | ||
N = inferred_length(itr_type) | ||
N == 0 && error("unrolled_reduce needs an `init` value for empty iterators") | ||
item_exprs = (:(itr[$n]) for n in 1:N) | ||
return reduce((expr1, expr2) -> :(op($expr1, $expr2)), item_exprs) | ||
end | ||
|
||
@generated unrolled_foreach(f, itrs...) = Expr(:block, zipped_f_exprs(itrs)...) | ||
@generated unrolled_any(f, itrs...) = Expr(:||, zipped_f_exprs(itrs)...) | ||
@generated unrolled_all(f, itrs...) = Expr(:&&, zipped_f_exprs(itrs)...) | ||
@generated unrolled_map(f, itrs...) = Expr(:tuple, zipped_f_exprs(itrs)...) | ||
|
||
struct NoInit end | ||
@generated unrolled_reduce_without_kwarg(op, itr) = nested_op_expr(itr) | ||
unrolled_reduce(op, itr; init = NoInit()) = | ||
unrolled_reduce_without_kwarg(op, init == NoInit() ? itr : (init, itr...)) | ||
|
||
unrolled_mapreduce(f, op, itrs...; kwarg...) = | ||
unrolled_reduce(op, unrolled_map(f, itrs...); kwarg...) | ||
|
||
unrolled_zip(itrs...) = unrolled_map(tuple, itrs...) | ||
|
||
unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr) | ||
# Note: Using === instead of == or isequal seems to improve type stability. | ||
|
||
unrolled_unique(itr) = | ||
unrolled_reduce(itr; init = ()) do unique_items, item | ||
unrolled_in(item, unique_items) ? unique_items : (unique_items..., item) | ||
end | ||
|
||
unrolled_filter(f, itr) = | ||
unrolled_reduce(itr; init = ()) do filtered_items, item | ||
f(item) ? (filtered_items..., item) : filtered_items | ||
end | ||
|
||
unrolled_split(f, itr) = | ||
unrolled_reduce(itr; init = ((), ())) do (f_items, not_f_items), item | ||
f(item) ? ((f_items..., item), not_f_items) : | ||
(f_items, (not_f_items..., item)) | ||
end | ||
|
||
unrolled_flatten(itr) = | ||
unrolled_reduce((item1, item2) -> (item1..., item2...), itr; init = ()) | ||
|
||
unrolled_flatmap(f, itrs...) = unrolled_flatten(unrolled_map(f, itrs...)) | ||
|
||
unrolled_product(itrs...) = | ||
unrolled_reduce(itrs; init = ((),)) do product_itr, itr | ||
unrolled_flatmap(itr) do item | ||
unrolled_map(product_tuple -> (product_tuple..., item), product_itr) | ||
end | ||
end | ||
|
||
# Note: ntuple is unrolled via Base.@ntuple when its second argument is a Val. | ||
unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N)) | ||
unrolled_drop(itr, ::Val{N}) where {N} = | ||
ntuple(i -> itr[N + i], Val(length(itr) - N)) | ||
|
||
# Drop the recursion limit for functions that take other functions as arguments. | ||
@static if hasfield(Method, :recursion_relation) | ||
const functions_without_recursion_limit = ( | ||
unrolled_foreach, | ||
unrolled_any, | ||
unrolled_all, | ||
unrolled_map, | ||
unrolled_reduce_without_kwarg, | ||
unrolled_reduce, | ||
unrolled_mapreduce, | ||
unrolled_filter, | ||
unrolled_split, | ||
unrolled_flatmap, | ||
) | ||
for func in functions_without_recursion_limit, method in methods(func) | ||
method.recursion_relation = (_...) -> true | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
using SafeTestsets | ||
|
||
#! format: off | ||
@safetestset "Test and Analyze" begin @time include("test_and_analyze.jl") end | ||
@safetestset "Aqua" begin @time include("aqua.jl") end | ||
#! format: on |