From f39d89382c5263d9572b791f9aa1aef8da133fd0 Mon Sep 17 00:00:00 2001 From: Andrew Grigorev Date: Mon, 17 Dec 2018 02:44:36 +0300 Subject: [PATCH] Add --lr option to train --- keras_retinanet/bin/train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras_retinanet/bin/train.py b/keras_retinanet/bin/train.py index 849286380..8bfdd4053 100755 --- a/keras_retinanet/bin/train.py +++ b/keras_retinanet/bin/train.py @@ -81,7 +81,8 @@ def model_with_weights(model, weights, skip_mismatch): return model -def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, freeze_backbone=False, config=None): +def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, + freeze_backbone=False, lr=1e-5, config=None): """ Creates three models (model, training_model, prediction_model). Args @@ -127,7 +128,7 @@ def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, freeze_ 'regression' : losses.smooth_l1(), 'classification': losses.focal() }, - optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001) + optimizer=keras.optimizers.adam(lr=lr, clipnorm=0.001) ) return model, training_model, prediction_model @@ -398,6 +399,7 @@ def csv_list(string): parser.add_argument('--multi-gpu-force', help='Extra flag needed to enable (experimental) multi-gpu support.', action='store_true') parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=50) parser.add_argument('--steps', help='Number of steps per epoch.', type=int, default=10000) + parser.add_argument('--lr', help='Learning rate.', type=float, default=1e-5) parser.add_argument('--snapshot-path', help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots') parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output', default='./logs') parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false') @@ -462,6 +464,7 @@ def main(args=None): weights=weights, multi_gpu=args.multi_gpu, freeze_backbone=args.freeze_backbone, + lr=args.lr, config=args.config )