Move to xautodl

This commit is contained in:
D-X-Y
2021-05-18 14:08:00 +00:00
parent 98fadf8086
commit 94a149b33f
149 changed files with 94 additions and 21 deletions

View File

@@ -0,0 +1,36 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import prepare_seed
from .starts import prepare_logger
from .starts import get_machine_info
from .starts import save_checkpoint
from .starts import copy_checkpoint
from .optimizers import get_optim_scheduler
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders
def get_procedures(procedure):
from .basic_main import basic_train, basic_valid
from .search_main import search_train, search_valid
from .search_main_v2 import search_train_v2
from .simple_KD_main import simple_KD_train, simple_KD_valid
train_funcs = {
"basic": basic_train,
"search": search_train,
"Simple-KD": simple_KD_train,
"search-v2": search_train_v2,
}
valid_funcs = {
"basic": basic_valid,
"search": search_valid,
"Simple-KD": simple_KD_valid,
"search-v2": search_valid,
}
train_func = train_funcs[procedure]
valid_func = valid_funcs[procedure]
return train_func, valid_func

View File

@@ -0,0 +1,100 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
# To be finished.
#
import os, sys, time, torch
from typing import Optional, Text, Callable
# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from .eval_funcs import obtain_accuracy
def get_device(tensors):
if isinstance(tensors, (list, tuple)):
return get_device(tensors[0])
elif isinstance(tensors, dict):
for key, value in tensors.items():
return get_device(value)
else:
return tensors.device
def basic_train_fn(
xloader,
network,
criterion,
optimizer,
metric,
logger,
):
results = procedure(
xloader,
network,
criterion,
optimizer,
metric,
"train",
logger,
)
return results
def basic_eval_fn(xloader, network, metric, logger):
with torch.no_grad():
results = procedure(
xloader,
network,
None,
None,
metric,
"valid",
logger,
)
return results
def procedure(
xloader,
network,
criterion,
optimizer,
metric,
mode: Text,
logger_fn: Callable = None,
):
data_time, batch_time = AverageMeter(), AverageMeter()
if mode.lower() == "train":
network.train()
elif mode.lower() == "valid":
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
if mode == "train":
optimizer.zero_grad()
outputs = network(inputs)
targets = targets.to(get_device(outputs))
if mode == "train":
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# record
with torch.no_grad():
results = metric(outputs, targets)
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return metric.get_info()

View File

