Skip to content

Commit

Permalink
bug fixes with gauss_fit and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RainerHeintzmann committed Aug 2, 2024
1 parent aa17441 commit 8c66617
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 26 deletions.
41 changes: 29 additions & 12 deletions examples/gauss_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,64 @@ using CUDA
# simulate a gaussian blob with Poisson noise and fit it with a Gaussian function
sz = (7,7) # (1600, 1600)
many_fits = true
use_cuda = true
N = 10_000
use_cuda = false
N = 1_000
hyperplanes = many_fits ? rand(Float32, (1, N)) : 0
hp_zeros = many_fits ? zeros(Float32, (1, N)) : 0

off = [3.2f0, 3.5f0].+hyperplanes
sigma = [1.4f0, 1.1f0]
off = [3.2f0, 3.5f0] .+ hyperplanes
sigma = [1.4f0, 1.1f0] .+ hp_zeros
hyperplanes = many_fits ? rand(Float32, (1, N)) : 0
intensity = [50f0] .* (1 .+ hyperplanes)
vec_true = ComponentVector(;bg=10.0f0, intensity=intensity, off = off, args = sigma)
vec_true = ComponentVector(;bg=10.0f0 .+ hp_zeros, intensity=intensity, off = off, args = sigma)
vec_true = Float64.(vec_true)

pdat = gaussian_vec(sz, vec_true)
dat = Float32.(poisson(Float64.(pdat)))

pdat = (use_cuda) ? CuArray(pdat) : pdat
dat = (use_cuda) ? CuArray(dat) : dat

qdat = copy(pdat)
qdat .= qdat[:,:,1]
# now prepare the fitting:
myfg! = get_fg!(pdat, gaussian_raw, length(sz); loss=loss_anscombe_pos, bg=7f0);
# myfg! = get_fg!(pdat, gaussian_raw, length(sz); loss=loss_anscombe_pos, bg=7f0);
myfg! = get_fg!(qdat, gaussian_raw, length(sz); loss=loss_gaussian);
shyperplanes = many_fits ? zeros(Float32, (1, size(dat)[end])) : 0
soff = [4.0f0, 4.0f0] .+ shyperplanes
bg = 0.5f0
bg = [0.5f0] .+ shyperplanes
intensity = [45f0] .+ shyperplanes
sigma = [3.0f0, 2.0f0]
sigma = [3.0f0, 2.0f0] .+ shyperplanes
if (use_cuda)
bg = CuArray([bg])
intensity = CuArray(intensity)
soff = CuArray(soff)
sigma = CuArray(sigma)
end
startvals = ComponentVector(;bg=bg, intensity=intensity, off = soff, args = sigma)
opt = Optim.Options(iterations = 499); #
startvals = Float64.(startvals)
opt = Optim.Options(iterations = 50); #
odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals);

if (false)
G = copy(startvals)
myfg!(1, G, startvals)

myfg2! = get_fg!(pdat[:,:,1], gaussian_raw, length(sz); loss=loss_anscombe_pos, bg=7f0);
sv = ComponentVector{Float32}(bg=startvals.bg[1], intensity=startvals.intensity[1], off = startvals.off[:,1], args = startvals.args[:,1])
G2 = copy(sv)
myfg2!(1, G2, sv)
G2
end

# and perform the fit
@time reso = Optim.optimize(odo, startvals, Optim.LBFGS(), opt);
# 2 sec, 5k fits/s (44.25 k allocations: 1.546 GiB, 7.35% gc time)
# with intensity variations: 26.833106 seconds (532.47 k allocations: 20.251 GiB, 5.99% gc time)
# in Cuda:
reso.f_calls # 61
reso.minimum
@vt pdat dat gaussian_vec(sz, startvals) gaussian_vec(sz, reso.minimizer)
reso.f_calls # 61 # 1766 für 10_000 fits, 155, 2.2 sec for 10_000 fits with all entries being vectors
reso.minimum #
@vt pdat gaussian_vec(sz, startvals) gaussian_vec(sz, reso.minimizer)

