fix training with frozen layers (#378)

This commit is contained in:
Mohammad Khoshbin 2022-08-02 19:25:28 +04:30 committed by GitHub
parent 1e51f564e0
commit b8956dd5a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,8 +40,8 @@ logger = logging.getLogger(__name__)
def train(hyp, opt, device, tb_writer=None):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze
# Directories
wdir = save_dir / 'weights'
@ -99,7 +99,7 @@ def train(hyp, opt, device, tb_writer=None):
test_path = data_dict['val']
# Freeze
freeze = [] # parameter names to freeze (full or partial)
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # parameter names to freeze (full or partial)
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
if any(x in k for x in freeze):
@ -555,6 +555,7 @@ if __name__ == '__main__':
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
opt = parser.parse_args()
# Set DDP variables