From bdf6d85d7af40e02ef718632e64e0ef82f8c3f0b Mon Sep 17 00:00:00 2001 From: Alexander Krull Date: Thu, 8 Aug 2019 11:55:28 +0200 Subject: [PATCH] predict script is now working with high dimensional data and writes tiffs in imagej-conform way --- n2v/version.py | 2 +- scripts/predictN2V.py | 36 ++++++++++++++++++++++++++---------- scripts/trainN2V.py | 9 +++++---- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/n2v/version.py b/n2v/version.py index 66a87bb..2fb2513 100644 --- a/n2v/version.py +++ b/n2v/version.py @@ -1 +1 @@ -__version__ = '0.1.5' +__version__ = '0.1.6' diff --git a/scripts/predictN2V.py b/scripts/predictN2V.py index cdc735d..8bab9f4 100644 --- a/scripts/predictN2V.py +++ b/scripts/predictN2V.py @@ -4,10 +4,11 @@ import sys import argparse from glob import glob +import csbdeep.io parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--baseDir", help="directory in which all your network will live", default='models') -parser.add_argument("--name", help="name of your network", default='N2V2D') +parser.add_argument("--name", help="name of your network", default='N2V') parser.add_argument("--dataPath", help="The path to your data") parser.add_argument("--fileName", help="name of your data file", default="*.tif") parser.add_argument("--output", help="The path to which your data is to be saved", default='.') @@ -20,6 +21,8 @@ args = parser.parse_args() +print(args.output) + assert (not 'T' in args.dims) or (args.dims[0]=='T') # We import all our dependencies. @@ -28,7 +31,6 @@ import numpy as np from matplotlib import pyplot as plt from tifffile import imread -from tifffile import imwrite # A previously trained model is loaded by creating a new N2V-object without providing a 'config'. @@ -48,31 +50,45 @@ datagen = N2V_DataGenerator() imgs = datagen.load_imgs_from_directory(directory = args.dataPath, dims=args.dims, filter=args.fileName) + files = glob(os.path.join(args.dataPath, args.fileName)) files.sort() for i, img in enumerate(imgs): img_=img + + if 'Z' in args.dims: + myDims='TZYXC' + else: + myDims='TYXC' + + if not 'C' in args.dims : + img_=img[...,0] + myDims=myDims[:-1] + + myDims_=myDims[1:] + + if not 'C' in args.dims : img_=img[...,0] # if we have a time dimension we process the images one by one if args.dims[0]=='T': + outDims=myDims pred=img_.copy() - myDims=args.dims[1:] - - for j in range(img_.shape[0]): - pred[j] = model.predict( img_[j], axes=myDims, n_tiles=tiles) + print('predicting slice', j, img_[j].shape, myDims_) + pred[j] = model.predict( img_[j], axes=myDims_, n_tiles=tiles) else: + outDims=myDims_ img_=img_[0,...] print("denoising image "+str(i+1) +" of "+str(len(imgs))) # Denoise the image. - print(args.dims) - pred = model.predict( img_, axes=args.dims, n_tiles=tiles) - + pred = model.predict( img_, axes=myDims_, n_tiles=tiles) + print(pred.shape) outpath=args.output filename=os.path.basename(files[i]).replace('.tif','_N2V.tif') outpath=os.path.join(outpath,filename) - imwrite(filename,pred.astype(np.float32)) + print('writing file to ',outpath, outDims, pred.shape) + csbdeep.io.save_tiff_imagej_compatible(outpath, pred.astype(np.float32), outDims) diff --git a/scripts/trainN2V.py b/scripts/trainN2V.py index 6d1b53e..9062ded 100644 --- a/scripts/trainN2V.py +++ b/scripts/trainN2V.py @@ -6,21 +6,21 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--baseDir", help="base directory in which your network will live", default='models') -parser.add_argument("--name", help="name of your network", default='N2V3D') +parser.add_argument("--name", help="name of your network", default='N2V') parser.add_argument("--dataPath", help="The path to your training data") parser.add_argument("--fileName", help="name of your training data file", default="*.tif") -parser.add_argument("--validationFraction", help="Fraction of data you want to use for validation (percent)", default=10.0, type=float) +parser.add_argument("--validationFraction", help="Fraction of data you want to use for validation (percent)", default=5.0, type=float) parser.add_argument("--dims", help="dimensions of your data, can include: X,Y,Z,C (channel), T (time)", default='YX') parser.add_argument("--patchSizeXY", help="XY-size of your training patches", default=64, type=int) parser.add_argument("--patchSizeZ", help="Z-size of your training patches", default=64, type=int) parser.add_argument("--epochs", help="number of training epochs", default=100, type=int) -parser.add_argument("--stepsPerEpoch", help="number training steps per epoch", default=5, type=int) +parser.add_argument("--stepsPerEpoch", help="number training steps per epoch", default=400, type=int) parser.add_argument("--batchSize", help="size of your training batches", default=64, type=int) parser.add_argument("--netDepth", help="depth of your U-Net", default=2, type=int) parser.add_argument("--netKernelSize", help="Size of conv. kernels in first layer", default=3, type=int) parser.add_argument("--n2vPercPix", help="percentage of pixels to manipulated by N2V", default=1.6, type=float) parser.add_argument("--learningRate", help="initial learning rate", default=0.0004, type=float) - +parser.add_argument("--unet_n_first", help="number of feature channels in the first u-net layer", default=32, type=int) if len(sys.argv)==1: parser.print_help(sys.stderr) @@ -83,6 +83,7 @@ train_batch_size=args.batchSize, n2v_perc_pix=args.n2vPercPix, n2v_patch_shape=pshape, n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate=args.learningRate, unet_n_depth=args.netDepth, + unet_n_first=args.unet_n_first ) # Let's look at the parameters stored in the config-object.