Skip to content

Commit

Permalink
fixed gradients and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Aug 1, 2024
1 parent 5917f82 commit 7d23a92
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 97 deletions.
125 changes: 88 additions & 37 deletions src/general.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
calculate_separables_nokw([::Type{AT},] fct, sz::NTuple{N, Int}, offset = sz.÷2 .+1, scale = one(real(eltype(AT))),
args...;all_axes = get_sep_mem(AT, sz), kwargs...) where {AT, N}
args...;all_axes = nothing, kwargs...) where {AT, N}
creates a list of one-dimensional vectors, which can be combined to yield a separable array. In a way this can be seen as a half-way Lazy operator.
The (potentially heavy) work of calculating the one-dimensional functions is done now but the memory-heavy calculation of the array is done later.
Expand Down Expand Up @@ -36,13 +36,16 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
offset = nothing,
scale = nothing,
args...;
all_axes = get_sep_mem(AT, sz, get_arg_sz(sz, offset, scale, args...)),
all_axes = nothing,
kwargs...) where {AT, N}

RT = real(float(eltype(AT)))
RAT = similar_arr_type(AT, RT, Val(1))
offset = isnothing(offset) ? (sz2 .+1 ) : RT.(offset)
scale = isnothing(scale) ? RAT([one(RT)]) : RT.(scale)
if isnothing(all_axes)
all_axes = get_sep_mem(AT, sz, get_arg_sz(sz, offset, scale, args...))
end

# offset = ntuple((d) -> pick_n(d, offset), Val(N))
# scale = ntuple((d) -> pick_n(d, scale), Val(N))
Expand All @@ -60,6 +63,8 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
idc = get_1d_ids(d, sz, offset, scale)
args_d = arg_n(d, args, RT, sz) #
# in_place_assing!(res, 1, fct, idc, sz1d, args_d)
# @show size(res)
# @show size(idc)
res .= fct.(idc, sz1d, args_d...) # 5 allocs, 160 bytes
end
return all_axes
Expand Down Expand Up @@ -91,7 +96,7 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_
_, in_place_assing_pullback = rrule_via_ad(config, out_of_place_assing, res, d, fct, idc, sz1d, args_d)

function debug_dummy(dy)
println("in debug_dummy") # sz is (10, 20)
# println("in debug_dummy") # sz is (10, 20)
# @show dy # NoTangent()
# @show size(dy) # 1st calls: (1, 20) 2nd call: (10, 1)
myres = in_place_assing_pullback(dy)
Expand Down Expand Up @@ -123,7 +128,7 @@ end

"""
calculate_separables([::Type{AT},] fct, sz::NTuple{N, Int}, args...;
all_axes = get_sep_mem(AT, sz),
all_axes = nothing,
defaults=NamedTuple(),
offset = sz.÷2 .+1,
scale = one(real(eltype(AT))),
Expand Down Expand Up @@ -162,27 +167,29 @@ julia> gauss_sep = SeparableFunctions.calculate_separables_nokw(Array{Float32},
"""
# Typed entry point: converts the keyword-style parameters into the positional form
# expected by `calculate_separables_nokw` and forwards the call.
#
# Keywords
# - `all_axes`: forwarded unchanged; with `nothing` (the default) the allocation is
#   left to `calculate_separables_nokw`.
# - `defaults`: handed to `kwargs_to_args` together with `kwargs` to build the extra
#   positional arguments; also forwarded as a keyword.
# - `offset`, `scale`: forwarded as the first two positional parameters after `sz`.
#
# Returns the result of `calculate_separables_nokw`.
function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int},
        args...;
        all_axes = nothing,
        defaults = NamedTuple(),
        offset = sz .÷ 2 .+ 1,
        scale = one(real(eltype(AT))),
        kwargs...) where {AT, N}
    extra_args = kwargs_to_args(defaults, kwargs)
    # `all_axes` has to be forwarded explicitly: it is a named keyword and therefore
    # not contained in `kwargs`, so omitting it would ignore caller-supplied buffers.
    return calculate_separables_nokw(AT, fct, sz, offset, scale, extra_args..., args...;
        all_axes = all_axes, defaults = defaults, kwargs...)
end