@@ -0,0 +1,155 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from .eval_funcs import obtain_accuracy
def basic_train(
xloader,
network,
criterion,
scheduler,
optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
loss, acc1, acc5 = procedure(
xloader,
network,
criterion,
scheduler,
optimizer,
"train",
optim_config,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def basic_valid(
xloader, network, criterion, optim_config, extra_info, print_freq, logger
):
with torch.no_grad():
loss, acc1, acc5 = procedure(
xloader,
network,
criterion,
None,
None,
"valid",
None,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def procedure(
xloader,
network,
criterion,
scheduler,
optimizer,
mode,
config,
extra_info,
print_freq,
logger,
):
data_time, batch_time, losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
if mode == "train":
network.train()
elif mode == "valid":
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
# logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
logger.log(
"[{:5s}] config :: auxiliary={:}".format(
mode, config.auxiliary if hasattr(config, "auxiliary") else -1
)
)
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == "train":
scheduler.update(None, 1.0 * i / len(xloader))
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
if mode == "train":
optimizer.zero_grad()
features, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
2, len(logits)
)
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
loss = criterion(logits, targets)
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
loss += config.auxiliary * loss_aux
if mode == "train":
loss.backward()
optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i + 1) == len(xloader):
Sstr = (
" {:5s} ".format(mode.upper())
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
)
if scheduler is not None:
Sstr += " {:}".format(scheduler.get_min_info())
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
loss=losses, top1=top1, top5=top5
)
Istr = "Size={:}".format(list(inputs.size()))
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
logger.log(
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
mode=mode.upper(),
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
return losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,20 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res

View File

@@ -0,0 +1,438 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, time, copy, torch, pathlib
# modules in AutoDL
import datasets
from config_utils import load_config
from procedures import prepare_seed, get_optim_scheduler
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
from utils import get_model_infos
from .eval_funcs import obtain_accuracy
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies, device = [], torch.cuda.current_device()
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(device=device, non_blocking=True)
inputs = inputs.cuda(device=device, non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append(batch_time.val - data_time.val)
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2:
latencies = latencies[1:]
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == "train":
network.train()
elif mode == "valid":
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
device = torch.cuda.current_device()
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == "train":
scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(device=device, non_blocking=True)
if mode == "train":
optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == "train":
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(
arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(arch_config)
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
flop, param = get_model_infos(net, opt_config.xshape)
logger.log("Network : {:}".format(net.get_message()), False)
logger.log(
"{:} Seed-------------------------- {:} --------------------------".format(
time_string(), seed
)
)
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
default_device = torch.cuda.current_device()
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(
device=default_device
)
criterion = criterion.cuda(device=default_device)
# start training
start_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
opt_config.epochs + opt_config.warmup,
)
(
train_losses,
train_acc1es,
train_acc5es,
valid_losses,
valid_acc1es,
valid_acc5es,
) = ({}, {}, {}, {}, {}, {})
train_times, valid_times, lrs = {}, {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
lr = min(scheduler.get_lr())
train_loss, train_acc1, train_acc5, train_tm = procedure(
train_loader, network, criterion, scheduler, optimizer, "train"
)
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
train_times[epoch] = train_tm
lrs[epoch] = lr
with torch.no_grad():
for key, xloder in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
xloder, network, criterion, None, None, "valid"
)
valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
valid_times["{:}@{:}".format(key, epoch)] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
)
logger.log(
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
time_string(),
need_time,
epoch,
total_epoch,
train_loss,
train_acc1,
train_acc5,
valid_loss,
valid_acc1,
valid_acc5,
lr,
)
)
info_seed = {
"flop": flop,
"param": param,
"arch_config": arch_config._asdict(),
"opt_config": opt_config._asdict(),
"total_epoch": total_epoch,
"train_losses": train_losses,
"train_acc1es": train_acc1es,
"train_acc5es": train_acc5es,
"train_times": train_times,
"valid_losses": valid_losses,
"valid_acc1es": valid_acc1es,
"valid_acc5es": valid_acc5es,
"valid_times": valid_times,
"learning_rates": lrs,
"net_state_dict": net.state_dict(),
"net_string": "{:}".format(net),
"finish-train": True,
}
return info_seed
def get_nas_bench_loaders(workers):
torch.set_num_threads(workers)
root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve()
torch_dir = pathlib.Path(os.environ["TORCH_HOME"])
# cifar
cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config"
cifar_config = load_config(cifar_config_path, None, None)
get_datasets = datasets.get_datasets # a function to return the dataset
break_line = "-" * 150
print("{:} Create data-loader for all datasets".format(time_string()))
print(break_line)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
"cifar10", str(torch_dir / "cifar.python"), -1
)
print(
"original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
)
)
cifar10_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
)
assert cifar10_splits.train[:10] == [
0,
5,
7,
11,
13,
15,
16,
17,
20,
24,
] and cifar10_splits.valid[:10] == [
1,
2,
3,
4,
6,
8,
9,
10,
12,
14,
]
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
temp_dataset.transform = VALID_CIFAR10.transform
# data loader
trainval_cifar10_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR10,
batch_size=cifar_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
train_cifar10_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR10,
batch_size=cifar_config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train),
num_workers=workers,
pin_memory=True,
)
valid_cifar10_loader = torch.utils.data.DataLoader(
temp_dataset,
batch_size=cifar_config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid),
num_workers=workers,
pin_memory=True,
)
test__cifar10_loader = torch.utils.data.DataLoader(
VALID_CIFAR10,
batch_size=cifar_config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
print(
"CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
len(trainval_cifar10_loader), cifar_config.batch_size
)
)
print(
"CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format(
len(train_cifar10_loader), cifar_config.batch_size
)
)
print(
"CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format(
len(valid_cifar10_loader), cifar_config.batch_size
)
)
print(
"CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format(
len(test__cifar10_loader), cifar_config.batch_size
)
)
print(break_line)
# CIFAR-100
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
"cifar100", str(torch_dir / "cifar.python"), -1
)
print(
"original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
)
)
cifar100_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
)
assert cifar100_splits.xvalid[:10] == [
1,
3,
4,
5,
8,
10,
13,
14,
15,
16,
] and cifar100_splits.xtest[:10] == [
0,
2,
6,
7,
9,
11,
12,
17,
20,
24,
]
train_cifar100_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR100,
batch_size=cifar_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_cifar100_loader = torch.utils.data.DataLoader(
VALID_CIFAR100,
batch_size=cifar_config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
num_workers=workers,
pin_memory=True,
)
test__cifar100_loader = torch.utils.data.DataLoader(
VALID_CIFAR100,
batch_size=cifar_config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
num_workers=workers,
pin_memory=True,
)
print(
"CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))
)
print(
"CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))
)
print(
"CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))
)
print(break_line)
imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
imagenet16_config = load_config(imagenet16_config_path, None, None)
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets(
"ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1
)
print(
"original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
)
)
imagenet_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
None,
None,
)
assert imagenet_splits.xvalid[:10] == [
1,
2,
3,
6,
7,
8,
9,
12,
16,
18,
] and imagenet_splits.xtest[:10] == [
0,
4,
5,
10,
11,
13,
14,
15,
17,
20,
]
train_imagenet_loader = torch.utils.data.DataLoader(
TRAIN_ImageNet16_120,
batch_size=imagenet16_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_imagenet_loader = torch.utils.data.DataLoader(
VALID_ImageNet16_120,
batch_size=imagenet16_config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid),
num_workers=workers,
pin_memory=True,
)
test__imagenet_loader = torch.utils.data.DataLoader(
VALID_ImageNet16_120,
batch_size=imagenet16_config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest),
num_workers=workers,
pin_memory=True,
)
print(
"ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format(
len(train_imagenet_loader), imagenet16_config.batch_size
)
)
print(
"ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format(
len(valid_imagenet_loader), imagenet16_config.batch_size
)
)
print(
"ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format(
len(test__imagenet_loader), imagenet16_config.batch_size
)
)
# 'cifar10', 'cifar100', 'ImageNet16-120'
loaders = {
"cifar10@trainval": trainval_cifar10_loader,
"cifar10@train": train_cifar10_loader,
"cifar10@valid": valid_cifar10_loader,
"cifar10@test": test__cifar10_loader,
"cifar100@train": train_cifar100_loader,
"cifar100@valid": valid_cifar100_loader,
"cifar100@test": test__cifar100_loader,
"ImageNet16-120@train": train_imagenet_loader,
"ImageNet16-120@valid": valid_imagenet_loader,
"ImageNet16-120@test": test__imagenet_loader,
}
return loaders

