#### scheduler
# Build the learning-rate scheduler(s) selected by ``args.scheduler``.
# When sampled softmax is enabled (``args.sample_softmax > 0``) the 'cosine'
# and 'dev_perf' branches also build a matching scheduler for
# ``optimizer_sparse``.
if args.scheduler == 'cosine':
    # eta_min comes from its own argument instead of reusing args.lr_min:
    # earlier versions ran with eta_min defaulting to 0, whereas lr_min
    # defaults to 1e-6, so keeping them separate stays backward compatible.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.max_step, eta_min=args.eta_min)
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
            optimizer_sparse, args.max_step, eta_min=args.eta_min)
elif args.scheduler == 'inv_sqrt':
    # Inverse-square-root schedule, originally used for the Transformer
    # ("Attention is all you need").
    def lr_lambda(step):
        """Return the multiplier LambdaLR applies to the base learning rate.

        The value is a scaling factor, not a learning rate: it ramps
        linearly as ``step / warmup_step ** 1.5`` up to and including
        ``args.warmup_step``, then decays as ``1 / sqrt(step)``.
        """
        # With no warmup at all, step 0 would divide by zero below.
        if step == 0 and args.warmup_step == 0:
            return 1.
        if step > args.warmup_step:
            return 1. / (step ** 0.5)
        return step / (args.warmup_step ** 1.5)

    # No separate scheduler is built for optimizer_sparse in this branch.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
    # Plateau-based decay, driven by the metric passed to scheduler.step().
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.decay_rate,
        patience=args.patience,
        min_lr=args.lr_min)
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_sparse,
            factor=args.decay_rate,
            patience=args.patience,
            min_lr=args.lr_min)
elif args.scheduler == 'constant':
    # Nothing to construct: no scheduler object is created for this mode.
    pass
- step-wise 学习率退火:可以看到在 warmup 阶段学习率是慢慢上升的,而过了 warmup 阶段则使用相应的学习率 schedule 函数进行改变
# Per-step learning-rate update, run once for every training step.
train_step += 1
if args.scheduler == 'inv_sqrt':
    # Stepped on every iteration, with no manual warmup in this branch.
    # NOTE(review): passing an explicit step to scheduler.step() is
    # deprecated in recent PyTorch releases; confirm against the installed
    # version before changing it, since dropping the argument alters the
    # schedule.
    scheduler.step(train_step)
elif args.scheduler in ('cosine', 'constant', 'dev_perf'):
    if train_step < args.warmup_step:
        # Linear warmup: scale the lr in proportion to the step count.
        curr_lr = args.lr * train_step / args.warmup_step
        optimizer.param_groups[0]['lr'] = curr_lr
        if args.sample_softmax > 0:
            # The sparse optimizer is warmed up at twice the dense rate.
            optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
    elif args.scheduler == 'cosine':
        # After warmup only the cosine schedule advances per step;
        # 'constant' and 'dev_perf' are left untouched here.
        scheduler.step(train_step)
        if args.sample_softmax > 0:
            scheduler_sparse.step(train_step)
- 如果是基于开发集的学习率退火,那么就在 evaluate 的时候把 loss 放进去
# Learning-rate annealing driven by dev-set performance: the 'dev_perf'
# scheduler(s) are stepped with the validation loss instead of a step count.
if args.scheduler == 'dev_perf':
    scheduler.step(val_loss)
    if args.sample_softmax > 0:
        # Keep the sparse optimizer's scheduler in sync on the same metric.
        scheduler_sparse.step(val_loss)