# Convenience method without an explicit array type: forwards to the typed
# `calculate_separables` method using `DefaultArrType`.
#
# Keywords
# - `all_axes`: forwarded unchanged (`nothing` by default).
# - `defaults`: handed to `kwargs_to_args` together with `kwargs`; the resulting
#   values are prepended to `args` as positional arguments.
# - `offset`, `scale`: forwarded as keywords.
#
# Returns the result of the typed method.
function calculate_separables(fct, sz::NTuple{N, Int}, args...;
        all_axes = nothing,
        defaults = NamedTuple(),
        offset = sz .÷ 2 .+ 1,
        scale = one(real(eltype(DefaultArrType))),
        kwargs...) where {N}
    extra_args = kwargs_to_args(defaults, kwargs)
    # Forward exactly once and include `all_axes`, so that caller-supplied buffers
    # reach the typed method and no work is done twice.
    return calculate_separables(DefaultArrType, fct, sz, extra_args..., args...;
        all_axes = all_axes, offset = offset, scale = scale, kwargs...)
end


"""
calculate_broadcasted([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operator = *, kwargs...) where {TA, N}
calculate_broadcasted([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operator = get_operator(fct), kwargs...) where {TA, N}
returns an instantiated broadcasted separable array, which behaves almost like an array yet uses broadcasting. Tests revealed maximal speed improvements for this version.
However, reduce operators with specified dimensions cause an error. This can be avoided by calling `collect`.
Expand Down Expand Up @@ -214,16 +221,16 @@ julia> collect(my_gaussian)
```
"""
# Typed entry point: fills the one-dimensional buffers stored in `all_axes.args` via
# `calculate_separables` and returns `all_axes` itself.
#
# Keywords
# - `operator`: defaults to `get_operator(fct)`; only used here to build the default
#   `all_axes`.
# - `all_axes`: the broadcast structure to fill; by default obtained from
#   `get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...))`.
# - `kwargs...`: forwarded to `calculate_separables`.
function calculate_broadcasted(::Type{AT}, fct, sz::NTuple{N, Int}, args...;
        operator = get_operator(fct),
        all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)),
        kwargs...) where {AT, N}
    # replace the sep memory inside the broadcast structure with new values:
    calculate_separables(AT, fct, sz, args...; all_axes = all_axes.args, kwargs...)
    return all_axes
end

function calculate_broadcasted(fct, sz::NTuple{N, Int}, args...;
operator = *, all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)),
operator = get_operator(fct), all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)),
kwargs...) where {N}
calculate_separables(DefaultArrType, fct, sz, args...; all_axes=all_axes.args, kwargs...)
return all_axes
Expand All @@ -245,13 +252,13 @@ end

### Versions where offset and scale are without keyword arguments
# Typed variant taking offset and scale positionally (as part of `args`): fills the
# buffers stored in `all_axes.args` via `calculate_separables_nokw_hook` and returns
# `all_axes` itself.
#
# Keywords
# - `operator`: defaults to `get_operator(fct)`; only used here to build the default
#   `all_axes`.
# - `all_axes`: the broadcast structure to fill; by default obtained from
#   `get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...))`.
# - `defaults`: accepted but not used in this method (see the note in the body).
# - `kwargs...`: forwarded to `calculate_separables_nokw_hook`.
function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...;
        operator = get_operator(fct),
        all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)),
        defaults = nothing, kwargs...) where {AT, N}
    # defaults should be evaluated here and filled into args...
    calculate_separables_nokw_hook(AT, fct, sz, args...;
        all_axes = all_axes.args, kwargs...)
    return all_axes
end

function calculate_separables_nokw_hook(::Type{AT}, fct, sz::NTuple{N, Int}, args...; kwargs...) where {AT, N}
Expand All @@ -265,7 +272,7 @@ end
# this code only works for the multiplicative version and actually only saves very few allocations.
# it's not worth the specialization:
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_broadcasted_nokw), ::Type{AT}, fct, sz::NTuple{N, Int}, args...;
operator = *, all_axes = get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)),
operator = get_operator(fct), all_axes = nothing, # get_bc_mem(AT, sz, operator, get_arg_sz(sz, args...)),
defaults = nothing, kwargs...) where {AT, N}

# @show typeof(all_axes)
Expand All @@ -285,25 +292,46 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cal
# @show size(y_sep[2])
# @show size(y)

