Add more algorithms
This commit is contained in:
32
others/GDAS/lib/scheduler/scheduler.py
Normal file
32
others/GDAS/lib/scheduler/scheduler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
from bisect import bisect_right
|
||||
|
||||
|
||||
class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
|
||||
if not list(milestones) == sorted(milestones):
|
||||
raise ValueError('Milestones should be a list of'
|
||||
' increasing integers. Got {:}', milestones)
|
||||
assert len(milestones) == len(gammas), '{:} vs {:}'.format(milestones, gammas)
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
LR = 1
|
||||
for x in self.gammas[:bisect_right(self.milestones, self.last_epoch)]: LR = LR * x
|
||||
return [base_lr * LR for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
def obtain_scheduler(config, optimizer):
|
||||
if config.type == 'multistep':
|
||||
scheduler = MultiStepLR(optimizer, milestones=config.milestones, gammas=config.gammas)
|
||||
elif config.type == 'cosine':
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)
|
||||
else:
|
||||
raise ValueError('Unknown learning rate scheduler type : {:}'.format(config.type))
|
||||
return scheduler
|
Reference in New Issue
Block a user