Skip to content

Commit

Permalink
added speedtest and improved speed a little
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 8, 2024
1 parent 48fe382 commit 5da3951
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 48 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ authors = ["RainerHeintzmann <heintzmann@gmail.com>"]
version = "0.1.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Thunks = "490da00b-a60c-4ded-a4cf-df7cded56bfa"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ImageTransformations = "0.9, 0.10"
Expand All @@ -25,6 +26,7 @@ julia = "1"
[extras]
IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "IndexFunArrays"]
6 changes: 3 additions & 3 deletions src/SeparableFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ module SeparableFunctions
using NDTools, LazyArrays
using ImageTransformations, StaticArrays
using Interpolations
# using ChainRulesCore # for adjoint definition
using ZygoteRules
using Zygote # to use rrule_via_ad
using ChainRulesCore # for adjoint definition
# using ZygoteRules
# using Zygote # to use rrule_via_ad

export calculate_separables, separable_view, separable_create
export calculate_broadcasted
Expand Down
157 changes: 122 additions & 35 deletions src/general.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
factor = one(real(eltype(AT))),
args...; dims = 1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
defaults=NamedTuple(), pos=zero(real(eltype(AT))),
pos=zero(real(eltype(AT))),
kwargs...) where {AT, N}

RT = real(eltype(AT))
Expand All @@ -52,7 +52,6 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
factor = isnothing(factor) ? one(real(eltype(RT))) : RT.(factor)
start = 1 .- offset

idc = pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos))
# @show typeof(idc)
dims = [dims...]
valid_sz = sz[dims]
Expand All @@ -61,7 +60,8 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
# @show kwarg_n(dims[1], kwargs)
# @show arg_n(dims[1], args)
# @show idc
extra_args = kwargs_to_args(defaults, kwarg_n(dims[1], kwargs))

# @show extra_args
if isa(factor, Number)
factor = ntuple((d) -> factor, lastindex(dims))
end
Expand All @@ -71,59 +71,110 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
# @show kwargs
# @show collect(arg_n(dims[1], args))
# @show (idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)
res[1][:] .= (factor[1]) .* fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)
# res[1][:] .= (factor[1]) .* fct.(idc, sz[dims[1]], arg_n(dims[1], args)...)
# idc = pick_n(dims[1], scale) .* ((start[dims[1]]:start[dims[1]]+sz[dims[1]]-1) .- pick_n(dims[1], pos))
# res = in_place_assing!(res, 1, factor[1], fct, idc, sz[dims[1]], arg_n(dims[1], args))
#push!(res, collect(reorient(fct.(idc, sz[1], arg_n(1, args)...; kwarg_n(1, kwargs)...), 1, Val(N))))
for d = 2:lastindex(dims)
idc = pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos))
# for d = eachindex(dims)
ntuple((d) ->
# idc = pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos))
# myaxis = collect(fct.(idc,arg_n(d, args)...)) # no need to reorient
extra_args = kwargs_to_args(defaults, kwarg_n(dims[d], kwargs))
tmp = let
if isa(factor[d], Number)
factor[d]
else
@view factor[d][:]
end
end
res[d][:] = tmp .* fct.(idc, sz[dims[d]], extra_args..., arg_n(dims[d], args)...)
# res[d] .= reorient(fct.(idc, sz[d], arg_n(d, args)...; kwarg_n(d, kwargs)...), d, Val(N))
# extra_args = kwargs_to_args(defaults, kwarg_n(dims[d], kwargs))
# tmp = let
# if isa(factor[d], Number)
# factor[d]
# else
# @view factor[d][:]
# end
# end
# res[d][:] .= tmp .* fct.(idc, sz[dims[d]], arg_n(dims[d], args)...)
in_place_assing!(res, d, factor[d], fct, pick_n(dims[d], scale) .* ((start[dims[d]]:start[dims[d]]+sz[dims[d]]-1) .- pick_n(dims[d], pos)), sz[dims[d]], arg_n(dims[d], args))

# LazyArray representation of expression
# push!(res, myaxis)
end
, lastindex(dims)) # Vector{AT}()
# end
return res
end