function calculate_broadcasted_nokw_pullback(dy)
# println("in calculate_broadcasted_nokw_pullback") # sz is (10, 20)
function calculate_broadcasted_nokw_pullback_mul(dy)
# println("in calculate_broadcasted_nokw_pullback_mul") # sz is (10, 20)
# @show dy
# @show y_sep
projected = (N<=1) ? ntuple((d) -> dy, Val(N)) : ntuple((d) -> begin
dims=((1:d-1)...,(d+1:N)...)
reduce(+, operator.(y_sep[[dims...]]...) .* dy, dims=dims)
other_dims=((1:d-1)...,(d+1:N)...)
reduce(+, operator.(conj.(y_sep[[other_dims...]])...) .* dy, dims=other_dims)
end, Val(N)) # 7 Mb
# @show typeof(projected[2])
# @show projected
myres = calculate_sep_nokw_pullback(projected) # 8 kB
return myres
end
return y, calculate_broadcasted_nokw_pullback # in_place_assing_pullback # in_place_assing_pullback
function calculate_broadcasted_nokw_pullback_add(dy)
# println("in calculate_broadcasted_nokw_pullbac_add") # sz is (10, 20)
projected = (N<=1) ? ntuple((d) -> dy, Val(N)) : ntuple((d) -> begin
other_dims=((1:d-1)...,(d+1:N)...)
reduce(+, dy, dims=other_dims)
end, Val(N)) # 7 Mb
myres = calculate_sep_nokw_pullback(projected) # 8 kB
return myres
end
mypullback = let
if (operator == *)
calculate_broadcasted_nokw_pullback_mul
elseif (operator == +)
calculate_broadcasted_nokw_pullback_add # in_place_assing_pullback # in_place_assing_pullback
else
error("SeparableFunctions operator not supported")
end
end
return y, mypullback
end

# function calculate_separables_nokw_hook2()
# end

# Needs revision
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw_hook), ::Type{AT}, fct, sz::NTuple{N, Int}, args...;
all_axes = get_sep_mem(AT, sz, get_arg_sz(sz, args...)), kwargs...) where {AT, N}
all_axes = nothing, kwargs...) where {AT, N} # get_sep_mem(AT, sz, get_arg_sz(sz, args...)),
# @show "in rrule sep hook"
# ids = ntuple((d) -> reorient(get_1d_ids(d, sz, args[1], args[2]), d, Val(N)), Val(N)) # offset==args[1] and scale==args[2]
RT = real(float(eltype(AT)))
Expand Down Expand Up @@ -348,15 +376,27 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cal
get_arg_gradient(fct, red_dims, y[d], ids[d], sz[d], dy[d], args_1d[d]...)
end, Val(N))), length(args)-2)

# @show doffset
# @show dargs

return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dargs...)
end
return y, calculate_separables_nokw_hook_pullback
end

"""
get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussian, bg=zero(real(T))) where {T,N}
This function returns an `fg!(F, G, vec)` function, which calculates the forward and gradient of a separable function.
It can directly be used for the `Optimize` package.
The function `fct` should be a separable function of the "_vec" type.
The first arguments of `fct` have to be the index of this coordinate and the size of this axis.
# Arguments
+ `data`: the data to fit to
+ `fct`: the separable function to fit to the data
+ `prod_dims`: the number of dimensions to be multiplied together. This is the number of dimensions which are not separated.
+ `loss`: the loss function to use. Default is the Gaussian loss.
+ `bg`: the background value (used only by the loss function, not by the forward model of the data). Default is zero.
"""
function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussian, bg=zero(real(T))) where {T,N}
RT = real(eltype(data))
AT = typeof(data)
Expand All @@ -372,7 +412,7 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia
y = by.args
# yv = ntuple((d) -> (@view y[d]), Val(prod_dims))

dy = get_sep_mem(typeof(data), size(data)[1:prod_dims], hyper_sz)
dy = get_sep_mem(AT, size(data)[1:prod_dims], hyper_sz)
# dyv = ntuple((d) -> (@view dy[d]), Val(prod_dims))

resid = similar(data) # checkpoint
Expand All @@ -397,10 +437,14 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia
# if !isnothing(F) || !isnothing(G)
# resid .= bg .+ intensity .* by .- data
# end
loss = loss_fg!(F, resid, bg .+ get_vec_dim(intensity, 1, sz) .* by)
myint = get_vec_dim(intensity, 1, sz)
loss = loss_fg!(F, resid, bg .+ myint .* by)
# @show resid
if !isnothing(G)
# for arrays this should be .=
G.bg = C*sum(resid)
if hasproperty(vec, :bg)
G.bg = C*sum(resid)
end

