Skip to content

Commit

Permalink
Merge pull request #33 from MagneticResonanceImaging/OverloadCalcFilt…
Browse files Browse the repository at this point in the history
…eredBackProjection

Overloaded calcFBP function to allow for easier non-MRF fbps
  • Loading branch information
JakobAsslaender authored Feb 16, 2024
2 parents 9f9d7dc + 53f9f62 commit 7ddbc7e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 54 deletions.
75 changes: 66 additions & 9 deletions src/BackProjection.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,89 @@
function calculateBackProjection(data::AbstractArray{T}, trj, U, cmaps; verbose = false) where {T}
function calculateBackProjection(data::AbstractArray{T}, trj, img_shape::NTuple{N,Int}; U = N==3 ? I(size(data,2)) : I(1), density_compensation=:none, verbose=false) where {N,T}
if typeof(trj) <: AbstractMatrix
trj = [trj]
end

if ndims(data) == 2
data = reshape(data, size(data, 1), 1, size(data, 2))
end
Ncoils = size(data, 3)
Ncoef = size(U,2)

p = plan_nfft(reduce(hcat, trj), img_shape; precompute=TENSOR, blocking=true, fftflags=FFTW.MEASURE)
xbp = Array{T}(undef, img_shape..., Ncoef, Ncoils)

data_temp = similar(@view data[:, :, 1]) # size = Ncycles*Nr x Nt
img_idx = CartesianIndices(img_shape)
verbose && println("calculating backprojection..."); flush(stdout)
for icoef axes(U,2)
t = @elapsed for icoil axes(data, 3)
@simd for i CartesianIndices(data_temp)
@inbounds data_temp[i] = data[i,icoil] * conj(U[i[2],icoef])
end
applyDensityCompensation!(data_temp, trj; density_compensation)
@views mul!(xbp[img_idx, icoef, icoil], adjoint(p), vec(data_temp))
end
verbose && println("coefficient = $icoef: t = $t s"); flush(stdout)
end
return xbp
end

function calculateBackProjection(data::AbstractArray{T,N}, trj, cmaps::AbstractVector{<:AbstractArray{T}}; U = N==3 ? I(size(data,2)) : I(1), density_compensation=:none, verbose=false) where {N,T}
if typeof(trj) <: AbstractMatrix
trj = [trj]
end

if ndims(data) == 2
data = reshape(data, size(data, 1), 1, size(data, 2))
end

test_dimension(data, trj, U, cmaps)

_, Ncoef = size(U)
Ncoef = size(U,2)
img_shape = size(cmaps[1])

p = plan_nfft(reduce(hcat,trj), img_shape; precompute=TENSOR, blocking = true, fftflags = FFTW.MEASURE)
xbp = zeros(T, img_shape..., Ncoef)
xtmp = Array{T}(undef, img_shape)

dataU = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt
data_temp = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt
img_idx = CartesianIndices(img_shape)
verbose && println("calculating backprojection..."); flush(stdout)
for icoef axes(U,2)
t = @elapsed for icoil eachindex(cmaps)
@simd for i CartesianIndices(dataU)
@inbounds dataU[i] = data[i,icoil] * conj(U[i[2],icoef])
@simd for i CartesianIndices(data_temp)
@inbounds data_temp[i] = data[i,icoil] * conj(U[i[2],icoef])
end
mul!(xtmp, adjoint(p), vec(dataU))
applyDensityCompensation!(data_temp, trj; density_compensation)

mul!(xtmp, adjoint(p), vec(data_temp))
xbp[img_idx,icoef] .+= conj.(cmaps[icoil]) .* xtmp
end
verbose && println("coefficient = $icoef: t = $t s"); flush(stdout)
end
return xbp
end

function applyDensityCompensation!(data, trj; density_compensation=:radial_3D)
for it in axes(data, 2)
if density_compensation == :radial_3D
data[:, it] .*= transpose(sum(abs2, trj[it], dims=1))
elseif density_compensation == :radial_2D
data[:, it] .*= transpose(sqrt.(sum(abs2, trj[it], dims=1)))
elseif density_compensation == :none
# do nothing here
elseif isa(density_compensation, AbstractVector{<:AbstractVector})
data[:, it] .*= density_compensation[it]
else
error("`density_compensation` can only be `:radial_3D`, `:radial_2D`, `:none`, or of type `AbstractVector{<:AbstractVector}`")
end
end
end