# a special in-place assignment, which gets its own differentiation rule for the reverse mode
# to avoid problems with memory-assignment and AD.
function in_place_assing!(res, d, tmp, fct, idc, sz1d, args_d)
res[d][:] .= tmp .* fct.(idc, sz1d, args_d...)
return res[d]
end

function out_of_place_assing(res, d, tmp, fct, idc, sz1d, args_d)
return tmp .* fct.(idc, sz1d, args_d...)
end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(in_place_assing!), res, d, tmp, fct, idc, sz1d, args_d)
println("in rrule in_place_assing!")
y = in_place_assing!(res, d, tmp, fct, idc, sz1d, args_d)
# @show collect(y)
_, mypullback = rrule_via_ad(config, out_of_place_assing, res, d, tmp, fct, idc, sz1d, args_d)

# function in_place_assing_pullback(dy) # dy is a tuple of arrays.
# println("in in_place_assing_pullback")

# d_idc = mypullback(dy)
# @show size(dy[1])
# @show size(derivatives) # idc
# each_deriv = ntuple((i) -> sum(dy[i] .* derivatives[1]), length(dy))
# @show each_deriv
# # @show dy[1]
# return NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), each_deriv, NoTangent(), NoTangent()
# # return NoTangent(), each_deriv, NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()
# end
return y, mypullback # in_place_assing_pullback
end


function calculate_separables(::Type{AT}, fct, sz::NTuple{N, Int},
args...; dims = 1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
defaults=NamedTuple(), pos=zero(real(eltype(AT))),
defaults=NamedTuple(), pos=zero(real(eltype(AT))),
offset = sz2 .+1,
scale = one(real(eltype(AT))),
factor = one(real(eltype(AT))),
kwargs...) where {AT, N}
return calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)

extra_args = kwargs_to_args(defaults, kwargs)
return calculate_separables_nokw(AT, fct, sz, offset, scale, factor, extra_args..., args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
end

function calculate_separables(fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))),
defaults=NamedTuple(), pos=zero(real(eltype(DefaultArrType))),
offset = sz2 .+1,
scale = one(real(eltype(DefaultArrType))),
factor = one(real(eltype(DefaultArrType))),
kwargs...) where {N}
calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)
extra_args = kwargs_to_args(defaults, kwargs)
calculate_separables(DefaultArrType, fct, sz, extra_args..., args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, factor=factor, kwargs...)
end

# define custom adjoint for calculate_separables
# function ChainRulesCore.rrule(::typeof(calculate_separables), conv, rec, otf)
# end

#
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables), ::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims = 1:N,

# calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw), ::Type{AT}, fct, sz::NTuple{N, Int},
# function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(calculate_separables_nokw), ::Type{AT}, fct, sz::NTuple{N, Int},
# offset = sz.÷2 .+1, scale=one(real(eltype(AT))), factor = one(real(eltype(AT)),), args...; dims = 1:N,
# all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
# defaults = NamedTuple(), pos=zero(real(eltype(AT))),
# offset = sz.÷2 .+1, scale=one(real(eltype(AT))), kwargs...) where {AT, N}
# defaults = NamedTuple(), pos=zero(real(eltype(AT))), kwargs...) where {AT, N}

# println("inside rrule! $(sz), $(dims), $(offset)")
# println("inside calculate_separables_nokw rrule! $(sz), $(dims), $(offset)")
# # foward pass
# y = calculate_separables(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, kwargs...)
# y = collect(calculate_broadcasted_nokw(AT, fct, sz, offset, scale, factor, args...;
# dims=dims, all_axes=all_axes, defaults = defaults, pos=pos, kwargs...))

# @show size(y)
# @show sum(abs2.(y))
# # extra_args = kwargs_to_args(defaults, kwarg_n(dims[1], kwargs))
# # res[1][:] .= fct.(idc, sz[dims[1]], extra_args..., arg_n(dims[1], args)...)

Expand All @@ -133,22 +184,41 @@ end
# # fct(x, sz, args...; kwargs...), idx)[2](1)[1]
# # d_pos_fct = d_off_fct