other_dims = ntuple((d)-> (ntuple((n)->n, d-1)..., ntuple((n)->d+n, prod_dims-d)...), Val(prod_dims))
# @show other_dims
Expand All @@ -412,8 +456,15 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia
for d in 1:prod_dims
# This costs 33 kB allocation, sometimes it can be faster?
# dy[d] .= sum(resid.* (.*(other_ys[d]...)), dims=other_dims[d])
# this costs 2 kB allocations but is slower
dy[d] .= broadcast_reduce(*, +, resid, other_ys[d]..., dims=other_dims[d], init=zero(T))
# this costs 2 kB allocations but is slower reduce(+, dy, dims=dims)
if (operator == *)
dy[d] .= conj.(broadcast_reduce(*, +, conj.(resid), other_ys[d]..., dims=other_dims[d], init=zero(T)))
# dy[d] .= broadcast_reduce(*, +, resid, other_ys[d]..., dims=other_dims[d], init=zero(T))
elseif (operator == +)
dy[d] .= reduce(+, resid, dims=other_dims[d], init=zero(T))
else
error("SeparableFunctions operator in fg! not supported")
end
end
# less memory, but a little slower:
# dy = ntuple((d) -> broadcast_reduce(*, +, resid, other_ys[d]..., dims=other_dims[d], init=zero(T)), Val(N))
Expand Down Expand Up @@ -450,7 +501,7 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia
# return loss, dfdbg, dfdI, doffset, dscale, dargs...
# loss = (!isnothing(F)) ? mapreduce(abs2, +, resid; init=zero(T)) : T(0);

if !isnothing(G) # forward needs to be calculated
if (!isnothing(G) && hasproperty(vec, :intensity)) # forward needs to be calculated
# resid .*= by # is slower!
resid .= conj.(by) .* resid
# optional_convert_assign!(G.intensity, G.intensity, (sum(resid, dims=1:length(sz)),))
Expand All @@ -471,13 +522,13 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia
end

# Variant without an explicit array type: computes the separable one-dimensional parts
# with `calculate_separables_nokw` for `DefaultArrType` and combines them with
# `operator` into an instantiated `Broadcast.broadcasted` object.
#
# Keywords
# - `operator`: defaults to `get_operator(fct)`.
# - `all_axes`: forwarded unchanged to `calculate_separables_nokw`; `nothing` is the
#   default. The default must not reference `AT`, because this method has no such
#   type parameter.
# - `defaults`: accepted but not used in this method (see the note in the body).
# - `kwargs...`: forwarded to `calculate_separables_nokw`.
function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...;
        operator = get_operator(fct), all_axes = nothing,
        defaults = nothing, kwargs...) where {N}
    # defaults should be evaluated here and filled into args...
    separables = calculate_separables_nokw(DefaultArrType, fct, sz, args...;
        all_axes = all_axes, kwargs...)
    return Broadcast.instantiate(Broadcast.broadcasted(operator, separables...))
end

# towards a Gaussian that can also be rotated:
Expand All @@ -489,7 +540,7 @@ end


"""
separable_view{N}(fct, sz, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = .*)
separable_view{N}(fct, sz, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = *)
creates a `LazyArray` view of an N-dimensional separable function.
Note that this view consumes much less memory than a full allocation of the collected result.
Expand Down Expand Up @@ -520,18 +571,18 @@ julia> my_gaussian = separable_view(fct, (6,5), (0.1,0.2), (0.5,1.0))
```
"""
# Typed entry point: computes the separable one-dimensional parts with
# `calculate_separables` and wraps their lazy combination via `operator` in a
# `LazyArray`.
#
# Keywords
# - `operator`: defaults to `get_operator(fct)`; it is applied only when building the
#   lazy expression and is not forwarded to `calculate_separables`.
# - `kwargs...`: forwarded to `calculate_separables`.
function separable_view(::Type{AT}, fct, sz::NTuple{N, Int}, args...;
        operator = get_operator(fct), kwargs...) where {AT, N}
    res = calculate_separables(AT, fct, sz, args...; kwargs...)
    return LazyArray(@~ operator.(res...)) # to prevent premature evaluation
end

# Convenience method without an explicit array type: forwards to the typed
# `separable_view` method using `DefaultArrType`.
#
# Keywords
# - `operator`: defaults to `get_operator(fct)` and is forwarded.
# - `kwargs...`: forwarded unchanged.
function separable_view(fct, sz::NTuple{N, Int}, args...;
        operator = get_operator(fct), kwargs...) where {N}
    # `sz` is already constrained by the signature; no type assertion needed here.
    return separable_view(DefaultArrType, fct, sz, args...;
        operator = operator, kwargs...)
end

"""
separable_create([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = *, kwargs...) where {TA, N}
separable_create([::Type{TA},] fct, sz::NTuple{N, Int}, args...; offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operator = get_operator(fct), kwargs...) where {TA, N}
creates an array view of an N-dimensional separable function including memory allocation and collection.
See the example below.
Expand Down
Loading

0 comments on commit 7d23a92

Please sign in to comment.