Skip to content

Commit

Permalink
Merge pull request #31 from juglab/issue_20
Browse files Browse the repository at this point in the history
Issue 20
  • Loading branch information
tibuch authored Sep 10, 2019
2 parents 71b1769 + f479967 commit 2326bb6
Show file tree
Hide file tree
Showing 18 changed files with 658 additions and 280 deletions.
37 changes: 9 additions & 28 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,22 @@ env:
matrix:
include:
- os: linux
python: 3.5
- os: linux
python: 3.6
###########
- os: linux
python: 3.6
if: branch = master
env: TENSORFLOW='tensorflow<1.12' KERAS='keras'
- os: linux
dist: bionic
python: 3.6
if: branch = master
env: TENSORFLOW='tensorflow<1.11' KERAS='keras'
env: TENSORFLOW='tensorflow==1.14.0' KERAS='keras==2.2.4'
- os: linux
dist: bionic
python: 3.6
if: branch = master
env: TENSORFLOW='tensorflow<1.10' KERAS='keras==2.2.2'
env: TENSORFLOW='tensorflow==1.12.0' KERAS='keras==2.2.4'
- os: linux
python: 3.6
if: branch = master
# env: TENSORFLOW='tensorflow<1.9' KERAS='keras==2.2.0' # causes segmentation fault, why?
env: TENSORFLOW='tensorflow<1.8' KERAS='keras==2.2.0'
- os: linux
python: 3.6
if: branch = master
env: TENSORFLOW='tensorflow<1.8' KERAS='keras==2.1.6'
- os: linux
python: 3.6
if: branch = master
# lowest supported keras version, last tensorflow release that supports CUDA 8
env: TENSORFLOW='tensorflow==1.4.1' KERAS='keras==2.1.6'
###########
dist: bionic
python: 3.7
env: TENSORFLOW='tensorflow==1.14.0' KERAS='keras==2.2.4'

install:
- pip install $TENSORFLOW $KERAS
- pip install .

