mirror of
https://github.com/WongKinYiu/yolov7.git
synced 2025-02-17 12:50:14 +08:00
fix training with frozen layers (#378)
This commit is contained in:
parent
1e51f564e0
commit
b8956dd5a5
7
train.py
7
train.py
@ -40,8 +40,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def train(hyp, opt, device, tb_writer=None):
|
||||
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
||||
save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
|
||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze
|
||||
|
||||
# Directories
|
||||
wdir = save_dir / 'weights'
|
||||
@ -99,7 +99,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||
test_path = data_dict['val']
|
||||
|
||||
# Freeze
|
||||
freeze = [] # parameter names to freeze (full or partial)
|
||||
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # parameter names to freeze (full or partial)
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = True # train all layers
|
||||
if any(x in k for x in freeze):
|
||||
@ -555,6 +555,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
|
||||
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
||||
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
||||
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
|
||||
opt = parser.parse_args()
|
||||
|
||||
# Set DDP variables
|
||||
|
Loading…
Reference in New Issue
Block a user