Skip to content

Commit

Permalink
added support for _nokw_ towards AD capability.
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 5, 2024
1 parent be1296b commit 48fe382
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 78 deletions.
6 changes: 1 addition & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ authors = ["RainerHeintzmann <heintzmann@gmail.com>"]
version = "0.1.1"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Thunks = "490da00b-a60c-4ded-a4cf-df7cded56bfa"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ImageTransformations = "0.9, 0.10"
Expand All @@ -22,9 +21,6 @@ NDTools = "0.7"
StaticArrays = "0.1, 0.3, 0.4, 0.13, 0.14, 1"
Thunks = "0.3"
julia = "1"
ChainRules = "1.65, 1.66, 1.67, 1.68, 1.69"
ChainRulesCore = "1.20, 1.21, 1.22, 1.23, 1.24"
Zygote = "0.5, 0.6"

[extras]
IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
Expand Down
5 changes: 3 additions & 2 deletions src/SeparableFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ module SeparableFunctions
using NDTools, LazyArrays
using ImageTransformations, StaticArrays
using Interpolations
using ChainRulesCore # for adjoint definition
using Zygote
# using ChainRulesCore # for adjoint definition
using ZygoteRules
using Zygote # to use rrule_via_ad

export calculate_separables, separable_view, separable_create
export calculate_broadcasted
Expand Down
212 changes: 166 additions & 46 deletions src/general.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""
calculate_separables([::Type{AT},] fct, sz::NTuple{N, Int}, args...; dims = 1:N, all_axes = (similar_arr_type(AT, dims=Val(1)))(undef, sum(sz[[dims...]])), pos=zero(real(eltype(AT))), offset=sz.÷2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N}
calculate_separables_nokw([::Type{AT},] fct, sz::NTuple{N, Int}, offset = sz.÷2 .+1, scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))), args...; dims = 1:N,
all_axes = (similar_arr_type(AT, dims=Val(1)))(undef, sum(sz[[dims...]])), pos=zero(real(eltype(AT))),
kwargs...) where {AT, N}
creates a list of one-dimensional vectors, which can be combined to yield a separable array. In a way this can be seen as a half-way Lazy operation.
The (potentially heavy) work of calculating the one-dimensional functions is done now but the memory-heavy calculation of the array is done later.
Expand All @@ -9,12 +12,13 @@ This function is used in `separable_view` and `separable_create`.
+ `AT`: optional type signfying the array result type. You can for example use `CuArray{Float32}` using `CUDA` to create the views on the GPU.
+ `fct`: the function to calculate for each axis index (no need for broadcasting!) of this iterable of seperable axes. Note that the first arguments of `fct` have to be the index of this coordinate and the size of this axis. Any further `args` and `nargs` can follow. Often the second argument is not used but it still needs to be present.
+ `sz`: the size of the result array (when appying the one-D axes)
+ `offset`: specifying the center (zero-position) of the result array in one-based coordinates. The default corresponds to the Fourier-center.
+ `scale`: multiplies the index before passing it to `fct`
+ `factor`: multiplies the result of `fct` before storing it in the result array.
+ `args`: further arguments which are passed over to the function `fct`.
+ `dims`: a vector `[]` of valid dimensions. Only these dimension will be calculated but they are oriented in ND.
+ `all_axes`: if provided, this memory is used instead of allocating a new one. This can be useful if you want to use the same memory for multiple calculations.
+ `pos`: a position shifting the indices passed to `fct` in relationship to the `offset`.
+ `offset`: specifying the center (zero-position) of the result array in one-based coordinates. The default corresponds to the Fourier-center.
+ `scale`: multiplies the index before passing it to `fct`
#Example:
```julia
Expand All @@ -33,13 +37,21 @@ julia> gauss_sep = calculate_separables(fct, (6,5), (0.5,1.0), pos = (0.1,0.2))
6.50731f-5 0.000356206 0.000717312 0.000531398 0.000144823
```
"""
function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N,
function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
offset = sz2 .+1,
scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))),
args...; dims = 1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
defaults=NamedTuple(), pos=zero(real(eltype(AT))),
offset=sz2 .+1,
scale=one(real(eltype(AT))),
factor=one(real(eltype(AT))), kwargs...) where {AT, N}
kwargs...) where {AT, N}

RT = real(eltype(AT))
offset = isnothing(offset) ? sz2 .+1 : RT.(offset)
scale = isnothing(scale) ? one(real(eltype(RT))) : RT.(scale)
factor = isnothing(factor) ? one(real(eltype(RT))) : RT.(factor)
start = 1 .- offset

idc = pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos))
# @show typeof(idc)
dims = [dims...]
Expand All @@ -53,7 +65,12 @@ function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims
if isa(factor, Number)
factor = ntuple((d) -> factor, lastindex(dims))
end
# @show factor
# @show offset
# @show extra_args
# @show args
# @show kwargs
# @show collect(arg_n(dims[1], args))
# @show (idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)
res[1][:] .= (factor[1]) .* fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)
#push!(res, collect(reorient(fct.(idc, sz[1], arg_n(1, args)...; kwarg_n(1, kwargs)...), 1, Val(N))))
for d = 2:lastindex(dims)
Expand All @@ -75,52 +92,130 @@ function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims
return res
end

function calculate_separables(fct, sz::NTuple{N, Int}, args...; dims=1:N,
function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int},
args...; dims = 1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
defaults=NamedTuple(), pos=zero(real(eltype(AT))),
offset = sz2 .+1,
scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))),
kwargs...) where {AT, N}
return calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
end

function calculate_separables(fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), offset=sz2 .+1, scale=one(real(eltype(DefaultArrType))), kwargs...) where {N}
calculate_separables(DefaultArrType, fct, sz, args...; all_axes=all_axes, dims=dims, pos=pos, offset=offset, scale=scale, kwargs...)
pos=zero(real(eltype(DefaultArrType))),
kwargs...) where {N}
calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)
end

# define custom adjoint for calculate_separables
# function ChainRulesCore.rrule(::typeof(calculate_separables), conv, rec, otf)
# end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
defaults = NamedTuple(), pos=zero(real(eltype(AT))),
offset = sz2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N}
#
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N,
# all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
# defaults = NamedTuple(), pos=zero(real(eltype(AT))),
# offset = sz.÷2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N}

# println("inside rrule! $(sz), $(dims), $(offset)")
# # foward pass
# y = calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...)

println("inside rrule! $(sz), $(dims), $(offset)")
# foward pass
y = calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...)
# # extra_args = kwargs_to_args(defaults, kwarg_n(dims[1], kwargs))
# # res[1][:] .= fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)

# extra_args = kwargs_to_args(defaults, kwarg_n(dims[1], kwargs))
# res[1][:] .= fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)
# # d_off_fct = (idx, sz, args...; kwargs...) -> rrule((x) ->
# # fct(x, sz, args...; kwargs...), idx)[2](1)[1]
# # d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule((x, sz, args...; kwargs...) ->
# # fct(x, sz, args...; kwargs...), idx)[2](1)[1]
# # d_pos_fct = d_off_fct

# d_off_fct = (idx, sz, args...; kwargs...) -> rrule((x) ->
# fct(x, sz, args...; kwargs...), idx)[2](1)[1]
# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule((x, sz, args...; kwargs...) ->
# fct(x, sz, args...; kwargs...), idx)[2](1)[1]
# d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) ->
# fct(x, sz, args...), idx; kwargs...)[1]
# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) ->
# fct(x, sz, args...), idx;kwargs...)[1]
# d_pos_fct = d_off_fct

d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) ->
fct(x, sz, args...), idx; kwargs...)[1]
d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) ->
fct(x, sz, args...), idx;kwargs...)[1]
d_pos_fct = d_off_fct

function calculate_separables_pullback(dy) # dy is the gradient of the output
# multiply by dy?
doffset = calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...)
dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
dargs = args;
# It should return the gradient of the inputs
# println(doffset)
return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent())
end
return y, calculate_separables_pullback
end
# function calculate_separables_pullback(dy) # dy is the gradient of the output
# # multiply by dy?
# doffset = calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...)
# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dargs = args;
# # It should return the gradient of the inputs
# # println(doffset)
# return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent())
# end
# return y, calculate_separables_pullback
# end

# foo(a,b;c=42.0)=a*b*c
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N,
# function ZygoteRules._pullback(::ZygoteRules.AContext, A::typeof(Core.kwcall), kwargs, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...) where {AT, N}
# # function ZygoteRules._pullback(::ZygoteRules.AContext, A::typeof(Core.kwcall), kwargs, ::typeof(foo), a, b)
# # retrieve the kwargs
# dims = get(kwargs, :dims, 1:N) # insert default as we need that in pullback.
# all_axes = get(kwargs, :all_axes, (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]]))) # insert default as we need that in pullback.
# defaults = get(kwargs, :defaults, NamedTuple())
# pos = get(kwargs, :pos, zero(real(eltype(AT))))
# offset = get(kwargs, :offset, sz.÷2 .+1)
# scale = get(kwargs, :scale, one(real(eltype(AT))))

# println("kwargs calculate_separables in zygote rule sz:$(sz), dims:$(dims), offset:$(offset) args:$(args)")
# # foward pass
# y = calculate_separables(AT, fct, sz, args...; kwargs...)
# println("calculated y $(typeof(y))")

# # Use Zygote's _pullback to obtain the gradient function for `fct` in the separable one-d case
# d_off_fct = (idx, sz, args...; kwargs...) -> Zygote.pullback((x) -> fct(x, sz, args...; kwargs...), idx)[1]
# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * Zygote.pullback((x) -> fct(x, sz, args...; kwargs...), idx)[1]
# # config = Type{acontext} # Zygote.ZygoteRuleConfig

# # d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) ->
# # fct(x, sz, args...), idx; kwargs...)[1]
# # d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) ->
# # fct(x, sz, args...), idx;kwargs...)[1]
# d_pos_fct = d_off_fct

# # c = get(kwargs, :c, 42.0) # insert default as we need that in pullback.
# println("functions defined")
# function calculate_separables_pullback(dy) # dy is the gradient of the output
# println("kwargs calculate_separables pulling back")
# # multiply by dy?
# # @show d_off_fct(13.3, 20, args...)

# doffset = haskey(kwargs, :offset) ? calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...) : Zygote.NoTangent()
# # doffset = collect(doffset)
# # @show size(.*(doffset...))
# # @show eltype(.*(doffset...))
# println("calculated doffset $(typeof(doffset))")
# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dargs = args;
# dkwargs = (;offset = doffset, scale = dscale, pos = dpos);
# # It should return the gradient of the inputs
# # println(doffset)
# # return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent())
# println("done calculating $(typeof(dkwargs[:offset]))")
# return nothing, dkwargs, nothing, nothing, nothing, nothing
# end
# println("pullback defined")
# return y, calculate_separables_pullback

# # function foo_pullback(dy)
# # println("kwargs foo pulling back")
# # da = dy*b*c
# # db = dy*a*c
# # dc = haskey(kwargs, :c) ? dy*a*b : NoTangent()
# # dkwargs = (;c=dc)
# # return nothing, dkwargs, nothing, da, db
# # end
# # return y, foo_pullback
# end


"""
calculate_broadcasted([::Type{TA},] fct, sz::NTuple{N, Int}, args...; dims=1:N, pos=zero(real(eltype(DefaultArrType))), offset=sz.÷2 .+1, scale=one(real(eltype(DefaultArrType))), operation = *, kwargs...) where {TA, N}
Expand Down Expand Up @@ -157,16 +252,41 @@ julia> collect(my_gaussian)
"""
function calculate_broadcasted(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), offset=sz2 .+1, scale=one(real(eltype(DefaultArrType))), operation = *, kwargs...) where {AT, N}
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, kwargs...)...))
pos=zero(real(eltype(DefaultArrType))),
operation = *, kwargs...) where {AT, N}
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end

function calculate_broadcasted(fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), offset=sz2 .+1, scale=one(real(eltype(DefaultArrType))), operation = *, kwargs...) where {N}
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, kwargs...)...))
pos=zero(real(eltype(DefaultArrType))),
operation = *, kwargs...) where {N}
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end

### Versions where offst and scale are without keyword arguments
function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {AT, N}
# defaults should be evaluated here and filled into args...
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end

function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {N}
# defaults should be evaluated here and filled into args...
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end

# towards a Gaussian that can also be rotated:
# mulitply with exp(-(x-x0)*(y-y0)/(sigma_xy)) = exp(-(x-x0)/(sigma_xy)) ^ (y-y0)
# g = gaussian_sep((200, 200), sigma=2.2)
# vg = Base.broadcasted((x,y)->x, collect(g), ones(1,1,10));
# res = similar(g.args[1], size(vg));
# @time gg = accumulate!(*, res, vg, dims=3); #, init=collect(g)


"""
separable_view{N}(fct, sz, args...; pos=zero(real(eltype(AT))), offset = sz.÷2 .+1, scale = one(real(eltype(AT))), operation = .*)
Expand Down
Loading

0 comments on commit 48fe382

Please sign in to comment.