Update yaml configs
@@ -6,3 +6,7 @@ from .module_utils import call_by_yaml
from .module_utils import nested_call_by_dict
from .module_utils import nested_call_by_yaml
from .yaml_utils import load_yaml

from .torch_utils import count_parameters

from .logger_utils import Logger
xautodl/xmisc/logger_utils.py (new file, 49 lines)
@@ -0,0 +1,49 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import sys
from pathlib import Path

from .time_utils import time_for_file, time_string


class Logger:
    """A logger used in xautodl."""

    def __init__(self, root_dir, prefix="", log_time=True):
        """Create a logger that writes to `root_dir`/logs."""
        self.root_dir = Path(root_dir)
        self.log_dir = self.root_dir / "logs"
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self._prefix = prefix
        self._log_time = log_time
        self.logger_path = self.log_dir / "{:}{:}.log".format(
            self._prefix, time_for_file()
        )
        self._logger_file = open(self.logger_path, "w")

    @property
    def logger(self):
        return self._logger_file

    def log(self, string, save=True, stdout=False):
        string = "{:} {:}".format(time_string(), string) if self._log_time else string
        if stdout:
            # Raw write: no trailing newline, flushed immediately.
            sys.stdout.write(string)
            sys.stdout.flush()
        else:
            print(string)
        if save:
            self._logger_file.write("{:}\n".format(string))
            self._logger_file.flush()

    def close(self):
        self._logger_file.close()
        # A summary writer is not created in __init__, so only close one
        # if it has been attached to this logger afterwards.
        if getattr(self, "writer", None) is not None:
            self.writer.close()

    def __repr__(self):
        return "{name}(dir={log_dir}, prefix={_prefix}, log_time={_log_time})".format(
            name=self.__class__.__name__, **self.__dict__
        )
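For reference, a minimal usage sketch of the new Logger; the output directory and messages are illustrative, and it assumes the imports added in the first hunk above belong to xautodl.xmisc's package __init__:

from xautodl.xmisc import Logger

logger = Logger("./outputs", prefix="exp-")      # log file: ./outputs/logs/exp-<timestamp>.log
logger.log("starting training")                  # timestamped, saved, and printed via print()
logger.log("batch 10 | loss=0.42", stdout=True)  # written to stdout without a trailing newline
logger.close()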
@@ -62,18 +62,25 @@ def call_by_yaml(path, *args, **kwargs) -> object:
 
 def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
     """Similar to `call_by_dict`, but differently, the args may contain another dict that needs to be called."""
-    if not has_key_words(config):
+    if isinstance(config, list):
+        return [nested_call_by_dict(x) for x in config]
+    elif isinstance(config, tuple):
+        return tuple(nested_call_by_dict(x) for x in config)
+    elif not isinstance(config, dict):
         return config
-    module = get_module_by_module_path(config["module_path"])
-    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
-    args = tuple(list(config["args"]) + list(args))
-    kwargs = {**config["kwargs"], **kwargs}
-    # check whether there are nested special dict
-    new_args = [nested_call_by_dict(x) for x in args]
-    new_kwargs = {}
-    for key, x in kwargs.items():
-        new_kwargs[key] = nested_call_by_dict(x)
-    return cls_or_func(*new_args, **new_kwargs)
+    elif not has_key_words(config):
+        return {key: nested_call_by_dict(x) for key, x in config.items()}
+    else:
+        module = get_module_by_module_path(config["module_path"])
+        cls_or_func = getattr(module, config[CLS_FUNC_KEY])
+        args = tuple(list(config["args"]) + list(args))
+        kwargs = {**config["kwargs"], **kwargs}
+        # check whether there are nested special dicts to resolve recursively
+        new_args = [nested_call_by_dict(x) for x in args]
+        new_kwargs = {}
+        for key, x in kwargs.items():
+            new_kwargs[key] = nested_call_by_dict(x)
+        return cls_or_func(*new_args, **new_kwargs)
 
 
 def nested_call_by_yaml(path, *args, **kwargs) -> object:
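To illustrate the recursion this hunk introduces, here is a small self-contained sketch of the same idea. It does not call xautodl itself: the key names `module_path`, `args`, and `kwargs` mirror the ones used above, while `CLS_FUNC_KEY` and `_has_key_words` are stand-ins for the real constant and helper defined elsewhere in module_utils.py.

import importlib

CLS_FUNC_KEY = "cls_or_func"  # stand-in for the real constant in module_utils.py

def _has_key_words(config):
    # A dict is treated as "callable" when it carries the special keys.
    return isinstance(config, dict) and {"module_path", CLS_FUNC_KEY} <= set(config)

def _nested_call(config):
    if isinstance(config, list):
        return [_nested_call(x) for x in config]
    if not isinstance(config, dict):
        return config
    if not _has_key_words(config):
        return {k: _nested_call(v) for k, v in config.items()}
    module = importlib.import_module(config["module_path"])
    fn = getattr(module, config[CLS_FUNC_KEY])
    args = [_nested_call(x) for x in config.get("args", [])]
    kwargs = {k: _nested_call(v) for k, v in config.get("kwargs", {}).items()}
    return fn(*args, **kwargs)

# The inner dict is resolved to 4 first, then fed into math.hypot(3, 4).
inner = {"module_path": "operator", CLS_FUNC_KEY: "mul", "args": [2, 2], "kwargs": {}}
outer = {"module_path": "math", CLS_FUNC_KEY: "hypot", "args": [3, inner], "kwargs": {}}
print(_nested_call(outer))  # 5.0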
xautodl/xmisc/scheduler_utils.py (new file, 136 lines)
@@ -0,0 +1,136 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import math
import warnings

from torch.optim.lr_scheduler import _LRScheduler


class CosineDecayWithWarmup(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
        self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False
    ):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min

        super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose)

        self.T_cur = self.last_epoch

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )

        return [
            self.eta_min
            + (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
            / 2
            for base_lr in self.base_lrs
        ]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step()  # scheduler.step(27), instead of scheduler(20)
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError(
                    "Expected non-negative epoch, but got {}".format(epoch)
                )
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(
                        math.log(
                            (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
                        )
                    )
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (
                        self.T_mult - 1
                    )
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group["lr"] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
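To make the warm-restart bookkeeping above concrete, here is a small standalone sketch of the closed-form branch of step() for epoch-indexed calls, assuming T_0 = 10 and T_mult = 2 (so the schedule restarts at epochs 10, 30, 70, ...); it only echoes the arithmetic and does not touch an optimizer:

import math

T_0, T_mult = 10, 2
for epoch in [0, 5, 10, 25, 45]:
    if epoch >= T_0:
        # Same formulas as in step(): n counts completed restart cycles.
        n = int(math.log(epoch / T_0 * (T_mult - 1) + 1, T_mult))
        T_cur = epoch - T_0 * (T_mult ** n - 1) / (T_mult - 1)
        T_i = T_0 * T_mult ** n
    else:
        T_cur, T_i = epoch, T_0
    # get_lr() scales each base_lr by 0.5 * (1 + cos(pi * T_cur / T_i)).
    factor = 0.5 * (1 + math.cos(math.pi * T_cur / T_i))
    print(epoch, T_i, T_cur, round(factor, 3))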
xautodl/xmisc/time_utils.py (new file, 26 lines)
@@ -0,0 +1,26 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import time


def time_for_file():
    ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S"
    return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))


def time_string():
    ISOTIMEFORMAT = "%Y-%m-%d %X"
    string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
    return string


def convert_secs2time(epoch_time, return_str=False):
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    if return_str:
        return "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs)
    else:
        return need_hour, need_mins, need_secs
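A quick illustration of the two return modes of convert_secs2time; the values are illustrative, and note that both timestamp helpers above use gmtime, i.e. UTC:

from xautodl.xmisc.time_utils import convert_secs2time, time_string

print(convert_secs2time(3725))                   # (1, 2, 5)  -> hours, minutes, seconds
print(convert_secs2time(3725, return_str=True))  # "[01:02:05]"
print(time_string())                             # e.g. "[2021-06-01 08:00:00]" (UTC)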
xautodl/xmisc/torch_utils.py (new file, 26 lines)
@@ -0,0 +1,26 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
#####################################################
import torch
import torch.nn as nn
import numpy as np


def count_parameters(model_or_parameters, unit="mb"):
    if isinstance(model_or_parameters, nn.Module):
        counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
    elif isinstance(model_or_parameters, nn.Parameter):
        counts = model_or_parameters.numel()
    elif isinstance(model_or_parameters, (list, tuple)):
        counts = sum(count_parameters(x, None) for x in model_or_parameters)
    else:
        counts = sum(np.prod(v.size()) for v in model_or_parameters)
    if unit is None:
        return counts
    # Here "kb"/"mb"/"gb" mean thousands/millions/billions of parameters.
    if unit.lower() == "kb" or unit.lower() == "k":
        counts /= 1e3
    elif unit.lower() == "mb" or unit.lower() == "m":
        counts /= 1e6
    elif unit.lower() == "gb" or unit.lower() == "g":
        counts /= 1e9
    else:
        raise ValueError("Unknown unit: {:}".format(unit))
    return counts
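And a small usage sketch of count_parameters; the two-layer model is illustrative:

import torch.nn as nn
from xautodl.xmisc.torch_utils import count_parameters

# Two Linear layers: (10 * 32 + 32) + (32 * 2 + 2) = 418 parameters in total.
net = nn.Sequential(nn.Linear(10, 32), nn.Linear(32, 2))
print(count_parameters(net, unit=None))  # 418
print(count_parameters(net, unit="k"))   # 0.418 (thousands of parameters)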