Skip to content

Commit

Permalink
fix training with frozen layers (WongKinYiu#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkhoshbin72 authored Aug 2, 2022
1 parent 1e51f56 commit b8956dd
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@

def train(hyp, opt, device, tb_writer=None):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze

# Directories
wdir = save_dir / 'weights'
Expand Down Expand Up @@ -99,7 +99,7 @@ def train(hyp, opt, device, tb_writer=None):
test_path = data_dict['val']

# Freeze
freeze = [] # parameter names to freeze (full or partial)
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # parameter names to freeze (full or partial)
for k, v in model.named_parameters():
v.requires_grad = True # train all layers
if any(x in k for x in freeze):
Expand Down Expand Up @@ -555,6 +555,7 @@ def train(hyp, opt, device, tb_writer=None):
parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2')
opt = parser.parse_args()

# Set DDP variables
Expand Down

0 comments on commit b8956dd

Please sign in to comment.