# d_off_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) ->
# println("calculating fct gradients")
# # d_fct_dx = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) ->
# # fct(x, sz, args...), idx; kwargs...)[1]
# d_offset_fct = (idx, sz, args...; kwargs...) -> rrule_via_ad(config, (x) ->
# fct(x, sz, args...), idx; kwargs...)[1]
# d_scale_fct = (idx, sz, args...; kwargs...) -> idx * rrule_via_ad(config, (x) ->
# fct(x, sz, args...), idx;kwargs...)[1]
# d_pos_fct = d_off_fct
# d_pos_fct = d_offset_fct

# all_grad_axes = copy(all_axes) # generate a different memory buffer for the gradients

# function calculate_separables_pullback(dy) # dy is the gradient of the output
# # multiply by dy?
# doffset = calculate_separables(AT, d_off_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = dy, kwargs...)
# dscale = 1 #calculate_separables(AT, d_scale_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# println("calculating separables pullback")
# @show dy
# @show length(dy)
# @show sum(abs.(dy))
# @show args
# @show d_offset_fct.(13.3:0.2:15.3, (20,), args...)
# @show offset
# @show scale
# @show factor
# doffset = (factor, factor) .* sum(reshape([dy...], sz) .* collect(calculate_broadcasted_nokw(AT, d_offset_fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_grad_axes, pos=pos, defaults = defaults, kwargs...)))
# @show doffset
# dscale = dy # .* collect(calculate_broadcasted_nokw(AT, d_scale_fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, pos=pos, defaults = defaults, kwargs...))[:]
# dfactor = dy
# # dpos = 1 # calculate_separables(AT, d_pos_fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, offset=offset, scale=scale, defaults = defaults, factor = 1, kwargs...)
# dargs = args;
# # It should return the gradient of the inputs
# # println(doffset)
# return (NoTangent(), NoTangent(), NoTangent(), doffset, dargs..., NoTangent(), NoTangent(), NoTangent(), dpos, doffset, dscale, NoTangent())

# # calculate_separables_nokw(AT, fct, sz, offset, scale, factor, args...; dims=dims, all_axes=all_axes, defaults=defaults, pos=pos, kwargs...)
# return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), doffset, dscale, dfactor, dargs..., NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
# end
# @show "returning from pullback"
# return y, calculate_separables_pullback
# end

