Skip to content

Commit

Permalink
Merge pull request fizyr#843 from ei-grad/lr
Browse files Browse the repository at this point in the history
Allow to specify the learning rate
  • Loading branch information
hgaiser authored Dec 17, 2018
2 parents 0ca4d6f + f39d893 commit b6e4605
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def model_with_weights(model, weights, skip_mismatch):
return model


def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, freeze_backbone=False, config=None):
def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0,
freeze_backbone=False, lr=1e-5, config=None):
""" Creates three models (model, training_model, prediction_model).
Args
Expand Down Expand Up @@ -127,7 +128,7 @@ def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, freeze_
'regression' : losses.smooth_l1(),
'classification': losses.focal()
},
optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001)
optimizer=keras.optimizers.adam(lr=lr, clipnorm=0.001)
)

return model, training_model, prediction_model
Expand Down Expand Up @@ -398,6 +399,7 @@ def csv_list(string):
parser.add_argument('--multi-gpu-force', help='Extra flag needed to enable (experimental) multi-gpu support.', action='store_true')
parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=50)
parser.add_argument('--steps', help='Number of steps per epoch.', type=int, default=10000)
parser.add_argument('--lr', help='Learning rate.', type=float, default=1e-5)
parser.add_argument('--snapshot-path', help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots')
parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output', default='./logs')
parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false')
Expand Down Expand Up @@ -462,6 +464,7 @@ def main(args=None):
weights=weights,
multi_gpu=args.multi_gpu,
freeze_backbone=args.freeze_backbone,
lr=args.lr,
config=args.config
)

Expand Down

0 comments on commit b6e4605

Please sign in to comment.