View File

@@ -0,0 +1,134 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc
import numpy as np
import torch
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return "{name}(val={val}, avg={avg}, count={count})".format(
name=self.__class__.__name__, **self.__dict__
)
class Metric(abc.ABC):
"""The default meta metric class."""
def __init__(self):
self.reset()
def reset(self):
raise NotImplementedError
def __call__(self, predictions, targets):
raise NotImplementedError
def get_info(self):
raise NotImplementedError
def __repr__(self):
return "{name}({inner})".format(
name=self.__class__.__name__, inner=self.inner_repr()
)
def inner_repr(self):
return ""
class ComposeMetric(Metric):
"""The composed metric class."""
def __init__(self, *metric_list):
self.reset()
for metric in metric_list:
self.append(metric)
def reset(self):
self._metric_list = []
def append(self, metric):
if not isinstance(metric, Metric):
raise ValueError(
"The input metric is not correct: {:}".format(type(metric))
)
self._metric_list.append(metric)
def __len__(self):
return len(self._metric_list)
def __call__(self, predictions, targets):
results = list()
for metric in self._metric_list:
results.append(metric(predictions, targets))
return results
def get_info(self):
results = dict()
for metric in self._metric_list:
for key, value in metric.get_info().items():
results[key] = value
return results
def inner_repr(self):
xlist = []
for metric in self._metric_list:
xlist.append(str(metric))
return ",".join(xlist)
class MSEMetric(Metric):
"""The metric for mse."""
def reset(self):
self._mse = AverageMeter()
def __call__(self, predictions, targets):
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
batch = predictions.shape[0]
loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
loss = loss.item()
self._mse.update(loss, batch)
return loss
else:
raise NotImplementedError
def get_info(self):
return {"mse": self._mse.avg}
class SaveMetric(Metric):
"""The metric for mse."""
def reset(self):
self._predicts = []
def __call__(self, predictions, targets=None):
if isinstance(predictions, torch.Tensor):
predicts = predictions.cpu().numpy()
self._predicts.append(predicts)
return predicts
else:
raise NotImplementedError
def get_info(self):
all_predicts = np.concatenate(self._predicts)
return {"predictions": all_predicts}