Expand Down Expand Up @@ -264,12 +334,26 @@ function calculate_broadcasted(fct, sz::NTuple{N, Int}, args...; dims=1:N,
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables(DefaultArrType, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end


# function calculate_sep_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N,
# all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
# pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {AT, N}
# # defaults should be evaluated here and filled into args...
# return calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)
# end

# function calculate_sep_nokw(fct, sz::NTuple{N, Int}, args...; dims=1:N,
# all_axes = (similar_arr_type(DefaultArrType, eltype(DefaultArrType), Val(1)))(undef, sum(sz[[dims...]])),
# pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {N}
# return calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)
# end

### Versions where offst and scale are without keyword arguments
function calculate_broadcasted_nokw(::Type{AT}, fct, sz::NTuple{N, Int}, args...; dims=1:N,
all_axes = (similar_arr_type(AT, eltype(AT), Val(1)))(undef, sum(sz[[dims...]])),
pos=zero(real(eltype(DefaultArrType))), operation = *, defaults = nothing, kwargs...) where {AT, N}
# defaults should be evaluated here and filled into args...
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
Broadcast.instantiate(Broadcast.broadcasted(operation, calculate_separables_nokw(AT, fct, sz, args...; dims=dims, all_axes=all_axes, pos=pos, kwargs...)...))
end

function calculate_broadcasted_nokw(fct, sz::NTuple{N, Int}, args...; dims=1:N,
Expand Down Expand Up @@ -363,8 +447,11 @@ julia> my_gaussian = separable_create(fct, (6,5), (0.5,1.0); pos=(0.1,0.2))
```
"""
function separable_create(::Type{TA}, fct, sz::NTuple{N, Int}, args...; operation::Function = *, kwargs...)::similar_arr_type(TA, T, Val(N)) where {T, N, TA <: AbstractArray{T}}
res = calculate_separables(TA, fct, sz, args...; kwargs...)
operation.(res...)
# res = calculate_separables(TA, fct, sz, args...; kwargs...)
# operation.(res...)
res = similar(TA, sz)
res .= calculate_broadcasted(TA, fct, sz, args...; operation=operation, kwargs...)
return res
end

## the code below seems not type-stable but the code above is. Why?
Expand Down
15 changes: 11 additions & 4 deletions src/specific.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,21 @@ for F in generate_functions_expr()
all_axes = (similar_arr_type(TA, eltype(TA), Val(1)))(undef, sum(sz[[(1:N)...]]))
) where {TA, N}
fct = $(F[3]) # to assign the function to a symbol
calculate_broadcasted_nokw(TA, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)

# return calculate_broadcasted_nokw(TA, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
pos=zero(real(eltype(TA)))
operation=$(F[5])
return calculate_separables_nokw(TA, fct, sz, args...; pos=pos, all_axes=all_axes), operation
end

@eval function $(Symbol(F[1], :_nokw_sep))(sz::NTuple{N, Int}, args...;
all_axes = (similar_arr_type(Array{$(F[4])}, eltype(Array{$(F[4])}), Val(1)))(undef, sum(sz[[(1:N)...]]))
) where {N}
fct = $(F[3]) # to assign the function to a symbol
calculate_broadcasted_nokw(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
fct = $(F[3]) # to assign the function to a symbol
# return calculate_broadcasted_nokw(Array{$(F[4])}, fct, sz, args...; defaults=$(F[2]), operation=$(F[5]), all_axes=all_axes)
pos=zero(real(eltype(DefaultArrType)))
operation=$(F[5])
return calculate_separables_nokw(Array{$(F[4])}, fct, sz, args...; pos=pos, all_axes=all_axes), operation
end

@eval function $(Symbol(F[1], :_lz))(::Type{TA}, sz::NTuple{N, Int}, args...; kwargs...) where {TA, N}
Expand Down Expand Up @@ -152,7 +159,7 @@ function propagator_col!(arr::AbstractArray{T,N}; Δz=one(eltype(arr)), k_max=0.
# fac = eltype(arr)(4im * pi * Δz)
# f(r2) = cispi(sqrt(max(zero(real(eltype(TA))),k2_max - r2)) * (4 * Δz))
# f(r2) = exp(sqrt(max(zero(real(eltype(arr))),k2_max - r2)) * fac)
fac = eltype(arr)(4pi * Δz)
fac = real(eltype(arr))(4pi * Δz)
f(r2) = cis(sqrt(max(zero(real(eltype(arr))),k2_max - r2)) * fac)
if length(size(arr)) < 3 || sz[3] == 1
return calc_radial2_symm!(arr, f; scale=scale);
Expand Down
12 changes: 7 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@ function test_fct(T, fcts, sz, args...; kwargs...)
end

res = fct(Array{T}, sz, args...; kwargs...)
if typeof(res) <: Tuple
res = res[2].(res[1]...)
end
# @test (typeof(res) <: AbstractArray) == false
res = collect(res)
@test (typeof(res) <: AbstractArray) == true

@test ares
@test eltype(res)==T

all_axes = zeros(T, prod(sz))
res2 = fct(Array{T}, sz, args...; all_axes = all_axes, kwargs...)
if typeof(res2) <: Tuple
res2 = res2[2].(res2[1]...)
end
# @test (typeof(res2) <: AbstractArray) == false
res2 = collect(res2)
@test (typeof(res2) <: AbstractArray) == true
Expand Down Expand Up @@ -159,9 +166,4 @@ end
@test maximum(abs.(res4 .- res5)) < 1e-6
end

@testset "gradients" begin
sz = (10,10)
gradient((x) -> gaussian_nokw_sep(sz, x, 1.0, 1.0, 1.0), (2.2,3.3))
end

return
Loading

0 comments on commit 5da3951

Please sign in to comment.