Skip to content

Commit

Permalink
Merge pull request fizyr#705 from fizyr/fix-config-parsing
Browse files Browse the repository at this point in the history
Fix config parsing bug.
  • Loading branch information
hgaiser authored Oct 2, 2018
2 parents 965bf85 + a2f8b4e commit 078e12d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
14 changes: 9 additions & 5 deletions keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,25 @@ def create_models(backbone_retinanet, num_classes, weights, multi_gpu=0, freeze_

modifier = freeze_model if freeze_backbone else None

# load anchor parameters, or pass None (so that defaults will be used)
anchor_params = None
num_anchors = None
if config and 'anchor_parameters' in config:
anchor_params = parse_anchor_parameters(config)
num_anchors = anchor_params.num_anchors()

# Keras recommends initialising a multi-gpu model on the CPU to ease weight sharing, and to prevent OOM errors.
# optionally wrap in a parallel model
if multi_gpu > 1:
from keras.utils import multi_gpu_model
with tf.device('/cpu:0'):
model = model_with_weights(backbone_retinanet(num_classes, modifier=modifier), weights=weights, skip_mismatch=True)
model = model_with_weights(backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier), weights=weights, skip_mismatch=True)
training_model = multi_gpu_model(model, gpus=multi_gpu)
else:
model = model_with_weights(backbone_retinanet(num_classes, modifier=modifier), weights=weights, skip_mismatch=True)
model = model_with_weights(backbone_retinanet(num_classes, num_anchors=num_anchors, modifier=modifier), weights=weights, skip_mismatch=True)
training_model = model

# make prediction model
anchor_params = None
if config and 'anchor_parameters' in config:
anchor_params = parse_anchor_parameters(config)
prediction_model = retinanet_bbox(model=model, anchor_params=anchor_params)

# compile model
Expand Down
8 changes: 4 additions & 4 deletions keras_retinanet/layers/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ def __init__(self, size, stride, ratios=None, scales=None, *args, **kwargs):
Args
size: The base size of the anchors to generate.
stride: The stride of the anchors to generate.
ratios: The ratios of the anchors to generate (defaults to [0.5, 1, 2]).
scales: The scales of the anchors to generate (defaults to [2^0, 2^(1/3), 2^(2/3)]).
ratios: The ratios of the anchors to generate (defaults to AnchorParameters.default.ratios).
scales: The scales of the anchors to generate (defaults to AnchorParameters.default.scales).
"""
self.size = size
self.stride = stride
self.ratios = ratios
self.scales = scales

if ratios is None:
self.ratios = np.array([0.5, 1, 2], keras.backend.floatx()),
self.ratios = utils_anchors.AnchorParameters.default.ratios
elif isinstance(ratios, list):
self.ratios = np.array(ratios)
if scales is None:
self.scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)], keras.backend.floatx()),
self.scales = utils_anchors.AnchorParameters.default.scales
elif isinstance(scales, list):
self.scales = np.array(scales)

Expand Down
6 changes: 5 additions & 1 deletion keras_retinanet/models/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def retinanet(
inputs,
backbone_layers,
num_classes,
num_anchors = 9,
num_anchors = None,
create_pyramid_features = __create_pyramid_features,
submodels = None,
name = 'retinanet'
Expand All @@ -264,6 +264,10 @@ def retinanet(
]
```
"""

if num_anchors is None:
num_anchors = AnchorParameters.default.num_anchors()

if submodels is None:
submodels = default_submodels(num_classes, num_anchors)

Expand Down

0 comments on commit 078e12d

Please sign in to comment.