Skip to content

Commit

Permalink
about to remove sca in the example.
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 24, 2024
1 parent 4c7e924 commit 784142c
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 109 deletions.
56 changes: 43 additions & 13 deletions performance_tests/gauss_fit.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using SeparableFunctions
using Zygote
using IndexFunArrays
using BenchmarkTools
using ComponentArrays
Expand All @@ -24,15 +23,16 @@ w = separable_create(fct, sz, sigma)
@time my_gaussian = gaussian_col(sz; sigma = sigma);
# @btime my_gaussian = gaussian_col($sz; sigma = $sigma); # 6 µs

##############
sz = (64,64)
mid = sz2 .+1
myxx = Float32.(xx((sz[1],1)))
myyy = Float32.(yy((1, sz[2])))

my_full_gaussian(vec) = vec.bg .+ vec.intensity.*exp.(.-abs2.(vec.sca[1].*(myxx .- vec.off[1])./(sqrt(2f0)*vec.sigma[1])) .- abs2.(vec.sca[2].*(myyy .- vec.off[2])./(sqrt(2f0)*vec.sigma[2])))
my_full_gaussian(vec) = vec.bg .+ vec.intensity.*exp.(.-abs2.(vec.sca[1].*(myxx .- vec.off[1])./(sqrt(2f0)*vec.args[1])) .- abs2.(vec.sca[2].*(myyy .- vec.off[2])./(sqrt(2f0)*vec.args[2])))
# my_full_gaussian(vec) = vec.bg .+ vec.intensity.*exp.(.-abs2.(vec.sca[1].*(myxx .- vec.off[1])./(sqrt(2f0)*vec.sigma[1]))).*exp.(.- abs2.(vec.sca[2].*(myyy .- vec.off[2])./(sqrt(2f0)*vec.sigma[2])))

vec_true = ComponentVector(;bg=0.2f0, intensity=1.0f0, sca= (1f0, 1f0), off = (2.2f0, 3.3f0), sigma = (2.4f0, 1.5f0))
vec_true = ComponentVector(;bg=0.2f0, intensity=1.0f0, sca= (1f0, 1f0), off = (2.2f0, 3.3f0), args = (2.4f0, 1.5f0))
# @btime my_full_gaussian($vec_true); # 42.800 μs (18 allocations: 16.70 KiB)

Random.seed!(42)
Expand All @@ -41,8 +41,8 @@ dat_copy = copy(dat)

loss = (vec) -> sum(abs2.(my_full_gaussian(vec) .- dat))

bc_mem = gaussian_nokw_sep(sz, vec_true.off.+mid, vec_true.sca, vec_true.sigma)
my_sep_gaussian(vec) = vec.bg .+ vec.intensity .* gaussian_nokw_sep(sz, vec.off .+mid, vec.sca, vec.sigma; all_axes=bc_mem)
bc_mem = gaussian_nokw_sep(sz, vec_true.off.+mid, vec_true.sca, vec_true.args)
my_sep_gaussian(vec) = vec.bg .+ vec.intensity .* gaussian_nokw_sep(sz, vec.off .+mid, vec.sca, vec.args; all_axes=bc_mem)
loss_sep = (vec) -> sum(abs2.(my_sep_gaussian(vec) .- dat))

# off_start = (2.0f0, 3.0f0)
Expand All @@ -52,9 +52,9 @@ off_start = [2.0f0, 3.0f0]
# off_start = [-32f0, -32f0]
sigma_start = [3.0f0, 2.0f0]
sca_start = [1.2f0, 1.5f0]
startvals = ComponentVector(;bg = 0.5f0, intensity=1.0f0, sca=sca_start, off=off_start, sigma=sigma_start)
startvals = ComponentVector(;bg = 0.5f0, intensity=1.0f0, sca=sca_start, off=off_start, args=sigma_start)

@vt dat my_full_gaussian(startvals) my_sep_gaussian(startvals)
# @vt dat my_full_gaussian(startvals) my_sep_gaussian(startvals)