odo = OnceDifferentiable(Optim.NLSolversBase.only_fg!(myfg!), startvals);
if isa(dat, CuArray)
Expand Down
15 changes: 12 additions & 3 deletions src/general.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,12 +438,19 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia
# resid .= bg .+ intensity .* by .- data
# end
myint = get_vec_dim(intensity, 1, sz)
loss = loss_fg!(F, resid, bg .+ myint .* by)
mybg = get_vec_dim(bg, 1, sz)
loss = loss_fg!(F, resid, mybg .+ myint .* by)
# @show resid
if !isnothing(G)
# for arrays this should be .=
if hasproperty(vec, :bg)
G.bg = C*sum(resid)
if (prod(size(G.bg)) > 1)
G.bg[:] .= C*sum(resid, dims=1:length(sz))[:]
elseif isa(G.bg, Number)
G.bg = C*sum(resid)[1]
else
G.bg .= C*sum(resid)[1]
end
end

other_dims = ntuple((d)-> (ntuple((n)->n, d-1)..., ntuple((n)->d+n, prod_dims-d)...), Val(prod_dims))
Expand Down Expand Up @@ -508,8 +515,10 @@ function get_fg!(data::AbstractArray{T,N}, fct, prod_dims=N; loss = loss_gaussia

if (prod(size(G.intensity)) > 1)
G.intensity[:] .= C .* @view sum(resid, dims=1:length(sz))[:]
elseif isa(G.intensity, Number)
G.intensity = C .* sum(resid)[1]
else
G.intensity = C .* sum(resid, dims=1:length(sz))[1]
G.intensity .= C .* sum(resid)[1]
end
# G.intensity[:] .= C .* sum(resid .* by, dims=1:length(sz))[:]
# end
Expand Down
9 changes: 7 additions & 2 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,19 @@ end

# Versions that assign in place
function optional_convert_assign!(dst, ref_arg::AbstractArray{T,N}, val) where {T,N}
if (prod(size(ref_arg)) == 1)
if isa(dst, Number)
dst = sum(sum.(val))
return
elseif (prod(size(ref_arg)) == 1)
dst .= sum(sum.(val))
return
end
d=1
for v in val
dv = selectdim(dst, 1, d)
if (prod(size(dv)) == 1)
if isa(dv, Number)
dv = sum(real.(v[:]))
elseif (prod(size(dv)) == 1)
# the real cast seem dodgy here
dv .= sum(real.(v[:]))
else
Expand Down
14 changes: 5 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,13 @@ end
dat = gaussian_vec(sz, vec_true)
loss2 = (vec) -> sum(abs2.(gaussian_vec(sz, vec) .- dat))
@test loss2(vec_true) == 0
g = gradient(loss2, vec_true)
gn = grad(central_fdm(5, 1), loss2, vec_true) # 5th order method, 1st derivatives
g = gradient(loss2, vec_true)[1]
gn = grad(central_fdm(5, 1), loss2, vec_true)[1] # 5th order method, 1st derivatives
myfg! = get_fg!(dat, gaussian_raw, length(sz); loss = loss_gaussian)
G = similar(gn[1]) .* 0
G = similar(gn) .* 0
f = myfg!(1, G, vec_true)
# maximum(abs.(G))
for (mygn, myg, myfg) in zip(gn[1], g[1], G)
@test all(isapprox.(mygn, 0, atol = 5e-12))
@test all(isapprox.(myg, 0, atol = 5e-12))
@test all(isapprox.(myfg, 0, atol = 5e-12))
end
check_all(Float64, zeros(length(g)), g, gn, G; atol=5e-12)

# vec_start = ComponentVector(;bg=0.3, intensity=1.1, off = [2.3, 3.4], sca = [1.4, 1.3], args = [2.5, 1.6])
vec_start = vec_true .+ 0.2
Expand All @@ -315,7 +311,7 @@ end
# maximum(abs.(Gs[:] .- gns[1][:]))
for (mygn, myg, myfg) in zip(gns, gs, Gs)
@test all(isapprox.(mygn, myg, atol=4e-7))
@test all(isapprox.(mygn, myfg, atol=4e-7))
@test all(isapprox.(mygn, myfg, atol=4e-6))
end
end
end
Expand Down

0 comments on commit 8c66617

Please sign in to comment.