Skip to content

Commit

Permalink
Merge pull request #552 from DiamondLightSource/ccpi_reg
Browse files Browse the repository at this point in the history
FISTA algorithm updates
  • Loading branch information
dkazanc authored Mar 4, 2019
2 parents 8f0bfbc + e53f35b commit 8eb68bb
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 50 deletions.
25 changes: 12 additions & 13 deletions savu/plugins/filters/ccpi_regul_toolkit_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
.. module:: Wrapper for CCPi-Regularisation Toolkit (CPU) for efficient 2D/3D denoising
:platform: Unix
:synopsis: CCPi-Regularisation Toolkit delivers a variety of variational 2D/3D denoising methods. The available methods are: 'ROF_TV','FGP_TV','SB_TV','TGV','LLT_ROF','NDF','DIFF4th'
:synopsis: CCPi-Regularisation Toolkit delivers a variety of variational 2D/3D denoising methods. The available methods are: 'ROF_TV','FGP_TV','SB_TV','TGV','LLT_ROF','NDF','Diff4th'
.. moduleauthor:: Daniil Kazantsev <scientificsoftware@diamond.ac.uk>
"""
Expand All @@ -24,7 +24,7 @@
from savu.plugins.driver.cpu_plugin import CpuPlugin
from savu.plugins.utils import register_plugin

from ccpi.filters.regularisers import ROF_TV, FGP_TV, SB_TV, TGV, LLT_ROF, NDF, DIFF4th
from ccpi.filters.regularisers import ROF_TV, FGP_TV, SB_TV, TGV, LLT_ROF, NDF, Diff4th
from savu.data.plugin_list import CitationInformation

@register_plugin
Expand All @@ -38,7 +38,7 @@ class CcpiRegulToolkitCpu(Plugin, CpuPlugin):
'NDF': Nonlinear/Linear Diffusion model (Perona-Malik, Huber or Tukey);
'DIFF4th': Fourth-order nonlinear diffusion model
:param method: Choose methods |ROF_TV|FGP_TV|SB_TV|TGV|LLT_ROF|NDF|DIFF4th. Default: 'FGP_TV'.
:param method: Choose methods |ROF_TV|FGP_TV|SB_TV|TGV|LLT_ROF|NDF|Diff4th. Default: 'FGP_TV'.
:param reg_par: Regularisation (smoothing) parameter. Default: 0.05.
:param max_iterations: Total number of iterations. Default: 200.
:param time_step: Time marching step, relevant for ROF_TV, LLT_ROF,\
Expand All @@ -49,7 +49,7 @@ class CcpiRegulToolkitCpu(Plugin, CpuPlugin):
:param reg_parLLT: LLT-ROF method, parameter to control the 2nd-order term. Default: 0.05.
:param penalty_type: NDF method, Penalty type for the duffison, choose from\
huber, perona or tukey. Default: 'huber'.
:param edge_par: NDF and DIFF4th methods, noise magnitude parameter. Default: 0.01.
:param edge_par: NDF and Diff4th methods, noise magnitude parameter. Default: 0.01.
"""

def __init__(self):
Expand All @@ -69,14 +69,14 @@ def setup(self):
out_pData[0].plugin_data_setup('VOLUME_XZ', 'single')

