Skip to content

Commit

Permalink
gaussfit with CUDA support
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Jul 24, 2024
1 parent 784142c commit d1bbe2d
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 36 deletions.
7 changes: 7 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Environment for the scripts in examples/.
# NOTE(review): examples/gauss_fit.jl also calls `@vt`, which none of the packages
# listed here obviously provides (presumably View5D) — confirm and add it if needed.
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
InverseModeling = "ce844058-9528-415d-a63d-06f3dd08b29f"
Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SeparableFunctions = "c8c7ead4-852c-491e-a42d-3d43bc74259e"
55 changes: 55 additions & 0 deletions examples/gauss_fit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using SeparableFunctions
using ComponentArrays
using Optim
using BenchmarkTools
using Noise
using CUDA

# Simulate a Gaussian blob with Poisson noise and fit it with a Gaussian function.
sz = (1600, 1600)
vec_true = ComponentVector(;bg=10.0f0, intensity=50f0, off = [8.2f0, 6.5f0], args = [2.4f0, 1.5f0])

# `poisson` is applied to a Float64 copy of the model; the result is stored as Float32.
dat = Float32.(poisson(Float64.(gaussian_vec(sz, vec_true))))

# Uncomment to run the same fit on the GPU:
# dat = CuArray(dat)
# now prepare the fitting:
myfg! = get_fg!(dat, gaussian_raw, loss=loss_anscombe_pos, bg=7f0);
startvals = ComponentVector(;bg=0.5f0, intensity=45f0, off = [9f0, 7f0], args = [3.0f0, 2.0f0])

opt = Optim.Options(iterations = 19); #
odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals);

# and perform the fit
@time reso = Optim.optimize(odo, startvals, Optim.LBFGS(), opt);
reso.f_calls # 61
reso.minimum
# NOTE(review): `@vt` is not provided by any package loaded in this script
# (presumably View5D) — confirm and add the corresponding `using`.
@vt dat gaussian_vec(sz, startvals) gaussian_vec(sz, reso.minimizer)

# Rebuild the objective wrapper before benchmarking the optimization.
odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals);
if isa(dat, CuArray)
    # CUDA.@sync is used so the timing covers the GPU work.
    @btime CUDA.@sync reso = Optim.optimize($odo, $startvals, Optim.LBFGS(), $opt);
else
    @btime reso = Optim.optimize($odo, $startvals, Optim.LBFGS(), $opt);
end
# Zygote-free CPU: 800 µs, for 1600x1600: 2.7 sec
# Zygote-free GPU: 52 ms, for 1600x1600: 0.213 sec

using InverseModeling

# Half of the array size; `sz2` was used below without ever being defined in this script.
sz2 = sz .÷ 2
# NOTE(review): presumably this converts the absolute start offset into the
# center-relative µ expected by `gauss_fit`; the exact inverse of `off .+ sz2 .+ 1`
# would be `off .- sz2 .- 1` — confirm the sign of the final `.+1`.
gstartvals = ComponentVector(;offset = startvals.bg, i0=startvals.intensity, µ=startvals.off.-sz2 .+1, σ=startvals.args)
@time res1, res2, res3 = gauss_fit(dat, gstartvals; iterations = 99);
res3.f_calls
res3.minimum
vec_true
res1

@btime res1, res2, res3 = gauss_fit($dat, $gstartvals; x_reltol=0.001);
# 4.37 ms (27575 allocations: 8.12 MiB)
@vt dat res2 gaussian_vec(sz, reso.minimizer) (res2 .- dat) (gaussian_vec(sz, reso.minimizer) .- dat)

# Same fit, letting `gauss_fit` choose its own start values.
@time res1, res2, res3 = gauss_fit(dat);

@btime res1, res2, res3 = gauss_fit($dat);
# 5 ms (39192 allocations: 4.35 MiB)