function test_dimension(data, trj, U, cmaps)
Nt, _ = size(U)
img_shape = size(cmaps[1])
Ncoils = length(cmaps)
Nt = size(U,1)
img_shape = size(cmaps)[1:end-1]
Ncoils = size(cmaps)[end]

Nt != size(data, 2) && ArgumentError(
"The second dimension of data ($(size(data, 2))) and the first one of U ($Nt) do not match. Both should be number of time points.",
Expand Down
48 changes: 7 additions & 41 deletions src/CoilMaps.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,24 @@
function calcFilteredBackProjection(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}, Ncoils::Int; density_compensation::Union{Symbol, <:AbstractVector{<:AbstractVector{T}}}=:radial_3D, verbose = false) where {N,T}
p = plan_nfft(reduce(hcat,trj), img_shape; precompute=TENSOR, blocking = true, fftflags = FFTW.MEASURE)
xbp = Array{Complex{T}}(undef, img_shape..., Ncoils)

dataU = similar(@view data[:,:,1]) # size = Ncycles*Nr x Nt
img_idx = CartesianIndices(img_shape)
t = @elapsed for icoil axes(data,3)
if density_compensation == :radial_3D
@simd for i CartesianIndices(dataU)
dataU[i] = data[i,icoil] * conj(U[i[2],1]) * sum(abs2, @view trj[i[2]][:,i[1]])
end
elseif density_compensation == :radial_2D
@simd for i CartesianIndices(dataU)
dataU[i] = data[i,icoil] * conj(U[i[2],1]) * sqrt(sum(abs2, @view trj[i[2]][:,i[1]]))
end
elseif density_compensation == :none
# no density compensation; premultiply data with inverse of sampling density before calling function
@simd for i CartesianIndices(dataU)
dataU[i] = data[i,icoil] * conj(U[i[2],1])
end
elseif isa(density_compensation, Symbol)
error("`density_compensation` can only be `:radial_3D`, `:radial_2D`, `:none`, or of type `AbstractVector{<:AbstractVector{T}}`")
else
@simd for i CartesianIndices(dataU)
dataU[i] = data[i,icoil] * conj(U[i[2],1]) * density_compensation[i[2]][i[1]]
end
end

@views mul!(xbp[img_idx,icoil], adjoint(p), vec(dataU))
end
verbose && println("BP for coils maps: $t s")
return xbp
end


function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}; density_compensation::Union{Symbol, <:AbstractVector{<:AbstractVector{T}}}=:radial_3D, kernel_size = ntuple(_->6, N), calib_size = ntuple(_->24, N), eigThresh_1=0.01, eigThresh_2=0.9, nmaps=1, verbose = false) where {N,T}
Ncoils = size(data,3)
function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, img_shape::NTuple{N,Int}; U = N==3 ? I(size(data,2)) : I(1), density_compensation=:radial_3D, kernel_size=ntuple(_ -> 6, N), calib_size=ntuple(_ -> 24, N), eigThresh_1=0.01, eigThresh_2=0.9, nmaps=1, verbose=false) where {N,T}
Ncoils = size(data, 3)
Ndims = length(img_shape)
imdims = ntuple(i->i, Ndims)
imdims = ntuple(i -> i, Ndims)

xbp = calcFilteredBackProjection(data, trj, U, img_shape, Ncoils; density_compensation, verbose)
xbp = calculateBackProjection(data, trj, img_shape; U=U[:,1], density_compensation, verbose)
xbp = dropdims(xbp, dims=ndims(xbp)-1)

img_idx = CartesianIndices(img_shape)
kbp = fftshift(xbp, imdims)
fft!(kbp, imdims)
kbp = fftshift(kbp, imdims)

m = CartesianIndices(calib_size) .+ CartesianIndex((img_shape .- calib_size) 2)
kbp = kbp[m,:]
kbp = kbp[m, :]

