Use compute_loss_ota() if there is no loss_ota param or loss_ota==1

AlexeyAB84 2022-08-16 02:10:07 +03:00
parent 6ded32cc8d
commit 36ce6b2087


@@ -359,7 +359,7 @@ def train(hyp, opt, device, tb_writer=None):
             # Forward
             with amp.autocast(enabled=cuda):
                 pred = model(imgs)  # forward
-                if hyp['loss_ota'] == 1:
+                if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
                     loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)  # loss scaled by batch_size
                 else:
                     loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
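
For reference, a minimal sketch of how the new condition resolves for different hyperparameter files. This is not part of the commit; the helper name use_ota and the sample hyp dicts are illustrative, but the boolean expression mirrors the diff above.

# Illustrative only: shows which loss branch the new condition selects.
def use_ota(hyp):
    # OTA loss is the default: used when the key is missing (older hyp files
    # that predate the option) or explicitly set to 1; skipped only for 0.
    return 'loss_ota' not in hyp or hyp['loss_ota'] == 1

print(use_ota({}))               # True  -> compute_loss_ota (hyp file without the key)
print(use_ota({'loss_ota': 1}))  # True  -> compute_loss_ota
print(use_ota({'loss_ota': 0}))  # False -> compute_loss

In other words, the change keeps backward compatibility: hyperparameter YAMLs written before the loss_ota option existed still train with the OTA loss, and only an explicit loss_ota: 0 falls back to the plain compute_loss path.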