# @btime Optim.optimize($loss, $off_start, $sigma_start, LBFGS(); autodiff = :forward); # 1.000 ms (10001 allocations: 1.53 MiB)
33 changes: 17 additions & 16 deletions performance_tests/gauss_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ mid = sz.÷2 .+1
myxx = Float32.(xx((sz[1],1)))
myyy = Float32.(yy((1, sz[2])))

my_full_gaussian(vec) = vec.bg .+ vec.intensity.*exp.(.-abs2.(vec.sca[1].*(myxx .- vec.off[1])./(sqrt(2f0)*vec.args[1])) .- abs2.(vec.sca[2].*(myyy .- vec.off[2])./(sqrt(2f0)*vec.args[2])))
my_full_gaussian(vec) = vec.bg .+ vec.intensity.*exp.(.-abs2.((myxx .- vec.off[1])./(sqrt(2f0)*vec.args[1])) .- abs2.((myyy .- vec.off[2])./(sqrt(2f0)*vec.args[2])))
# my_full_gaussian(vec) = vec.bg .+ vec.intensity.*exp.(.-abs2.(vec.sca[1].*(myxx .- vec.off[1])./(sqrt(2f0)*vec.sigma[1]))).*exp.(.- abs2.(vec.sca[2].*(myyy .- vec.off[2])./(sqrt(2f0)*vec.sigma[2])))

vec_true = ComponentVector(;bg=0.2f0, intensity=1.0f0, sca= (1f0, 1f0), off = (2.2f0, 3.3f0), args = (2.4f0, 1.5f0))
vec_true = ComponentVector(;bg=0.2f0, intensity=1.0f0, off = (2.2f0, 3.3f0), args = (2.4f0, 1.5f0))
# @btime my_full_gaussian($vec_true); # 42.800 μs (18 allocations: 16.70 KiB)

Random.seed!(42)
Expand All @@ -41,8 +41,8 @@ dat_copy = copy(dat)

loss = (vec) -> sum(abs2.(my_full_gaussian(vec) .- dat))

bc_mem = gaussian_nokw_sep(sz, vec_true.off.+mid, vec_true.sca, vec_true.args)
my_sep_gaussian(vec) = vec.bg .+ vec.intensity .* gaussian_nokw_sep(sz, vec.off .+mid, vec.sca, vec.args; all_axes=bc_mem)
bc_mem = gaussian_nokw_sep(sz, vec_true.off.+mid, 1f0, vec_true.args)
my_sep_gaussian(vec) = vec.bg .+ vec.intensity .* gaussian_nokw_sep(sz, vec.off .+mid, 1f0, vec.args; all_axes=bc_mem)
loss_sep = (vec) -> sum(abs2.(my_sep_gaussian(vec) .- dat))

# off_start = (2.0f0, 3.0f0)
Expand All @@ -51,8 +51,8 @@ loss_sep = (vec) -> sum(abs2.(my_sep_gaussian(vec) .- dat))
off_start = [2.0f0, 3.0f0]
# off_start = [-32f0, -32f0]
sigma_start = [3.0f0, 2.0f0]
sca_start = [1.2f0, 1.5f0]
startvals = ComponentVector(;bg = 0.5f0, intensity=1.0f0, sca=sca_start, off=off_start, args=sigma_start)
# sca_start = [1.2f0, 1.5f0]
startvals = ComponentVector(;bg = 0.5f0, intensity=1.0f0, off=off_start, args=sigma_start)

# @vt dat my_full_gaussian(startvals) my_sep_gaussian(startvals)

Expand Down Expand Up @@ -82,9 +82,9 @@ v, g = value_and_gradient(loss_sep, AutoZygote(), startvals) # (54.162834f0
# v, g = value_and_gradient(loss_sep, AutoReverseDiff(), startvals) #(54.162834f0, (off = Float32[-0.3461006, -1.4118625], sigma = Float32[-0.15418892, 0.4756395]))

mymem = get_bc_mem(typeof(dat), size(dat), *)
@time fg!(dat, gaussian_raw, startvals.bg, startvals.intensity, startvals.off .+ sz2 .+1, startvals.sca, startvals.sigma; all_axes=mymem)
@time fg!(dat, gaussian_raw, startvals.bg, startvals.intensity, startvals.off .+ sz2 .+1, startvals.sigma; all_axes=mymem)

