From 48fe382a4c4912085537b23c0ea1fbaf4ceab542 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Fri, 5 Jul 2024 23:51:33 +0200 Subject: [PATCH] added support for _nokw_ towards AD capability. --- Project.toml | 6 +- src/SeparableFunctions.jl | 5 +- src/general.jl | 212 +++++++++++++++++++++++++++++--------- src/specific.jl | 17 ++- test/runtests.jl | 89 +++++++++++----- 5 files changed, 251 insertions(+), 78 deletions(-) diff --git a/Project.toml b/Project.toml index d08c541..d3d8ae2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,8 +4,6 @@ authors = ["RainerHeintzmann "] version = "0.1.1" [deps] -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -13,6 +11,7 @@ NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Thunks = "490da00b-a60c-4ded-a4cf-df7cded56bfa" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ImageTransformations = "0.9, 0.10" @@ -22,9 +21,6 @@ NDTools = "0.7" StaticArrays = "0.1, 0.3, 0.4, 0.13, 0.14, 1" Thunks = "0.3" julia = "1" -ChainRules = "1.65, 1.66, 1.67, 1.68, 1.69" -ChainRulesCore = "1.20, 1.21, 1.22, 1.23, 1.24" -Zygote = "0.5, 0.6" [extras] IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566" diff --git a/src/SeparableFunctions.jl b/src/SeparableFunctions.jl index ef874ba..65e0eec 100644 --- a/src/SeparableFunctions.jl +++ b/src/SeparableFunctions.jl @@ -28,8 +28,9 @@ module SeparableFunctions using NDTools, LazyArrays using ImageTransformations, StaticArrays using Interpolations -using ChainRulesCore # for adjoint definition -using Zygote +# using ChainRulesCore # for adjoint definition +using ZygoteRules +using Zygote # to use rrule_via_ad export calculate_separables, separable_view, 
separable_create export calculate_broadcasted diff --git a/src/general.jl b/src/general.jl index af26a90..7e1f64b 100644 --- a/src/general.jl +++ b/src/general.jl @@ -1,5 +1,8 @@ """ - calculate_separables([::Type{AT},] fct, sz::NTuple{N, Int}, args...; dims = 1:N, all_axes = (similar_arr_type(AT, dims=Val(1)))(undef, sum(sz[[dims...]])), pos=zero(real(eltype(AT))), offset=sz.÷2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N} + calculate_separables_nokw([::Type{AT},] fct, sz::NTuple{N, Int}, offset = sz.÷2 .+1, scale = one(real(eltype(AT))), + factor = one(real(eltype(AT))), args...; dims = 1:N, + all_axes = (similar_arr_type(AT, dims=Val(1)))(undef, sum(sz[[dims...]])), pos=zero(real(eltype(AT))), + kwargs...) where {AT, N} creates a list of one-dimensional vectors, which can be combined to yield a separable array. In a way this can be seen as a half-way Lazy operation. The (potentially heavy) work of calculating the one-dimensional functions is done now but the memory-heavy calculation of the array is done later. @@ -9,12 +12,13 @@ This function is used in `separable_view` and `separable_create`. + `AT`: optional type signfying the array result type. You can for example use `CuArray{Float32}` using `CUDA` to create the views on the GPU. + `fct`: the function to calculate for each axis index (no need for broadcasting!) of this iterable of seperable axes. Note that the first arguments of `fct` have to be the index of this coordinate and the size of this axis. Any further `args` and `nargs` can follow. Often the second argument is not used but it still needs to be present. + `sz`: the size of the result array (when appying the one-D axes) ++ `offset`: specifying the center (zero-position) of the result array in one-based coordinates. The default corresponds to the Fourier-center. ++ `scale`: multiplies the index before passing it to `fct` ++ `factor`: multiplies the result of `fct` before storing it in the result array. 
+ `args`: further arguments which are passed over to the function `fct`. + `dims`: a vector `[]` of valid dimensions. Only these dimension will be calculated but they are oriented in ND. + `all_axes`: if provided, this memory is used instead of allocating a new one. This can be useful if you want to use the same memory for multiple calculations. + `pos`: a position shifting the indices passed to `fct` in relationship to the `offset`. -+ `offset`: specifying the center (zero-position) of the result array in one-based coordinates. The default corresponds to the Fourier-center. -+ `scale`: multiplies the index before passing it to `fct` #Example: ```julia @@ -33,13 +37,21 @@ julia> gauss_sep = calculate_separables(fct, (6,5), (0.5,1.0), pos = (0.1,0.2)) 6.50731f-5 0.000356206 0.000717312 0.000531398 0.000144823 ``` """ -function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N, +function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, + offset = sz.÷2 .+1, + scale = one(real(eltype(AT))), + factor = one(real(eltype(AT))), + args...; dims = 1:N, all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])), defaults=NamedTuple(), pos=zero(real(eltype(AT))), - offset=sz.÷2 .+1, - scale=one(real(eltype(AT))), - factor=one(real(eltype(AT))), kwargs...) where {AT, N} + kwargs...) where {AT, N} + + RT = real(eltype(AT)) + offset = isnothing(offset) ? sz.÷2 .+1 : RT.(offset) + scale = isnothing(scale) ? one(real(eltype(RT))) : RT.(scale) + factor = isnothing(factor) ? one(real(eltype(RT))) : RT.(factor) start = 1 .- offset + idc = pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos)) # @show typeof(idc) dims = [dims...] 
@@ -53,7 +65,12 @@ function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims if isa(factor, Number) factor = ntuple((d) -> factor, lastindex(dims)) end - # @show factor + # @show offset + # @show extra_args + # @show args + # @show kwargs + # @show collect(arg_n(dims[1], args)) + # @show (idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...) res[1][:] .= (factor[1]) .* fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...) #push!(res, collect(reorient(fct.(idc, sz[1], arg_n(1, args)...; kwarg_n(1, kwargs)...), 1, Val(N)))) for d = 2:lastindex(dims) @@ -75,52 +92,130 @@ function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims return res end -function calculate_separables(fct, sz::NTuple{N, Int}, args...; dims=1:N, +function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, + args...; dims = 1:N, + all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])), + defaults=NamedTuple(), pos=zero(real(eltype(AT))), + offset = sz.÷2 .+1, + scale = one(real(eltype(AT))), + factor = one(real(eltype(AT))), + kwargs...) where {AT, N} + return calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...) +end + +function calculate_separables(fct, sz::NTuple{N, Int}, args...; dims=1:N, all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])), - pos=zero(real(eltype(DefaultArrType))), offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), kwargs...) where {N} - calculate_separables(DefaultArrType, fct, sz, args...; all_axes=all_axes, dims=dims, pos=pos, offset=offset, scale=scale, kwargs...) + pos=zero(real(eltype(DefaultArrType))), + kwargs...) where {N} + calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...) 
end # define custom adjoint for calculate_separables # function ChainRulesCore.rrule(::typeof(calculate_separables), conv, rec, otf) # end -function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N, - all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])), - defaults = NamedTuple(), pos=zero(real(eltype(AT))), - offset = sz.÷2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N} +# +# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N, +# all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])), +# defaults = NamedTuple(), pos=zero(real(eltype(AT))), +# offset = sz.÷2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N} + +# println("inside rrule! $(sz), $(dims), $(offset)") +# # foward pass +# y = calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...) - println("inside rrule! $(sz), $(dims), $(offset)") - # foward pass - y = calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...) +# # extra_args = kwargs_to_args(defaults, kwarg_n(dims[1], kwargs)) +# # res[1][:] .= fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...) - # extra_args = kwargs_to_args(defaults, kwarg_n(dims[1], kwargs)) - # res[1][:] .= fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...) +# # d_off_fct = (idx, sz, args...; kwargs...) -> rrule((x) -> +# # fct(x, sz, args...; kwargs...), idx)[2](1)[1] +# # d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule((x, sz, args...; kwargs...) -> +# # fct(x, sz, args...; kwargs...), idx)[2](1)[1] +# # d_pos_fct = d_off_fct -# d_off_fct = (idx, sz, args...; kwargs...) 
-> rrule((x) -> -# fct(x, sz, args...; kwargs...), idx)[2](1)[1] -# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule((x, sz, args...; kwargs...) -> -# fct(x, sz, args...; kwargs...), idx)[2](1)[1] +# d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) -> +# fct(x, sz, args...), idx; kwargs...)[1] +# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) -> +# fct(x, sz, args...), idx;kwargs...)[1] # d_pos_fct = d_off_fct - d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) -> - fct(x, sz, args...), idx; kwargs...)[1] - d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) -> - fct(x, sz, args...), idx;kwargs...)[1] - d_pos_fct = d_off_fct - - function calculate_separables_pullback(dy) # dy is the gradient of the output - # multiply by dy? - doffset = calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...) - dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...) - dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...) - dargs = args; - # It should return the gradient of the inputs - # println(doffset) - return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent()) - end - return y, calculate_separables_pullback -end +# function calculate_separables_pullback(dy) # dy is the gradient of the output +# # multiply by dy? +# doffset = calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...) 
+# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...) +# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...) +# dargs = args; +# # It should return the gradient of the inputs +# # println(doffset) +# return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent()) +# end +# return y, calculate_separables_pullback +# end + +# foo(a,b;c=42.0)=a*b*c +# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N, +# function ZygoteRules._pullback(::ZygoteRules.AContext, A::typeof(Core.kwcall), kwargs, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...) where {AT, N} +# # function ZygoteRules._pullback(::ZygoteRules.AContext, A::typeof(Core.kwcall), kwargs, ::typeof(foo), a, b) +# # retrieve the kwargs +# dims = get(kwargs, :dims, 1:N) # insert default as we need that in pullback. +# all_axes = get(kwargs, :all_axes, (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]]))) # insert default as we need that in pullback. +# defaults = get(kwargs, :defaults, NamedTuple()) +# pos = get(kwargs, :pos, zero(real(eltype(AT)))) +# offset = get(kwargs, :offset, sz.÷2 .+1) +# scale = get(kwargs, :scale, one(real(eltype(AT)))) + +# println("kwargs calculate_separables in zygote rule sz:$(sz), dims:$(dims), offset:$(offset) args:$(args)") +# # foward pass +# y = calculate_separables(AT, fct, sz, args...; kwargs...) +# println("calculated y $(typeof(y))") + +# # Use Zygote's _pullback to obtain the gradient function for `fct` in the separable one-d case +# d_off_fct = (idx, sz, args...; kwargs...) 
-> Zygote.pullback((x) -> fct(x, sz, args...; kwargs...), idx)[1] +# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * Zygote.pullback((x) -> fct(x, sz, args...; kwargs...), idx)[1] +# # config = Type{acontext} # Zygote.ZygoteRuleConfig + +# # d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) -> +# # fct(x, sz, args...), idx; kwargs...)[1] +# # d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) -> +# # fct(x, sz, args...), idx;kwargs...)[1] +# d_pos_fct = d_off_fct + +# # c = get(kwargs, :c, 42.0) # insert default as we need that in pullback. +# println("functions defined") +# function calculate_separables_pullback(dy) # dy is the gradient of the output +# println("kwargs calculate_separables pulling back") +# # multiply by dy? +# # @show d_off_fct(13.3, 20, args...) + +# doffset = haskey(kwargs, :offset) ? calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...) : Zygote.NoTangent() +# # doffset = collect(doffset) +# # @show size(.*(doffset...)) +# # @show eltype(.*(doffset...)) +# println("calculated doffset $(typeof(doffset))") +# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...) +# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...) 
+# dargs = args; +# dkwargs = (;offset = doffset, scale = dscale, pos = dpos); +# # It should return the gradient of the inputs +# # println(doffset) +# # return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent()) +# println("done calculating $(typeof(dkwargs[:offset]))") +# return nothing, dkwargs, nothing, nothing, nothing, nothing +# end +# println("pullback defined") +# return y, calculate_separables_pullback + +# # function foo_pullback(dy) +# # println("kwargs foo pulling back") +# # da = dy*b*c +# # db = dy*a*c +# # dc = haskey(kwargs, :c) ? dy*a*b : NoTangent() +# # dkwargs = (;c=dc) +# # return nothing, dkwargs, nothing, da, db +# # end +# # return y, foo_pullback +# end + """ calculate_broadcasted([::Type{TA},] fct, sz::NTuple{N, Int}, args...; dims=1:N, pos=zero(real(eltype(DefaultArrType))), offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operation = *, kwargs...) where {TA, N} @@ -157,16 +252,41 @@ julia> collect(my_gaussian) """ function calculate_broadcasted(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N, all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])), - pos=zero(real(eltype(DefaultArrType))), offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operation = *, kwargs...) where {AT, N} - Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, kwargs...)...)) + pos=zero(real(eltype(DefaultArrType))), + operation = *, kwargs...) 
where {AT, N} + Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...)) end function calculate_broadcasted(fct, sz::NTuple{N, Int}, args...; dims=1:N, all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])), - pos=zero(real(eltype(DefaultArrType))), offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operation = *, kwargs...) where {N} - Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, kwargs...)...)) + pos=zero(real(eltype(DefaultArrType))), + operation = *, kwargs...) where {N} + Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...)) +end + +### Versions where offset and scale are without keyword arguments +function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N, + all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])), + pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {AT, N} + # defaults should be evaluated here and filled into args... +Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...)) end +function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...; dims=1:N, + all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])), + pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {N} + # defaults should be evaluated here and filled into args... 
+Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...)) +end + +# towards a Gaussian that can also be rotated: +# multiply with exp(-(x-x0)*(y-y0)/(sigma_xy)) = exp(-(x-x0)/(sigma_xy)) ^ (y-y0) +# g = gaussian_sep((200, 200), sigma=2.2) +# vg = Base.broadcasted((x,y)->x, collect(g), ones(1,1,10)); +# res = similar(g.args[1], size(vg)); +# @time gg = accumulate!(*, res, vg, dims=3); #, init=collect(g) + + """ separable_view{N}(fct, sz, args...; pos=zero(real(eltype(AT))), offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operation = .*) diff --git a/src/specific.jl b/src/specific.jl index dc6bf30..5960b32 100644 --- a/src/specific.jl +++ b/src/specific.jl @@ -30,7 +30,7 @@ function generate_functions_expr() # Rules: the calculation function has no kwargs but the last N arguments are the kwargs of the wrapper function # FunctionName, kwarg_names, no_kwargs_function_definition, default_return_type, default_separamble_operator (:(gaussian),(sigma=1.0,), :((x,sz, sigma) -> exp(- x^2/(2 .* sigma^2))), Float32, *), - (:(normal), (sigma=1.0,), :((x,sz,sigma) -> exp(- x^2/(2 .* sigma^2)) / (sqrt(typeof(x)(2pi))*sigma)), Float32, *), + (:(normal), (sigma=1.0,), :((x,sz, sigma) -> exp(- x^2/(2 .* sigma^2)) / (sqrt(typeof(x)(2pi))*sigma)), Float32, *), (:(sinc), NamedTuple(), :((x,sz) -> sinc(x)), Float32, *), # the value "nothing" means that this default argument will not be handed over. But this works only for the last argument! (:(exp_ikx), (shift_by=nothing,), :((x,sz, shift_by=sz÷2) -> cis(x*(-typeof(x)(2pi)*shift_by/sz))), ComplexF32, *), @@ -73,6 +73,20 @@ for F in generate_functions_expr() calculate_broadcasted(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), kwargs...) 
end + @eval function $(Symbol(F[1], :_nokw_sep))(::Type{TA}, sz::NTuple{N, Int}, args...; + all_axes = (similar_arr_type(TA, eltype(TA), Val(1)))(undef, sum(sz[[(1:N)...]])) + ) where {TA, N} + fct = $(F[3]) # to assign the function to a symbol + calculate_broadcasted_nokw(TA, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes) + end + + @eval function $(Symbol(F[1], :_nokw_sep))(sz::NTuple{N, Int}, args...; + all_axes = (similar_arr_type(Array{$(F[4])}, eltype(Array{$(F[4])}), Val(1)))(undef, sum(sz[[(1:N)...]])) + ) where {N} + fct = $(F[3]) # to assign the function to a symbol + calculate_broadcasted_nokw(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes) + end + @eval function $(Symbol(F[1], :_lz))(::Type{TA}, sz::NTuple{N, Int}, args...; kwargs...) where {TA, N} fct = $(F[3]) # to assign the function to a symbol separable_view(TA, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), kwargs...) @@ -88,6 +102,7 @@ for F in generate_functions_expr() # separated: a vector of separated contributions is returned and the user has to combine them @eval export $(Symbol(F[1], :_sep)) # lazy: A LazyArray representation is returned + @eval export $(Symbol(F[1], :_nokw_sep)) # @eval export $(Symbol(F[1], :_lz)) end diff --git a/test/runtests.jl b/test/runtests.jl index ec16fed..0c3afff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using IndexFunArrays using SeparableFunctions function test_fct(T, fcts, sz, args...; kwargs...) - ifa, col, lz, sep, op = fcts + ifa, fct = fcts a = let if typeof(ifa) <: AbstractArray ifa @@ -12,26 +12,23 @@ function test_fct(T, fcts, sz, args...; kwargs...) end end - b = col(Array{T}, sz, args...; kwargs...) - c = lz(Array{T}, sz, args...; kwargs...) - res = sep(Array{T}, sz, args...; kwargs...) - @test a≈b - @test eltype(b)==T - @test a≈c - @test eltype(c)==T - @test a≈collect(res) - @test eltype(collect(res))==T + res = fct(Array{T}, sz, args...; kwargs...) 
+ # @test (typeof(res) <: AbstractArray) == false + res = collect(res) + @test (typeof(res) <: AbstractArray) == true + @test a≈res + @test eltype(res)==T all_axes = zeros(T, prod(sz)) - b2 = col(Array{T}, sz, args...; all_axes = all_axes, kwargs...) - @test b≈b2 - c2 = lz(Array{T}, sz, args...; all_axes = all_axes, kwargs...) - @test c≈c2 - res2 = sep(Array{T}, sz, args...; all_axes = all_axes, kwargs...) - @test collect(res)≈collect(res2) + res2 = fct(Array{T}, sz, args...; all_axes = all_axes, kwargs...) + # @test (typeof(res2) <: AbstractArray) == false + res2 = collect(res2) + @test (typeof(res2) <: AbstractArray) == true + @test res≈res2 @test sum(abs.(all_axes)) > 0 end + function test_fct_t(fcts, sz, args...; kwargs...) test_fct(Float32, fcts, sz, args...;kwargs...) test_fct(Float64, fcts, sz, args...;kwargs...) @@ -40,7 +37,8 @@ end @testset "calculate_separables" begin sz = (13,15) fct = (r, sz, sigma)-> exp(-r^2/(2*sigma^2)) - @time gauss_sep = calculate_separables(fct, sz, (0.5,1.0), pos = (0.1,0.2)) + offset = (2.2, -2.2) ; scale = (1.1, 1.2); factor = 1.0; + @time gauss_sep = calculate_separables(fct, sz, (0.5,1.0), pos = (0.1,0.2), offset=offset, scale=scale, factor=factor) @test size(.*(gauss_sep...)) == sz # test with preallocated array all_axes = zeros(Float32, prod(sz)) @@ -51,7 +49,15 @@ end @testset "gaussian" begin sz = (11,22) - test_fct_t((gaussian, gaussian_col, SeparableFunctions.gaussian_lz, gaussian_sep, *), sz; sigma=(11.2, 5.5)); + sigma = (11.2, 5.5) + mygaussian = gaussian(sz, sigma=sigma) + test_fct_t((mygaussian, gaussian_col), sz; sigma=sigma); + test_fct_t((mygaussian, SeparableFunctions.gaussian_lz), sz; sigma=sigma); + test_fct_t((mygaussian, gaussian_sep), sz; sigma=sigma); + offset = sz.÷2 .+1 ; scale = (1.0, 1.0); factor = 1.0; + test_fct_t((mygaussian, gaussian_nokw_sep), sz, offset, scale, factor, sigma); + + # test_fct_t((gaussian, gaussian_col, SeparableFunctions.gaussian_lz, gaussian_sep, *), sz; sigma=(11.2, 5.5)); # # 
test with preallocated array # all_axes = zeros(Float32, prod(sz)) # test_fct_t((gaussian, gaussian_col, SeparableFunctions.gaussian_lz, gaussian_sep, *), sz; all_axes = all_axes, sigma=(11.2, 5.5)); @@ -59,25 +65,52 @@ end @testset "rr2" begin sz = (11,22, 3) - offset = (2,3,1) - test_fct_t((rr2, rr2_col, SeparableFunctions.rr2_lz, rr2_sep, +), sz; scale=(2.2, 3.3, 1.0), offset=offset); + offset = (2,3,1) # try some offset not in the center + scale = (2.2, 3.3, 1.0) # and a non-unity scale + myrr2 = rr2(sz; offset=offset, scale=scale) + test_fct_t((myrr2, rr2_col), sz; scale=scale, offset=offset); + test_fct_t((myrr2, SeparableFunctions.rr2_lz), sz; scale=scale, offset=offset); + test_fct_t((myrr2, rr2_sep), sz; scale=scale, offset=offset); + factor = 2.0; + test_fct_t((2 .*myrr2, rr2_nokw_sep), sz, offset, scale, factor); + + offset = sz .÷ 2 .+1 # the default (Fourier-center) offset + scale = (1.0, 1.0, 1.0) # and the default unity scale + myrr2 = rr2(sz; offset=offset, scale=scale) + test_fct_t((myrr2, rr2_nokw_sep), sz); # should be the same as the default end @testset "box" begin sz = (11,22, 3) offset = (2,3,1) - test_fct_t((box, box_col, SeparableFunctions.box_lz, box_sep, *), sz; scale=(2.2, 3.3, 1.0), offset=offset); + scale = (2.2, 3.3, 1.0) + mybox = box(sz; offset=offset, scale=scale) + test_fct_t((mybox, box_col), sz; scale=scale, offset=offset); + test_fct_t((box, SeparableFunctions.box_lz), sz; scale=scale, offset=offset); + test_fct_t((box, box_sep), sz; scale=scale, offset=offset); + test_fct_t((mybox, box_nokw_sep), sz, offset, scale); end @testset "ramp" begin sz = (11,22) - test_fct_t((xx(sz) .+ yy(sz), ramp_col, SeparableFunctions.ramp_lz, ramp_sep, +), sz; slope=(1.0,1.0)); + slope = (1.0, 2.2) + myxy = slope[1].*xx(sz) .+ slope[2].*yy(sz) + test_fct_t((myxy, ramp_col,), sz; slope=slope); + test_fct_t((myxy, SeparableFunctions.ramp_lz), sz; slope=slope); + test_fct_t((myxy, ramp_sep), sz; slope=slope); + test_fct_t((myxy, ramp_nokw_sep), sz, nothing, 
nothing, nothing, slope); end @testset "exp_ikx" begin sz = (11, 22, 4) + shift_by = (1.1, 0.2, 2.2) + myexp_ikx = exp_ikx(sz; shift_by = shift_by) # scale leads to problems! Since exp_ikx(sz) ≈ exp_ikx(sz, scale=(1.0,1.0,1.0)) -> false - test_fct(ComplexF32, (exp_ikx, exp_ikx_col, SeparableFunctions.exp_ikx_lz, exp_ikx_sep, *), sz; shift_by=(1.1,0.2,2.2)); + test_fct(ComplexF32, (myexp_ikx, exp_ikx_col), sz; shift_by=shift_by); + test_fct(ComplexF32, (myexp_ikx, SeparableFunctions.exp_ikx_lz), sz; shift_by=shift_by); + test_fct(ComplexF32, (myexp_ikx, exp_ikx_sep), sz; shift_by=shift_by); + test_fct(ComplexF32, (myexp_ikx, exp_ikx_nokw_sep), sz, nothing, nothing, nothing, shift_by); + myshift = (0.1,0.2,0.3) a = ones(ComplexF64,sz) SeparableFunctions.mul_exp_ikx!(a; shift_by=myshift) @@ -88,7 +121,10 @@ end sz = (12, 23) scale = (1.1, 2.2) mysinc = sinc.(xx(sz; scale=scale)) .* sinc.(yy(sz; scale=scale)); - test_fct(Float32, (mysinc, sinc_col, SeparableFunctions.sinc_lz, sinc_sep, *), sz; scale=scale); + test_fct(Float32, (mysinc, sinc_col), sz; scale=scale); + test_fct(Float32, (mysinc, SeparableFunctions.sinc_lz), sz; scale=scale); + test_fct(Float32, (mysinc, sinc_sep), sz; scale=scale); + test_fct(Float32, (mysinc, sinc_nokw_sep), sz, nothing, scale); end function test_copy_corners(sz) @@ -123,4 +159,9 @@ end @test maximum(abs.(res4 .- res5)) < 1e-6 end +@testset "gradients" begin + sz = (10,10) + gradient((x) -> gaussian_nokw_sep(sz, x, 1.0, 1.0, 1.0), (2.2,3.3)) +end + return