adversarial_VLNDUET/pretrain_src/optim/misc.py
Shizhe Chen 747cf0587b init
2021-11-24 13:29:08 +01:00

38 lines
1.1 KiB
Python

"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
Misc optimizer helper
"""
from torch.optim import Adam, Adamax
from .adamw import AdamW
from .rangerlars import RangerLars
def build_optimizer(model, opts):
    """Build the optimizer selected by ``opts.optim`` for ``model``.

    The model's named parameters are split into two parameter groups by
    name: parameters whose name contains any of the ``no_decay`` substrings
    get a ``weight_decay`` of 0.0, every other parameter gets
    ``opts.weight_decay``.

    Args:
        model: object exposing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs.
        opts: options object; the attributes ``optim``, ``weight_decay``,
            ``learning_rate`` and ``betas`` are read.

    Returns:
        The optimizer instance, constructed with the two parameter groups,
        ``lr=opts.learning_rate`` and ``betas=opts.betas``.

    Raises:
        ValueError: if ``opts.optim`` is not one of ``'adam'``,
            ``'adamax'``, ``'adamw'`` or ``'rangerlars'``.
    """
    # Substring match on the parameter name. Note that 'bias' already
    # matches every name that 'LayerNorm.bias' would match.
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

    # Single pass over the parameters; the relative order inside each group
    # is the order in which named_parameters() yields them.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': opts.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    # Supported optimizers: adam, adamax, adamw, rangerlars.
    if opts.optim == 'adam':
        OptimCls = Adam
    elif opts.optim == 'adamax':
        OptimCls = Adamax
    elif opts.optim == 'adamw':
        OptimCls = AdamW
    elif opts.optim == 'rangerlars':
        OptimCls = RangerLars
    else:
        # Name the offending value so a misconfiguration is easy to spot.
        raise ValueError(f'invalid optimizer: {opts.optim!r}')

    optimizer = OptimCls(optimizer_grouped_parameters,
                         lr=opts.learning_rate, betas=opts.betas)
    return optimizer