Skip to content

Commit

Permalink
Merge pull request #29 from juglab/scriptfix
Browse files Browse the repository at this point in the history
Creating entry points for scripts
  • Loading branch information
alex-krull authored Aug 7, 2019
2 parents 6e5c1a2 + 771274e commit 8791cac
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion n2v/internals/N2V_DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX'):
"""

files = glob(join(directory, filter))
files.sort()
files.sort()
return self.load_imgs(files, dims=dims)


Expand Down
21 changes: 12 additions & 9 deletions scripts/predictN2V.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#!/usr/bin/env python3

import os
import sys
import argparse
from glob import glob

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--baseDir", help="directory in which all your network will live", default='models')
parser.add_argument("--name", help="name of your network", default='N2V3D')
parser.add_argument("--name", help="name of your network", default='N2V2D')
parser.add_argument("--dataPath", help="The path to your data")
parser.add_argument("--fileName", help="name of your data file", default="*.tif")
parser.add_argument("--output", help="The path to your data to be saved", default='predictions.tif')
parser.add_argument("--output", help="The path to which your data is to be saved", default='.')
parser.add_argument("--dims", help="dimensions of your data", default='YX')
parser.add_argument("--tile", help="will cut your image [TILE] times in every dimension to make it fit GPU memory", default=1, type=int)

Expand Down Expand Up @@ -45,9 +48,10 @@
datagen = N2V_DataGenerator()
imgs = datagen.load_imgs_from_directory(directory = args.dataPath, dims=args.dims, filter=args.fileName)

for i, img in enumerate(imgs):
print("img.shape",img.shape)
files = glob(os.path.join(args.dataPath, args.fileName))
files.sort()

for i, img in enumerate(imgs):
img_=img
if not 'C' in args.dims :
img_=img[...,0]
Expand All @@ -59,17 +63,16 @@


for j in range(img_.shape[0]):
print("img_[j].shape", img_[j].shape)
pred[j] = model.predict( img_[j], axes=myDims, n_tiles=tiles)
else:
img_=img_[0,...]
print("denoising image "+str(i) +" of "+str(len(imgs)))
print("denoising image "+str(i+1) +" of "+str(len(imgs)))
# Denoise the image.
print(args.dims)
pred = model.predict( img_, axes=args.dims, n_tiles=tiles)

print(pred.shape)
filename=args.output
if len(imgs) > 1:
filename=filename+'_'+str(i).zfill(4) +'.tif'
outpath=args.output
filename=os.path.basename(files[i]).replace('.tif','_N2V.tif')
outpath=os.path.join(outpath,filename)
imwrite(filename,pred.astype(np.float32))
2 changes: 2 additions & 0 deletions scripts/trainN2V.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

import os
import sys
import argparse
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
'Programming Language :: Python :: 3.6',
],

scripts=['scripts/trainN2V.py',
'scripts/predictN2V.py'
],

install_requires=[
"numpy",
"scipy",
Expand Down

0 comments on commit 8791cac

Please sign in to comment.