Use compute_loss_ota() if there is no loss_ota param or loss_ota==1
AlexeyAB committed on Aug 15, 2022
Parent 6ded32c · Commit 36ce6b2
1 changed file with 1 addition and 1 deletion.
train.py

@@ -359,7 +359,7 @@ def train(hyp, opt, device, tb_writer=None):
             # Forward
             with amp.autocast(enabled=cuda):
                 pred = model(imgs)  # forward
-                if hyp['loss_ota'] == 1:
+                if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
                     loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size
                 else:
                     loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
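
For context, a minimal sketch of what the new guard changes: with the old condition, a hyperparameter file that omits loss_ota raised a KeyError at this line; the new condition treats a missing key the same as loss_ota==1, so the OTA loss becomes the default. The helper select_loss_fn and the dummy callables below are hypothetical, introduced only to illustrate the guard, not part of train.py.

def select_loss_fn(hyp, compute_loss_ota, compute_loss):
    """Pick the OTA criterion unless hyp explicitly sets loss_ota to 0.

    Mirrors the guard in train.py: a hyp dict that omits 'loss_ota'
    no longer raises KeyError; the missing key defaults to OTA.
    """
    if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
        return compute_loss_ota
    return compute_loss

# Dummy stand-ins for the real loss criteria, for illustration only.
ota = lambda *args: 'OTA loss'
plain = lambda *args: 'plain loss'

assert select_loss_fn({}, ota, plain) is ota                 # key missing -> OTA (new behavior)
assert select_loss_fn({'loss_ota': 1}, ota, plain) is ota    # explicit 1 -> OTA
assert select_loss_fn({'loss_ota': 0}, ota, plain) is plain  # explicit 0 -> plain loss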