View File

@@ -0,0 +1,263 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, warmup_epochs, epochs):
if not isinstance(optimizer, Optimizer):
raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
self.base_lrs = list(
map(lambda group: group["initial_lr"], optimizer.param_groups)
)
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
self.current_iter = 0
def extra_repr(self):
return ""
def __repr__(self):
return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
name=self.__class__.__name__, **self.__dict__
) + ", {:})".format(
self.extra_repr()
)
def state_dict(self):
return {
key: value for key, value in self.__dict__.items() if key != "optimizer"
}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def get_min_info(self):
lrs = self.get_lr()
return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
min(lrs), max(lrs), self.current_epoch, self.current_iter
)
def get_min_lr(self):
return min(self.get_lr())
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert (
isinstance(cur_epoch, int) and cur_epoch >= 0
), "invalid cur-epoch : {:}".format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert (
isinstance(cur_iter, float) and cur_iter >= 0
), "invalid cur-iter : {:}".format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group["lr"] = lr
class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return "type={:}, T-max={:}, eta-min={:}".format(
"cosine", self.T_max, self.eta_min
)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if (
self.current_epoch >= self.warmup_epochs
and self.current_epoch < self.max_epochs
):
last_epoch = self.current_epoch - self.warmup_epochs
# if last_epoch < self.T_max:
# if last_epoch < self.max_epochs:
lr = (
self.eta_min
+ (base_lr - self.eta_min)
* (1 + math.cos(math.pi * last_epoch / self.T_max))
/ 2
)
# else:
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
elif self.current_epoch >= self.max_epochs:
lr = self.eta_min
else:
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
len(milestones), len(gammas)
)
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
"multistep", self.milestones, self.gammas, self.base_lrs
)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
idx = bisect_right(self.milestones, last_epoch)
lr = base_lr
for x in self.gammas[:idx]:
lr *= x
else:
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return "type={:}, gamma={:}, base-lrs={:}".format(
"exponential", self.gamma, self.base_lrs
)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
class LinearLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
self.max_LR = max_LR
self.min_LR = min_LR
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
"LinearLR", self.max_LR, self.min_LR, self.base_lrs
)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
ratio = (
(self.max_LR - self.min_LR)
* last_epoch
/ self.max_epochs
/ self.max_LR
)
lr = base_lr * (1 - ratio)
else:
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def get_optim_scheduler(parameters, config):
assert (
hasattr(config, "optim")
and hasattr(config, "scheduler")
and hasattr(config, "criterion")
), "config must have optim / scheduler / criterion keys instead of {:}".format(
config
)
if config.optim == "SGD":
optim = torch.optim.SGD(
parameters,
config.LR,
momentum=config.momentum,
weight_decay=config.decay,
nesterov=config.nesterov,
)
elif config.optim == "RMSprop":
optim = torch.optim.RMSprop(
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
)
else:
raise ValueError("invalid optim : {:}".format(config.optim))
if config.scheduler == "cos":
T_max = getattr(config, "T_max", config.epochs)
scheduler = CosineAnnealingLR(
optim, config.warmup, config.epochs, T_max, config.eta_min
)
elif config.scheduler == "multistep":
scheduler = MultiStepLR(
optim, config.warmup, config.epochs, config.milestones, config.gammas
)
elif config.scheduler == "exponential":
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == "linear":
scheduler = LinearLR(
optim, config.warmup, config.epochs, config.LR, config.LR_min
)
else:
raise ValueError("invalid scheduler : {:}".format(config.scheduler))
if config.criterion == "Softmax":
criterion = torch.nn.CrossEntropyLoss()
elif config.criterion == "SmoothSoftmax":
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
else:
raise ValueError("invalid criterion : {:}".format(config.criterion))
return optim, scheduler, criterion