@btime fg!($dat, $gaussian_raw, $startvals.bg, $startvals.intensity, $startvals.off .+ $sz2 .+1, startvals.sca, $startvals.sigma; all_axes=$mymem)
@btime fg!($dat, $gaussian_raw, $startvals.bg, $startvals.intensity, $startvals.off .+ $sz2 .+1, $startvals.sigma; all_axes=$mymem)
# 35.3 µs, 165 allocs, 126 kB

# broken!
Expand Down Expand Up @@ -136,17 +136,17 @@ function fg!(F, G, vec)
end
end

myfg! = get_fg!(dat, gaussian_raw)
myfg! = get_fg!(dat, gaussian_raw, loss=loss_anscombe_pos)

startvals_s = copy(startvals)
startvals_s.off .+= sz2 .+1
G = copy(startvals)
myfg!(1, G, startvals_s)

G
od = OnceDifferentiable(Optim.NLSolversBase.only_fg!(fg!), startvals);
@time res = Optim.optimize(od, startvals, Optim.LBFGS(), opt)
res.f_calls # 33
res.minimum # 13.9137
res.f_calls # 31
res.minimum # 13.927
# res.f_calls = 0

loss(startvals)
Expand All @@ -157,15 +157,15 @@ od = OnceDifferentiable(Optim.NLSolversBase.only_fg!(fg!), startvals);
# Full: 4.804 ms (14435 allocations: 12.78 MiB)
# Sep: 4.099 ms (26269 allocations: 11.20 MiB)
# Hand-Separated: 2.397 ms (15472 allocations: 9.19 MiB)
# 2.8 ms
# 2.5 ms

myfg!(1.0, G, startvals_s)
G
# opt = Optim.Options(iterations = 9); # why ?
odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals_s);
@time reso = Optim.optimize(odo, startvals_s, Optim.LBFGS(), opt);
reso.f_calls # 33
reso.minimum # 13.9138
reso.f_calls # 31
reso.minimum # 13.927

odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals_s);
@btime reso = Optim.optimize($odo, $startvals_s, Optim.LBFGS(), $opt);
Expand All @@ -176,11 +176,12 @@ using InverseModeling
gstartvals = ComponentVector(;offset = startvals.bg, i0=startvals.intensity, µ=startvals.off, σ=startvals.args)
@time res1, res2, res3 = gauss_fit(dat, gstartvals; iterations = 9);
res3.f_calls
res3.minimum
vec_true
res1

@btime res1, res2, res3 = gauss_fit($dat, $gstartvals; x_reltol=0.001);
# 4.214 ms (27575 allocations: 8.12 MiB)
# 4.37 ms (27575 allocations: 8.12 MiB)
@vt dat res2 (res2 .- dat)

# @btime Optim.optimize($loss, $off_start, $sigma_start, LBFGS(); autodiff = :forward); # 1.000 ms (10001 allocations: 1.53 MiB)
2 changes: 2 additions & 0 deletions src/SeparableFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ export calc_radial_symm!, calc_radial_symm, get_corner_ranges
export radial_speedup, radial_speedup_ifa
export kwargs_to_args

export loss_anscombe, loss_anscombe_pos, loss_gaussian, loss_poisson, loss_poisson_pos

DefaultResElType = Float32
DefaultArrType = Array{DefaultResElType}
DefaultComplexArrType = Array{complex(DefaultResElType)}
Expand Down
101 changes: 82 additions & 19 deletions src/general.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,24 +359,86 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cal
return y, calculate_separables_nokw_hook_pullback
end

function get_fg!(data::AbstractArray{T,N}, fct) where {T,N}
# from InverseModeling.jl
"""
    loss_gaussian(data, fwd, bg=zero(T))

Return the sum of squared differences between the forward model `fwd` and `data`.
`bg` is accepted for signature parity with the other losses but is not used.
"""
function loss_gaussian(data::AbstractArray{T}, fwd, bg=zero(T)) where {T}
    return mapreduce(abs2, +, fwd .- data; init=zero(T))
