From 36ce6b2087a16bbc2a44bb586d6d802ca6428acb Mon Sep 17 00:00:00 2001 From: AlexeyAB84 Date: Tue, 16 Aug 2022 02:10:07 +0300 Subject: [PATCH] Use compute_loss_ota() if there is no loss_ota param or loss_ota==1 --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 2864636553..ea763995e0 100644 --- a/train.py +++ b/train.py @@ -359,7 +359,7 @@ def train(hyp, opt, device, tb_writer=None): # Forward with amp.autocast(enabled=cuda): pred = model(imgs) # forward - if hyp['loss_ota'] == 1: + if 'loss_ota' not in hyp or hyp['loss_ota'] == 1: loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs) # loss scaled by batch_size else: loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size