View File

@@ -0,0 +1,151 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################
import inspect
import os
import pprint
import logging
from copy import deepcopy
from log_utils import pickle_load
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import get_module_logger
def set_log_basic_config(filename=None, format=None, level=None):
"""
Set the basic configuration for the logging system.
See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
:param filename: str or None
The path to save the logs.
:param format: the logging format
:param level: int
:return: Logger
Logger object.
"""
from qlib.config import C
if level is None:
level = C.logging_level
if format is None:
format = C.logging_config["formatters"]["logger_format"]["format"]
# Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.basicConfig(filename=filename, format=format, level=level)
def update_gpu(config, gpu):
config = deepcopy(config)
if "task" in config and "model" in config["task"]:
if "GPU" in config["task"]["model"]:
config["task"]["model"]["GPU"] = gpu
elif (
"kwargs" in config["task"]["model"]
and "GPU" in config["task"]["model"]["kwargs"]
):
config["task"]["model"]["kwargs"]["GPU"] = gpu
elif "model" in config:
if "GPU" in config["model"]:
config["model"]["GPU"] = gpu
elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]:
config["model"]["kwargs"]["GPU"] = gpu
elif "kwargs" in config and "GPU" in config["kwargs"]:
config["kwargs"]["GPU"] = gpu
elif "GPU" in config:
config["GPU"] = gpu
return config
def update_market(config, market):
config = deepcopy(config.copy())
config["market"] = market
config["data_handler_config"]["instruments"] = market
return config
def run_exp(
task_config,
dataset,
experiment_name,
recorder_name,
uri,
model_obj_name="model.pkl",
):
model = init_instance_by_config(task_config["model"])
model_fit_kwargs = dict(dataset=dataset)
# Let's start the experiment.
with R.start(
experiment_name=experiment_name,
recorder_name=recorder_name,
uri=uri,
resume=True,
):
# Setup log
recorder_root_dir = R.get_recorder().get_local_dir()
log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
set_log_basic_config(log_file)
logger = get_module_logger("q.run_exp")
logger.info("task_config::\n{:}".format(pprint.pformat(task_config, indent=2)))
logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
logger.info("dataset={:}".format(dataset))
# Train model
try:
if hasattr(model, "to"): # Recoverable model
ori_device = model.device
model = R.load_object(model_obj_name)
model.to(ori_device)
else:
model = R.load_object(model_obj_name)
logger.info("[Find existing object from {:}]".format(model_obj_name))
except OSError:
R.log_params(**flatten_dict(update_gpu(task_config, None)))
if "save_path" in inspect.getfullargspec(model.fit).args:
model_fit_kwargs["save_path"] = os.path.join(
recorder_root_dir, "model.ckp"
)
elif "save_dir" in inspect.getfullargspec(model.fit).args:
model_fit_kwargs["save_dir"] = os.path.join(
recorder_root_dir, "model-ckps"
)
model.fit(**model_fit_kwargs)
# remove model to CPU for saving
if hasattr(model, "to"):
old_device = model.device
model.to("cpu")
R.save_objects(**{model_obj_name: model})
model.to(old_device)
else:
R.save_objects(**{model_obj_name: model})
except Exception as e:
raise ValueError("Something wrong: {:}".format(e))
# Get the recorder
recorder = R.get_recorder()
# Generate records: prediction, backtest, and analysis
for record in task_config["record"]:
record = deepcopy(record)
if record["class"] == "MultiSegRecord":
record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
sr = init_instance_by_config(record)
sr.generate(**record["generate_kwargs"])
elif record["class"] == "SignalRecord":
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
else:
rconf = {"recorder": recorder}
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()

View File