end

"""
    get_fgC!(data, ::typeof(loss_gaussian), bg=zero(T))

Return `(fg!, 2)` for the sum-of-squares loss.

`fg!(F, G, fwd)` writes the plain residual `fwd .- data` into `G` (unless `G` is
`nothing`) and returns the loss value, or `T(0)` when `F` is `nothing`.
The second return value `2` is the constant factor by which the stored residual
has to be multiplied to obtain the true derivative of the loss.
"""
function get_fgC!(data::AbstractArray{T,N}, ::typeof(loss_gaussian), bg=zero(T)) where {T,N}
    function fg!(F, G, fwd)
        if isnothing(G)
            # value only (or nothing at all)
            isnothing(F) && return T(0)
            return mapreduce(abs2, +, (fwd .- data); init=zero(T))
        end
        # the residual is stored in G and reused for the loss value
        G .= (fwd .- data)
        isnothing(F) && return T(0)
        return mapreduce(abs2, +, G; init=zero(T))
    end
    return fg!, 2
end

"""
    loss_poisson(data, fwd, bg=zero(T))

Return `sum((fwd .+ bg) .- (data .+ bg) .* log.(fwd .+ bg))`, i.e. the Poisson
negative log-likelihood of `data` given the forward model `fwd` up to terms that
do not depend on `fwd`. `bg` is added to both `data` and `fwd`.
"""
function loss_poisson(data::AbstractArray{T}, fwd, bg=zero(T)) where {T}
    return sum((fwd .+ bg) .- (data .+ bg) .* log.(fwd .+ bg))
end

"""
    get_fgC!(data, ::typeof(loss_poisson), bg=zero(T))

Return `(fg!, 1)` for the Poisson loss.

`fg!(F, G, fwd)` writes the derivative of the loss with respect to `fwd`,
`1 - (data + bg) / (fwd + bg)`, into `G` (unless `G` is `nothing`) and returns the
loss value, or `T(0)` when `F` is `nothing`. The constant `1` signals that `G`
already is the full derivative and needs no extra factor.
"""
function get_fgC!(data::AbstractArray{T,N}, ::typeof(loss_poisson), bg=zero(T)) where {T,N}
    function fg!(F, G, fwd)
        if !isnothing(G)
            # The denominator is clamped from below to avoid division by zero.
            # Fixed: the lower bound referred to an undefined variable `b` instead of `bg`.
            G .= one(T) .- (data .+ bg) ./ max.(fwd .+ bg, max(bg, T(1f-10)))
        end
        isnothing(F) && return T(0)
        return reduce(+, (fwd .+ bg) .- (data .+ bg) .* log.(fwd .+ bg); init=zero(T))
    end
    return fg!, 1
end

"""
    loss_poisson_pos(data, fwd, bg=zero(T))

Poisson loss in which both `fwd .+ bg` and `data .+ bg` are clamped to be
non-negative before evaluating `sum(f .- d .* log.(f))`.
"""
function loss_poisson_pos(data::AbstractArray{T}, fwd, bg=zero(T)) where {T}
    z = zero(eltype(fwd))
    return sum(max.(z, fwd .+ bg) .- max.(z, data .+ bg) .* log.(max.(z, fwd .+ bg)))
end

"""
    get_fgC!(data, ::typeof(loss_poisson_pos), bg=zero(T))

Return `(fg!, 1)` for the clamped Poisson loss.

`fg!(F, G, fwd)` writes `1 - max(0, data + bg) / max(fwd + bg, 1f-10)` into `G`
(unless `G` is `nothing`) and returns the loss value, or `T(0)` when `F` is
`nothing`. The constant `1` signals that `G` needs no extra factor.
"""
function get_fgC!(data::AbstractArray{T,N}, ::typeof(loss_poisson_pos), bg=zero(T)) where {T,N}
    function fg!(F, G, fwd)
        z = zero(eltype(fwd))
        if !isnothing(G)
            # Fixed: the data term is clamped at zero exactly as in the loss value below,
            # so that the gradient matches the loss also for negative `data .+ bg`.
            G .= one(T) .- max.(z, data .+ bg) ./ max.(fwd .+ bg, T(1f-10))
        end
        isnothing(F) && return T(0)
        return reduce(+, max.(z, fwd .+ bg) .- max.(z, data .+ bg) .* log.(max.(z, fwd .+ bg)); init=zero(T))
    end
    return fg!, 1
