Update q-config and black for procedures/utils
This commit is contained in:
@@ -1,25 +1,36 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
|
||||
from .starts import prepare_seed
|
||||
from .starts import prepare_logger
|
||||
from .starts import get_machine_info
|
||||
from .starts import save_checkpoint
|
||||
from .starts import copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
||||
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
||||
from .funcs_nasbench import get_nas_bench_loaders
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {'basic' : basic_train, \
|
||||
'search': search_train,'Simple-KD': simple_KD_train, \
|
||||
'search-v2': search_train_v2}
|
||||
valid_funcs = {'basic' : basic_valid, \
|
||||
'search': search_valid,'Simple-KD': simple_KD_valid, \
|
||||
'search-v2': search_valid}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {
|
||||
"basic": basic_train,
|
||||
"search": search_train,
|
||||
"Simple-KD": simple_KD_train,
|
||||
"search-v2": search_train_v2,
|
||||
}
|
||||
valid_funcs = {
|
||||
"basic": basic_valid,
|
||||
"search": search_valid,
|
||||
"Simple-KD": simple_KD_valid,
|
||||
"search-v2": search_valid,
|
||||
}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
|
||||
|
@@ -3,73 +3,100 @@
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from utils import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
|
||||
loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
#logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
|
||||
logger.log('[{:5s}] config :: auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1))
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
|
||||
features, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
if mode == "train":
|
||||
network.train()
|
||||
elif mode == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
loss = criterion(logits, targets)
|
||||
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
# logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
|
||||
logger.log(
|
||||
"[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1)
|
||||
)
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == "train":
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
if scheduler is not None:
|
||||
Sstr += ' {:}'.format(scheduler.get_min_info())
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
|
||||
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
features, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
loss = criterion(logits, targets)
|
||||
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == "train":
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = (
|
||||
" {:5s} ".format(mode.upper())
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
)
|
||||
if scheduler is not None:
|
||||
Sstr += " {:}".format(scheduler.get_min_info())
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=losses, top1=top1, top5=top5
|
||||
)
|
||||
Istr = "Size={:}".format(list(inputs.size()))
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
|
||||
|
||||
logger.log(
|
||||
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
|
||||
)
|
||||
)
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
@@ -5,199 +5,348 @@ import os, time, copy, torch, pathlib
|
||||
|
||||
import datasets
|
||||
from config_utils import load_config
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate', 'get_nas_bench_loaders']
|
||||
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies, device = [], torch.cuda.current_device()
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append( batch_time.val - data_time.val )
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2: latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies, device = [], torch.cuda.current_device()
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append(batch_time.val - data_time.val)
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2:
|
||||
latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train' : network.train()
|
||||
elif mode == 'valid': network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
device = torch.cuda.current_device()
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == "train":
|
||||
network.train()
|
||||
elif mode == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
device = torch.cuda.current_device()
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == "train":
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == "train":
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
|
||||
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(arch_config)
|
||||
#net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
flop, param = get_model_infos(net, opt_config.xshape)
|
||||
logger.log('Network : {:}'.format(net.get_message()), False)
|
||||
logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
|
||||
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
|
||||
default_device = torch.cuda.current_device()
|
||||
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
|
||||
criterion = criterion.cuda(device=default_device)
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
|
||||
train_times , valid_times, lrs = {}, {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
lr = min(scheduler.get_lr())
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times [epoch] = train_tm
|
||||
lrs[epoch] = lr
|
||||
with torch.no_grad():
|
||||
for key, xloder in valid_loaders.items():
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloder , network, criterion, None, None, 'valid')
|
||||
valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss
|
||||
valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1
|
||||
valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5
|
||||
valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(arch_config)
|
||||
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
flop, param = get_model_infos(net, opt_config.xshape)
|
||||
logger.log("Network : {:}".format(net.get_message()), False)
|
||||
logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed))
|
||||
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
|
||||
default_device = torch.cuda.current_device()
|
||||
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
|
||||
criterion = criterion.cuda(device=default_device)
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
|
||||
train_times, valid_times, lrs = {}, {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
lr = min(scheduler.get_lr())
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(
|
||||
train_loader, network, criterion, scheduler, optimizer, "train"
|
||||
)
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times[epoch] = train_tm
|
||||
lrs[epoch] = lr
|
||||
with torch.no_grad():
|
||||
for key, xloder in valid_loaders.items():
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
|
||||
xloder, network, criterion, None, None, "valid"
|
||||
)
|
||||
valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
|
||||
valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
|
||||
valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
|
||||
valid_times["{:}@{:}".format(key, epoch)] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
|
||||
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5, lr))
|
||||
info_seed = {'flop' : flop,
|
||||
'param': param,
|
||||
'arch_config' : arch_config._asdict(),
|
||||
'opt_config' : opt_config._asdict(),
|
||||
'total_epoch' : total_epoch ,
|
||||
'train_losses': train_losses,
|
||||
'train_acc1es': train_acc1es,
|
||||
'train_acc5es': train_acc5es,
|
||||
'train_times' : train_times,
|
||||
'valid_losses': valid_losses,
|
||||
'valid_acc1es': valid_acc1es,
|
||||
'valid_acc5es': valid_acc5es,
|
||||
'valid_times' : valid_times,
|
||||
'learning_rates': lrs,
|
||||
'net_state_dict': net.state_dict(),
|
||||
'net_string' : '{:}'.format(net),
|
||||
'finish-train': True
|
||||
}
|
||||
return info_seed
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True))
|
||||
logger.log(
|
||||
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
|
||||
time_string(),
|
||||
need_time,
|
||||
epoch,
|
||||
total_epoch,
|
||||
train_loss,
|
||||
train_acc1,
|
||||
train_acc5,
|
||||
valid_loss,
|
||||
valid_acc1,
|
||||
valid_acc5,
|
||||
lr,
|
||||
)
|
||||
)
|
||||
info_seed = {
|
||||
"flop": flop,
|
||||
"param": param,
|
||||
"arch_config": arch_config._asdict(),
|
||||
"opt_config": opt_config._asdict(),
|
||||
"total_epoch": total_epoch,
|
||||
"train_losses": train_losses,
|
||||
"train_acc1es": train_acc1es,
|
||||
"train_acc5es": train_acc5es,
|
||||
"train_times": train_times,
|
||||
"valid_losses": valid_losses,
|
||||
"valid_acc1es": valid_acc1es,
|
||||
"valid_acc5es": valid_acc5es,
|
||||
"valid_times": valid_times,
|
||||
"learning_rates": lrs,
|
||||
"net_state_dict": net.state_dict(),
|
||||
"net_string": "{:}".format(net),
|
||||
"finish-train": True,
|
||||
}
|
||||
return info_seed
|
||||
|
||||
|
||||
def get_nas_bench_loaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
root_dir = (pathlib.Path(__file__).parent / '..' / '..').resolve()
|
||||
torch_dir = pathlib.Path(os.environ['TORCH_HOME'])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / 'configs' / 'nas-benchmark' / 'CIFAR.config'
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
get_datasets = datasets.get_datasets # a function to return the dataset
|
||||
break_line = '-' * 150
|
||||
print ('{:} Create data-loader for all datasets'.format(time_string()))
|
||||
print (break_line)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets('cifar10', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num))
|
||||
cifar10_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar-split.txt', None, None)
|
||||
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [1, 2, 3, 4, 6, 8, 9, 10, 12, 14]
|
||||
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(TRAIN_CIFAR10, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train), num_workers=workers, pin_memory=True)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(temp_dataset , batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid), num_workers=workers, pin_memory=True)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-10 : trval-loader has {:3d} batch with {:} per batch'.format(len(trainval_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : train-loader has {:3d} batch with {:} per batch'.format(len(train_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_cifar10_loader), cifar_config.batch_size))
|
||||
print ('CIFAR-10 : test--loader has {:3d} batch with {:} per batch'.format(len(test__cifar10_loader), cifar_config.batch_size))
|
||||
print (break_line)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets('cifar100', str(torch_dir/'cifar.python'), -1)
|
||||
print ('original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num))
|
||||
cifar100_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'cifar100-test-split.txt', None, None)
|
||||
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [0, 2, 6, 7, 9, 11, 12, 17, 20, 24]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(VALID_CIFAR100, batch_size=cifar_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('CIFAR-100 : train-loader has {:3d} batch'.format(len(train_cifar100_loader)))
|
||||
print ('CIFAR-100 : valid-loader has {:3d} batch'.format(len(valid_cifar100_loader)))
|
||||
print ('CIFAR-100 : test--loader has {:3d} batch'.format(len(test__cifar100_loader)))
|
||||
print (break_line)
|
||||
root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve()
|
||||
torch_dir = pathlib.Path(os.environ["TORCH_HOME"])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config"
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
get_datasets = datasets.get_datasets # a function to return the dataset
|
||||
break_line = "-" * 150
|
||||
print("{:} Create data-loader for all datasets".format(time_string()))
|
||||
print(break_line)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1)
|
||||
print(
|
||||
"original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
|
||||
len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
|
||||
)
|
||||
)
|
||||
cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None)
|
||||
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
6,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
]
|
||||
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
|
||||
)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_CIFAR10,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(
|
||||
temp_dataset,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(
|
||||
VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
|
||||
len(trainval_cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format(
|
||||
len(train_cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format(
|
||||
len(valid_cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format(
|
||||
len(test__cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(break_line)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1)
|
||||
print(
|
||||
"original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
|
||||
len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
|
||||
)
|
||||
)
|
||||
cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None)
|
||||
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [
|
||||
0,
|
||||
2,
|
||||
6,
|
||||
7,
|
||||
9,
|
||||
11,
|
||||
12,
|
||||
17,
|
||||
20,
|
||||
24,
|
||||
]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
|
||||
)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(
|
||||
VALID_CIFAR100,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(
|
||||
VALID_CIFAR100,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)))
|
||||
print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)))
|
||||
print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)))
|
||||
print(break_line)
|
||||
|
||||
imagenet16_config_path = 'configs/nas-benchmark/ImageNet-16.config'
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets('ImageNet16-120', str(torch_dir/'cifar.python'/'ImageNet16'), -1)
|
||||
print ('original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes'.format(len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num))
|
||||
imagenet_splits = load_config(root_dir / 'configs' / 'nas-benchmark' / 'imagenet-16-120-test-split.txt', None, None)
|
||||
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [0, 4, 5, 10, 11, 13, 14, 15, 17, 20]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(TRAIN_ImageNet16_120, batch_size=imagenet16_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid), num_workers=workers, pin_memory=True)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(VALID_ImageNet16_120, batch_size=imagenet16_config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest) , num_workers=workers, pin_memory=True)
|
||||
print ('ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch'.format(len(train_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch'.format(len(valid_imagenet_loader), imagenet16_config.batch_size))
|
||||
print ('ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch'.format(len(test__imagenet_loader), imagenet16_config.batch_size))
|
||||
imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets(
|
||||
"ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1
|
||||
)
|
||||
print(
|
||||
"original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
|
||||
len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
|
||||
)
|
||||
)
|
||||
imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None)
|
||||
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [
|
||||
0,
|
||||
4,
|
||||
5,
|
||||
10,
|
||||
11,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
17,
|
||||
20,
|
||||
]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_ImageNet16_120,
|
||||
batch_size=imagenet16_config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(
|
||||
VALID_ImageNet16_120,
|
||||
batch_size=imagenet16_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(
|
||||
VALID_ImageNet16_120,
|
||||
batch_size=imagenet16_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
print(
|
||||
"ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format(
|
||||
len(train_imagenet_loader), imagenet16_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format(
|
||||
len(valid_imagenet_loader), imagenet16_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format(
|
||||
len(test__imagenet_loader), imagenet16_config.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {'cifar10@trainval': trainval_cifar10_loader,
|
||||
'cifar10@train' : train_cifar10_loader,
|
||||
'cifar10@valid' : valid_cifar10_loader,
|
||||
'cifar10@test' : test__cifar10_loader,
|
||||
'cifar100@train' : train_cifar100_loader,
|
||||
'cifar100@valid' : valid_cifar100_loader,
|
||||
'cifar100@test' : test__cifar100_loader,
|
||||
'ImageNet16-120@train': train_imagenet_loader,
|
||||
'ImageNet16-120@valid': valid_imagenet_loader,
|
||||
'ImageNet16-120@test' : test__imagenet_loader}
|
||||
return loaders
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {
|
||||
"cifar10@trainval": trainval_cifar10_loader,
|
||||
"cifar10@train": train_cifar10_loader,
|
||||
"cifar10@valid": valid_cifar10_loader,
|
||||
"cifar10@test": test__cifar10_loader,
|
||||
"cifar100@train": train_cifar100_loader,
|
||||
"cifar100@valid": valid_cifar100_loader,
|
||||
"cifar100@test": test__cifar100_loader,
|
||||
"ImageNet16-120@train": train_imagenet_loader,
|
||||
"ImageNet16-120@valid": valid_imagenet_loader,
|
||||
"ImageNet16-120@test": test__imagenet_loader,
|
||||
}
|
||||
return loaders
|
||||
|
@@ -8,197 +8,201 @@ from torch.optim import Optimizer
|
||||
|
||||
|
||||
class _LRScheduler(object):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault('initial_lr', group['lr'])
|
||||
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
def extra_repr(self):
|
||||
return ""
|
||||
|
||||
def extra_repr(self):
|
||||
return ''
|
||||
def __repr__(self):
|
||||
return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
) + ", {:})".format(
|
||||
self.extra_repr()
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
+ ', {:})'.format(self.extra_repr()))
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
|
||||
min(lrs), max(lrs), self.current_epoch, self.current_iter
|
||||
)
|
||||
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
|
||||
|
||||
def get_min_lr(self):
|
||||
return min( self.get_lr() )
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group['lr'] = lr
|
||||
def get_min_lr(self):
|
||||
return min(self.get_lr())
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
#if last_epoch < self.T_max:
|
||||
#if last_epoch < self.max_epochs:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
#else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
def extra_repr(self):
|
||||
return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
# if last_epoch < self.T_max:
|
||||
# if last_epoch < self.max_epochs:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
# else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas))
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
def extra_repr(self):
|
||||
return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
|
||||
"multistep", self.milestones, self.gammas, self.base_lrs
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]: lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]:
|
||||
lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
def extra_repr(self):
|
||||
return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
lr = base_lr * (1-ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
def extra_repr(self):
|
||||
return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
|
||||
"LinearLR", self.max_LR, self.min_LR, self.base_lrs
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
lr = base_lr * (1 - ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
|
||||
if config.optim == 'SGD':
|
||||
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
|
||||
elif config.optim == 'RMSprop':
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
else:
|
||||
raise ValueError('invalid optim : {:}'.format(config.optim))
|
||||
assert (
|
||||
hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion")
|
||||
), "config must have optim / scheduler / criterion keys instead of {:}".format(config)
|
||||
if config.optim == "SGD":
|
||||
optim = torch.optim.SGD(
|
||||
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov
|
||||
)
|
||||
elif config.optim == "RMSprop":
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
else:
|
||||
raise ValueError("invalid optim : {:}".format(config.optim))
|
||||
|
||||
if config.scheduler == 'cos':
|
||||
T_max = getattr(config, 'T_max', config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
elif config.scheduler == 'multistep':
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
elif config.scheduler == 'exponential':
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == 'linear':
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
else:
|
||||
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
|
||||
if config.scheduler == "cos":
|
||||
T_max = getattr(config, "T_max", config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
elif config.scheduler == "multistep":
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
elif config.scheduler == "exponential":
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == "linear":
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
else:
|
||||
raise ValueError("invalid scheduler : {:}".format(config.scheduler))
|
||||
|
||||
if config.criterion == 'Softmax':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == 'SmoothSoftmax':
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError('invalid criterion : {:}'.format(config.criterion))
|
||||
return optim, scheduler, criterion
|
||||
if config.criterion == "Softmax":
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == "SmoothSoftmax":
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError("invalid criterion : {:}".format(config.criterion))
|
||||
return optim, scheduler, criterion
|
||||
|
@@ -7,11 +7,12 @@ from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.log import set_log_basic_config
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
def update_gpu(config, gpu):
|
||||
config = config.copy()
|
||||
if "task" in config and "GPU" in config["task"]["model"]:
|
||||
if "task" in config and "moodel" in config["task"] and "GPU" in config["task"]["model"]:
|
||||
config["task"]["model"]["GPU"] = gpu
|
||||
elif "model" in config and "GPU" in config["model"]:
|
||||
config["model"]["GPU"] = gpu
|
||||
@@ -29,11 +30,6 @@ def update_market(config, market):
|
||||
|
||||
def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
|
||||
|
||||
# model initiaiton
|
||||
print("")
|
||||
print("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
|
||||
print("dataset={:}".format(dataset))
|
||||
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
|
||||
# start exp
|
||||
@@ -41,6 +37,10 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
|
||||
|
||||
log_file = R.get_recorder().root_uri / "{:}.log".format(experiment_name)
|
||||
set_log_basic_config(log_file)
|
||||
logger = get_module_logger("q.run_exp")
|
||||
logger.info("task_config={:}".format(task_config))
|
||||
logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
|
||||
logger.info("dataset={:}".format(dataset))
|
||||
|
||||
# train model
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
|
@@ -3,124 +3,170 @@
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean( expected_flop )
|
||||
expected_flop = torch.mean(expected_flop)
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = - torch.log( expected_flop )
|
||||
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log( expected_flop )
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None: return 0, 0
|
||||
else : return loss, loss.item()
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = -torch.log(expected_flop)
|
||||
# elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log(expected_flop)
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None:
|
||||
return 0, 0
|
||||
else:
|
||||
return loss, loss.item()
|
||||
|
||||
|
||||
def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
|
||||
def search_train(
|
||||
search_loader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
base_optimizer,
|
||||
arch_optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = (
|
||||
extra_info["epoch-str"],
|
||||
extra_info["FLOP-exp"],
|
||||
extra_info["FLOP-weight"],
|
||||
extra_info["FLOP-tolerant"],
|
||||
)
|
||||
|
||||
network.train()
|
||||
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
#network.apply( change_key('search_mode', 'basic') )
|
||||
#features, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update (prec1.item(), base_inputs.size(0))
|
||||
top5.update (prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop('genotype', None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
network.train()
|
||||
logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step+1) == len(search_loader):
|
||||
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
|
||||
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
|
||||
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
#print(network.module.get_arch_info())
|
||||
#print(network.module.width_attentions[0])
|
||||
#print(network.module.width_attentions[1])
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
# network.apply( change_key('search_mode', 'basic') )
|
||||
# features, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update(prec1.item(), base_inputs.size(0))
|
||||
top5.update(prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop("genotype", None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step + 1) == len(search_loader):
|
||||
Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=base_losses, top1=top1, top5=top5
|
||||
)
|
||||
Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
|
||||
aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
|
||||
)
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
|
||||
# Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
# print(network.module.get_arch_info())
|
||||
# print(network.module.width_attentions[0])
|
||||
# print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(
|
||||
" **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
baseloss=base_losses.avg,
|
||||
archloss=arch_losses.avg,
|
||||
)
|
||||
)
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
|
||||
|
||||
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
|
||||
network.eval()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
end = time.time()
|
||||
#logger.log('Starting evaluating {:}'.format(epoch_info))
|
||||
with torch.no_grad():
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
network.eval()
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
end = time.time()
|
||||
# logger.log('Starting evaluating {:}'.format(epoch_info))
|
||||
with torch.no_grad():
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
logits, expected_flop = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
logits, expected_flop = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=losses, top1=top1, top5=top5
|
||||
)
|
||||
Istr = "Size={:}".format(list(inputs.size()))
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
|
||||
|
||||
logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
logger.log(
|
||||
" **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
|
||||
)
|
||||
)
|
||||
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
@@ -3,85 +3,118 @@
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean( expected_flop )
|
||||
expected_flop = torch.mean(expected_flop)
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = - torch.log( expected_flop )
|
||||
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log( expected_flop )
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None: return 0, 0
|
||||
else : return loss, loss.item()
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = -torch.log(expected_flop)
|
||||
# elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log(expected_flop)
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None:
|
||||
return 0, 0
|
||||
else:
|
||||
return loss, loss.item()
|
||||
|
||||
|
||||
def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
|
||||
def search_train_v2(
|
||||
search_loader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
base_optimizer,
|
||||
arch_optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = (
|
||||
extra_info["epoch-str"],
|
||||
extra_info["FLOP-exp"],
|
||||
extra_info["FLOP-weight"],
|
||||
extra_info["FLOP-tolerant"],
|
||||
)
|
||||
|
||||
network.train()
|
||||
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update (prec1.item(), base_inputs.size(0))
|
||||
top5.update (prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop('genotype', None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
network.train()
|
||||
logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step+1) == len(search_loader):
|
||||
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
|
||||
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
|
||||
#num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
|
||||
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
#print(network.module.get_arch_info())
|
||||
#print(network.module.width_attentions[0])
|
||||
#print(network.module.width_attentions[1])
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update(prec1.item(), base_inputs.size(0))
|
||||
top5.update(prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop("genotype", None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step + 1) == len(search_loader):
|
||||
Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=base_losses, top1=top1, top5=top5
|
||||
)
|
||||
Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
|
||||
aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
|
||||
)
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
|
||||
# num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
|
||||
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
|
||||
# Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
# print(network.module.get_arch_info())
|
||||
# print(network.module.width_attentions[0])
|
||||
# print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(
|
||||
" **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
baseloss=base_losses.avg,
|
||||
archloss=arch_losses.avg,
|
||||
)
|
||||
)
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
|
@@ -3,92 +3,143 @@
|
||||
#####################################################
|
||||
import os, sys, time, torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# our modules
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from utils import obtain_accuracy
|
||||
|
||||
|
||||
def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
|
||||
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
def simple_KD_train(
|
||||
xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger
|
||||
):
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
"train",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature):
|
||||
basic_loss = criterion(student_logits, targets) * (1. - alpha)
|
||||
log_student= F.log_softmax(student_logits / temperature, dim=1)
|
||||
sof_teacher= F.softmax (teacher_logits / temperature, dim=1)
|
||||
KD_loss = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature)
|
||||
return basic_loss + KD_loss
|
||||
def loss_KD_fn(
|
||||
criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature
|
||||
):
|
||||
basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
|
||||
log_student = F.log_softmax(student_logits / temperature, dim=1)
|
||||
sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
|
||||
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature)
|
||||
return basic_loss + KD_loss
|
||||
|
||||
|
||||
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
teacher.eval()
|
||||
|
||||
logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature))
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
|
||||
student_f, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
|
||||
if mode == "train":
|
||||
network.train()
|
||||
elif mode == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
with torch.no_grad():
|
||||
teacher_f, teacher_logits = teacher(inputs)
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
teacher.eval()
|
||||
|
||||
loss = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature)
|
||||
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (sprec1.item(), inputs.size(0))
|
||||
top5.update (sprec5.item(), inputs.size(0))
|
||||
# teacher
|
||||
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
|
||||
Ttop1.update (tprec1.item(), inputs.size(0))
|
||||
Ttop5.update (tprec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
logger.log(
|
||||
"[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
|
||||
mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == "train":
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
if scheduler is not None:
|
||||
Sstr += ' {:}'.format(scheduler.get_min_info())
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
|
||||
logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
|
||||
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
student_f, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
with torch.no_grad():
|
||||
teacher_f, teacher_logits = teacher(inputs)
|
||||
|
||||
loss = loss_KD_fn(
|
||||
criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature
|
||||
)
|
||||
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == "train":
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(sprec1.item(), inputs.size(0))
|
||||
top5.update(sprec5.item(), inputs.size(0))
|
||||
# teacher
|
||||
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
|
||||
Ttop1.update(tprec1.item(), inputs.size(0))
|
||||
Ttop5.update(tprec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = (
|
||||
" {:5s} ".format(mode.upper())
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
)
|
||||
if scheduler is not None:
|
||||
Sstr += " {:}".format(scheduler.get_min_info())
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=losses, top1=top1, top5=top5
|
||||
)
|
||||
Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg)
|
||||
Istr = "Size={:}".format(list(inputs.size()))
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
|
||||
|
||||
logger.log(
|
||||
" **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format(
|
||||
mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
|
||||
)
|
||||
)
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
@@ -3,62 +3,71 @@
|
||||
##################################################
|
||||
import os, sys, torch, random, PIL, copy, numpy as np
|
||||
from os import path as osp
|
||||
from shutil import copyfile
|
||||
from shutil import copyfile
|
||||
|
||||
|
||||
def prepare_seed(rand_seed):
|
||||
random.seed(rand_seed)
|
||||
np.random.seed(rand_seed)
|
||||
torch.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed_all(rand_seed)
|
||||
random.seed(rand_seed)
|
||||
np.random.seed(rand_seed)
|
||||
torch.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed_all(rand_seed)
|
||||
|
||||
|
||||
def prepare_logger(xargs):
|
||||
args = copy.deepcopy( xargs )
|
||||
from log_utils import Logger
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log('Main Function with logger : {:}'.format(logger))
|
||||
logger.log('Arguments : -------------------------------')
|
||||
for name, value in args._get_kwargs():
|
||||
logger.log('{:16} : {:}'.format(name, value))
|
||||
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
|
||||
logger.log("Pillow Version : {:}".format(PIL.__version__))
|
||||
logger.log("PyTorch Version : {:}".format(torch.__version__))
|
||||
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
|
||||
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
|
||||
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
|
||||
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
|
||||
return logger
|
||||
args = copy.deepcopy(xargs)
|
||||
from log_utils import Logger
|
||||
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log("Main Function with logger : {:}".format(logger))
|
||||
logger.log("Arguments : -------------------------------")
|
||||
for name, value in args._get_kwargs():
|
||||
logger.log("{:16} : {:}".format(name, value))
|
||||
logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
|
||||
logger.log("Pillow Version : {:}".format(PIL.__version__))
|
||||
logger.log("PyTorch Version : {:}".format(torch.__version__))
|
||||
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
|
||||
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
|
||||
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
|
||||
logger.log(
|
||||
"CUDA_VISIBLE_DEVICES : {:}".format(
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None"
|
||||
)
|
||||
)
|
||||
return logger
|
||||
|
||||
|
||||
def get_machine_info():
|
||||
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
|
||||
info+= "\nPillow Version : {:}".format(PIL.__version__)
|
||||
info+= "\nPyTorch Version : {:}".format(torch.__version__)
|
||||
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
|
||||
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
|
||||
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
||||
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
|
||||
else:
|
||||
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
|
||||
return info
|
||||
info = "Python Version : {:}".format(sys.version.replace("\n", " "))
|
||||
info += "\nPillow Version : {:}".format(PIL.__version__)
|
||||
info += "\nPyTorch Version : {:}".format(torch.__version__)
|
||||
info += "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
|
||||
info += "\nCUDA available : {:}".format(torch.cuda.is_available())
|
||||
info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"])
|
||||
else:
|
||||
info += "\nDoes not set CUDA_VISIBLE_DEVICES"
|
||||
return info
|
||||
|
||||
|
||||
def save_checkpoint(state, filename, logger):
|
||||
if osp.isfile(filename):
|
||||
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
|
||||
os.remove(filename)
|
||||
torch.save(state, filename)
|
||||
assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
|
||||
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
|
||||
return filename
|
||||
if osp.isfile(filename):
|
||||
if hasattr(logger, "log"):
|
||||
logger.log("Find {:} exist, delete is at first before saving".format(filename))
|
||||
os.remove(filename)
|
||||
torch.save(state, filename)
|
||||
assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename)
|
||||
if hasattr(logger, "log"):
|
||||
logger.log("save checkpoint into {:}".format(filename))
|
||||
return filename
|
||||
|
||||
|
||||
def copy_checkpoint(src, dst, logger):
|
||||
if osp.isfile(dst):
|
||||
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
|
||||
os.remove(dst)
|
||||
copyfile(src, dst)
|
||||
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))
|
||||
if osp.isfile(dst):
|
||||
if hasattr(logger, "log"):
|
||||
logger.log("Find {:} exist, delete is at first before saving".format(dst))
|
||||
os.remove(dst)
|
||||
copyfile(src, dst)
|
||||
if hasattr(logger, "log"):
|
||||
logger.log("copy the file from {:} into {:}".format(src, dst))
|
||||
|
Reference in New Issue
Block a user