@@ -0,0 +1,198 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from models import change_key
from .eval_funcs import obtain_accuracy
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean(expected_flop)
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
loss = -torch.log(expected_flop)
# elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
elif flop_cur > flop_need: # Too Large FLOP
loss = torch.log(expected_flop)
else: # Required FLOP
loss = None
if loss is None:
return 0, 0
else:
return loss, loss.item()
def search_train(
search_loader,
network,
criterion,
scheduler,
base_optimizer,
arch_optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = (
extra_info["epoch-str"],
extra_info["FLOP-exp"],
extra_info["FLOP-weight"],
extra_info["FLOP-tolerant"],
)
network.train()
logger.log(
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
epoch_str, flop_need, flop_weight
)
)
end = time.time()
network.apply(change_key("search_mode", "search"))
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
search_loader
):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
base_optimizer.zero_grad()
logits, expected_flop = network(base_inputs)
# network.apply( change_key('search_mode', 'basic') )
# features, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
base_optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
top1.update(prec1.item(), base_inputs.size(0))
top5.update(prec5.item(), base_inputs.size(0))
# update the architecture
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop("genotype", None, None)
flop_loss, flop_loss_scale = get_flop_loss(
expected_flop, flop_cur, flop_need, flop_tolerant
)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
arch_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step + 1) == len(search_loader):
Sstr = (
"**TRAIN** "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
loss=base_losses, top1=top1, top5=top5
)
Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
)
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
# Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
# print(network.module.get_arch_info())
# print(network.module.width_attentions[0])
# print(network.module.width_attentions[1])
logger.log(
" **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
baseloss=base_losses.avg,
archloss=arch_losses.avg,
)
)
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
data_time, batch_time, losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
network.eval()
network.apply(change_key("search_mode", "search"))
end = time.time()
# logger.log('Starting evaluating {:}'.format(epoch_info))
with torch.no_grad():
for i, (inputs, targets) in enumerate(xloader):
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
logits, expected_flop = network(inputs)
loss = criterion(logits, targets)
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i + 1) == len(xloader):
Sstr = (
"**VALID** "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
loss=losses, top1=top1, top5=top5
)
Istr = "Size={:}".format(list(inputs.size()))
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
logger.log(
" **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
return losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,139 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
# modules in AutoDL
from log_utils import AverageMeter, time_string
from models import change_key
from .eval_funcs import obtain_accuracy
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
expected_flop = torch.mean(expected_flop)
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
loss = -torch.log(expected_flop)
# elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
elif flop_cur > flop_need: # Too Large FLOP
loss = torch.log(expected_flop)
else: # Required FLOP
loss = None
if loss is None:
return 0, 0
else:
return loss, loss.item()
def search_train_v2(
search_loader,
network,
criterion,
scheduler,
base_optimizer,
arch_optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = (
extra_info["epoch-str"],
extra_info["FLOP-exp"],
extra_info["FLOP-weight"],
extra_info["FLOP-tolerant"],
)
network.train()
logger.log(
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
epoch_str, flop_need, flop_weight
)
)
end = time.time()
network.apply(change_key("search_mode", "search"))
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
search_loader
):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
base_optimizer.zero_grad()
logits, expected_flop = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
base_optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
top1.update(prec1.item(), base_inputs.size(0))
top5.update(prec5.item(), base_inputs.size(0))
# update the architecture
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop("genotype", None, None)
flop_loss, flop_loss_scale = get_flop_loss(
expected_flop, flop_cur, flop_need, flop_tolerant
)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
arch_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step + 1) == len(search_loader):
Sstr = (
"**TRAIN** "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
loss=base_losses, top1=top1, top5=top5
)
Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
)
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
# num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
# Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
# print(network.module.get_arch_info())
# print(network.module.width_attentions[0])
# print(network.module.width_attentions[1])
logger.log(
" **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
baseloss=base_losses.avg,
archloss=arch_losses.avg,
)
)
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import os, sys, time, torch
import torch.nn.functional as F
# modules in AutoDL
from log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy
def simple_KD_train(
xloader,
teacher,
network,
criterion,
scheduler,
optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
loss, acc1, acc5 = procedure(
xloader,
teacher,
network,
criterion,
scheduler,
optimizer,
"train",
optim_config,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def simple_KD_valid(
xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
):
with torch.no_grad():
loss, acc1, acc5 = procedure(
xloader,
teacher,
network,
criterion,
None,
None,
"valid",
optim_config,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def loss_KD_fn(
criterion,
student_logits,
teacher_logits,
studentFeatures,
teacherFeatures,
targets,
alpha,
temperature,
):
basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
log_student = F.log_softmax(student_logits / temperature, dim=1)
sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
alpha * temperature * temperature
)
return basic_loss + KD_loss
def procedure(
xloader,
teacher,
network,
criterion,
scheduler,
optimizer,
mode,
config,
extra_info,
print_freq,
logger,
):
data_time, batch_time, losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
if mode == "train":
network.train()
elif mode == "valid":
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
teacher.eval()
logger.log(
"[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
mode,
config.auxiliary if hasattr(config, "auxiliary") else -1,
config.KD_alpha,
config.KD_temperature,
)
)
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == "train":
scheduler.update(None, 1.0 * i / len(xloader))
# measure data loading time
data_time.update(time.time() - end)
# calculate prediction and loss
targets = targets.cuda(non_blocking=True)
if mode == "train":
optimizer.zero_grad()
student_f, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
2, len(logits)
)
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
with torch.no_grad():
teacher_f, teacher_logits = teacher(inputs)
loss = loss_KD_fn(
criterion,
logits,
teacher_logits,
student_f,
teacher_f,
targets,
config.KD_alpha,
config.KD_temperature,
)
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
loss += config.auxiliary * loss_aux
if mode == "train":
loss.backward()
optimizer.step()
# record
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(sprec1.item(), inputs.size(0))
top5.update(sprec5.item(), inputs.size(0))
# teacher
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
Ttop1.update(tprec1.item(), inputs.size(0))
Ttop5.update(tprec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % print_freq == 0 or (i + 1) == len(xloader):
Sstr = (
" {:5s} ".format(mode.upper())
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
)
if scheduler is not None:
Sstr += " {:}".format(scheduler.get_min_info())
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
loss=losses, top1=top1, top5=top5
)
Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg)
Istr = "Size={:}".format(list(inputs.size()))
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
logger.log(
" **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format(
mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg
)
)
logger.log(
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
mode=mode.upper(),
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
return losses.avg, top1.avg, top5.avg

View File

@@ -0,0 +1,79 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile
def prepare_seed(rand_seed):
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
torch.cuda.manual_seed_all(rand_seed)
def prepare_logger(xargs):
args = copy.deepcopy(xargs)
from log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log("Main Function with logger : {:}".format(logger))
logger.log("Arguments : -------------------------------")
for name, value in args._get_kwargs():
logger.log("{:16} : {:}".format(name, value))
logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
logger.log("Pillow Version : {:}".format(PIL.__version__))
logger.log("PyTorch Version : {:}".format(torch.__version__))
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log(
"CUDA_VISIBLE_DEVICES : {:}".format(
os.environ["CUDA_VISIBLE_DEVICES"]
if "CUDA_VISIBLE_DEVICES" in os.environ
else "None"
)
)
return logger
def get_machine_info():
info = "Python Version : {:}".format(sys.version.replace("\n", " "))
info += "\nPillow Version : {:}".format(PIL.__version__)
info += "\nPyTorch Version : {:}".format(torch.__version__)
info += "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
info += "\nCUDA available : {:}".format(torch.cuda.is_available())
info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
if "CUDA_VISIBLE_DEVICES" in os.environ:
info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"])
else:
info += "\nDoes not set CUDA_VISIBLE_DEVICES"
return info
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, "log"):
logger.log(
"Find {:} exist, delete is at first before saving".format(filename)
)
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(
filename
), "save filename : {:} failed, which is not found.".format(filename)
if hasattr(logger, "log"):
logger.log("save checkpoint into {:}".format(filename))
return filename
def copy_checkpoint(src, dst, logger):
if osp.isfile(dst):
if hasattr(logger, "log"):
logger.log("Find {:} exist, delete is at first before saving".format(dst))
os.remove(dst)
copyfile(src, dst)
if hasattr(logger, "log"):
logger.log("copy the file from {:} into {:}".format(src, dst))