From 7d23a92be117569af7926d75e1d8aa198f873364 Mon Sep 17 00:00:00 2001 From: RainerHeintzmann Date: Thu, 1 Aug 2024 17:35:44 +0200 Subject: [PATCH] fixed gradients and tests --- src/general.jl | 125 +++++++++++++++++++++++++++++++++-------------- src/specific.jl | 49 +++++++++++++------ src/utilities.jl | 32 +++++++----- test/runtests.jl | 117 +++++++++++++++++++++++++++++++------------- 4 files changed, 226 insertions(+), 97 deletions(-) diff --git a/src/general.jl b/src/general.jl index 4a4ce3c..b3f5759 100644 --- a/src/general.jl +++ b/src/general.jl @@ -1,6 +1,6 @@ """ calculate_separables_nokw([::Type{AT},] fct, sz::NTuple{N, Int}, offset = sz.÷2 .+1, scale = one(real(eltype(AT))), - args...;all_axes = get_sep_mem(AT, sz), kwargs...) where {AT, N} + args...;all_axes = nothing, kwargs...) where {AT, N} creates a list of one-dimensional vectors, which can be combined to yield a separable array. In a way this can be seen as a half-way Lazy operator. The (potentially heavy) work of calculating the one-dimensional functions is done now but the memory-heavy calculation of the array is done later. @@ -36,13 +36,16 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, offset = nothing, scale = nothing, args...; - all_axes = get_sep_mem(AT, sz, get_arg_sz(sz, offset, scale, args...)), + all_axes = nothing, kwargs...) where {AT, N} RT = real(float(eltype(AT))) RAT = similar_arr_type(AT, RT, Val(1)) offset = isnothing(offset) ? (sz.÷2 .+1 ) : RT.(offset) scale = isnothing(scale) ? RAT([one(RT)]) : RT.(scale) + if isnothing(all_axes) + all_axes = get_sep_mem(AT, sz, get_arg_sz(sz, offset, scale, args...)) + end # offset = ntuple((d) -> pick_n(d, offset), Val(N)) # scale = ntuple((d) -> pick_n(d, scale), Val(N)) @@ -60,6 +63,8 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, idc = get_1d_ids(d, sz, offset, scale) args_d = arg_n(d, args, RT, sz) # # in_place_assing!(res, 1, fct, idc, sz1d, args_d) + # @show size(res) + # @show size(idc) res .= fct.(idc, sz1d, args_d...) # 5 allocs, 160 bytes end return all_axes @@ -91,7 +96,7 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_ _, in_place_assing_pullback = rrule_via_ad(config, out_of_place_assing, res, d, fct, idc, sz1d, args_d) function debug_dummy(dy) - println("in debug_dummy") # sz is (10, 20) + # println("in debug_dummy") # sz is (10, 20) # @show dy # NoTangent() # @show size(dy) # 1st calls: (1, 20) 2nd call: (10, 1) myres = in_place_assing_pullback(dy) @@ -123,7 +128,7 @@ end """ calculate_separables([::Type{AT},] fct, sz::NTuple{N, Int}, args...; - all_axes = get_sep_mem(AT, sz), + all_axes = nothing, defaults=NamedTuple(), offset = sz.÷2 .+1, scale = one(real(eltype(AT))), @@ -162,27 +167,29 @@ julia> gauss_sep = SeparableFunctions.calculate_separables_nokw(Array{Float32}, """ function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; + all_axes = nothing, defaults=NamedTuple(), offset = sz.÷2 .+1, scale = one(real(eltype(AT))), kwargs...) where {AT, N} extra_args = kwargs_to_args(defaults, kwargs) - return calculate_separables_nokw(AT, fct, sz, offset, scale, extra_args..., args...; defaults=defaults, kwargs...) + return calculate_separables_nokw(AT, fct, sz, offset, scale, extra_args..., args...; all_axes=all_axes, defaults=defaults, kwargs...) end function calculate_separables(fct, sz::NTuple{N, Int}, args...; + all_axes = nothing, defaults=NamedTuple(), offset = sz.÷2 .+1, scale = one(real(eltype(DefaultArrType))), kwargs...) where {N} extra_args = kwargs_to_args(defaults, kwargs) - calculate_separables(DefaultArrType, fct, sz, extra_args..., args...; offset=offset, scale=scale, kwargs...) + calculate_separables(DefaultArrType, fct, sz, extra_args..., args...; all_axes = all_axes, offset=offset, scale=scale, kwargs...) end """ - calculate_broadcasted([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operator = *, kwargs...) where {TA, N} + calculate_broadcasted([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operator = get_operator(fct), kwargs...) where {TA, N} returns an instantiated broadcasted separable array, which essentially behaves almost like an array yet uses broadcasting. Test revealed maximal speed improvements for this version. Yet, a problem is that reduce operators with specified dimensions cause an error. However this can be avoided by calling `collect`. @@ -214,16 +221,16 @@ julia> collect(my_gaussian) ``` """ function calculate_broadcasted(::Type{AT}, fct, sz::NTuple{N, Int}, args...; - operator = *, all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), + operator = get_operator(fct), all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), kwargs...) where {AT, N} # replace the sep memory inside the broadcast structure with new values: calculate_separables(AT, fct, sz, args...; all_axes = all_axes.args, kwargs...) return all_axes - # Broadcast.instantiate(Broadcast.broadcasted(operator, calculate_separables(AT, fct, sz, args...; all_axes=all_axes, kwargs...)...)) + # return Broadcast.instantiate(Broadcast.broadcasted(operator, calculate_separables(AT, fct, sz, args...; all_axes=all_axes, kwargs...)...)) end function calculate_broadcasted(fct, sz::NTuple{N, Int}, args...; - operator = *, all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), + operator = get_operator(fct), all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), kwargs...) where {N} calculate_separables(DefaultArrType, fct, sz, args...; all_axes=all_axes.args, kwargs...) return all_axes @@ -245,13 +252,13 @@ end ### Versions where offst and scale are without keyword arguments function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...; - operator = *, all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), + operator = get_operator(fct), all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), defaults = nothing, kwargs...) where {AT, N} # defaults should be evaluated here and filled into args... calculate_separables_nokw_hook(AT, fct, sz, args...; all_axes=all_axes.args, kwargs...) - # res = Broadcast.instantiate(Broadcast.broadcasted(operator, calculate_separables_nokw_hook(AT, fct, sz, args...; all_axes=all_axes, kwargs...)...)) # @show eltype(collect(res)) return all_axes + # return Broadcast.instantiate(Broadcast.broadcasted(operator, calculate_separables_nokw_hook(AT, fct, sz, args...; all_axes=all_axes, kwargs...)...)) end function calculate_separables_nokw_hook(::Type{AT}, fct, sz::NTuple{N, Int}, args...; kwargs...) where {AT, N} @@ -265,7 +272,7 @@ end # this code only works for the multiplicative version and actually only saves very few allocations. # it's not worth the specialization: function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_broadcasted_nokw), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; - operator = *, all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), + operator = get_operator(fct), all_axes = nothing, # get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)), defaults = nothing, kwargs...) where {AT, N} # @show typeof(all_axes) @@ -285,17 +292,38 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cal # @show size(y_sep[2]) # @show size(y) - function calculate_broadcasted_nokw_pullback(dy) - # println("in calculate_broadcasted_nokw_pullback") # sz is (10, 20) + function calculate_broadcasted_nokw_pullback_mul(dy) + # println("in calculate_broadcasted_nokw_pullback_mul") # sz is (10, 20) + # @show dy + # @show y_sep projected = (N<=1) ? ntuple((d) -> dy, Val(N)) : ntuple((d) -> begin - dims=((1:d-1)...,(d+1:N)...) - reduce(+, operator.(y_sep[[dims...]]...) .* dy, dims=dims) + other_dims=((1:d-1)...,(d+1:N)...) + reduce(+, operator.(conj.(y_sep[[other_dims...]])...) .* dy, dims=other_dims) end, Val(N)) # 7 Mb # @show typeof(projected[2]) + # @show projected myres = calculate_sep_nokw_pullback(projected) # 8 kB return myres end - return y, calculate_broadcasted_nokw_pullback # in_place_assing_pullback # in_place_assing_pullback + function calculate_broadcasted_nokw_pullback_add(dy) + # println("in calculate_broadcasted_nokw_pullbac_add") # sz is (10, 20) + projected = (N<=1) ? ntuple((d) -> dy, Val(N)) : ntuple((d) -> begin + other_dims=((1:d-1)...,(d+1:N)...) + reduce(+, dy, dims=other_dims) + end, Val(N)) # 7 Mb + myres = calculate_sep_nokw_pullback(projected) # 8 kB + return myres + end + mypullback = let + if (operator == *) + calculate_broadcasted_nokw_pullback_mul + elseif (operator == +) + calculate_broadcasted_nokw_pullback_add # in_place_assing_pullback # in_place_assing_pullback + else + error("SeparableFunctions operator not supported") + end + end + return y, mypullback end # function calculate_separables_nokw_hook2() @@ -303,7 +331,7 @@ end # Needs revision function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw_hook), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; - all_axes = get_sep_mem(AT, sz, get_arg_sz(sz, args...)), kwargs...) where {AT, N} + all_axes = nothing, kwargs...) where {AT, N} # get_sep_mem(AT, sz, get_arg_sz(sz, args...)), # @show "in rrule sep hook" # ids = ntuple((d) -> reorient(get_1d_ids(d, sz, args[1], args[2]), d, Val(N)), Val(N)) # offset==args[1] and scale==args[2] RT = real(float(eltype(AT))) @@ -348,15 +376,27 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cal get_arg_gradient(fct, red_dims, y[d], ids[d], sz[d], dy[d], args_1d[d]...) end, Val(N))), length(args)-2) - # @show doffset - # @show dargs - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dargs...) end return y, calculate_separables_nokw_hook_pullback end +""" + get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussian, bg=zero(real(T))) where {T,N} + +This function returns an `fg!(F, G, vec)` function, which calculates the forward and gradient of a separable function. +It can directly be used for the `Optimize` package. +The function `fct` should be a separable function of the "_vec" type. +The first arguments of `fct` have to be the index of this coordinate and the size of this axis. +# Arguments ++ `data`: the data to fit to ++ `fct`: the separable function to fit to the data ++ `prod_dims`: the number of dimensions to be multiplied together. This is the number of dimensions which are not separated. ++ `loss`: the loss function to use. Default is the Gaussian loss. ++ `bg`: the background value (only for the loss function. Not the forward model of the data). Default is zero. + +""" function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussian, bg=zero(real(T))) where {T,N} RT = real(eltype(data)) AT = typeof(data) @@ -372,7 +412,7 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia y = by.args # yv = ntuple((d) -> (@view y[d]), Val(prod_dims)) - dy = get_sep_mem(typeof(data), size(data)[1:prod_dims], hyper_sz) + dy = get_sep_mem(AT, size(data)[1:prod_dims], hyper_sz) # dyv = ntuple((d) -> (@view dy[d]), Val(prod_dims)) resid = similar(data) # checkpoint @@ -397,10 +437,14 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia # if !isnothing(F) || !isnothing(G) # resid .= bg .+ intensity .* by .- data # end - loss = loss_fg!(F, resid, bg .+ get_vec_dim(intensity, 1, sz) .* by) + myint = get_vec_dim(intensity, 1, sz) + loss = loss_fg!(F, resid, bg .+ myint .* by) + # @show resid if !isnothing(G) # for arrays this should be .= - G.bg = C*sum(resid) + if hasproperty(vec, :bg) + G.bg = C*sum(resid) + end other_dims = ntuple((d)-> (ntuple((n)->n, d-1)..., ntuple((n)->d+n, prod_dims-d)...), Val(prod_dims)) # @show other_dims @@ -412,8 +456,15 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia for d in 1:prod_dims # This costs 33 kB allocation, sometimes it can be faster? # dy[d] .= sum(resid.* (.*(other_ys[d]...)), dims=other_dims[d]) - # this costs 2 kB allocations but is slower - dy[d] .= broadcast_reduce(*, +, resid, other_ys[d]..., dims=other_dims[d], init=zero(T)) + # this costs 2 kB allocations but is slower reduce(+, dy, dims=dims) + if (operator == *) + dy[d] .= conj.(broadcast_reduce(*, +, conj.(resid), other_ys[d]..., dims=other_dims[d], init=zero(T))) + # dy[d] .= broadcast_reduce(*, +, resid, other_ys[d]..., dims=other_dims[d], init=zero(T)) + elseif (operator == +) + dy[d] .= reduce(+, resid, dims=other_dims[d], init=zero(T)) + else + error("SeparableFunctions operator in fg! not supported") + end end # less memory, but a little slower: # dy = ntuple((d) -> broadcast_reduce(*, +, resid, other_ys[d]..., dims=other_dims[d], init=zero(T)), Val(N)) @@ -450,7 +501,7 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia # return loss, dfdbg, dfdI, doffset, dscale, dargs... # loss = (!isnothing(F)) ? mapreduce(abs2, +, resid; init=zero(T)) : T(0); - if !isnothing(G) # forward needs to be calculated + if (!isnothing(G) && hasproperty(vec, :intensity)) # forward needs to be calculated # resid .*= by # is slower! resid .= conj.(by) .* resid # optional_convert_assign!(G.intensity, G.intensity, (sum(resid, dims=1:length(sz)),)) @@ -471,13 +522,13 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia end function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...; - operator = *, all_axes = get_bc_mem(AT, sz, operator), + operator = get_operator(fct), all_axes = nothing, # get_bc_mem(AT, sz, operator), defaults = nothing, kwargs...) where {N} # defaults should be evaluated here and filled into args... - calculate_separables_nokw(DefaultArrType, fct, sz, args...; all_axes=all_axes.args, kwargs...) - # res = Broadcast.instantiate(Broadcast.broadcasted(operator, calculate_separables_nokw(DefaultArrType, fct, sz, args...; all_axes=all_axes, kwargs...)...)) + # calculate_separables_nokw(DefaultArrType, fct, sz, args...; all_axes=all_axes.args, kwargs...) + # return all_axes + return Broadcast.instantiate(Broadcast.broadcasted(operator, calculate_separables_nokw(DefaultArrType, fct, sz, args...; all_axes=all_axes, kwargs...)...)) # @show eltype(collect(res)) - return all_axes end # towards a Gaussian that can also be rotated: @@ -489,7 +540,7 @@ end """ - separable_view{N}(fct, sz, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = .*) + separable_view{N}(fct, sz, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = *) creates an `LazyArray` view of an N-dimensional separable function. Note that this view consumes much less memory than a full allocation of the collected result. @@ -520,18 +571,18 @@ julia> my_gaussian = separable_view(fct, (6,5), (0.1,0.2), (0.5,1.0)) ``` """ function separable_view(::Type{AT}, fct, sz::NTuple{N, Int}, args...; - operator = *, kwargs...) where {AT, N} - res = calculate_separables(AT, fct, sz, args...; operator=operator, kwargs...) + operator = get_operator(fct), kwargs...) where {AT, N} + res = calculate_separables(AT, fct, sz, args...; kwargs...) return LazyArray(@~ operator.(res...)) # to prevent premature evaluation end function separable_view(fct, sz::NTuple{N, Int}, args...; - operator = *, kwargs...) where {N} + operator = get_operator(fct), kwargs...) where {N} separable_view(DefaultArrType, fct, sz::NTuple{N, Int}, args...; operator=operator, kwargs...) end """ - separable_create([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = *, kwargs...) where {TA, N} + separable_create([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = get_operator(fct), kwargs...) where {TA, N} creates an array view of an N-dimensional separable function including memory allocation and collection. See the example below. diff --git a/src/specific.jl b/src/specific.jl index 821c953..7ab5046 100644 --- a/src/specific.jl +++ b/src/specific.jl @@ -19,6 +19,8 @@ # end using ChainRulesCore +export mem_fct, raw_fct + function generate_functions_expr() # offset and scale is already wrapped in the generator function # x_expr = :(scale .* (x .- offset)) @@ -83,7 +85,6 @@ for F in generate_functions_expr() # @show "creating rrule for $(Symbol(F[1], :_raw)) @eval function get_idx_gradient(::typeof($(Symbol(F[1], :_raw))), prod_dims, y, x, sz, dy) # println("in set_idx_gradient") - # return dot(dy, $(F[6]).(y, x, sz)) return mapreduce(*, +, conj.(dy), $(F[6]).(y, x, sz); dims=1:prod_dims) end @@ -102,18 +103,16 @@ for F in generate_functions_expr() # @show "creating rrule for $(Symbol(F[1], :_raw)) " @eval function get_idx_gradient(::typeof($(Symbol(F[1], :_raw))), prod_dims, y, x, sz, dy, args...) # println("in set_idx_gradient") - # return dot(dy, $(F[6]).(y, x, sz, args...)) # includes all dimensions! - return mapreduce(*, +, conj.(dy), $(F[6]).(y, x, sz, args...), dims=1:prod_dims) + return mapreduce(*, +, conj.(dy), $(F[6]).(y, x, sz, args...); dims=1:prod_dims) end @eval function get_arg_gradient(::typeof($(Symbol(F[1], :_raw))), prod_dims, y, x, sz, dy, args...) # println("in set_arg_gradient") - # return dot(dy, $(F[7]).(y, x, sz, args...)) # includes all dimensions! return mapreduce(*, +, conj.(dy), $(F[7]).(y, x, sz, args...), dims=1:prod_dims) end @eval function ChainRulesCore.rrule(::typeof($(Symbol(F[1], :_raw))), x, sz, args...; kwargs...) - # println("in rrule raw") + # println("in rrule2 raw") y = $(Symbol(F[1], :_raw))(x, sz, args...; kwargs...) # to assign the function to a symbol function mypullback(dy) # println("pb") @@ -201,21 +200,30 @@ for F in generate_functions_expr() if any(isa.(args, Tuple)) error("use vectors rather than tuples in component arrays, since Zygote has trouble with tuples.") end - all_axes = isnothing(all_axes) ? get_bc_mem(similar_arr_type(TA, $(F[4]), Val(N)), sz, $(F[5]), get_arg_sz(sz, off, sca, bg, intensity, args...)) : all_axes; + all_axes = isnothing(all_axes) ? get_bc_mem(TA, sz, $(F[5]), get_arg_sz(sz, off, sca, bg, intensity, args...)) : all_axes; + # use the return value instead of all_axes directly, since only this triggers the gradient calculation correctly return bg .+ intensity .* ($(Symbol(F[1], :_nokw_sep))(TA, sz, off, sca, args...; all_axes=all_axes)) end @eval function $(Symbol(F[1], :_vec))(sz::NTuple{N, Int}, vec; all_axes = nothing) where {N} - TA = Array{$(F[4])} + T = $(F[4]) + TA = Array{T} if hasproperty(vec, :off) && isa(vec.off, AbstractArray) - TA = similar_arr_type(typeof(vec.off), $(F[4]), Val(N)) - elseif hasproperty(vec, :intensity) && isa(vec.intensity, AbstractArray) - TA = similar_arr_type(typeof(vec.intensity), $(F[4]), Val(N)) - elseif hasproperty(vec, :sca) && isa(vec.sca, AbstractArray) - TA = similar_arr_type(typeof(vec.sca), $(F[4]), Val(N)) - elseif hasproperty(vec, :args) && isa(vec.args, AbstractArray) - TA = similar_arr_type(typeof(vec.args[1]), $(F[4]), Val(N)) + T = promote_type(T, eltype(vec.off)) + TA = similar_arr_type(typeof(vec.off), T, Val(N)) + end + if hasproperty(vec, :intensity) && isa(vec.intensity, AbstractArray) + T = promote_type(T, eltype(vec.intensity)) + TA = similar_arr_type(typeof(vec.intensity), T, Val(N)) + end + if hasproperty(vec, :sca) && isa(vec.sca, AbstractArray) + T = promote_type(T, eltype(vec.sca)) + TA = similar_arr_type(typeof(vec.sca), T, Val(N)) + end + if hasproperty(vec, :args) && isa(vec.args, AbstractArray) + T = promote_type(T, eltype(vec.args)) + TA = similar_arr_type(typeof(vec.args), T, Val(N)) end return $(Symbol(F[1], :_vec))(TA, sz, vec; all_axes=all_axes) end @@ -230,6 +238,18 @@ for F in generate_functions_expr() separable_view(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operator=$(F[5]), kwargs...) end + @eval mem_fct(::typeof($(Symbol(F[1], :_col))), ::Type{AT}, sz, hyper_sz=()) where AT = get_bc_mem(AT, sz, $(F[5]), hyper_sz) + @eval mem_fct(::typeof($(Symbol(F[1], :_sep))), ::Type{AT}, sz, hyper_sz=()) where AT = get_bc_mem(AT, sz, $(F[5]), hyper_sz) + @eval mem_fct(::typeof($(Symbol(F[1], :_nokw_sep))), ::Type{AT}, sz, hyper_sz=()) where AT = get_bc_mem(AT, sz, $(F[5]), hyper_sz) + @eval mem_fct(::typeof($(Symbol(F[1], :_vec))), ::Type{AT}, sz, hyper_sz=()) where AT = get_bc_mem(AT, sz, $(F[5]), hyper_sz) + @eval mem_fct(::typeof($(Symbol(F[1], :_lz))), ::Type{AT}, sz, hyper_sz=()) where AT = get_sep_mem(AT, sz, hyper_sz) + + @eval raw_fct(::typeof($(Symbol(F[1], :_raw)))) = $(Symbol(F[1], :_raw)) + @eval raw_fct(::typeof($(Symbol(F[1], :_col)))) = $(Symbol(F[1], :_raw)) + @eval raw_fct(::typeof($(Symbol(F[1], :_sep)))) = $(Symbol(F[1], :_raw)) + @eval raw_fct(::typeof($(Symbol(F[1], :_nokw_sep)))) = $(Symbol(F[1], :_raw)) + @eval raw_fct(::typeof($(Symbol(F[1], :_vec)))) = $(Symbol(F[1], :_raw)) + @eval raw_fct(::typeof($(Symbol(F[1], :_lz)))) = $(Symbol(F[1], :_raw)) # collected: fast separable calculation but resulting in an ND array @eval export $(Symbol(F[1], :_col)) # separated: a vector of separated contributions is returned and the user has to combine them @@ -241,7 +261,6 @@ for F in generate_functions_expr() # @eval export $(Symbol(F[1], :_lz)) end - ## Here some individual versions based on copy_corners! stuff. They only exist in the _cor version as they are not separable in X and Y. """ propagator_col([]::Type{TA},] sz::NTuple{N, Int}; Δz=one(eltype(TA)), k_max=0.5f0, scale=0.5f0 ./ (max.(sz ./ 2, 1))) where{TA, N} diff --git a/src/utilities.jl b/src/utilities.jl index 449627d..8fc017c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -51,9 +51,12 @@ end # end function optional_convert(ref_arg::AbstractArray{T,N}, val) where {T,N} if (prod(size(ref_arg)) == 1) - return sum(sum.(val)) + # return a vector of a single number since the input is also a vector of a single number + result = similar(ref_arg, eltype(val[1]), 1) + result[1] = sum(sum.(val)) + return result end - res = similar(ref_arg) + res = similar(ref_arg, eltype(val[1])) d=1 for v in val dim = size(ref_arg, 1) > 1 ? d : 1 @@ -86,9 +89,10 @@ function optional_convert_assign!(dst, ref_arg::AbstractArray{T,N}, val) where { for v in val dv = selectdim(dst, 1, d) if (prod(size(dv)) == 1) - dv .= sum(v[:]) + # the real cast seem dodgy here + dv .= sum(real.(v[:])) else - dv[:] .= v[:] + dv[:] .= real.(v[:]) end d += 1 end @@ -108,15 +112,18 @@ end estimates the size of the arguments. This is useful for the memory allocation function, which requires the extra size of the arguments. - +returned is a tuple of extra dimensions. This can be used directly in the memory allocation function. """ function get_arg_sz(sz, args...) max(size.(get_vec_dim.(args, 1, Ref(sz)))...)[length(sz)+1:end] end +function get_arg_sz(sz) + return () +end """ - get_sep_mem(::Type{AT}, sz::NTuple{N, Int}, hyper_sz=(1,)) where {AT, N} + get_sep_mem(::Type{AT}, sz::NTuple{N, Int}, hyper_sz=()) where {AT, N} allocates a contingous memory for the separable functions. This is useful if you want to use the same memory for multiple calculations. It should be passed to the `calculate_separables_nokw` function via the all_axes argument. @@ -128,7 +135,7 @@ Parameters: These will automatically be applied differently to each such hyperplane. E.g.: offfset=[reshape([1.0,2.0,3.0],(1,1,3)),reshape([-1.0,1.0,-2.0],(1,1,3))] for a 2D sz=(512,512) and 3 hyperplanes. """ -function get_sep_mem(::Type{AT}, sz::NTuple{N, Int}, hyper_sz=(1,)) where {AT, N} +function get_sep_mem(::Type{AT}, sz::NTuple{N, Int}, hyper_sz=()) where {AT, N} hyperplanes = prod(hyper_sz) all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, hyperplanes*sum(sz)) D = length(sz)+ length(hyper_sz) @@ -139,17 +146,20 @@ function get_sep_mem(::Type{AT}, sz::NTuple{N, Int}, hyper_sz=(1,)) where {AT, N end """ - get_bc_mem(::Type{AT}, sz::NTuple{N, Int}, operator, hyper_sz=(1,)) where {AT, N} + get_bc_mem(::Type{AT}, sz::NTuple{N, Int}, operator, hyper_sz=()) where {AT, N} allocates a contigous memory block for the separable functions and wraps it into an instantiate broadcast (bc) structure including the bc-`operator`. This structure is also returned by functions like `gaussian_sep` and can be reused by supplying it via the keyword argument `all_axes`. To obtain the bc-operator for predefined functions use `get_operator(fct)` with `fct` being the `raw_` version of the function, e.g. `get_operator(gassian_raw)` """ -function get_bc_mem(::Type{AT}, sz::NTuple{N, Int}, operator, hyper_sz=(1,)) where {AT, N} +function get_bc_mem(::Type{AT}, sz::NTuple{N, Int}, operator, hyper_sz=()) where {AT, N} return Broadcast.instantiate(Broadcast.broadcasted(operator, get_sep_mem(AT, sz, hyper_sz)...)) end +function get_mem(::Type{AT}, sz::NTuple{N, Int}, fct, hyper_sz=()) where {AT, N} + return mem_fct(fct, AT, sz, hyper_sz) +end """ @@ -171,7 +181,7 @@ This is useful for calling separable functions with their scalar arguments diffe # Example ```jdoctest -julia> args = (1,(4,5)) +julia> args = (1, (4,5)) (1, (4, 5)) julia> collect(SeparableFunctions.arg_n(2, args)) 2-element Vector{Int64}: @@ -200,7 +210,7 @@ This is useful for calling separable functions with their scalar arguments diffe # Example ```jdoctest -julia> kw = (a=1,b=(4,5)) +julia> kw = (a=1, b=(4,5)) (a = 1, b = (4, 5)) julia> SeparableFunctions.kwarg_n(2, kw) (a = 1, b = 5) diff --git a/test/runtests.jl b/test/runtests.jl index 556840c..c2e2431 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using Random using ComponentArrays function test_fct(T, fcts, sz, args...; kwargs...) + AT = Array{T} ifa, fct = fcts a = let if typeof(ifa) <: AbstractArray @@ -16,7 +17,7 @@ function test_fct(T, fcts, sz, args...; kwargs...) end end - res = fct.(Array{T}, sz, args...; kwargs...) + res = fct(AT, sz, args...; kwargs...) if typeof(res) <: Tuple res = res[2].(res[1]...) end @@ -27,8 +28,8 @@ function test_fct(T, fcts, sz, args...; kwargs...) @test a≈res @test eltype(res)==T - all_axes = zeros(T, prod(sz)) - res2 = fct.(Array{T}, sz, args...; all_axes = all_axes, kwargs...) + all_axes = mem_fct(fct, AT, sz)# get_bc_mem(AT, sz, *) # zeros(T, prod(sz)) + res2 = fct(AT, sz, args...; all_axes = all_axes, kwargs...) if typeof(res2) <: Tuple res2 = res2[2].(res2[1]...) end @@ -48,14 +49,14 @@ end @testset "calculate_separables" begin sz = (13,15) fct = (r, sz, sigma)-> exp(-r^2/(2*sigma^2)) - offset = (2.2, -2.2) ; scale = (1.1, 1.2); + offset = (2.2, -2.2); scale = (1.1, 1.2); @time gauss_sep = calculate_separables(fct, sz, (0.5,1.0), offset = offset .+ (0.1,0.2), scale=scale) @test size(.*(gauss_sep...)) == sz # test with preallocated array - all_axes = zeros(Float32, prod(sz)) - @time gauss_sep = calculate_separables(fct, sz, (0.5,1.0), all_axes = all_axes) - # @test all_axes[7] ≈ 1.0 - # @test all_axes[13+8] ≈ 1.0 + all_axes = get_sep_mem(Array{Float32}, sz) # zeros(Float32, prod(sz)) + @time gauss_sep = calculate_separables(fct, sz, (0.5, 1.0), all_axes = all_axes) + @test all_axes[1][7] ≈ 1.0 + @test all_axes[2][8] ≈ 1.0 end @testset "gaussian" begin @@ -169,24 +170,42 @@ end @test maximum(abs.(res4 .- res5)) < 1e-6 end +function check_all(T, args...; kwargs...) + for a in args[2:end] + for r in 1:length(a) + @test eltype(a[r]) == T + @test all(isapprox.(a[r], args[1][r]; kwargs...)) + end + end +end + function test_gradient(T, fct, sz, args...; kwargs...) RT = real(T) Random.seed!(1234) dat = rand(T, sz...) off0 = rand(RT, length(sz)) sca0 = rand(RT, length(sz)) - args = ntuple((d)->RT.(args[d]), length(args)) + argsc = ntuple((d)->RT.(args[d]), length(args)) loss = (off, sca, args...) -> sum(abs2.(fct(sz, off, sca, args..., kwargs...) .- dat)) - # @show loss(off0, sca0, args...) - g = gradient(loss, off0, sca0, args...) - gn = grad(central_fdm(5, 1), loss, off0, sca0, args...) # 5th order method, 1st derivative - - for r in 1:length(gn) - @test eltype(g[r]) == RT - # @show g[r] - # @show gn[r] - @test all(isapprox.(g[r], gn[r], rtol=1e-1)) + # @show loss(off0, sca0, argsc...) + gn = grad(central_fdm(5, 1), loss, off0, sca0, argsc...) # 5th order method, 1st derivative + g = gradient(loss, off0, sca0, argsc...) + + fg! = get_fg!(dat, raw_fct(fct), length(sz); loss = loss_gaussian) + if length(argsc)>0 + vec = ComponentVector(;off=off0, sca=sca0, args=[argsc[1]...]) + else + vec = ComponentVector(;off=off0, sca=sca0) + end + G = copy(vec) + f = fg!(1, G, vec) + if length(argsc)>0 + G = (G.off, G.sca, G.args) + else + G = (G.off, G.sca) end + + check_all(RT, gn, g, G; rtol=1e-1) end @testset "gradient tests" begin @@ -200,6 +219,34 @@ end gn = grad(central_fdm(5, 1), loss, rng, sigma0) # 5th order method, 1st derivative @test g[1] ≈ gn[1] @test g[2] ≈ gn[2] + # here some detailed tests for the complex-valued functions: + + sz = (2,2) + sz = (22, 11) + rfun(off, sca, arg) = real(sum(exp_ikx_nokw_sep(sz, off, sca, arg))) + ifun(off, sca, arg) = imag(sum(exp_ikx_nokw_sep(sz, off, sca, arg))) + lfun(off, sca, arg) = sum(abs2.(exp_ikx_nokw_sep(sz, off, sca, arg) .- 1.0 .- 1.0im)) + gn = grad(central_fdm(5, 1), ifun, (0.0,0.0), (1.0, 1.0), (0.0,0.0)) # 5th order method, 1st derivative + g = gradient(ifun, (0.0,0.0), (1.0, 1.0), (0.0,0.0)) + check_all(Float64, gn, g; atol=0.001) + + gn = grad(central_fdm(5, 1), ifun, (0.0,0.0), (2.0, 3.0), (2.0, 1.0)) # 5th order method, 1st derivative + g = gradient(ifun, (0.0,0.0), (2.0, 3.0), (2.0, 1.0)) + check_all(Float64, gn, g; atol=0.01) + + # problem only when the shift vector is diagonal and offset is non-zero + gn = grad(central_fdm(5, 1), ifun, (0.3, 1.2), (0.6, 1.0), (1.0, 1.0)) # 5th order method, 1st derivative + g = gradient(ifun, (0.3, 1.2), (0.6, 1.0), (1.0, 1.0)) + check_all(Float64, gn, g; atol=0.05) + + gn = grad(central_fdm(5, 1), rfun, (1.3, 1.2), (0.6, 1.5), (1.0, 2.0)) # 5th order method, 1st derivative + g = gradient(rfun, (1.3, 1.2), (0.6, 1.5), (1.0, 2.0)) + check_all(Float64, gn, g; atol=0.02) + + off0 = (1.3f0, 1.2f0); sca0 = (0.6f0, 1.5f0); args0 = (1.0f0, 2.0f0) + gn = grad(central_fdm(5, 1), lfun, off0 , sca0, args0) # 5th order method, 1st derivative + g = gradient(lfun, off0 , sca0, args0) + check_all(Float32, gn, g; atol=0.5) test_gradient(Float32, gaussian_nokw_sep, (11,22), (2.2, -0.8)) test_gradient(Float64, gaussian_nokw_sep, (6, 22, 7), 2.0) @@ -237,36 +284,38 @@ end # now the vec version with intensity, scale, offset and bg (just a single replica): for use_hyper_dims in (true, false) - N_hyper = 4 - hyperint = (use_hyper_dims) ? 5.0 .+ rand(1,1,N_hyper) : 1 - hyperoff = (use_hyper_dims) ? 1 .+ 0.2 .* rand(1,1,N_hyper) : 1 - hyperarg = (use_hyper_dims) ? 1 .+ 0.2 .* rand(1,1,N_hyper) : 1 + N_hyper = 2 + hyperint = (use_hyper_dims) ? 1.0 .+ zeros(1,N_hyper) : 1.0 + hyperoff = (use_hyper_dims) ? 1 .+ 0.2 .* zeros(1,N_hyper) : 1 + hyperarg = (use_hyper_dims) ? 1 .+ 0.2 .* zeros(1,N_hyper) : 1 sz = (11, 22) - vec_true = ComponentVector(;bg=0.2, intensity=1.0 .*hyperint, off = [2.2, 3.3].*hyperoff, sca = [1.3, 1.2], args = [2.4, 1.5].*hyperarg) + # sz = (3, 3) + vec_true = ComponentVector(;bg=0.0, intensity=1.0 .*hyperint, off = [2.2, 3.3].*hyperoff, sca = [1.3, 1.2], args = [2.4, 1.5].*hyperarg) dat = gaussian_vec(sz, vec_true) loss2 = (vec) -> sum(abs2.(gaussian_vec(sz, vec) .- dat)) @test loss2(vec_true) == 0 g = gradient(loss2, vec_true) gn = grad(central_fdm(5, 1), loss2, vec_true) # 5th order method, 1st derivatives myfg! = get_fg!(dat, gaussian_raw, length(sz); loss = loss_gaussian) - G = similar(gn[1]) + G = similar(gn[1]) .* 0 f = myfg!(1, G, vec_true) # maximum(abs.(G)) for (mygn, myg, myfg) in zip(gn[1], g[1], G) - @test all(isapprox.(mygn, 0, atol=5e-7)) - @test all(isapprox.(myg, 0, atol=5e-7)) - @test all(isapprox.(myfg, 0, atol=5e-7)) + @test all(isapprox.(mygn, 0, atol = 5e-12)) + @test all(isapprox.(myg, 0, atol = 5e-12)) + @test all(isapprox.(myfg, 0, atol = 5e-12)) end # vec_start = ComponentVector(;bg=0.3, intensity=1.1, off = [2.3, 3.4], sca = [1.4, 1.3], args = [2.5, 1.6]) vec_start = vec_true .+ 0.2 - g = gradient(loss2, vec_start) - gn = grad(central_fdm(5, 1), loss2, vec_start) # 5th order method, 1st derivatives - f = myfg!(1, G, vec_start) - # maximum(abs.(G[:] .- gn[1][:])) - for (mygn, myg, myfg) in zip(gn[1], g[1], G) - @test all(isapprox.(mygn, myg, atol=4e-2)) - @test all(isapprox.(mygn, myfg, atol=4e-2)) + gs = gradient(loss2, vec_start) + gns = grad(central_fdm(5, 1), loss2, vec_start) # 5th order method, 1st derivatives + Gs = similar(gns[1]) .* 0 + fs = myfg!(1, Gs, vec_start) + # maximum(abs.(Gs[:] .- gns[1][:])) + for (mygn, myg, myfg) in zip(gns[1], gs[1], Gs) + @test all(isapprox.(mygn, myg, atol=4e-7)) + @test all(isapprox.(mygn, myfg, atol=4e-12)) end end end