"""
|
|
Copyright (c) Microsoft Corporation.
|
|
Licensed under the MIT license.
|
|
|
|
Misc lr helper
|
|
"""
from torch.optim import Adam, Adamax

from .adamw import AdamW
from .rangerlars import RangerLars

def build_optimizer(model, opts):
    param_optimizer = list(model.named_parameters())
    # biases and LayerNorm parameters are exempt from weight decay
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': opts.weight_decay},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    # pick the optimizer class requested via opts.optim
    if opts.optim == 'adam':
        OptimCls = Adam
    elif opts.optim == 'adamax':
        OptimCls = Adamax
    elif opts.optim == 'adamw':
        OptimCls = AdamW
    elif opts.optim == 'rangerlars':
        OptimCls = RangerLars
    else:
        raise ValueError(f'invalid optimizer: {opts.optim}')
    optimizer = OptimCls(optimizer_grouped_parameters,
                         lr=opts.learning_rate, betas=opts.betas)
    return optimizer
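

if __name__ == '__main__':
    # Usage sketch: `opts` is assumed to be an argparse-style namespace
    # exposing the attributes build_optimizer reads (optim, weight_decay,
    # learning_rate, betas); the model and hyper-parameter values below are
    # illustrative placeholders only. Run via `python -m` from the containing
    # package so the relative imports above resolve.
    from types import SimpleNamespace

    import torch.nn as nn

    demo_model = nn.Linear(768, 2)
    demo_opts = SimpleNamespace(optim='adam', weight_decay=0.01,
                                learning_rate=5e-5, betas=(0.9, 0.98))
    print(build_optimizer(demo_model, demo_opts))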