end

"""
    loss_anscombe(data, fwd, bg=zero(T))

Return the sum of squared differences between `sqrt.(data .+ bg)` and
`sqrt.(fwd .+ bg)`.
"""
function loss_anscombe(data::AbstractArray{T}, fwd, bg=zero(T)) where {T}
    return sum(abs2.(sqrt.(data .+ bg) .- sqrt.(fwd .+ bg)))
end

"""
    get_fgC!(data, ::typeof(loss_anscombe), bg=zero(T))

Return `(fg!, 1)` for the square-root-domain loss.

`fg!(F, G, fwd)` writes `1 - sqrt(data + bg) / sqrt(fwd + bg)` into `G` (unless
`G` is `nothing`), with the denominator argument clamped from below, and returns
the loss value, or `T(0)` when `F` is `nothing`.
"""
function get_fgC!(data::AbstractArray{T,N}, ::typeof(loss_anscombe), bg=zero(T)) where {T,N}
    # lower bound for the argument of the square root in the denominator
    lower = max(bg, T(1f-10))
    function fg!(F, G, fwd)
        if G !== nothing
            G .= one(T) .- sqrt.(data .+ bg) ./ sqrt.(max.(fwd .+ bg, lower))
        end
        isnothing(F) && return T(0)
        return mapreduce(abs2, +, sqrt.(fwd .+ bg) .- sqrt.(data .+ bg); init=zero(T))
    end
    return fg!, 1
end

"""
    loss_anscombe_pos(data, fwd, bg=zero(T))

Square-root-domain loss in which `fwd .+ bg` and `data .+ bg` are clamped to be
non-negative before taking the square roots.
"""
function loss_anscombe_pos(data::AbstractArray{T}, fwd, bg=zero(T)) where {T}
    z = eltype(fwd)(0)
    return sum(abs2.(sqrt.(max.(z, fwd .+ bg)) .- sqrt.(max.(z, data .+ bg))))
end

"""
    get_fgC!(data, ::typeof(loss_anscombe_pos), bg=zero(T))

Return `(fg!, 1)` for the clamped square-root-domain loss.

`fg!(F, G, fwd)` writes `1 - sqrt(max(0, data + bg)) / sqrt(fwd + bg)` into `G`
(unless `G` is `nothing`), with the denominator argument clamped from below, and
returns the loss value, or `T(0)` when `F` is `nothing`.
"""
function get_fgC!(data::AbstractArray{T,N}, ::typeof(loss_anscombe_pos), bg=zero(T)) where {T,N}
    # lower bound for the argument of the square root in the denominator
    lower = max(bg, T(1f-10))
    function fg!(F, G, fwd)
        z = eltype(fwd)(0)
        if G !== nothing
            G .= one(T) .- sqrt.(max.(z, data .+ bg)) ./ sqrt.(max.(fwd .+ bg, lower))
        end
        isnothing(F) && return T(0)
        return mapreduce(abs2, +, sqrt.(max.(z, fwd .+ bg)) .- sqrt.(max.(z, data .+ bg)); init=zero(T))
    end
    return fg!, 1
end


function get_fg!(data::AbstractArray{T,N}, fct; loss = loss_gaussian, bg=zero(real(T))) where {T,N}
RT = real(eltype(data))
operator = get_operator(fct)
by = get_bc_mem(typeof(data), size(data), operator)
y = by.args
yv = ntuple((d) -> (@view y[d][:]), Val(N))

