diff --git a/keras_retinanet/bin/train.py b/keras_retinanet/bin/train.py index 71be859ab..1c3a7f05d 100755 --- a/keras_retinanet/bin/train.py +++ b/keras_retinanet/bin/train.py @@ -403,6 +403,7 @@ def csv_list(string): parser.add_argument('--snapshot-path', help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots') parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output', default='./logs') parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false') + parser.add_argument('--initial-epoch', help='Last epoch from resumed snapshot', type=int, default=0) parser.add_argument('--no-evaluation', help='Disable per epoch evaluation.', dest='evaluation', action='store_false') parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', action='store_true') parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true') @@ -506,7 +507,8 @@ def main(args=None): workers=args.workers, use_multiprocessing=use_multiprocessing, max_queue_size=args.max_queue_size, - validation_data=validation_generator + validation_data=validation_generator, + initial_epoch=args.initial_epoch )