def pre_process(self):
# accessing Ccpi-RGLTK modules
# accessing Ccpi-RGL toolkit modules
self.device = 'cpu'
if (self.parameters['method'] == 'ROF_TV'):
# set parameters for the ROF-TV method
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter':self.parameters['reg_par'],\
'number_of_iterations': self.parameters['max_iterations'],\
'time_marching_parameter': self.parameters['time_step']}
'time_marching_parameter': self.parameters['time_step']}
if (self.parameters['method'] == 'FGP_TV'):
# set parameters for the FGP-TV method
self.pars = {'algorithm': self.parameters['method'], \
Expand All @@ -85,7 +85,7 @@ def pre_process(self):
'tolerance_constant':1e-06,\
'methodTV': 0 ,\
'nonneg': 0 ,\
'printingOut': 0}
'printingOut': 0}
if (self.parameters['method'] == 'SB_TV'):
# set parameters for the SB-TV method
self.pars = {'algorithm': self.parameters['method'], \
Expand Down Expand Up @@ -125,8 +125,8 @@ def pre_process(self):
'edge_parameter':self.parameters['edge_par'],\
'number_of_iterations': self.parameters['max_iterations'],\
'time_marching_parameter': self.parameters['time_step'],\
'penalty_type': penaltyNDF}
if (self.parameters['method'] == 'DIFF4th'):
'penalty_type': penaltyNDF}
if (self.parameters['method'] == 'Diff4th'):
# set parameters for the DIFF4th method
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter':self.parameters['reg_par'],\
Expand Down Expand Up @@ -165,7 +165,7 @@ def process_frames(self, data):
im_res = TGV(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['alpha1'],
self.pars['alpha0'],
self.pars['alpha0'],
self.pars['number_of_iterations'],
self.pars['LipshitzConstant'],self.device)
if (self.parameters['method'] == 'LLT_ROF'):
Expand All @@ -182,7 +182,7 @@ def process_frames(self, data):
self.pars['time_marching_parameter'],
self.pars['penalty_type'],self.device)
if (self.parameters['method'] == 'DIFF4th'):
im_res = DIFF4th(self.pars['input'],
im_res = Diff4th(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['edge_parameter'],
self.pars['number_of_iterations'],
Expand All @@ -192,7 +192,6 @@ def process_frames(self, data):
def post_process(self):
pass


def get_citation_information(self):
cite_info1 = CitationInformation()
cite_info1.name = 'citation1'
Expand Down Expand Up @@ -426,4 +425,4 @@ def get_citation_information(self):
"%I Springer\n")
cite_info2.doi = "doi: 10.1007/s11263-010-0330-1"

return [cite_info1, cite_info2]
return [cite_info1, cite_info2]
34 changes: 17 additions & 17 deletions savu/plugins/filters/ccpi_regul_toolkit_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from savu.plugins.driver.gpu_plugin import GpuPlugin
from savu.plugins.utils import register_plugin

from ccpi.filters.regularisers import ROF_TV, FGP_TV, SB_TV, TGV, LLT_ROF, NDF, DIFF4th
from ccpi.filters.regularisers import ROF_TV, FGP_TV, SB_TV, TGV, LLT_ROF, NDF, Diff4th
from savu.data.plugin_list import CitationInformation

@register_plugin
Expand All @@ -38,14 +38,14 @@ class CcpiRegulToolkitGpu(Plugin, GpuPlugin):
:param reg_par: Regularisation (smoothing) parameter. Default: 0.05.
:param max_iterations: Total number of iterations. Default: 200.
:param time_step: Time marching step, relevant for ROF_TV, LLT_ROF,\
NDF, DIFF4th methods. Default: 0.001.
NDF, Diff4th methods. Default: 0.001.
:param lipshitz_constant: TGV method, Lipshitz constant. Default: 12.
:param alpha1: TGV method, parameter to control the 1st-order term. Default: 1.0.
:param alpha0: TGV method, parameter to control the 2nd-order term. Default: 0.8.
:param reg_parLLT: LLT-ROF method, parameter to control the 2nd-order term. Default: 0.05.
:param penalty_type: NDF method, Penalty type for the duffison, choose from\
huber, perona or tukey. Default: 'huber'.
:param edge_par: NDF and DIFF4th methods, noise magnitude parameter. Default: 0.01.
:param edge_par: NDF and Diff4th methods, noise magnitude parameter. Default: 0.01.
"""

def __init__(self):
Expand Down Expand Up @@ -74,7 +74,7 @@ def pre_process(self):
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter':self.parameters['reg_par'],\
'number_of_iterations': self.parameters['max_iterations'],\
'time_marching_parameter': self.parameters['time_step']}
'time_marching_parameter': self.parameters['time_step']}
if (self.parameters['method'] == 'FGP_TV'):
# set parameters for the FGP-TV method
self.pars = {'algorithm': self.parameters['method'], \
Expand All @@ -83,30 +83,30 @@ def pre_process(self):
'tolerance_constant':1e-06,\
'methodTV': 0 ,\
'nonneg': 0 ,\
'printingOut': 0}
'printingOut': 0}
if (self.parameters['method'] == 'SB_TV'):
# set parameters for the SB-TV method
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter':self.parameters['reg_par'],\
'number_of_iterations': self.parameters['max_iterations'],\
'tolerance_constant':1e-06,\
'methodTV': 0 ,\
'printingOut': 0}
'printingOut': 0}
if (self.parameters['method'] == 'TGV'):
# set parameters for the TGV method
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter' : self.parameters['reg_par'],\
'alpha1' : self.parameters['alpha1'],\
'alpha0': self.parameters['alpha0'],\
'number_of_iterations': self.parameters['max_iterations'],\
'LipshitzConstant' :self.parameters['lipshitz_constant']}
'LipshitzConstant' :self.parameters['lipshitz_constant']}
if (self.parameters['method'] == 'LLT_ROF'):
# set parameters for the LLT-ROF method
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter':self.parameters['reg_par'],\
'regularisation_parameterLLT':self.parameters['reg_parLLT'], \
'number_of_iterations': self.parameters['max_iterations'],\
'time_marching_parameter': self.parameters['time_step']}
'time_marching_parameter': self.parameters['time_step']}
if (self.parameters['method'] == 'NDF'):
# set parameters for the NDF method
if (self.parameters['penalty_type'] == 'huber'):
Expand All @@ -124,8 +124,8 @@ def pre_process(self):
'number_of_iterations': self.parameters['max_iterations'],\
'time_marching_parameter': self.parameters['time_step'],\
'penalty_type': penaltyNDF}
if (self.parameters['method'] == 'DIFF4th'):
# set parameters for the DIFF4th method
if (self.parameters['method'] == 'Diff4th'):
# set parameters for the Diff4th method
self.pars = {'algorithm': self.parameters['method'], \
'regularisation_parameter':self.parameters['reg_par'],\
'edge_parameter':self.parameters['edge_par'],\
Expand All @@ -143,44 +143,44 @@ def process_frames(self, data):
im_res = ROF_TV(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['number_of_iterations'],
self.pars['time_marching_parameter'],self.device)
self.pars['time_marching_parameter'],self.device)
if (self.parameters['method'] == 'FGP_TV'):
im_res = FGP_TV(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['number_of_iterations'],
self.pars['tolerance_constant'],
self.pars['methodTV'],
self.pars['nonneg'],
self.pars['printingOut'],self.device )
self.pars['printingOut'],self.device )
if (self.parameters['method'] == 'SB_TV'):
im_res = SB_TV(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['number_of_iterations'],
self.pars['tolerance_constant'],
self.pars['methodTV'],
self.pars['printingOut'],self.device)
self.pars['printingOut'],self.device)
if (self.parameters['method'] == 'TGV'):
im_res = TGV(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['alpha1'],
self.pars['alpha0'],
self.pars['number_of_iterations'],
self.pars['LipshitzConstant'],self.device)
self.pars['LipshitzConstant'],self.device)
if (self.parameters['method'] == 'LLT_ROF'):
im_res = LLT_ROF(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['regularisation_parameterLLT'],
self.pars['number_of_iterations'],
self.pars['time_marching_parameter'],self.device)
self.pars['time_marching_parameter'],self.device)
if (self.parameters['method'] == 'NDF'):
im_res = NDF(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['edge_parameter'],
self.pars['number_of_iterations'],
self.pars['time_marching_parameter'],
self.pars['penalty_type'],self.device)
self.pars['penalty_type'],self.device)
if (self.parameters['method'] == 'DIFF4th'):
im_res = DIFF4th(self.pars['input'],
im_res = Diff4th(self.pars['input'],
self.pars['regularisation_parameter'],
self.pars['edge_parameter'],
self.pars['number_of_iterations'],
Expand Down
45 changes: 25 additions & 20 deletions savu/plugins/reconstructions/fista_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from savu.plugins.driver.gpu_plugin import GpuPlugin

import numpy as np
# install FISTA-tomo with: conda install -c dkazanc fista-tomo
# or from https://github.com/dkazanc/FISTA-tomo
from fista.tomo.recModIter import RecTools
# FISTA algorithm is a part of TomoRec software
# https://github.com/dkazanc/TomoRec
from tomorec.methodsIR import RecToolsIR

from savu.plugins.utils import register_plugin
from scipy import ndimage
Expand All @@ -37,23 +37,24 @@
class FistaRecon(BaseRecon, GpuPlugin):
"""
A Plugin to reconstruct data by using FISTA iterative algorithm implemented \
in FISTA-tomo package. Dependencies on FISTA-tomo, ASTRA toolbox and CCPi RGL toolkit: \
in TomoRec package. Dependencies on TomoRec, ASTRA toolbox and CCPi RGL toolkit: \
https://github.com/vais-ral/CCPi-Regularisation-Toolkit.
:param iterationsFISTA: Number of FISTA iterations. Default: 200.
:param iterationsFISTA: Number of FISTA iterations. Default: 20.
:param datafidelity: Data fidelity, Least Squares only at the moment. Default: 'LS'.
:param ordersubsets: The number of ordered-subsets to accelerate reconstruction. Default: None.
:param nonnegativity: Nonnegativity constraint, choose on or None. Default: 'ENABLE'.
:param ordersubsets: The number of ordered-subsets to accelerate reconstruction. Default: 6.
:param converg_const: Lipschitz constant, can be set to a value. Default: 'power'.
:param regularisation: To regularise choose ROF_TV, FGP_TV, SB_TV, LLT_ROF,\
TGV, NDF, DIFF4th. Default: 'FGP_TV'.
NDF, Diff4th. Default: 'FGP_TV'.
:param regularisation_parameter: Regularisation (smoothing) value, higher \
the value stronger the smoothing effect. Default: 0.005.
:param regularisation_iterations: The number of regularisation iterations. Default: 50.
the value stronger the smoothing effect. Default: 0.001.
:param regularisation_iterations: The number of regularisation iterations. Default: 170.
:param time_marching_parameter: Time marching parameter, relevant for \
(ROF_TV, LLT_ROF, NDF, DIFF4th) penalties. Default: 0.0025.
:param edge_param: Edge (noise) related parameter, relevant for NDF and DIFF4th. Default: 0.01.
:param NDF_penalty: NDF specific penalty type: 1 - Huber, 2 - Perona-Malik, 3 - Tukey Biweight. Default: 1.
(ROF_TV, LLT_ROF, NDF, Diff4th) penalties. Default: 0.0025.
:param edge_param: Edge (noise) related parameter, relevant for NDF and Diff4th. Default: 0.01.
:param regularisation_parameter2: Regularisation (smoothing) value for LLT_ROF method. Default: 0.005.
:param NDF_penalty: NDF specific penalty type Huber, Perona, Tukey. Default: 'Huber'.
"""

def __init__(self):
Expand All @@ -68,6 +69,7 @@ def pre_process(self):
# extract given parameters
self.iterationsFISTA = self.parameters['iterationsFISTA']
self.datafidelity = self.parameters['datafidelity']
self.nonnegativity = self.parameters['nonnegativity']
self.ordersubsets = self.parameters['ordersubsets']
self.converg_const = self.parameters['converg_const']
self.regularisation = self.parameters['regularisation']
Expand All @@ -76,7 +78,8 @@ def pre_process(self):
self.time_marching_parameter = self.parameters['time_marching_parameter']
self.edge_param = self.parameters['edge_param']
self.NDF_penalty = self.parameters['NDF_penalty']
self.Rectools = None

self.RecToolsIR = None
if (self.ordersubsets > 1):
self.regularisation_iterations = (int)(self.parameters['regularisation_iterations']/self.ordersubsets) + 1
else:
Expand All @@ -92,31 +95,33 @@ def process_frames(self, data):
# Lipschitz constant if not given explicitly
self.setup_Lipschitz_constant()

# Run FISTA reconstrucion algorithm
# Run FISTA reconstrucion algorithm here
recon = self.Rectools.FISTA(sino,\
iterationsFISTA = self.iterationsFISTA,\
regularisation = self.regularisation,\
regularisation_parameter = self.regularisation_parameter,\
regularisation_iterations = self.regularisation_iterations,\
regularisation_parameter2 = self.regularisation_parameter2,\
time_marching_parameter = self.time_marching_parameter,\
edge_param = self.edge_param,\
NDF_penalty = self.NDF_penalty,\
tolerance_regul = 1e-10,\
edge_param = self.edge_param,\
lipschitz_const = self.Lipschitz_const)
return recon

def setup_Lipschitz_constant(self):
if self.Rectools is not None:
if self.RecToolsIR is not None:
return

# set parameters and initiate a fista-omo class object
self.Rectools = RecTools(DetectorsDimH = self.DetectorsDimH, # DetectorsDimH # detector dimension (horizontal)
# set parameters and initiate a TomoRec class object
self.Rectools = RecToolsIR(DetectorsDimH = self.DetectorsDimH, # DetectorsDimH # detector dimension (horizontal)
DetectorsDimV = None, # DetectorsDimV # detector dimension (vertical) for 3D case only
AnglesVec = self.angles, # array of angles in radians
ObjSize = self.vol_shape[0] , # a scalar to define reconstructed object dimensions
datafidelity=self.datafidelity,# data fidelity, choose LS, PWLS (wip), GH (wip), Student (wip)
nonnegativity=self.nonnegativity, # enable nonnegativity constraint (set to 'on')
OS_number = self.ordersubsets, # the number of subsets, NONE/(or > 1) ~ classical / ordered subsets
tolerance = 1e-08, # tolerance to stop outer iterations earlier
tolerance = 1e-9, # tolerance to stop outer iterations earlier
device='gpu')
if (self.parameters['converg_const'] == 'power'):
self.Lipschitz_const = self.Rectools.powermethod() # calculate Lipschitz constant
Expand Down Expand Up @@ -156,4 +161,4 @@ def get_citation_information(self):
"%D 2009\n" +
"%I SIAM\n")
cite_info1.doi = "doi: "
return cite_info1
return cite_info1
Binary file modified test_data/test_process_lists/fista_test.nxs
Binary file not shown.

0 comments on commit 8eb68bb

Please sign in to comment.