resid = similar(data)
dy = get_sep_mem(typeof(data), size(data))
dyv = ntuple((d) -> (@view dy[d][:]), Val(N))

resid = similar(data) # checkpoint

loss_fg!, C = get_fgC!(data, loss, bg)

# this function returns the forward value and mutates the gradient G
function fg!(F, G, vec)
bg = hasfield(vec, :bg) ? vec.bg : zero(RT);
intensity = hasfield(vec, :intensity) ? vec.intensity : one(RT)
off = hasfield(vec, :off) ? vec.off : RT.(sz2 .+1)
sca = hasfield(vec, :sca) ? vec.sca : one(RT);
args = hasfield(vec, :args) ? (vec.args,) : one(RT)
bg = hasproperty(vec, :bg) ? vec.bg : zero(RT);
intensity = hasproperty(vec, :intensity) ? vec.intensity : one(RT)
off = hasproperty(vec, :off) ? vec.off : RT.(sz2 .+1)
sca = hasproperty(vec, :sca) ? vec.sca : one(RT);
args = hasproperty(vec, :args) ? (vec.args,) : one(RT)
sz = size(data)
# mid = sz .÷ 2 .+ 1
# off = off .+ mid
Expand All @@ -386,11 +448,12 @@ function get_fg!(data::AbstractArray{T,N}, fct) where {T,N}
# 5kB, result is in by
calculate_broadcasted_nokw(typeof(data), fct, sz, off, sca, args...; operator=operator, all_axes=by)

if !isnothing(F) || !isnothing(G)
resid .= bg .+ intensity .* by .- data
end
# if !isnothing(F) || !isnothing(G)
# resid .= bg .+ intensity .* by .- data
# end
loss = loss_fg!(F, resid, bg .+ intensity .* by)
if !isnothing(G)
G.bg = 2*sum(resid)
G.bg = C*sum(resid)

other_dims = ntuple((d)-> (ntuple((n)->n, d-1)..., ntuple((n)->d+n, N-d)...), Val(N))
# @show other_dims
Expand All @@ -411,38 +474,38 @@ function get_fg!(data::AbstractArray{T,N}, fct) where {T,N}

args_1d = ntuple((d)-> pick_n.(d, args), Val(N))

if hasfield(vec, :off)
G.off = optional_convert(off, ntuple((d) -> (-2*intensity*pick_n(d, sca)) .*
if hasproperty(vec, :off)
G.off = optional_convert(off, ntuple((d) -> (-C*intensity*pick_n(d, sca)) .*
get_idx_gradient(fct, yv[d], ids[d], sz[d], dyv[d], args_1d[d]...), Val(N))) # ids @ offset the -1 is since the argument of fct is idx-offset
end

if hasfield(vec, :args)
if hasproperty(vec, :args)
dargs = ntuple((argno) -> optional_convert(args[argno],
(2*intensity).*ntuple((d) -> get_arg_gradient(fct, yv[d], ids[d], sz[d], dyv[d], args_1d[d]...), Val(N))), length(args)) # ids @ offset the -1 is since the argument of fct is idx-offset
(C*intensity).*ntuple((d) -> get_arg_gradient(fct, yv[d], ids[d], sz[d], dyv[d], args_1d[d]...), Val(N))), length(args)) # ids @ offset the -1 is since the argument of fct is idx-offset
G.args = dargs[1]
end
# dargs = (0f0,0f0)

if hasfield(vec, :sca)
if hasproperty(vec, :sca)
# misuse the dy memory
for d = 1:N
dyv[d] .= dyv[d].* ids_offset_only[d]
end
G.sca = optional_convert(sca, ntuple((d) ->
(2*intensity).*get_idx_gradient(fct, yv[d], ids[d], sz[d], dyv[d], args_1d[d]...), Val(N))) # ids @ offset the -1 is since the argument of fct is idx-offset
(C*intensity).*get_idx_gradient(fct, yv[d], ids[d], sz[d], dyv[d], args_1d[d]...), Val(N))) # ids @ offset the -1 is since the argument of fct is idx-offset
end
# 1.5 kB:
# (2*intensity).*get_idx_gradient(fct, yv[d], ids[d], sz[d], dyv[d].* ids_offset_only[d], args_1d[d]...), Val(N))) # ids @ offset the -1 is since the argument of fct is idx-offset
# dscale = (0f0,0f0)
end

