Skip to content

Commit

Permalink
Use the modified Keras API for DenseNet.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgaiser committed Oct 12, 2018
1 parent 0f5a712 commit dca2d0d
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions keras_retinanet/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
"""

import keras
from keras.applications.densenet import densenet
from keras.applications import densenet
from keras.utils import get_file

from . import retinanet
from . import Backbone
from ..utils.image import preprocess_image

allowed_backbones = {'densenet121': [6, 12, 24, 16], 'densenet169': [6, 12, 32, 32], 'densenet201': [6, 12, 48, 32]}

allowed_backbones = {
'densenet121': ([6, 12, 24, 16], densenet.DenseNet121),
'densenet169': ([6, 12, 32, 32], densenet.DenseNet169),
'densenet201': ([6, 12, 48, 32], densenet.DenseNet201),
}


class DenseNetBackbone(Backbone):
Expand Down Expand Up @@ -81,20 +86,20 @@ def densenet_retinanet(num_classes, backbone='densenet121', inputs=None, modifie
if inputs is None:
inputs = keras.layers.Input((None, None, 3))

blocks = allowed_backbones[backbone]
backbone = densenet.DenseNet(blocks=blocks, input_tensor=inputs, include_top=False, pooling=None, weights=None)
blocks, creator = allowed_backbones[backbone]
model = creator(input_tensor=inputs, include_top=False, pooling=None, weights=None)

# get last conv layer from the end of each dense block
layer_outputs = [backbone.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)]
layer_outputs = [model.get_layer(name='conv{}_block{}_concat'.format(idx + 2, block_num)).output for idx, block_num in enumerate(blocks)]

# create the densenet backbone
backbone = keras.models.Model(inputs=inputs, outputs=layer_outputs[1:], name=backbone.name)
model = keras.models.Model(inputs=inputs, outputs=layer_outputs[1:], name=model.name)

# invoke modifier if given
if modifier:
backbone = modifier(backbone)
model = modifier(model)

# create the full model
model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=backbone.outputs, **kwargs)
model = retinanet.retinanet(inputs=inputs, num_classes=num_classes, backbone_layers=model.outputs, **kwargs)

return model

0 comments on commit dca2d0d

Please sign in to comment.