script:
- pytest -v --durations=50
- cd tests; pytest -v -s test*.py
- cd functional; ./test_training2D_RGB.py; ./test_prediction2D_RGB.py; ./test_training3D.py; ./test_prediction3D.py
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ Our implementation is based on [CSBDEEP](http://csbdeep.bioimagecomputing.com) (

## Installation
This implementation requires [Tensorflow](https://www.tensorflow.org/install/).
We have tested Noise2Void on LinuxMint 18.3 using python 3.6 and tensorflow-gpu 1.12.0.
We have tested Noise2Void on LinuxMint 19 and Ubuntu 18.0 using python 3.6 and 3.7 and tensorflow-gpu 1.12.0 and 1.14.0.

#### If you start from scratch...
We recommend using [conda](https://docs.conda.io/en/latest/miniconda.html).
We recommend using [miniconda](https://docs.conda.io/en/latest/miniconda.html).
If you do not yet have a strong opinion, just use it too!

After installing Miniconda, the following lines might are likely the easiest way to get Tensorflow and CuDNN installed on your machine (_Note:_ Macs are not supported, and if you sit on a Windows machine all this might also require some modifications.):

```
$ conda create -n 'n2v' python=3.6
$ source activate n2v
$ conda install tensorflow-gpu==1.12
$ conda install tensorflow-gpu==1.14
$ pip install jupyter
```

Expand All @@ -35,11 +35,13 @@ $ pip install n2v
```

#### Option 2: Git-Clone and install from sources
Or clone the repository:
This option is ideal if you want to edit the code. Clone the repository:

```
$ git clone https://github.com/juglab/n2v.git
```
Change into its directory and install it:

```
$ cd n2v
$ pip install -e .
Expand Down
93 changes: 41 additions & 52 deletions examples/2D/denoising2D_RGB/01_training.ipynb

Large diffs are not rendered by default.

11 changes: 4 additions & 7 deletions examples/2D/denoising2D_RGB/02_prediction.ipynb

Large diffs are not rendered by default.

58 changes: 34 additions & 24 deletions examples/2D/denoising2D_SEM/01_training.ipynb

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions examples/2D/denoising2D_SEM/02_prediction.ipynb

Large diffs are not rendered by default.

96 changes: 42 additions & 54 deletions examples/3D/01_training.ipynb

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions examples/3D/02_prediction.ipynb

Large diffs are not rendered by default.

169 changes: 94 additions & 75 deletions n2v/models/n2v_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import keras.backend as K

from csbdeep.models import BaseConfig
from csbdeep.utils import _raise, axes_check_and_normalize, axes_dict, backend_channels_last

from six import string_types

import numpy as np
from logging.config import BaseConfigurator

# This class is a adapted version of csbdeep.models.config.py.
class N2VConfig(argparse.Namespace):
Expand Down Expand Up @@ -68,83 +70,87 @@ class N2VConfig(argparse.Namespace):
"""

def __init__(self, X, **kwargs):
"""See class docstring."""

assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

n_dim = len(X.shape) - 2
n_channel_in = X.shape[-1]
n_channel_out = n_channel_in
mean = np.mean(X)
std = np.std(X)

if n_dim == 2:
axes = 'SYXC'
elif n_dim == 3:
axes = 'SZYXC'

# parse and check axes
axes = axes_check_and_normalize(axes)
ax = axes_dict(axes)
ax = {a: (ax[a] is not None) for a in ax}

(ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
axes = axes.replace('S','') # remove sample axis if it exists

if backend_channels_last():
if ax['C']:
axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))

# X is empty if config is None
if (X.size != 0):

assert len(X.shape) == 4 or len(X.shape) == 5, "Only 'SZYXC' or 'SYXC' as dimensions is supported."

n_dim = len(X.shape) - 2
n_channel_in = X.shape[-1]
n_channel_out = n_channel_in
mean = np.mean(X)
std = np.std(X)

if n_dim == 2:
axes = 'SYXC'
elif n_dim == 3:
axes = 'SZYXC'

# parse and check axes
axes = axes_check_and_normalize(axes)
ax = axes_dict(axes)
ax = {a: (ax[a] is not None) for a in ax}

(ax['X'] and ax['Y']) or _raise(ValueError('lateral axes X and Y must be present.'))
not (ax['Z'] and ax['T']) or _raise(ValueError('using Z and T axes together not supported.'))

axes.startswith('S') or (not ax['S']) or _raise(ValueError('sample axis S must be first.'))
axes = axes.replace('S','') # remove sample axis if it exists

if backend_channels_last():
if ax['C']:
axes[-1] == 'C' or _raise(ValueError('channel axis must be last for backend (%s).' % K.backend()))
else:
axes += 'C'
else:
axes += 'C'
else:
if ax['C']:
axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
if ax['C']:
axes[0] == 'C' or _raise(ValueError('channel axis must be first for backend (%s).' % K.backend()))
else:
axes = 'C'+axes

# normalization parameters
self.mean = str(mean)
self.std = str(std)
# directly set by parameters
self.n_dim = n_dim
self.axes = axes
self.n_channel_in = int(n_channel_in)
self.n_channel_out = int(n_channel_out)

# default config (can be overwritten by kwargs below)
self.unet_residual = False
self.unet_n_depth = 2
self.unet_kern_size = 5 if self.n_dim==2 else 3
self.unet_n_first = 32
self.unet_last_activation = 'linear'
if backend_channels_last():
self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,)
else:
axes = 'C'+axes

# normalization parameters
self.mean = str(mean)
self.std = str(std)
# directly set by parameters
self.n_dim = n_dim
self.axes = axes
self.n_channel_in = int(n_channel_in)
self.n_channel_out = int(n_channel_out)

# default config (can be overwritten by kwargs below)
self.unet_residual = False
self.unet_n_depth = 2
self.unet_kern_size = 5 if self.n_dim==2 else 3
self.unet_n_first = 32
self.unet_last_activation = 'linear'
if backend_channels_last():
self.unet_input_shape = self.n_dim*(None,) + (self.n_channel_in,)
else:
self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,)

self.train_loss = 'mae'
self.train_epochs = 100
self.train_steps_per_epoch = 400
self.train_learning_rate = 0.0004
self.train_batch_size = 16
self.train_tensorboard = True
self.train_checkpoint = 'weights_best.h5'
self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
self.batch_norm = True
self.n2v_perc_pix = 1.5
self.n2v_patch_shape = (64, 64) if self.n_dim==2 else (64, 64, 64)
self.n2v_manipulator = 'uniform_withCP'
self.n2v_neighborhood_radius = 5

# disallow setting 'n_dim' manually
try:
del kwargs['n_dim']
# warnings.warn("ignoring parameter 'n_dim'")
except:
pass
self.unet_input_shape = (self.n_channel_in,) + self.n_dim*(None,)

self.train_loss = 'mae'
self.train_epochs = 100
self.train_steps_per_epoch = 400
self.train_learning_rate = 0.0004
self.train_batch_size = 16
self.train_tensorboard = True
self.train_checkpoint = 'weights_best.h5'
self.train_reduce_lr = {'factor': 0.5, 'patience': 10}
self.batch_norm = True
self.n2v_perc_pix = 1.5
self.n2v_patch_shape = (64, 64) if self.n_dim==2 else (64, 64, 64)
self.n2v_manipulator = 'uniform_withCP'
self.n2v_neighborhood_radius = 5

# disallow setting 'n_dim' manually
try:
del kwargs['n_dim']
# warnings.warn("ignoring parameter 'n_dim'")
except:
pass

self.probabilistic = False

for k in kwargs:
setattr(self, k, kwargs[k])
Expand Down Expand Up @@ -215,3 +221,16 @@ def _is_int(v,low=None,high=None):
return all(ok.values()), tuple(k for (k,v) in ok.items() if not v)
else:
return all(ok.values())

def update_parameters(self, allow_new=True, **kwargs):
if not allow_new:
attr_new = []
for k in kwargs:
try:
getattr(self, k)
except AttributeError:
attr_new.append(k)
if len(attr_new) > 0:
raise AttributeError("Not allowed to add new parameters (%s)" % ', '.join(attr_new))
for k in kwargs:
setattr(self, k, kwargs[k])
61 changes: 47 additions & 14 deletions n2v/models/n2v_standard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from csbdeep.models import CARE
from csbdeep.utils import _raise, axes_check_and_normalize, axes_dict
from csbdeep.utils import _raise, axes_check_and_normalize, axes_dict, load_json, save_json
from csbdeep.internals import nets
from csbdeep.internals.predict import Progress

Expand Down Expand Up @@ -64,25 +64,32 @@ class N2V(CARE):
"""

def __init__(self, config, name=None, basedir='.'):
"""See class docstring"""
config is None or isinstance(config, N2VConfig) or _raise(ValueError('Invalid configuration: %s' % str(config)))
"""See class docstring."""

config is None or isinstance(config,self._config_class) or _raise (
ValueError("Invalid configuration of type '%s', was expecting type '%s'." % (type(config).__name__, self._config_class.__name__))
)
if config is not None and not config.is_valid():
invalid_attr = config.is_valid(True)[1]
raise ValueError('Invalid configuration attributes: ' + ', '.join(invalid_attr))
(not (config is None and basedir is None)) or _raise(ValueError())
(not (config is None and basedir is None)) or _raise(ValueError("No config provided and cannot be loaded from disk since basedir=None."))

name is None or isinstance(name, string_types) or _raise(ValueError())
basedir is None or isinstance(basedir, (string_types, Path)) or _raise(ValueError())
name is None or (isinstance(name,string_types) and len(name)>0) or _raise(ValueError("No valid name: '%s'" % str(name)))
basedir is None or isinstance(basedir,(string_types,Path)) or _raise(ValueError("No valid basedir: '%s'" % str(basedir)))
self.config = config
self.name = name if name is not None else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f")
self.basedir = Path(basedir) if basedir is not None else None
if config is not None:
# config was provided -> update before it is saved to disk
self._update_and_check_config()
self._set_logdir()
if config is None:
# config was loaded from disk -> update it after loading
self._update_and_check_config()
self._model_prepared = False
self.keras_model = self._build()
if config is None:
self._find_and_load_weights()
else:
config.probabilistic = False


def _build(self):
Expand Down Expand Up @@ -205,8 +212,8 @@ def train(self, X, validation_X, epochs=None, steps_per_epoch=None):
self.config.train_batch_size, int(train_num_pix/100 * self.config.n2v_perc_pix),
self.config.n2v_patch_shape, manipulator)

# validation_Y is also validation_X plus a concatinated masking channel.
# To speed things up, we precomupte the masking vo the validation data.
# validation_Y is also validation_X plus a concatenated masking channel.
# To speed things up, we precompute the masking vo the validation data.
validation_Y = np.concatenate((validation_X, np.zeros(validation_X.shape, dtype=validation_X.dtype)), axis=axes.index('C'))
n2v_utils.manipulate_val_data(validation_X, validation_Y,
num_pix=int(val_num_pix/100 * self.config.n2v_perc_pix),
Expand Down Expand Up @@ -272,13 +279,16 @@ def on_epoch_end(self, epoch, logs=None):
if epoch % self.freq == 0:
# TODO: implement batched calls to sess.run
# (current call will likely go OOM on GPU)
tensors = self.model.inputs + self.gt_outputs + self.model.sample_weights
if self.model.uses_learning_phase:
cut_v_data = len(self.model.inputs)
val_data = [self.validation_data[0][:self.n_images]] + [0]
tensors = self.model.inputs + [K.learning_phase()]
tensors += [K.learning_phase()]
val_data = list(v[:self.n_images] for v in self.validation_data[:-1])
val_data += self.validation_data[-1:]
else:
val_data = list(v[:self.n_images] for v in self.validation_data)
tensors = self.model.inputs
# GIT issue 20: We need to remove the masking component from the validation data to prevent crash.
end_index = (val_data[1].shape)[-1]//2
val_data[1] = val_data[1][...,:end_index]
feed_dict = dict(zip(tensors, val_data))
result = self.sess.run([self.merged], feed_dict=feed_dict)
summary_str = result[0]
Expand Down Expand Up @@ -365,3 +375,26 @@ def predict(self, img, axes, resizer=PadAndCropResizer(), n_tiles=None):
pred = self._predict_mean_and_scale(normalized, axes=axes, normalizer=None, resizer=resizer, n_tiles=n_tiles)[0]

return self.__denormalize__(pred, mean, std)

def _set_logdir(self):
self.logdir = self.basedir / self.name

config_file = self.logdir / 'config.json'
if self.config is None:
if config_file.exists():
config_dict = load_json(str(config_file))
self.config = self._config_class(np.array([]), **config_dict)
if not self.config.is_valid():
invalid_attr = self.config.is_valid(True)[1]
raise ValueError('Invalid attributes in loaded config: ' + ', '.join(invalid_attr))
else:
raise FileNotFoundError("config file doesn't exist: %s" % str(config_file.resolve()))
else:
if self.logdir.exists():
warnings.warn('output path for model already exists, files may be overwritten: %s' % str(self.logdir.resolve()))
self.logdir.mkdir(parents=True, exist_ok=True)
save_json(vars(self.config), str(config_file))

@property
def _config_class(self):
return N2VConfig
2 changes: 1 addition & 1 deletion n2v/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.6'
__version__ = '0.1.7'
Loading

0 comments on commit 2326bb6

Please sign in to comment.