Move to xautodl
36  xautodl/procedures/__init__.py  Normal file
@@ -0,0 +1,36 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import prepare_seed
from .starts import prepare_logger
from .starts import get_machine_info
from .starts import save_checkpoint
from .starts import copy_checkpoint
from .optimizers import get_optim_scheduler
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders


def get_procedures(procedure):
    from .basic_main import basic_train, basic_valid
    from .search_main import search_train, search_valid
    from .search_main_v2 import search_train_v2
    from .simple_KD_main import simple_KD_train, simple_KD_valid

    train_funcs = {
        "basic": basic_train,
        "search": search_train,
        "Simple-KD": simple_KD_train,
        "search-v2": search_train_v2,
    }
    valid_funcs = {
        "basic": basic_valid,
        "search": search_valid,
        "Simple-KD": simple_KD_valid,
        "search-v2": search_valid,
    }

    train_func = train_funcs[procedure]
    valid_func = valid_funcs[procedure]
    return train_func, valid_func
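As a usage sketch (not part of this commit; assumes the xautodl package and its AutoDL dependencies are importable):

    from xautodl.procedures import get_procedures

    # "basic" is one of the registered keys: "basic", "search",
    # "Simple-KD", "search-v2"; an unknown key raises KeyError.
    train_func, valid_func = get_procedures("basic")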
100  xautodl/procedures/advanced_main.py  Normal file
@@ -0,0 +1,100 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
# To be finished.
#
import os, sys, time, torch
from typing import Optional, Text, Callable

# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from .eval_funcs import obtain_accuracy


def get_device(tensors):
    if isinstance(tensors, (list, tuple)):
        return get_device(tensors[0])
    elif isinstance(tensors, dict):
        for key, value in tensors.items():
            return get_device(value)
    else:
        return tensors.device


def basic_train_fn(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    logger,
):
    results = procedure(
        xloader,
        network,
        criterion,
        optimizer,
        metric,
        "train",
        logger,
    )
    return results


def basic_eval_fn(xloader, network, metric, logger):
    with torch.no_grad():
        results = procedure(
            xloader,
            network,
            None,
            None,
            metric,
            "valid",
            logger,
        )
    return results


def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Callable = None,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    if mode.lower() == "train":
        network.train()
    elif mode.lower() == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss

        if mode == "train":
            optimizer.zero_grad()

        outputs = network(inputs)
        targets = targets.to(get_device(outputs))

        if mode == "train":
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # record
        with torch.no_grad():
            results = metric(outputs, targets)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()
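A sketch of how this metric-driven loop can be exercised, assuming the MSEMetric defined in metric_utils.py below; the toy model and loader are placeholders (the module itself is marked "To be finished."):

    import torch
    from xautodl.procedures.advanced_main import basic_eval_fn
    from xautodl.procedures.metric_utils import MSEMetric

    network = torch.nn.Linear(8, 1)      # placeholder regression model
    xs = torch.randn(32, 8)
    ys = network(xs).detach() + 0.1 * torch.randn(32, 1)
    loader = [(xs, ys)]                  # any iterable of (inputs, targets)
    info = basic_eval_fn(loader, network, MSEMetric(), logger=None)
    print(info)                          # {'mse': ...}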
155  xautodl/procedures/basic_main.py  Normal file
@@ -0,0 +1,155 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch

# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from .eval_funcs import obtain_accuracy


def basic_train(
    xloader,
    network,
    criterion,
    scheduler,
    optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    loss, acc1, acc5 = procedure(
        xloader,
        network,
        criterion,
        scheduler,
        optimizer,
        "train",
        optim_config,
        extra_info,
        print_freq,
        logger,
    )
    return loss, acc1, acc5


def basic_valid(
    xloader, network, criterion, optim_config, extra_info, print_freq, logger
):
    with torch.no_grad():
        loss, acc1, acc5 = procedure(
            xloader,
            network,
            criterion,
            None,
            None,
            "valid",
            None,
            extra_info,
            print_freq,
            logger,
        )
    return loss, acc1, acc5


def procedure(
    xloader,
    network,
    criterion,
    scheduler,
    optimizer,
    mode,
    config,
    extra_info,
    print_freq,
    logger,
):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    # logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
    logger.log(
        "[{:5s}] config :: auxiliary={:}".format(
            mode, config.auxiliary if hasattr(config, "auxiliary") else -1
        )
    )
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)

        if mode == "train":
            optimizer.zero_grad()

        features, logits = network(inputs)
        if isinstance(logits, list):
            assert len(logits) == 2, "logits must have {:} items instead of {:}".format(
                2, len(logits)
            )
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        loss = criterion(logits, targets)
        if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == "train":
            loss.backward()
            optimizer.step()

        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = (
                " {:5s} ".format(mode.upper())
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
            )
            if scheduler is not None:
                Sstr += " {:}".format(scheduler.get_min_info())
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=losses, top1=top1, top5=top5
            )
            Istr = "Size={:}".format(list(inputs.size()))
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            mode=mode.upper(),
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            loss=losses.avg,
        )
    )
    return losses.avg, top1.avg, top5.avg
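One step above deserves a call-out: when the network returns two towers of logits, the loss combines the main and auxiliary classifiers as CE(logits, y) + config.auxiliary * CE(logits_aux, y). A minimal sketch of that combination (the 0.4 weight is an assumed DARTS-style value, not a default from this repo):

    import torch.nn.functional as F

    def combined_loss(logits, logits_aux, targets, auxiliary_weight=0.4):
        # Mirrors the branch in procedure(): main cross-entropy plus a
        # down-weighted auxiliary head; auxiliary_weight is assumed.
        loss = F.cross_entropy(logits, targets)
        if logits_aux is not None and auxiliary_weight > 0:
            loss = loss + auxiliary_weight * F.cross_entropy(logits_aux, targets)
        return loss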
20  xautodl/procedures/eval_funcs.py  Normal file
@@ -0,0 +1,20 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc


def obtain_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (rather than view) since correct is non-contiguous after t()
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
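A small worked example of the helper above, with hand-picked values:

    import torch
    from xautodl.procedures.eval_funcs import obtain_accuracy

    # Two samples, three classes: the first is correct at top-1,
    # the second only within top-2.
    output = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.5, 0.4]])
    target = torch.tensor([0, 2])
    top1, top2 = obtain_accuracy(output, target, topk=(1, 2))
    print(top1.item(), top2.item())  # 50.0 100.0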
438  xautodl/procedures/funcs_nasbench.py  Normal file
@@ -0,0 +1,438 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import os, time, copy, torch, pathlib

# modules in AutoDL
import datasets
from config_utils import load_config
from procedures import prepare_seed, get_optim_scheduler
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net
from utils import get_model_infos
from .eval_funcs import obtain_accuracy


__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]


def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies


def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(device=device, non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum


def evaluate_for_seed(
    arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
):

    prepare_seed(seed)  # random seed
    net = get_cell_based_tiny_net(arch_config)
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, opt_config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log(
        "{:} Seed-------------------------- {:} --------------------------".format(
            time_string(), seed
        )
    )
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
    default_device = torch.cuda.current_device()
    network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(
        device=default_device
    )
    criterion = criterion.cuda(device=default_device)
    # start training
    start_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        opt_config.epochs + opt_config.warmup,
    )
    (
        train_losses,
        train_acc1es,
        train_acc5es,
        valid_losses,
        valid_acc1es,
        valid_acc5es,
    ) = ({}, {}, {}, {}, {}, {})
    train_times, valid_times, lrs = {}, {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)
        lr = min(scheduler.get_lr())
        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        lrs[epoch] = lr
        with torch.no_grad():
            for key, xloader in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloader, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        )
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
                time_string(),
                need_time,
                epoch,
                total_epoch,
                train_loss,
                train_acc1,
                train_acc5,
                valid_loss,
                valid_acc1,
                valid_acc5,
                lr,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "arch_config": arch_config._asdict(),
        "opt_config": opt_config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "learning_rates": lrs,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed


def get_nas_bench_loaders(workers):

    torch.set_num_threads(workers)

    root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve()
    torch_dir = pathlib.Path(os.environ["TORCH_HOME"])
    # cifar
    cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config"
    cifar_config = load_config(cifar_config_path, None, None)
    get_datasets = datasets.get_datasets  # a function to return the dataset
    break_line = "-" * 150
    print("{:} Create data-loader for all datasets".format(time_string()))
    print(break_line)
    TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
        "cifar10", str(torch_dir / "cifar.python"), -1
    )
    print(
        "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
        )
    )
    cifar10_splits = load_config(
        root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
    )
    assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24]
    assert cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
    temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
    temp_dataset.transform = VALID_CIFAR10.transform
    # data loader
    trainval_cifar10_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR10,
        batch_size=cifar_config.batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    train_cifar10_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR10,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train),
        num_workers=workers,
        pin_memory=True,
    )
    valid_cifar10_loader = torch.utils.data.DataLoader(
        temp_dataset,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid),
        num_workers=workers,
        pin_memory=True,
    )
    test__cifar10_loader = torch.utils.data.DataLoader(
        VALID_CIFAR10,
        batch_size=cifar_config.batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )
    print(
        "CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
            len(trainval_cifar10_loader), cifar_config.batch_size
        )
    )
    print(
        "CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format(
            len(train_cifar10_loader), cifar_config.batch_size
        )
    )
    print(
        "CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format(
            len(valid_cifar10_loader), cifar_config.batch_size
        )
    )
    print(
        "CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format(
            len(test__cifar10_loader), cifar_config.batch_size
        )
    )
    print(break_line)
    # CIFAR-100
    TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
        "cifar100", str(torch_dir / "cifar.python"), -1
    )
    print(
        "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
        )
    )
    cifar100_splits = load_config(
        root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
    )
    assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16]
    assert cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
    train_cifar100_loader = torch.utils.data.DataLoader(
        TRAIN_CIFAR100,
        batch_size=cifar_config.batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    valid_cifar100_loader = torch.utils.data.DataLoader(
        VALID_CIFAR100,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
        num_workers=workers,
        pin_memory=True,
    )
    test__cifar100_loader = torch.utils.data.DataLoader(
        VALID_CIFAR100,
        batch_size=cifar_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
        num_workers=workers,
        pin_memory=True,
    )
    print(
        "CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))
    )
    print(
        "CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))
    )
    print(
        "CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))
    )
    print(break_line)

    imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
    imagenet16_config = load_config(imagenet16_config_path, None, None)
    TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets(
        "ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1
    )
    print(
        "original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
            len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
        )
    )
    imagenet_splits = load_config(
        root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
        None,
        None,
    )
    assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18]
    assert imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
    train_imagenet_loader = torch.utils.data.DataLoader(
        TRAIN_ImageNet16_120,
        batch_size=imagenet16_config.batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True,
    )
    valid_imagenet_loader = torch.utils.data.DataLoader(
        VALID_ImageNet16_120,
        batch_size=imagenet16_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid),
        num_workers=workers,
        pin_memory=True,
    )
    test__imagenet_loader = torch.utils.data.DataLoader(
        VALID_ImageNet16_120,
        batch_size=imagenet16_config.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest),
        num_workers=workers,
        pin_memory=True,
    )
    print(
        "ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format(
            len(train_imagenet_loader), imagenet16_config.batch_size
        )
    )
    print(
        "ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format(
            len(valid_imagenet_loader), imagenet16_config.batch_size
        )
    )
    print(
        "ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format(
            len(test__imagenet_loader), imagenet16_config.batch_size
        )
    )

    # 'cifar10', 'cifar100', 'ImageNet16-120'
    loaders = {
        "cifar10@trainval": trainval_cifar10_loader,
        "cifar10@train": train_cifar10_loader,
        "cifar10@valid": valid_cifar10_loader,
        "cifar10@test": test__cifar10_loader,
        "cifar100@train": train_cifar100_loader,
        "cifar100@valid": valid_cifar100_loader,
        "cifar100@test": test__cifar100_loader,
        "ImageNet16-120@train": train_imagenet_loader,
        "ImageNet16-120@valid": valid_imagenet_loader,
        "ImageNet16-120@test": test__imagenet_loader,
    }
    return loaders
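A consumption sketch for the returned dictionary (assumes TORCH_HOME points at the prepared cifar.python data and that the configs/nas-benchmark files referenced above exist):

    from xautodl.procedures.funcs_nasbench import get_nas_bench_loaders

    loaders = get_nas_bench_loaders(workers=4)
    # Keys follow the "<dataset>@<split>" convention:
    for key in ("cifar10@train", "cifar100@valid", "ImageNet16-120@test"):
        print(key, len(loaders[key]), "batches")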
134  xautodl/procedures/metric_utils.py  Normal file
@@ -0,0 +1,134 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc
import numpy as np
import torch


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )


class Metric(abc.ABC):
    """The default meta metric class."""

    def __init__(self):
        self.reset()

    def reset(self):
        raise NotImplementedError

    def __call__(self, predictions, targets):
        raise NotImplementedError

    def get_info(self):
        raise NotImplementedError

    def __repr__(self):
        return "{name}({inner})".format(
            name=self.__class__.__name__, inner=self.inner_repr()
        )

    def inner_repr(self):
        return ""


class ComposeMetric(Metric):
    """The composed metric class."""

    def __init__(self, *metric_list):
        self.reset()
        for metric in metric_list:
            self.append(metric)

    def reset(self):
        self._metric_list = []

    def append(self, metric):
        if not isinstance(metric, Metric):
            raise ValueError(
                "The input metric is not correct: {:}".format(type(metric))
            )
        self._metric_list.append(metric)

    def __len__(self):
        return len(self._metric_list)

    def __call__(self, predictions, targets):
        results = list()
        for metric in self._metric_list:
            results.append(metric(predictions, targets))
        return results

    def get_info(self):
        results = dict()
        for metric in self._metric_list:
            for key, value in metric.get_info().items():
                results[key] = value
        return results

    def inner_repr(self):
        xlist = []
        for metric in self._metric_list:
            xlist.append(str(metric))
        return ",".join(xlist)


class MSEMetric(Metric):
    """The metric for mse."""

    def reset(self):
        self._mse = AverageMeter()

    def __call__(self, predictions, targets):
        if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
            batch = predictions.shape[0]
            loss = torch.nn.functional.mse_loss(predictions.data, targets.data)
            loss = loss.item()
            self._mse.update(loss, batch)
            return loss
        else:
            raise NotImplementedError

    def get_info(self):
        return {"mse": self._mse.avg}


class SaveMetric(Metric):
    """The metric that caches predictions."""

    def reset(self):
        self._predicts = []

    def __call__(self, predictions, targets=None):
        if isinstance(predictions, torch.Tensor):
            predicts = predictions.cpu().numpy()
            self._predicts.append(predicts)
            return predicts
        else:
            raise NotImplementedError

    def get_info(self):
        all_predicts = np.concatenate(self._predicts)
        return {"predictions": all_predicts}
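A short sketch of composing the metrics above on hand-made tensors:

    import torch
    from xautodl.procedures.metric_utils import ComposeMetric, MSEMetric, SaveMetric

    metric = ComposeMetric(MSEMetric(), SaveMetric())
    preds, targets = torch.randn(4, 1), torch.randn(4, 1)
    metric(preds, targets)          # updates every sub-metric
    info = metric.get_info()        # merged: {'mse': ..., 'predictions': ...}
    print(sorted(info.keys()))      # ['mse', 'predictions']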
263  xautodl/procedures/optimizers.py  Normal file
@@ -0,0 +1,263 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer


class _LRScheduler(object):
    def __init__(self, optimizer, warmup_epochs, epochs):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        self.base_lrs = list(
            map(lambda group: group["initial_lr"], optimizer.param_groups)
        )
        self.max_epochs = epochs
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0
        self.current_iter = 0

    def extra_repr(self):
        return ""

    def __repr__(self):
        return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
            name=self.__class__.__name__, **self.__dict__
        ) + ", {:})".format(self.extra_repr())

    def state_dict(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def get_min_info(self):
        lrs = self.get_lr()
        return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
            min(lrs), max(lrs), self.current_epoch, self.current_iter
        )

    def get_min_lr(self):
        return min(self.get_lr())

    def update(self, cur_epoch, cur_iter):
        if cur_epoch is not None:
            assert (
                isinstance(cur_epoch, int) and cur_epoch >= 0
            ), "invalid cur-epoch : {:}".format(cur_epoch)
            self.current_epoch = cur_epoch
        if cur_iter is not None:
            assert (
                isinstance(cur_iter, float) and cur_iter >= 0
            ), "invalid cur-iter : {:}".format(cur_iter)
            self.current_iter = cur_iter
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr


class CosineAnnealingLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, T-max={:}, eta-min={:}".format(
            "cosine", self.T_max, self.eta_min
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if (
                self.current_epoch >= self.warmup_epochs
                and self.current_epoch < self.max_epochs
            ):
                last_epoch = self.current_epoch - self.warmup_epochs
                # if last_epoch < self.T_max:
                # if last_epoch < self.max_epochs:
                lr = (
                    self.eta_min
                    + (base_lr - self.eta_min)
                    * (1 + math.cos(math.pi * last_epoch / self.T_max))
                    / 2
                )
                # else:
                #     lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
            elif self.current_epoch >= self.max_epochs:
                lr = self.eta_min
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class MultiStepLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
            len(milestones), len(gammas)
        )
        self.milestones = milestones
        self.gammas = gammas
        super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
            "multistep", self.milestones, self.gammas, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                idx = bisect_right(self.milestones, last_epoch)
                lr = base_lr
                for x in self.gammas[:idx]:
                    lr *= x
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, gamma):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, gamma={:}, base-lrs={:}".format(
            "exponential", self.gamma, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                lr = base_lr * (self.gamma ** last_epoch)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class LinearLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
        self.max_LR = max_LR
        self.min_LR = min_LR
        super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)

    def extra_repr(self):
        return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
            "LinearLR", self.max_LR, self.min_LR, self.base_lrs
        )

    def get_lr(self):
        lrs = []
        for base_lr in self.base_lrs:
            if self.current_epoch >= self.warmup_epochs:
                last_epoch = self.current_epoch - self.warmup_epochs
                assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                ratio = (
                    (self.max_LR - self.min_LR)
                    * last_epoch
                    / self.max_epochs
                    / self.max_LR
                )
                lr = base_lr * (1 - ratio)
            else:
                lr = (
                    self.current_epoch / self.warmup_epochs
                    + self.current_iter / self.warmup_epochs
                ) * base_lr
            lrs.append(lr)
        return lrs


class CrossEntropyLabelSmooth(nn.Module):
    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss


def get_optim_scheduler(parameters, config):
    assert (
        hasattr(config, "optim")
        and hasattr(config, "scheduler")
        and hasattr(config, "criterion")
    ), "config must have optim / scheduler / criterion keys instead of {:}".format(
        config
    )
    if config.optim == "SGD":
        optim = torch.optim.SGD(
            parameters,
            config.LR,
            momentum=config.momentum,
            weight_decay=config.decay,
            nesterov=config.nesterov,
        )
    elif config.optim == "RMSprop":
        optim = torch.optim.RMSprop(
            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
        )
    else:
        raise ValueError("invalid optim : {:}".format(config.optim))

    if config.scheduler == "cos":
        T_max = getattr(config, "T_max", config.epochs)
        scheduler = CosineAnnealingLR(
            optim, config.warmup, config.epochs, T_max, config.eta_min
        )
    elif config.scheduler == "multistep":
        scheduler = MultiStepLR(
            optim, config.warmup, config.epochs, config.milestones, config.gammas
        )
    elif config.scheduler == "exponential":
        scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
    elif config.scheduler == "linear":
        scheduler = LinearLR(
            optim, config.warmup, config.epochs, config.LR, config.LR_min
        )
    else:
        raise ValueError("invalid scheduler : {:}".format(config.scheduler))

    if config.criterion == "Softmax":
        criterion = torch.nn.CrossEntropyLoss()
    elif config.criterion == "SmoothSoftmax":
        criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
    else:
        raise ValueError("invalid criterion : {:}".format(config.criterion))
    return optim, scheduler, criterion
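A usage sketch with a plain namespace standing in for a loaded config; every field value below is illustrative, not a repo default:

    import types
    import torch
    from xautodl.procedures.optimizers import get_optim_scheduler

    config = types.SimpleNamespace(
        optim="SGD", LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
        scheduler="cos", warmup=5, epochs=100, T_max=100, eta_min=0.0,
        criterion="Softmax",
    )
    net = torch.nn.Linear(10, 2)
    optim, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    scheduler.update(0, 0.0)   # inside warmup: LR ramps linearly from 0
    print(scheduler.get_min_lr())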
151  xautodl/procedures/q_exps.py  Normal file
@@ -0,0 +1,151 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
#####################################################

import inspect
import os
import pprint
import logging
from copy import deepcopy

from log_utils import pickle_load
import qlib
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.utils import flatten_dict
from qlib.log import get_module_logger


def set_log_basic_config(filename=None, format=None, level=None):
    """
    Set the basic configuration for the logging system.
    See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
    :param filename: str or None
        The path to save the logs.
    :param format: the logging format
    :param level: int
    :return: Logger
        Logger object.
    """
    from qlib.config import C

    if level is None:
        level = C.logging_level

    if format is None:
        format = C.logging_config["formatters"]["logger_format"]["format"]

    # Remove all handlers associated with the root logger object.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(filename=filename, format=format, level=level)


def update_gpu(config, gpu):
    config = deepcopy(config)
    if "task" in config and "model" in config["task"]:
        if "GPU" in config["task"]["model"]:
            config["task"]["model"]["GPU"] = gpu
        elif (
            "kwargs" in config["task"]["model"]
            and "GPU" in config["task"]["model"]["kwargs"]
        ):
            config["task"]["model"]["kwargs"]["GPU"] = gpu
    elif "model" in config:
        if "GPU" in config["model"]:
            config["model"]["GPU"] = gpu
        elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]:
            config["model"]["kwargs"]["GPU"] = gpu
    elif "kwargs" in config and "GPU" in config["kwargs"]:
        config["kwargs"]["GPU"] = gpu
    elif "GPU" in config:
        config["GPU"] = gpu
    return config


def update_market(config, market):
    config = deepcopy(config)
    config["market"] = market
    config["data_handler_config"]["instruments"] = market
    return config


def run_exp(
    task_config,
    dataset,
    experiment_name,
    recorder_name,
    uri,
    model_obj_name="model.pkl",
):

    model = init_instance_by_config(task_config["model"])
    model_fit_kwargs = dict(dataset=dataset)

    # Let's start the experiment.
    with R.start(
        experiment_name=experiment_name,
        recorder_name=recorder_name,
        uri=uri,
        resume=True,
    ):
        # Setup log
        recorder_root_dir = R.get_recorder().get_local_dir()
        log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))

        set_log_basic_config(log_file)
        logger = get_module_logger("q.run_exp")
        logger.info("task_config::\n{:}".format(pprint.pformat(task_config, indent=2)))
        logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
        logger.info("dataset={:}".format(dataset))

        # Train model
        try:
            if hasattr(model, "to"):  # Recoverable model
                ori_device = model.device
                model = R.load_object(model_obj_name)
                model.to(ori_device)
            else:
                model = R.load_object(model_obj_name)
            logger.info("[Find existing object from {:}]".format(model_obj_name))
        except OSError:
            R.log_params(**flatten_dict(update_gpu(task_config, None)))
            if "save_path" in inspect.getfullargspec(model.fit).args:
                model_fit_kwargs["save_path"] = os.path.join(
                    recorder_root_dir, "model.ckp"
                )
            elif "save_dir" in inspect.getfullargspec(model.fit).args:
                model_fit_kwargs["save_dir"] = os.path.join(
                    recorder_root_dir, "model-ckps"
                )
            model.fit(**model_fit_kwargs)
            # move the model to CPU for saving
            if hasattr(model, "to"):
                old_device = model.device
                model.to("cpu")
                R.save_objects(**{model_obj_name: model})
                model.to(old_device)
            else:
                R.save_objects(**{model_obj_name: model})
        except Exception as e:
            raise ValueError("Something wrong: {:}".format(e))
        # Get the recorder
        recorder = R.get_recorder()

        # Generate records: prediction, backtest, and analysis
        for record in task_config["record"]:
            record = deepcopy(record)
            if record["class"] == "MultiSegRecord":
                record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
                sr = init_instance_by_config(record)
                sr.generate(**record["generate_kwargs"])
            elif record["class"] == "SignalRecord":
                srconf = {"model": model, "dataset": dataset, "recorder": recorder}
                record["kwargs"].update(srconf)
                sr = init_instance_by_config(record)
                sr.generate()
            else:
                rconf = {"recorder": recorder}
                record["kwargs"].update(rconf)
                ar = init_instance_by_config(record)
                ar.generate()
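A small sketch of the two config helpers above; the dict shapes mirror a qlib task config and the values are made up:

    from xautodl.procedures.q_exps import update_gpu, update_market

    task = {"task": {"model": {"class": "SomeModel", "GPU": 0}}}
    print(update_gpu(task, 1)["task"]["model"]["GPU"])   # 1

    cfg = {"market": "csi300", "data_handler_config": {"instruments": "csi300"}}
    print(update_market(cfg, "csi100")["market"])        # csi100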
198  xautodl/procedures/search_main.py  Normal file
@@ -0,0 +1,198 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch
from log_utils import AverageMeter, time_string
from models import change_key

from .eval_funcs import obtain_accuracy


def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
    expected_flop = torch.mean(expected_flop)

    if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP
        loss = -torch.log(expected_flop)
    # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
    elif flop_cur > flop_need:  # Too Large FLOP
        loss = torch.log(expected_flop)
    else:  # Required FLOP
        loss = None
    if loss is None:
        return 0, 0
    else:
        return loss, loss.item()


def search_train(
    search_loader,
    network,
    criterion,
    scheduler,
    base_optimizer,
    arch_optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str, flop_need, flop_weight, flop_tolerant = (
        extra_info["epoch-str"],
        extra_info["FLOP-exp"],
        extra_info["FLOP-weight"],
        extra_info["FLOP-tolerant"],
    )

    network.train()
    logger.log(
        "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
            epoch_str, flop_need, flop_weight
        )
    )
    end = time.time()
    network.apply(change_key("search_mode", "search"))
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
        search_loader
    ):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        # network.apply( change_key('search_mode', 'basic') )
        # features, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop("genotype", None, None)
        flop_loss, flop_loss_scale = get_flop_loss(
            expected_flop, flop_cur, flop_need, flop_tolerant
        )
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()

        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = (
                "**TRAIN** "
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
            )
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=base_losses, top1=top1, top5=top5
            )
            Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
                aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
            )
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])

    logger.log(
        " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            baseloss=base_losses.avg,
            archloss=arch_losses.avg,
        )
    )
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg


def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )

    network.eval()
    network.apply(change_key("search_mode", "search"))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)

            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = (
                    "**VALID** "
                    + time_string()
                    + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
                )
                Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                    batch_time=batch_time, data_time=data_time
                )
                Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                    loss=losses, top1=top1, top5=top5
                )
                Istr = "Size={:}".format(list(inputs.size()))
                logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            loss=losses.avg,
        )
    )

    return losses.avg, top1.avg, top5.avg
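To make the piecewise FLOP penalty concrete, a hand-run of get_flop_loss with arbitrary numbers:

    import torch
    from xautodl.procedures.search_main import get_flop_loss

    expected = torch.tensor([0.5])   # differentiable expected-FLOP estimate
    # Too small (80 < 100 - 10): penalize via -log(expected) to push FLOPs up.
    loss, scale = get_flop_loss(expected, flop_cur=80, flop_need=100, flop_tolerant=10)
    print(scale)                          # ~0.693
    # Inside the tolerated band (90 <= 95 <= 100): no penalty.
    print(get_flop_loss(expected, 95, 100, 10))   # (0, 0)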
139  xautodl/procedures/search_main_v2.py  Normal file
@@ -0,0 +1,139 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, time, torch

# modules in AutoDL
from log_utils import AverageMeter, time_string
from models import change_key
from .eval_funcs import obtain_accuracy


def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
    expected_flop = torch.mean(expected_flop)

    if flop_cur < flop_need - flop_tolerant:  # Too Small FLOP
        loss = -torch.log(expected_flop)
    # elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
    elif flop_cur > flop_need:  # Too Large FLOP
        loss = torch.log(expected_flop)
    else:  # Required FLOP
        loss = None
    if loss is None:
        return 0, 0
    else:
        return loss, loss.item()


def search_train_v2(
    search_loader,
    network,
    criterion,
    scheduler,
    base_optimizer,
    arch_optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str, flop_need, flop_weight, flop_tolerant = (
        extra_info["epoch-str"],
        extra_info["FLOP-exp"],
        extra_info["FLOP-weight"],
        extra_info["FLOP-tolerant"],
    )

    network.train()
    logger.log(
        "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
            epoch_str, flop_need, flop_weight
        )
    )
    end = time.time()
    network.apply(change_key("search_mode", "search"))
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
        search_loader
    ):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop("genotype", None, None)
        flop_loss, flop_loss_scale = get_flop_loss(
            expected_flop, flop_cur, flop_need, flop_tolerant
        )
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()

        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = (
                "**TRAIN** "
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
            )
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=base_losses, top1=top1, top5=top5
            )
            Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
                aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
            )
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
            # num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])

    logger.log(
        " **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            baseloss=base_losses.avg,
            archloss=arch_losses.avg,
        )
    )
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
204  xautodl/procedures/simple_KD_main.py  Normal file
@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import os, sys, time, torch
import torch.nn.functional as F

# modules in AutoDL
from log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy


def simple_KD_train(
    xloader,
    teacher,
    network,
    criterion,
    scheduler,
    optimizer,
    optim_config,
    extra_info,
    print_freq,
    logger,
):
    loss, acc1, acc5 = procedure(
        xloader,
        teacher,
        network,
        criterion,
        scheduler,
        optimizer,
        "train",
        optim_config,
        extra_info,
        print_freq,
        logger,
    )
    return loss, acc1, acc5


def simple_KD_valid(
    xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
):
    with torch.no_grad():
        loss, acc1, acc5 = procedure(
            xloader,
            teacher,
            network,
            criterion,
            None,
            None,
            "valid",
            optim_config,
            extra_info,
            print_freq,
            logger,
        )
    return loss, acc1, acc5


def loss_KD_fn(
    criterion,
    student_logits,
    teacher_logits,
    studentFeatures,
    teacherFeatures,
    targets,
    alpha,
    temperature,
):
    # studentFeatures / teacherFeatures are accepted but unused here; they let
    # feature-based distillation variants share the same signature.
    basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
    log_student = F.log_softmax(student_logits / temperature, dim=1)
    sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
    KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
        alpha * temperature * temperature
    )
    return basic_loss + KD_loss
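
As a quick sanity check, loss_KD_fn can be exercised on random tensors; the batch size, class count, alpha, and temperature below are arbitrary illustration values, not values used by this repository.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
student_logits = torch.randn(4, 10)  # batch of 4, 10 classes
teacher_logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
loss = loss_KD_fn(
    criterion, student_logits, teacher_logits,
    None, None,  # the feature arguments are unused by this implementation
    targets, 0.9, 4.0,  # alpha=0.9, temperature=4.0
)
# loss = (1 - alpha) * CE(student, targets)
#      + alpha * T^2 * KL(teacher || student) on temperature-softened logits
print(loss.item())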

def procedure(
    xloader,
    teacher,
    network,
    criterion,
    scheduler,
    optimizer,
    mode,
    config,
    extra_info,
    print_freq,
    logger,
):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )
    Ttop1, Ttop5 = AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    teacher.eval()

    logger.log(
        "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
            mode,
            config.auxiliary if hasattr(config, "auxiliary") else -1,
            config.KD_alpha,
            config.KD_temperature,
        )
    )
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)

        if mode == "train":
            optimizer.zero_grad()

        student_f, logits = network(inputs)
        if isinstance(logits, list):
            assert len(logits) == 2, "logits must have {:} items instead of {:}".format(
                2, len(logits)
            )
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        with torch.no_grad():
            teacher_f, teacher_logits = teacher(inputs)

        loss = loss_KD_fn(
            criterion,
            logits,
            teacher_logits,
            student_f,
            teacher_f,
            targets,
            config.KD_alpha,
            config.KD_temperature,
        )
        if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
            # the network must return auxiliary logits when config.auxiliary > 0
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == "train":
            loss.backward()
            optimizer.step()

        # record
        sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(sprec1.item(), inputs.size(0))
        top5.update(sprec5.item(), inputs.size(0))
        # teacher
        tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
        Ttop1.update(tprec1.item(), inputs.size(0))
        Ttop5.update(tprec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = (
                " {:5s} ".format(mode.upper())
                + time_string()
                + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
            )
            if scheduler is not None:
                Sstr += " {:}".format(scheduler.get_min_info())
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time
            )
            Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                loss=losses, top1=top1, top5=top5
            )
            Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg)
            Istr = "Size={:}".format(list(inputs.size()))
            logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format(
            mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg
        )
    )
    logger.log(
        " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
            mode=mode.upper(),
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            loss=losses.avg,
        )
    )
    return losses.avg, top1.avg, top5.avg
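
To wire these functions together outside the repository's full training scripts, a toy harness along the following lines should work. Everything here (ToyNet, PrintLogger, the SimpleNamespace fields) is an illustrative stand-in, and a CUDA device is required because procedure() moves targets to the GPU.

from types import SimpleNamespace
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ToyNet(nn.Module):
    # returns (features, logits), matching the calling convention above
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 10)
    def forward(self, x):
        return x, self.fc(x)

class PrintLogger:
    def log(self, msg):
        print(msg)

config = SimpleNamespace(KD_alpha=0.9, KD_temperature=4.0, auxiliary=-1)
data = TensorDataset(torch.randn(16, 8).cuda(), torch.randint(0, 10, (16,)).cuda())
loader = DataLoader(data, batch_size=4)
teacher, student = ToyNet().cuda(), ToyNet().cuda()
loss, acc1, acc5 = simple_KD_valid(
    loader, teacher, student, nn.CrossEntropyLoss().cuda(),
    config, "toy", 1, PrintLogger(),
)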
79
xautodl/procedures/starts.py
Normal file
@@ -0,0 +1,79 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile


def prepare_seed(rand_seed):
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)
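
prepare_seed covers the Python, NumPy, and PyTorch RNGs; fully reproducible runs additionally need the cuDNN flags shown below. This wrapper is a hypothetical extension, not part of this commit, and it typically slows training because autotuned convolution kernels are disabled.

import torch

def prepare_seed_strict(rand_seed):
    # hypothetical helper: seed as above, then force deterministic cuDNN
    # kernels to remove run-to-run nondeterminism from autotuning
    prepare_seed(rand_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False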

def prepare_logger(xargs):
    args = copy.deepcopy(xargs)
    from log_utils import Logger

    logger = Logger(args.save_dir, args.rand_seed)
    logger.log("Main Function with logger : {:}".format(logger))
    logger.log("Arguments : -------------------------------")
    for name, value in args._get_kwargs():
        logger.log("{:16} : {:}".format(name, value))
    logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
    logger.log("Pillow Version : {:}".format(PIL.__version__))
    logger.log("PyTorch Version : {:}".format(torch.__version__))
    logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
    logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
    logger.log(
        "CUDA_VISIBLE_DEVICES : {:}".format(
            os.environ["CUDA_VISIBLE_DEVICES"]
            if "CUDA_VISIBLE_DEVICES" in os.environ
            else "None"
        )
    )
    return logger
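
prepare_logger expects an argparse-style namespace with at least save_dir and rand_seed, since it iterates args._get_kwargs() (which argparse.Namespace provides). A hypothetical call, assuming the repository's log_utils.Logger is importable:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str, default="./output/demo")
parser.add_argument("--rand_seed", type=int, default=42)
args = parser.parse_args([])  # empty list: use the defaults

logger = prepare_logger(args)  # logs the arguments plus machine/library info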

def get_machine_info():
    info = "Python Version : {:}".format(sys.version.replace("\n", " "))
    info += "\nPillow Version : {:}".format(PIL.__version__)
    info += "\nPyTorch Version : {:}".format(torch.__version__)
    info += "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
    info += "\nCUDA available : {:}".format(torch.cuda.is_available())
    info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"])
    else:
        info += "\nCUDA_VISIBLE_DEVICES is not set"
    return info

def save_checkpoint(state, filename, logger):
    if osp.isfile(filename):
        if hasattr(logger, "log"):
            logger.log(
                "Found {:} already exists; deleting it before saving".format(filename)
            )
        os.remove(filename)
    torch.save(state, filename)
    assert osp.isfile(
        filename
    ), "failed to save checkpoint to {:} : the file is not found after torch.save".format(
        filename
    )
    if hasattr(logger, "log"):
        logger.log("saved checkpoint into {:}".format(filename))
    return filename

def copy_checkpoint(src, dst, logger):
    if osp.isfile(dst):
        if hasattr(logger, "log"):
            logger.log("Found {:} already exists; deleting it before copying".format(dst))
        os.remove(dst)
    copyfile(src, dst)
    if hasattr(logger, "log"):
        logger.log("copied the checkpoint from {:} into {:}".format(src, dst))
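
A typical pattern built on the two helpers above is to save the latest state every epoch and copy it aside whenever validation improves; the paths, epoch, network, and is_best flag below are illustrative stand-ins.

from os import path as osp

last_path = save_checkpoint(
    {"epoch": epoch, "state_dict": network.state_dict()},
    osp.join(save_dir, "checkpoint-last.pth"),
    logger,
)
if is_best:  # e.g. validation accuracy reached a new maximum
    copy_checkpoint(last_path, osp.join(save_dir, "checkpoint-best.pth"), logger)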