Add more algorithms
This commit is contained in:
22
lib/procedures/__init__.py
Normal file
22
lib/procedures/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .starts import prepare_seed, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {'basic' : basic_train, \
|
||||
'search': search_train,'Simple-KD': simple_KD_train, \
|
||||
'search-v2': search_train_v2}
|
||||
valid_funcs = {'basic' : basic_valid, \
|
||||
'search': search_valid,'Simple-KD': simple_KD_valid, \
|
||||
'search-v2': search_valid}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
|
75
lib/procedures/basic_main.py
Normal file
75
lib/procedures/basic_main.py
Normal file
@@ -0,0 +1,75 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
|
||||
loss, acc1, acc5 = procedure(xloader, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(xloader, network, criterion, None, None, 'valid', None, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
#logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
|
||||
logger.log('[{:5s}] config :: auxiliary={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1))
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
|
||||
features, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
loss = criterion(logits, targets)
|
||||
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
if scheduler is not None:
|
||||
Sstr += ' {:}'.format(scheduler.get_min_info())
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
|
||||
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
return losses.avg, top1.avg, top5.avg
|
201
lib/procedures/optimizers.py
Normal file
201
lib/procedures/optimizers.py
Normal file
@@ -0,0 +1,201 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from bisect import bisect_right
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class _LRScheduler(object):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault('initial_lr', group['lr'])
|
||||
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return ''
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
+ ', {:})'.format(self.extra_repr()))
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
|
||||
|
||||
def get_min_lr(self):
|
||||
return min( self.get_lr() )
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
if last_epoch < self.T_max:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
else:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]: lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
lr = base_lr * (1-ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
|
||||
if config.optim == 'SGD':
|
||||
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
|
||||
elif config.optim == 'RMSprop':
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
else:
|
||||
raise ValueError('invalid optim : {:}'.format(config.optim))
|
||||
|
||||
if config.scheduler == 'cos':
|
||||
T_max = getattr(config, 'T_max', config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
elif config.scheduler == 'multistep':
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
elif config.scheduler == 'exponential':
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == 'linear':
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
else:
|
||||
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
|
||||
|
||||
if config.criterion == 'Softmax':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == 'SmoothSoftmax':
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError('invalid criterion : {:}'.format(config.criterion))
|
||||
return optim, scheduler, criterion
|
126
lib/procedures/search_main.py
Normal file
126
lib/procedures/search_main.py
Normal file
@@ -0,0 +1,126 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean( expected_flop )
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = - torch.log( expected_flop )
|
||||
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log( expected_flop )
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None: return 0, 0
|
||||
else : return loss, loss.item()
|
||||
|
||||
|
||||
def search_train(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
|
||||
|
||||
network.train()
|
||||
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
#network.apply( change_key('search_mode', 'basic') )
|
||||
#features, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update (prec1.item(), base_inputs.size(0))
|
||||
top5.update (prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop('genotype', None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step+1) == len(search_loader):
|
||||
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
|
||||
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
|
||||
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
#print(network.module.get_arch_info())
|
||||
#print(network.module.width_attentions[0])
|
||||
#print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
|
||||
|
||||
|
||||
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
|
||||
network.eval()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
end = time.time()
|
||||
#logger.log('Starting evaluating {:}'.format(epoch_info))
|
||||
with torch.no_grad():
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
logits, expected_flop = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (prec1.item(), inputs.size(0))
|
||||
top5.update (prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = '**VALID** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
|
||||
logger.log(' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
|
||||
return losses.avg, top1.avg, top5.avg
|
87
lib/procedures/search_main_v2.py
Normal file
87
lib/procedures/search_main_v2.py
Normal file
@@ -0,0 +1,87 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
from models import change_key
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean( expected_flop )
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = - torch.log( expected_flop )
|
||||
#elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log( expected_flop )
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None: return 0, 0
|
||||
else : return loss, loss.item()
|
||||
|
||||
|
||||
def search_train_v2(search_loader, network, criterion, scheduler, base_optimizer, arch_optimizer, optim_config, extra_info, print_freq, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = extra_info['epoch-str'], extra_info['FLOP-exp'], extra_info['FLOP-weight'], extra_info['FLOP-tolerant']
|
||||
|
||||
network.train()
|
||||
logger.log('[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(epoch_str, flop_need, flop_weight))
|
||||
end = time.time()
|
||||
network.apply( change_key('search_mode', 'search') )
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update (prec1.item(), base_inputs.size(0))
|
||||
top5.update (prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop('genotype', None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update (acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step+1) == len(search_loader):
|
||||
Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=base_losses, top1=top1, top5=top5)
|
||||
Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses)
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
|
||||
#num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
|
||||
#Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
#logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
#print(network.module.get_arch_info())
|
||||
#print(network.module.width_attentions[0])
|
||||
#print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, baseloss=base_losses.avg, archloss=arch_losses.avg))
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
94
lib/procedures/simple_KD_main.py
Normal file
94
lib/procedures/simple_KD_main.py
Normal file
@@ -0,0 +1,94 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
import torch.nn.functional as F
|
||||
# our modules
|
||||
from log_utils import AverageMeter, time_string
|
||||
from utils import obtain_accuracy
|
||||
|
||||
|
||||
def simple_KD_train(xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
|
||||
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, scheduler, optimizer, 'train', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
|
||||
def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(xloader, teacher, network, criterion, None, None, 'valid', optim_config, extra_info, print_freq, logger)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def loss_KD_fn(criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature):
|
||||
basic_loss = criterion(student_logits, targets) * (1. - alpha)
|
||||
log_student= F.log_softmax(student_logits / temperature, dim=1)
|
||||
sof_teacher= F.softmax (teacher_logits / temperature, dim=1)
|
||||
KD_loss = F.kl_div(log_student, sof_teacher, reduction='batchmean') * (alpha * temperature * temperature)
|
||||
return basic_loss + KD_loss
|
||||
|
||||
|
||||
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else: raise ValueError("The mode is not right : {:}".format(mode))
|
||||
teacher.eval()
|
||||
|
||||
logger.log('[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, config.KD_alpha, config.KD_temperature))
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == 'train': optimizer.zero_grad()
|
||||
|
||||
student_f, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, 'logits must has {:} items instead of {:}'.format(2, len(logits))
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
with torch.no_grad():
|
||||
teacher_f, teacher_logits = teacher(inputs)
|
||||
|
||||
loss = loss_KD_fn(criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature)
|
||||
if config is not None and hasattr(config, 'auxiliary') and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update (sprec1.item(), inputs.size(0))
|
||||
top5.update (sprec5.item(), inputs.size(0))
|
||||
# teacher
|
||||
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
|
||||
Ttop1.update (tprec1.item(), inputs.size(0))
|
||||
Ttop5.update (tprec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i+1) == len(xloader):
|
||||
Sstr = ' {:5s} '.format(mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
|
||||
if scheduler is not None:
|
||||
Sstr += ' {:}'.format(scheduler.get_min_info())
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
|
||||
Lstr+= ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(Ttop1.avg, Ttop5.avg)
|
||||
Istr = 'Size={:}'.format(list(inputs.size()))
|
||||
logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)
|
||||
|
||||
logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
|
||||
logger.log(' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(mode=mode.upper(), top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg))
|
||||
return losses.avg, top1.avg, top5.avg
|
67
lib/procedures/starts.py
Normal file
67
lib/procedures/starts.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import os, sys, time, torch, random, PIL, copy, numpy as np
|
||||
from os import path as osp
|
||||
from shutil import copyfile
|
||||
|
||||
|
||||
def prepare_seed(rand_seed):
|
||||
random.seed(rand_seed)
|
||||
np.random.seed(rand_seed)
|
||||
torch.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed_all(rand_seed)
|
||||
|
||||
|
||||
def prepare_logger(xargs):
|
||||
args = copy.deepcopy( xargs )
|
||||
from log_utils import Logger
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log('Main Function with logger : {:}'.format(logger))
|
||||
logger.log('Arguments : -------------------------------')
|
||||
for name, value in args._get_kwargs():
|
||||
logger.log('{:16} : {:}'.format(name, value))
|
||||
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
|
||||
logger.log("Pillow Version : {:}".format(PIL.__version__))
|
||||
logger.log("PyTorch Version : {:}".format(torch.__version__))
|
||||
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
|
||||
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
|
||||
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
|
||||
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
|
||||
return logger
|
||||
|
||||
|
||||
def get_machine_info():
|
||||
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
|
||||
info+= "\nPillow Version : {:}".format(PIL.__version__)
|
||||
info+= "\nPyTorch Version : {:}".format(torch.__version__)
|
||||
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
|
||||
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
|
||||
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
||||
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
|
||||
else:
|
||||
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
|
||||
return info
|
||||
|
||||
|
||||
def save_checkpoint(state, filename, logger):
|
||||
if osp.isfile(filename):
|
||||
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
|
||||
os.remove(filename)
|
||||
torch.save(state, filename)
|
||||
assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
|
||||
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
|
||||
return filename
|
||||
|
||||
|
||||
def copy_checkpoint(src, dst, logger):
|
||||
if osp.isfile(dst):
|
||||
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
|
||||
os.remove(dst)
|
||||
copyfile(src, dst)
|
||||
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))
|
Reference in New Issue
Block a user