Skip to content

Commit

Permalink
predict script is now working with high dimensional data and writes t…
Browse files Browse the repository at this point in the history
…iffs in imagej-conform way
  • Loading branch information
Alexander Krull committed Aug 8, 2019
1 parent fc4b9e3 commit bdf6d85
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
2 changes: 1 addition & 1 deletion n2v/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.5'
__version__ = '0.1.6'
36 changes: 26 additions & 10 deletions scripts/predictN2V.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import sys
import argparse
from glob import glob
import csbdeep.io

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--baseDir", help="directory in which all your network will live", default='models')
parser.add_argument("--name", help="name of your network", default='N2V2D')
parser.add_argument("--name", help="name of your network", default='N2V')
parser.add_argument("--dataPath", help="The path to your data")
parser.add_argument("--fileName", help="name of your data file", default="*.tif")
parser.add_argument("--output", help="The path to which your data is to be saved", default='.')
Expand All @@ -20,6 +21,8 @@

args = parser.parse_args()

print(args.output)

assert (not 'T' in args.dims) or (args.dims[0]=='T')

# We import all our dependencies.
Expand All @@ -28,7 +31,6 @@
import numpy as np
from matplotlib import pyplot as plt
from tifffile import imread
from tifffile import imwrite


# A previously trained model is loaded by creating a new N2V-object without providing a 'config'.
Expand All @@ -48,31 +50,45 @@
datagen = N2V_DataGenerator()
imgs = datagen.load_imgs_from_directory(directory = args.dataPath, dims=args.dims, filter=args.fileName)


files = glob(os.path.join(args.dataPath, args.fileName))
files.sort()

for i, img in enumerate(imgs):
img_=img

if 'Z' in args.dims:
myDims='TZYXC'
else:
myDims='TYXC'

if not 'C' in args.dims :
img_=img[...,0]
myDims=myDims[:-1]

myDims_=myDims[1:]


if not 'C' in args.dims :
img_=img[...,0]

# if we have a time dimension we process the images one by one
if args.dims[0]=='T':
outDims=myDims
pred=img_.copy()
myDims=args.dims[1:]


for j in range(img_.shape[0]):
pred[j] = model.predict( img_[j], axes=myDims, n_tiles=tiles)
print('predicting slice', j, img_[j].shape, myDims_)
pred[j] = model.predict( img_[j], axes=myDims_, n_tiles=tiles)
else:
outDims=myDims_
img_=img_[0,...]
print("denoising image "+str(i+1) +" of "+str(len(imgs)))
# Denoise the image.
print(args.dims)
pred = model.predict( img_, axes=args.dims, n_tiles=tiles)

pred = model.predict( img_, axes=myDims_, n_tiles=tiles)

print(pred.shape)
outpath=args.output
filename=os.path.basename(files[i]).replace('.tif','_N2V.tif')
outpath=os.path.join(outpath,filename)
imwrite(filename,pred.astype(np.float32))
print('writing file to ',outpath, outDims, pred.shape)
csbdeep.io.save_tiff_imagej_compatible(outpath, pred.astype(np.float32), outDims)
9 changes: 5 additions & 4 deletions scripts/trainN2V.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--baseDir", help="base directory in which your network will live", default='models')
parser.add_argument("--name", help="name of your network", default='N2V3D')
parser.add_argument("--name", help="name of your network", default='N2V')
parser.add_argument("--dataPath", help="The path to your training data")
parser.add_argument("--fileName", help="name of your training data file", default="*.tif")
parser.add_argument("--validationFraction", help="Fraction of data you want to use for validation (percent)", default=10.0, type=float)
parser.add_argument("--validationFraction", help="Fraction of data you want to use for validation (percent)", default=5.0, type=float)
parser.add_argument("--dims", help="dimensions of your data, can include: X,Y,Z,C (channel), T (time)", default='YX')
parser.add_argument("--patchSizeXY", help="XY-size of your training patches", default=64, type=int)
parser.add_argument("--patchSizeZ", help="Z-size of your training patches", default=64, type=int)
parser.add_argument("--epochs", help="number of training epochs", default=100, type=int)
parser.add_argument("--stepsPerEpoch", help="number training steps per epoch", default=5, type=int)
parser.add_argument("--stepsPerEpoch", help="number training steps per epoch", default=400, type=int)
parser.add_argument("--batchSize", help="size of your training batches", default=64, type=int)
parser.add_argument("--netDepth", help="depth of your U-Net", default=2, type=int)
parser.add_argument("--netKernelSize", help="Size of conv. kernels in first layer", default=3, type=int)
parser.add_argument("--n2vPercPix", help="percentage of pixels to manipulated by N2V", default=1.6, type=float)
parser.add_argument("--learningRate", help="initial learning rate", default=0.0004, type=float)

parser.add_argument("--unet_n_first", help="number of feature channels in the first u-net layer", default=32, type=int)

if len(sys.argv)==1:
parser.print_help(sys.stderr)
Expand Down Expand Up @@ -83,6 +83,7 @@
train_batch_size=args.batchSize, n2v_perc_pix=args.n2vPercPix, n2v_patch_shape=pshape,
n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate=args.learningRate,
unet_n_depth=args.netDepth,
unet_n_first=args.unet_n_first
)

# Let's look at the parameters stored in the config-object.
Expand Down

0 comments on commit bdf6d85

Please sign in to comment.