# return loss, dfdbg, dfdI, doffset, dscale, dargs...
loss = (!isnothing(F)) ? mapreduce(abs2, +, resid; init=zero(T)) : T(0);
# loss = (!isnothing(F)) ? mapreduce(abs2, +, resid; init=zero(T)) : T(0);

if !isnothing(G) # forward needs to be calculated
# resid .*= by # is slower!
resid .= resid .* by
G.intensity = 2*sum(resid)
G.intensity = C*sum(resid)
# G.intensity = 2*sum(resid.*by)
end

Expand Down
20 changes: 19 additions & 1 deletion src/specific.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,23 @@ for F in generate_functions_expr()
# operator=$(F[5])
# return calculate_separables_nokw(Array{$(F[4])}, fct, sz, args...; all_axes=all_axes), operator
end


# `<name>_vec(TA, sz, vec; all_axes)`: evaluate the separable function for the
# parameters stored as properties of `vec` (e.g. a ComponentArray) and return
# `bg .+ intensity .* <name>_nokw_sep(sz, off, sca, args...; all_axes)`.
# Properties missing from `vec` fall back to defaults:
#   intensity -> one, bg -> zero, off -> `sz .÷ 2 .+ 1`, sca -> ones, args -> none.
@eval function $(Symbol(F[1], :_vec))(::Type{TA}, sz::NTuple{N, Int}, vec;
            all_axes = get_bc_mem(Array{$(F[4])}, sz, $(F[5])) ) where {TA, N}
    RT = real(eltype(TA))
    intensity = (hasproperty(vec, :intensity)) ? vec.intensity : one(RT)
    bg = (hasproperty(vec, :bg)) ? vec.bg : zero(RT)
    # Fixed: the integer-division operator was missing (`sz 2 .+ 1` does not parse).
    off = (hasproperty(vec, :off)) ? vec.off : sz .÷ 2 .+ 1
    sca = (hasproperty(vec, :sca)) ? vec.sca : ones(RT, length(sz))
    # `vec.args` is passed on as a single positional argument; without it nothing is splatted
    args = (hasproperty(vec, :args)) ? (vec.args,) : ()
    return bg .+ intensity .* ($(Symbol(F[1], :_nokw_sep))(sz, off, sca, args...; all_axes=all_axes))
end

# Convenience method of `<name>_vec` without an explicit array type: forwards to the
# typed method using `Array{$(F[4])}` as the array type.
@eval function $(Symbol(F[1], :_vec))(sz::NTuple{N, Int}, vec;
            all_axes = get_bc_mem(Array{$(F[4])}, sz, $(F[5])) ) where {N}
    default_arr_type = Array{$(F[4])}
    return $(Symbol(F[1], :_vec))(default_arr_type, sz, vec; all_axes=all_axes)
end

@eval function $(Symbol(F[1], :_lz))(::Type{TA}, sz::NTuple{N, Int}, args...; kwargs...) where {TA, N}
fct = $(F[3]) # to assign the function to a symbol
separable_view(TA, fct, sz, args...; defaults=$(F[2]), operator=$(F[5]), kwargs...)
Expand All @@ -191,6 +207,8 @@ for F in generate_functions_expr()
@eval export $(Symbol(F[1], :_col))
# separated: a vector of separated contributions is returned and the user has to combine them
@eval export $(Symbol(F[1], :_sep))
# a broadcasted version which accepts a ComponentArray as an input
@eval export $(Symbol(F[1], :_vec))
# lazy: A LazyArray representation is returned
@eval export $(Symbol(F[1], :_nokw_sep))
# @eval export $(Symbol(F[1], :_lz))
Expand Down

0 comments on commit d1bbe2d

Please sign in to comment.