diff --git a/train.py b/train.py index 2864636553..ea763995e0 100644 --- a/train.py +++ b/train.py @@ -359,7 +359,7 @@ def train(hyp, opt, device, tb_writer=None): # Forward with amp.autocast(enabled=cuda): pred = model(imgs) # forward - if hyp['loss_ota'] == 1: + if 'loss_ota' not in hyp or hyp['loss_ota'] == 1: loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs) # loss scaled by batch_size else: loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size