
Commit 2129a5d
Update training script
Raphael Meudec committed May 31, 2020
1 parent ad35558 commit 2129a5d
Showing 2 changed files with 25 additions and 12 deletions.
35 changes: 24 additions & 11 deletions scripts/train.py
@@ -12,8 +12,9 @@
 YOLOV4_ANCHORS_MASKS = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
 
 INPUT_SHAPE = (608, 608, 3)
-BATCH_SIZE = 1
-BOUNDING_BOXES_FIXED_NUMBER = 10
+BATCH_SIZE = 16
+BOUNDING_BOXES_FIXED_NUMBER = 50
+PASCAL_VOC_NUM_CLASSES = 20
 
 
 def broadcast_iou(box_1, box_2):
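As a rough, illustrative aside (not part of the commit): with the new constants, each batch carries 16 images of 608 x 608 x 3, which for float32 inputs alone is about 68 MiB, ignoring activations and gradients:

batch_size, (height, width, channels) = 16, (608, 608, 3)
image_batch_bytes = batch_size * height * width * channels * 4  # 4 bytes per float32 value
print(f"{image_batch_bytes / 2**20:.1f} MiB")  # ~67.7 MiB for the decoded images alone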
@@ -184,7 +185,7 @@ def pad_bounding_boxes_to_fixed_number_of_bounding_boxes(bounding_boxes, pad_num
     return tf.pad(bounding_boxes, paddings, constant_values=0.0)
 
 
-def prepare_dataset(dataset):
+def prepare_dataset(dataset, shuffle=True):
     dataset = dataset.map(lambda el: (el["image"], el["objects"]))
     dataset = dataset.map(
         lambda image, object: (
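The hunks below lean on the repository's padding helper, whose return statement (tf.pad with constant_values=0.0) is visible in the context above. A minimal standalone sketch of the same idea, assuming each image's annotations arrive as a (num_boxes, 5) tensor of box coordinates plus class id; the function name and shapes here are illustrative, not the repository's exact code:

import tensorflow as tf

BOUNDING_BOXES_FIXED_NUMBER = 50  # value introduced by this commit

def pad_boxes_to_fixed_number(bounding_boxes, pad_number=BOUNDING_BOXES_FIXED_NUMBER):
    # (num_boxes, 5) -> (pad_number, 5); missing rows are filled with zeros
    missing = pad_number - tf.shape(bounding_boxes)[0]
    return tf.pad(bounding_boxes, [[0, missing], [0, 0]], constant_values=0.0)

boxes = tf.constant([[0.1, 0.2, 0.4, 0.5, 7.0]])  # a single box with class id 7
print(pad_boxes_to_fixed_number(boxes).shape)     # (50, 5)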
@@ -196,16 +197,27 @@ def prepare_dataset(dataset):
                 ],
                 axis=-1,
             ),
-        )
+        ),
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
     )
     dataset = dataset.map(
         lambda image, bounding_boxes: (
             image,
             pad_bounding_boxes_to_fixed_number_of_bounding_boxes(
                 bounding_boxes, pad_number=BOUNDING_BOXES_FIXED_NUMBER
             ),
-        )
+        ),
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
     )
+    dataset = dataset.map(
+        lambda image, bounding_box: (
+            tf.image.resize(image, INPUT_SHAPE[:2]) / 255.0,
+            bounding_box,
+        ),
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
+    )
+    if shuffle:
+        dataset = dataset.shuffle(buffer_size=1000)
     dataset = dataset.batch(BATCH_SIZE)
     dataset = dataset.map(
         lambda image, bounding_box_with_class: (
@@ -218,19 +230,20 @@ def prepare_dataset(dataset):
                 YOLOV4_ANCHORS_MASKS,
                 INPUT_SHAPE[0],  # Assumes square input
             ),
-        )
+        ),
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
     )
 
     return dataset
 
 
 voc_dataset = tfds.load("voc", shuffle_files=True)
 ds_train, ds_test = voc_dataset["train"], voc_dataset["test"]
-ds_train = prepare_dataset(ds_train)
-ds_test = prepare_dataset(ds_test)
+ds_train = prepare_dataset(ds_train, shuffle=True)
+ds_test = prepare_dataset(ds_test, shuffle=False)
 
 model = YOLOv4(
-    input_shape=INPUT_SHAPE, anchors=YOLOV4_ANCHORS, num_classes=80, training=True
+    input_shape=INPUT_SHAPE, anchors=YOLOV4_ANCHORS, num_classes=PASCAL_VOC_NUM_CLASSES, training=True
 )
 
 optimizer = tf.keras.optimizers.Adam(lr=1e-4)
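A small aside on the optimizer line above: tf.keras also accepts the explicit learning_rate keyword, which newer releases prefer over the lr alias; an equivalent call (assuming no other arguments are needed) would be:

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)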
@@ -244,8 +257,8 @@ def prepare_dataset(dataset):
 history = model.fit(
     ds_train,
     validation_data=ds_test,
-    validation_steps=100,
-    epochs=2,
+    validation_steps=10,
+    epochs=100,
     callbacks=[
         tf.keras.callbacks.TensorBoard(log_dir="./logs"),
         tf.keras.callbacks.ModelCheckpoint(
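Taken together, the train.py changes follow a common tf.data recipe: run the map stages in parallel with AUTOTUNE, shuffle only the training split, then batch. A self-contained sketch of that pattern on dummy tensors (the prepare name and the dummy shapes below are illustrative, not the repository's code):

import tensorflow as tf

def prepare(dataset, shuffle=True, batch_size=16):
    # Preprocess in parallel; AUTOTUNE lets tf.data choose the parallelism level.
    dataset = dataset.map(
        lambda image, boxes: (tf.image.resize(image, (608, 608)) / 255.0, boxes),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    if shuffle:  # shuffle the training split only
        dataset = dataset.shuffle(buffer_size=1000)
    return dataset.batch(batch_size)

# Dummy tensors standing in for decoded examples.
images = tf.random.uniform((8, 100, 120, 3))
boxes = tf.zeros((8, 50, 5))
ds = prepare(tf.data.Dataset.from_tensor_slices((images, boxes)), shuffle=False)
for image_batch, box_batch in ds.take(1):
    print(image_batch.shape, box_batch.shape)  # (8, 608, 608, 3) (8, 50, 5)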
2 changes: 1 addition & 1 deletion tests/test_model.py
@@ -37,7 +37,7 @@ def test_model_should_predict_valid_shapes_at_inference(
 
 @pytest.mark.parametrize("input_shape", [(32, 33, 3), (33, 32, 3)])
 def test_model_instanciation_should_fail_with_input_shapes_not_multiple_of_32(
-    input_shape
+    input_shape,
 ):
     with pytest.raises(ValueError):
         YOLOv4(input_shape, 80, [])
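To run just this parametrised test locally, a typical invocation (assuming pytest is installed and the command is issued from the repository root) would be:

python -m pytest tests/test_model.py -k input_shapes_not_multiple_of_32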
