Skip to content

Commit

Permalink
Merge pull request #30 from juglab/scriptfix
Browse files Browse the repository at this point in the history
Scriptfix
  • Loading branch information
alex-krull authored Aug 8, 2019
2 parents 8791cac + bdf6d85 commit f4e7ead
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
2 changes: 1 addition & 1 deletion n2v/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.4'
__version__ = '0.1.6'
36 changes: 26 additions & 10 deletions scripts/predictN2V.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import sys
import argparse
from glob import glob
import csbdeep.io

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--baseDir", help="directory in which all your network will live", default='models')
parser.add_argument("--name", help="name of your network", default='N2V2D')
parser.add_argument("--name", help="name of your network", default='N2V')
parser.add_argument("--dataPath", help="The path to your data")
parser.add_argument("--fileName", help="name of your data file", default="*.tif")
parser.add_argument("--output", help="The path to which your data is to be saved", default='.')
Expand All @@ -20,6 +21,8 @@

args = parser.parse_args()

print(args.output)

assert (not 'T' in args.dims) or (args.dims[0]=='T')

# We import all our dependencies.
Expand All @@ -28,7 +31,6 @@
import numpy as np
from matplotlib import pyplot as plt
from tifffile import imread
from tifffile import imwrite


# A previously trained model is loaded by creating a new N2V-object without providing a 'config'.
Expand All @@ -48,31 +50,45 @@
datagen = N2V_DataGenerator()
imgs = datagen.load_imgs_from_directory(directory = args.dataPath, dims=args.dims, filter=args.fileName)


files = glob(os.path.join(args.dataPath, args.fileName))
files.sort()

for i, img in enumerate(imgs):
img_=img

if 'Z' in args.dims:
myDims='TZYXC'
else:
myDims='TYXC'

if not 'C' in args.dims :
img_=img[...,0]
myDims=myDims[:-1]

myDims_=myDims[1:]


if not 'C' in args.dims :
img_=img[...,0]

# if we have a time dimension we process the images one by one
if args.dims[0]=='T':
outDims=myDims
pred=img_.copy()
myDims=args.dims[1:]


for j in range(img_.shape[0]):
pred[j] = model.predict( img_[j], axes=myDims, n_tiles=tiles)
print('predicting slice', j, img_[j].shape, myDims_)
pred[j] = model.predict( img_[j], axes=myDims_, n_tiles=tiles)
else:
outDims=myDims_
img_=img_[0,...]
print("denoising image "+str(i+1) +" of "+str(len(imgs)))
# Denoise the image.
print(args.dims)
pred = model.predict( img_, axes=args.dims, n_tiles=tiles)

pred = model.predict( img_, axes=myDims_, n_tiles=tiles)

print(pred.shape)
outpath=args.output
filename=os.path.basename(files[i]).replace('.tif','_N2V.tif')
outpath=os.path.join(outpath,filename)
imwrite(filename,pred.astype(np.float32))
print('writing file to ',outpath, outDims, pred.shape)
csbdeep.io.save_tiff_imagej_compatible(outpath, pred.astype(np.float32), outDims)
9 changes: 5 additions & 4 deletions scripts/trainN2V.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--baseDir", help="base directory in which your network will live", default='models')
parser.add_argument("--name", help="name of your network", default='N2V3D')
parser.add_argument("--name", help="name of your network", default='N2V')
parser.add_argument("--dataPath", help="The path to your training data")
parser.add_argument("--fileName", help="name of your training data file", default="*.tif")
parser.add_argument("--validationFraction", help="Fraction of data you want to use for validation (percent)", default=10.0, type=float)
parser.add_argument("--validationFraction", help="Fraction of data you want to use for validation (percent)", default=5.0, type=float)
parser.add_argument("--dims", help="dimensions of your data, can include: X,Y,Z,C (channel), T (time)", default='YX')
parser.add_argument("--patchSizeXY", help="XY-size of your training patches", default=64, type=int)
parser.add_argument("--patchSizeZ", help="Z-size of your training patches", default=64, type=int)
parser.add_argument("--epochs", help="number of training epochs", default=100, type=int)
parser.add_argument("--stepsPerEpoch", help="number training steps per epoch", default=5, type=int)
parser.add_argument("--stepsPerEpoch", help="number training steps per epoch", default=400, type=int)
parser.add_argument("--batchSize", help="size of your training batches", default=64, type=int)
parser.add_argument("--netDepth", help="depth of your U-Net", default=2, type=int)
parser.add_argument("--netKernelSize", help="Size of conv. kernels in first layer", default=3, type=int)
parser.add_argument("--n2vPercPix", help="percentage of pixels to manipulated by N2V", default=1.6, type=float)
parser.add_argument("--learningRate", help="initial learning rate", default=0.0004, type=float)

parser.add_argument("--unet_n_first", help="number of feature channels in the first u-net layer", default=32, type=int)

if len(sys.argv)==1:
parser.print_help(sys.stderr)
Expand Down Expand Up @@ -83,6 +83,7 @@
train_batch_size=args.batchSize, n2v_perc_pix=args.n2vPercPix, n2v_patch_shape=pshape,
n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, train_learning_rate=args.learningRate,
unet_n_depth=args.netDepth,
unet_n_first=args.unet_n_first
)

# Let's look at the parameters stored in the config-object.
Expand Down

0 comments on commit f4e7ead

Please sign in to comment.