loss(startvals)
loss_sep(startvals)
Expand Down Expand Up @@ -104,7 +104,8 @@ dat .= dat_copy

# DifferentiationInterface.withgradient

opt = Optim.Options(x_reltol=0.001)
# opt = Optim.Options(x_reltol=0.001)
opt = Optim.Options(iterations = 9) # why ?

@time res = Optim.optimize(loss, startvals, Optim.LBFGS(), opt; autodiff = :forward); #
res.f_calls
Expand All @@ -115,10 +116,12 @@ res.f_calls
@btime res = Optim.optimize($loss, $startvals, Optim.LBFGS(), opt; autodiff = :forward);
# 1.06 ms, 11.13 ms
# Hand-Separated: 14 ms
# 10 ms

@btime res = Optim.optimize($loss, $startvals, Optim.LBFGS(), opt; autodiff = :finite);
# 4.065 ms (1989 allocations: 2.52 MiB), 40.440 ms (19519 allocations: 25.53 MiB)
# Hand-separated: 60.908 ms (20067 allocations: 24.28 MiB)
# 27 ms

function fg!(F, G, vec)
val_pb = Zygote.pullback(loss, vec);
Expand All @@ -132,19 +135,46 @@ function fg!(F, G, vec)
return val_pb[1]
end
end
od = OnceDifferentiable(Optim.NLSolversBase.only_fg!(fg!), startvals)

myfg! = get_fg!(dat, gaussian_raw)

startvals_s = copy(startvals)
startvals_s.off .+= sz2 .+1
G = copy(startvals)
myfg!(1, G, startvals_s)

od = OnceDifferentiable(Optim.NLSolversBase.only_fg!(fg!), startvals);
@time res = Optim.optimize(od, startvals, Optim.LBFGS(), opt)
res.f_calls # 61
res.f_calls # 33
res.minimum # 13.9137
# res.f_calls = 0

@btime res = Optim.optimize($od, $startvals, Optim.LBFGS(), opt);
loss(startvals)
fg!(1.0, G, startvals)
G
od = OnceDifferentiable(Optim.NLSolversBase.only_fg!(fg!), startvals);
@btime res = Optim.optimize($od, $startvals, Optim.LBFGS(), $opt);
# Full: 4.804 ms (14435 allocations: 12.78 MiB)
# Sep: 4.099 ms (26269 allocations: 11.20 MiB)
# Hand-Separated: 2.397 ms (15472 allocations: 9.19 MiB)
# 2.8 ms

myfg!(1.0, G, startvals_s)
G
# opt = Optim.Options(iterations = 9); # why ?
odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals_s);
@time reso = Optim.optimize(odo, startvals_s, Optim.LBFGS(), opt);
reso.f_calls # 33
reso.minimum # 13.9138

odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals_s);
@btime reso = Optim.optimize($odo, $startvals_s, Optim.LBFGS(), $opt);
# Zygote-free: 1.5 ms

using InverseModeling

gstartvals = ComponentVector(;offset = startvals.bg, i0=startvals.intensity, µ=startvals.off, σ=startvals.sigma)
@time res1, res2, res3 = gauss_fit(dat, gstartvals; x_reltol=0.001);
gstartvals = ComponentVector(;offset = startvals.bg, i0=startvals.intensity, µ=startvals.off, σ=startvals.args)
@time res1, res2, res3 = gauss_fit(dat, gstartvals; iterations = 9);
res3.f_calls
vec_true
res1
Expand Down
4 changes: 2 additions & 2 deletions src/SeparableFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ export calculate_separables_nokw, calculate_separables, separable_view, separabl
export calculate_broadcasted
export copy_corners!
export propagator_col, propagator_col!, phase_kz_col, phase_kz_col!
export get_sep_mem, get_bc_mem
export fg!
export get_sep_mem, get_bc_mem, get_operator
export get_fg!

export calc_radial_symm!, calc_radial_symm, get_corner_ranges
export radial_speedup, radial_speedup_ifa
Expand Down
Loading

0 comments on commit 784142c

Please sign in to comment.