From b8956dd5a5bcbb81c92944545ca03390c22a695f Mon Sep 17 00:00:00 2001 From: Mohammad Khoshbin Date: Tue, 2 Aug 2022 19:25:28 +0430 Subject: [PATCH] fix training with frozen layers (#378) --- train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 0c5deaa..c6db018 100644 --- a/train.py +++ b/train.py @@ -40,8 +40,8 @@ logger = logging.getLogger(__name__) def train(hyp, opt, device, tb_writer=None): logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) - save_dir, epochs, batch_size, total_batch_size, weights, rank = \ - Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank + save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \ + Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze # Directories wdir = save_dir / 'weights' @@ -99,7 +99,7 @@ def train(hyp, opt, device, tb_writer=None): test_path = data_dict['val'] # Freeze - freeze = [] # parameter names to freeze (full or partial) + freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # parameter names to freeze (full or partial) for k, v in model.named_parameters(): v.requires_grad = True # train all layers if any(x in k for x in freeze): @@ -555,6 +555,7 @@ if __name__ == '__main__': parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B') parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') + parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2') opt = parser.parse_args() # Set DDP variables