Skip to content

Commit

Permalink
style edits and removing use_multiprocessing flag in args
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Dec 11, 2018
1 parent 0bab4b2 commit 508d877
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
15 changes: 10 additions & 5 deletions keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,8 @@ def csv_list(string):
parser.add_argument('--weighted-average', help='Compute the mAP using the weighted average of precisions among classes.', action='store_true')

#Fit generator arguments
parser.add_argument('--multiprocessing', help='Fit generator mulitprocessing for n worker processes', action='store_true')
parser.add_argument('--workers', help='Number of multiprocessing workers', default=1)
parser.add_argument('--max_queue_size', help='Queue length for multiprocessing workers in fit generator', default=10)
parser.add_argument('--workers', help='Number of multiprocessing workers. To disable multiprocessing, set workers to 0', default=1)
parser.add_argument('--max-queue-size', help='Queue length for multiprocessing workers in fit generator.', default=10)

return check_args(parser.parse_args(args))

Expand Down Expand Up @@ -483,7 +482,13 @@ def main(args=None):
validation_generator,
args,
)


#Use multiprocessing if workers > 0
if args.workers > 0:
use_multiprocessing = True
else:
use_multiprocessing = False

# start training
training_model.fit_generator(
generator=train_generator,
Expand All @@ -492,7 +497,7 @@ def main(args=None):
verbose=1,
callbacks=callbacks,
workers=args.workers,
use_multiprocessing=args.multiprocessing,
use_multiprocessing=use_multiprocessing,
max_queue_size=args.max_queue_size
)

Expand Down
10 changes: 7 additions & 3 deletions keras_retinanet/preprocessing/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from ..utils.transform import transform_aabb


class Generator(keras.utils.Sequence):
""" Abstract generator class.
"""
Expand Down Expand Up @@ -314,15 +315,18 @@ def compute_input_output(self, group):

return inputs, targets

#Keras Sequence methods
def __len__(self):
"""Number of batches for generator"""
"""
Number of batches for generator.
"""

return len(self.groups)

def __getitem__(self,index):
"""
Keras sequence method for generating batches
Keras sequence method for generating batches.
"""
group = self.groups[index]
inputs,targets=self.compute_input_output(group)

return inputs,targets
2 changes: 1 addition & 1 deletion tests/preprocessing/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_complete(self):
simple_generator = SimpleGenerator(input_bboxes_group, input_labels_group, image=input_image, num_classes=6)
# expect a UserWarning
with pytest.warns(UserWarning):
_, [_, labels_batch] = simple_generator.__getitem__(0)
_, [_, labels_batch] = simple_generator[0]

# test that only object with class 5 is present in labels_batch
labels = np.unique(np.argmax(labels_batch == 5, axis=2))
Expand Down

0 comments on commit 508d877

Please sign in to comment.