t = @elapsed begin
cmaps = espirit(kbp, img_shape, kernel_size, eigThresh_1=eigThresh_1, eigThresh_2=eigThresh_2, nmaps=nmaps)
end
verbose && println("espirit: $t s")

cmaps = [cmaps[img_idx,ic,1] for ic=1:Ncoils]
cmaps = [cmaps[img_idx, ic, 1] for ic = 1:Ncoils]
return cmaps
end
1 change: 1 addition & 0 deletions src/MRFingerprintingRecon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ include("NFFTNormalOpBasisFunc.jl")
include("CoilMaps.jl")
include("BackProjection.jl")
include("Trajectories.jl")
include("deprecated.jl")

end # module
6 changes: 3 additions & 3 deletions src/NFFTNormalOpBasisFunc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function NFFTNormalOpBasisFunc(
img_shape,
trj::AbstractVector{<:AbstractMatrix{T}},
U::AbstractMatrix{Tc};
cmaps = (1,),
cmaps=[ones(T, img_shape)],
verbose = false,
Λ_kmask_indcs = calculateToeplitzKernelBasis(2 .* img_shape, trj, U; verbose = verbose),
num_fft_threads = round(Int, Threads.nthreads()/size(U, 2))
Expand Down Expand Up @@ -154,13 +154,13 @@ function NFFTNormalOpBasisFuncLO(
img_shape,
trj::AbstractVector{<:AbstractMatrix{T}},
U::AbstractMatrix{Tc};
cmaps = (1,),
cmaps=[ones(T, img_shape)],
verbose = false,
Λ_kmask_indcs = calculateToeplitzKernelBasis(2 .* img_shape, trj, U; verbose = verbose),
num_fft_threads = round(Int, Threads.nthreads()/size(U, 2))
) where {T, Tc <: Union{T, Complex{T}}}

S = NFFTNormalOpBasisFunc(img_shape, trj, U; cmaps = cmaps, Λ_kmask_indcs = Λ_kmask_indcs, num_fft_threads = num_fft_threads)
S = NFFTNormalOpBasisFunc(img_shape, trj, U; cmaps, Λ_kmask_indcs, num_fft_threads)
return NFFTNormalOpBasisFuncLO(S)
end

Expand Down
10 changes: 10 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function calculateBackProjection(data::AbstractArray{T}, trj, U, cmaps::AbstractVector{<:AbstractArray{T}}; density_compensation=:none, verbose=false) where T
@warn "calculateBackProjection(data, trj, U, cmaps) has been deprecated – call calculateBackProjection(data, trj, cmaps; U=U) with U as a keyword argument instead." maxlog=1
return calculateBackProjection(data, trj, cmaps; U, density_compensation, verbose)
end


function calcCoilMaps(data::AbstractArray{Complex{T},3}, trj::AbstractVector{<:AbstractMatrix{T}}, U::AbstractMatrix{Complex{T}}, img_shape::NTuple{N,Int}; density_compensation=:radial_3D, kernel_size=ntuple(_ -> 6, N), calib_size=ntuple(_ -> 24, N), eigThresh_1=0.01, eigThresh_2=0.9, nmaps=1, verbose=false) where {N,T}
@warn "calcCoilMaps(data, trj, U, img_shape) has been deprecated – call calcCoilMaps(data, trj, img_shape; U=U) with U as a keyword argument instead." maxlog=1
return calcCoilMaps(data, trj, img_shape; U, density_compensation, kernel_size, calib_size, eigThresh_1, eigThresh_2, nmaps, verbose)
end
4 changes: 3 additions & 1 deletion test/reconstruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ for it ∈ axes(data,2)
end

## BackProjection
b = vec(calculateBackProjection(data, trj, U, [ones(T, Nx,Nx)], verbose = true))
b = vec(calculateBackProjection(data, trj, (Nx, Nx); U, verbose = true))
bcoil = vec(calculateBackProjection(data, trj, [ones(Complex{T}, Nx,Nx)]; U, verbose = true))
@test bcoil b

## construct forward operator
A = NFFTNormalOpBasisFuncLO((Nx,Nx), trj, U; verbose = true)
Expand Down

0 comments on commit 7ddbc7e

Please sign in to comment.