Update SuperMLP
This commit is contained in:
@@ -14,7 +14,9 @@ class _LRScheduler(object):
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
|
||||
self.base_lrs = list(
|
||||
map(lambda group: group["initial_lr"], optimizer.param_groups)
|
||||
)
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
@@ -31,7 +33,9 @@ class _LRScheduler(object):
|
||||
)
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
|
||||
return {
|
||||
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
@@ -50,10 +54,14 @@ class _LRScheduler(object):
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch)
|
||||
assert (
|
||||
isinstance(cur_epoch, int) and cur_epoch >= 0
|
||||
), "invalid cur-epoch : {:}".format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter)
|
||||
assert (
|
||||
isinstance(cur_iter, float) and cur_iter >= 0
|
||||
), "invalid cur-iter : {:}".format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group["lr"] = lr
|
||||
@@ -66,29 +74,44 @@ class CosineAnnealingLR(_LRScheduler):
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min)
|
||||
return "type={:}, T-max={:}, eta-min={:}".format(
|
||||
"cosine", self.T_max, self.eta_min
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
|
||||
if (
|
||||
self.current_epoch >= self.warmup_epochs
|
||||
and self.current_epoch < self.max_epochs
|
||||
):
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
# if last_epoch < self.T_max:
|
||||
# if last_epoch < self.max_epochs:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
lr = (
|
||||
self.eta_min
|
||||
+ (base_lr - self.eta_min)
|
||||
* (1 + math.cos(math.pi * last_epoch / self.T_max))
|
||||
/ 2
|
||||
)
|
||||
# else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas))
|
||||
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
|
||||
len(milestones), len(gammas)
|
||||
)
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
@@ -108,7 +131,10 @@ class MultiStepLR(_LRScheduler):
|
||||
for x in self.gammas[:idx]:
|
||||
lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
@@ -119,7 +145,9 @@ class ExponentialLR(_LRScheduler):
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs)
|
||||
return "type={:}, gamma={:}, base-lrs={:}".format(
|
||||
"exponential", self.gamma, self.base_lrs
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
@@ -129,7 +157,10 @@ class ExponentialLR(_LRScheduler):
|
||||
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
@@ -151,10 +182,18 @@ class LinearLR(_LRScheduler):
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
ratio = (
|
||||
(self.max_LR - self.min_LR)
|
||||
* last_epoch
|
||||
/ self.max_epochs
|
||||
/ self.max_LR
|
||||
)
|
||||
lr = base_lr * (1 - ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
@@ -176,26 +215,42 @@ class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert (
|
||||
hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion")
|
||||
), "config must have optim / scheduler / criterion keys instead of {:}".format(config)
|
||||
hasattr(config, "optim")
|
||||
and hasattr(config, "scheduler")
|
||||
and hasattr(config, "criterion")
|
||||
), "config must have optim / scheduler / criterion keys instead of {:}".format(
|
||||
config
|
||||
)
|
||||
if config.optim == "SGD":
|
||||
optim = torch.optim.SGD(
|
||||
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov
|
||||
parameters,
|
||||
config.LR,
|
||||
momentum=config.momentum,
|
||||
weight_decay=config.decay,
|
||||
nesterov=config.nesterov,
|
||||
)
|
||||
elif config.optim == "RMSprop":
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
optim = torch.optim.RMSprop(
|
||||
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid optim : {:}".format(config.optim))
|
||||
|
||||
if config.scheduler == "cos":
|
||||
T_max = getattr(config, "T_max", config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
scheduler = CosineAnnealingLR(
|
||||
optim, config.warmup, config.epochs, T_max, config.eta_min
|
||||
)
|
||||
elif config.scheduler == "multistep":
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
scheduler = MultiStepLR(
|
||||
optim, config.warmup, config.epochs, config.milestones, config.gammas
|
||||
)
|
||||
elif config.scheduler == "exponential":
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == "linear":
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
scheduler = LinearLR(
|
||||
optim, config.warmup, config.epochs, config.LR, config.LR_min
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid scheduler : {:}".format(config.scheduler))
|
||||
|
||||
|
Reference in New Issue
Block a user