mirror of
https://github.com/WongKinYiu/yolov7.git
synced 2025-02-17 12:50:14 +08:00
Use compute_loss_ota() if there is not loss_ota param or loss_ota==1
This commit is contained in:
parent
6ded32cc8d
commit
36ce6b2087
2
train.py
2
train.py
@ -359,7 +359,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||
# Forward
|
||||
with amp.autocast(enabled=cuda):
|
||||
pred = model(imgs) # forward
|
||||
if hyp['loss_ota'] == 1:
|
||||
if 'loss_ota' not in hyp or hyp['loss_ota'] == 1:
|
||||
loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs) # loss scaled by batch_size
|
||||
else:
|
||||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
||||
|
Loading…
Reference in New Issue
Block a user