init
exps-cnn/acc_search_v2.py  (Normal file, 310 lines)
@@ -0,0 +1,310 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import Cutout, count_parameters_in_MB
from nas import Network, NetworkACC2, NetworkV3, NetworkV4, NetworkV5, NetworkFACC1
from nas import return_alphas_str
from train_utils import main_procedure
from scheduler import load_config

Networks = {'base': Network, 'acc2': NetworkACC2, 'facc1': NetworkFACC1, 'NetworkV3': NetworkV3, 'NetworkV4': NetworkV4, 'NetworkV5': NetworkV5}


parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100'], help='Choose between CIFAR-10 and CIFAR-100.')
parser.add_argument('--arch', type=str, choices=Networks.keys(), help='Choose networks.')
parser.add_argument('--batch_size', type=int, help='the batch size')
parser.add_argument('--learning_rate_max', type=float, help='initial learning rate')
parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate')
parser.add_argument('--tau_max', type=float, help='initial tau')
parser.add_argument('--tau_min', type=float, help='minimum tau')
parser.add_argument('--momentum', type=float, help='momentum')
parser.add_argument('--weight_decay', type=float, help='weight decay')
parser.add_argument('--epochs', type=int, help='num of training epochs')
# architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
#
parser.add_argument('--init_channels', type=int, help='num of init channels')
parser.add_argument('--layers', type=int, help='total number of layers')
#
parser.add_argument('--cutout', type=int, help='cutout length, negative means no cutout')
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
parser.add_argument('--model_config', type=str, help='the model configuration')

# resume
parser.add_argument('--resume', type=str, help='the resume path')
parser.add_argument('--only_base', action='store_true', default=False, help='only train the searched model')
# split data
parser.add_argument('--validate', action='store_true', default=False, help='split train-data into train/val or not')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()

assert torch.cuda.is_available(), 'torch.cuda is not available'

if args.manualSeed is None:
  args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)


def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch version : {}".format(torch.__version__), log)
  print_log("CUDA version : {}".format(torch.version.cuda), log)
  print_log("cuDNN version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
  args.dataset = args.dataset.lower()

  # Mean + Std
  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Data Augmentation
  if args.dataset == 'cifar10' or args.dataset == 'cifar100':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
             transforms.Normalize(mean, std)]
    if args.cutout > 0: lists += [Cutout(args.cutout)]
    train_transform = transforms.Compose(lists)
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Datasets
  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True)
    num_classes = 100
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Data Loader
  if args.validate:
    indices = list(range(len(train_data)))
    split = int(args.train_portion * len(indices))
    random.shuffle(indices)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
                      pin_memory=True, num_workers=args.workers)
    test_loader  = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
                      pin_memory=True, num_workers=args.workers)
  else:
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
    test_loader  = torch.utils.data.DataLoader(test_data , batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

  # network and criterion
  criterion = torch.nn.CrossEntropyLoss().cuda()
  basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers)
  model = torch.nn.DataParallel(basemodel).cuda()
  print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log)
  print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log)

  # optimizer and LR-scheduler
  base_optimizer = torch.optim.SGD(basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay)
  #base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay)
  base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
  arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

  # snapshot
  checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth')
  if args.resume is not None and os.path.isfile(args.resume):
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log)
  elif os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, genotypes = 0, {}
    print_log('Train model-search from scratch.', log)

  config = load_config(args.model_config)

  if args.only_base:
    print_log('---- Only Train the Searched Model ----', log)
    main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
    return

  # Main loop
  start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
  for epoch in range(start_epoch, args.epochs):
    base_scheduler.step()

    basemodel.set_tau( args.tau_max - epoch*1.0/args.epochs*(args.tau_max-args.tau_min) )
    #if epoch + 2 == args.epochs:
    #  torch.cuda.empty_cache()
    #  basemodel.set_gumbel(False)

    need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}], tau={:}'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size, basemodel.get_tau()), log)

    genotype = basemodel.genotype()
    print_log('genotype = {:}'.format(genotype), log)

    print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log)

    # training
    train_acc1, train_acc5, train_obj, train_time \
                = train(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log)
    total_train_time += train_time
    # validation
    valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log)
    print_log('{:03d}/{:03d}, Train-Accuracy = {:.2f}, Test-Accuracy = {:.2f}'.format(epoch, args.epochs, train_acc1, valid_acc1), log)
    # save genotype
    genotypes[epoch] = basemodel.genotype()
    # save checkpoint
    torch.save({'epoch' : epoch + 1,
                'args'  : deepcopy(args),
                'state_dict': basemodel.state_dict(),
                'genotypes' : genotypes,
                'base_optimizer' : base_optimizer.state_dict(),
                'arch_optimizer' : arch_optimizer.state_dict(),
                'base_scheduler' : base_scheduler.state_dict()},
                checkpoint_path)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)

  # clear GPU cache
  torch.cuda.empty_cache()
  main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
  log.close()


def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log):
  data_time, batch_time = AverageMeter(), AverageMeter()
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  model.train()

  valid_iter = iter(valid_queue)
  end = time.time()
  for step, (inputs, targets) in enumerate(train_queue):
    batch, C, H, W = inputs.size()

    #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # get a random minibatch from the search queue with replacement
    try:
      input_search, target_search = next(valid_iter)
    except StopIteration:
      valid_iter = iter(valid_queue)
      input_search, target_search = next(valid_iter)

    target_search = target_search.cuda(non_blocking=True)

    # update the architecture
    arch_optimizer.zero_grad()
    output_search = model(input_search)
    arch_loss = criterion(output_search, target_search)
    arch_loss.backward()
    arch_optimizer.step()

    # update the parameters
    base_optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, targets)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
    base_optimizer.step()

    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    objs.update(loss.item() , batch)
    top1.update(prec1.item(), batch)
    top5.update(prec5.item(), batch)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % args.print_freq == 0 or (step+1) == len(train_queue):
      Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg, batch_time.sum


def infer(valid_queue, model, criterion, epoch, log):
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

  model.eval()
  with torch.no_grad():
    for step, (inputs, targets) in enumerate(valid_queue):
      batch, C, H, W = inputs.size()
      targets = targets.cuda(non_blocking=True)

      logits = model(inputs)
      loss = criterion(logits, targets)

      prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
      objs.update(loss.item() , batch)
      top1.update(prec1.item(), batch)
      top5.update(prec5.item(), batch)

      if step % args.print_freq == 0 or (step+1) == len(valid_queue):
        Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
        Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
        print_log(Sstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg


if __name__ == '__main__':
  main()
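Note: each search iteration above alternates two updates: the Adam arch_optimizer steps on a held-out batch to update the architecture parameters (alphas), then the SGD base_optimizer steps on a training batch to update the network weights, while set_tau anneals the sampling temperature tau linearly from --tau_max to --tau_min. A minimal sketch of that schedule follows; the example values are assumptions (the real tau_max/tau_min/epochs come from the command line, and set_tau itself lives in the nas package outside this diff).

  # Sketch only: the linear temperature schedule passed to basemodel.set_tau(...)
  def tau_at(epoch, epochs=50, tau_max=10.0, tau_min=4.0):
    # identical formula to the one used in main() above
    return tau_max - epoch * 1.0 / epochs * (tau_max - tau_min)

  print([round(tau_at(e), 2) for e in (0, 25, 49)])   # -> [10.0, 7.0, 4.12]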
exps-cnn/acc_search_v3.py  (Normal file, 397 lines)
@@ -0,0 +1,397 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import Cutout, count_parameters_in_MB
from nas import Network, NetworkACC2, NetworkV3, NetworkV4, NetworkV5, NetworkFACC1
from nas import return_alphas_str
from train_utils import main_procedure
from scheduler import load_config

Networks = {'base': Network, 'acc2': NetworkACC2, 'facc1': NetworkFACC1, 'NetworkV3': NetworkV3, 'NetworkV4': NetworkV4, 'NetworkV5': NetworkV5}


parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100'], help='Choose between CIFAR-10 and CIFAR-100.')
parser.add_argument('--arch', type=str, choices=Networks.keys(), help='Choose networks.')
parser.add_argument('--batch_size', type=int, help='the batch size')
parser.add_argument('--learning_rate_max', type=float, help='initial learning rate')
parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate')
parser.add_argument('--tau_max', type=float, help='initial tau')
parser.add_argument('--tau_min', type=float, help='minimum tau')
parser.add_argument('--momentum', type=float, help='momentum')
parser.add_argument('--weight_decay', type=float, help='weight decay')
parser.add_argument('--epochs', type=int, help='num of training epochs')
# architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
#
parser.add_argument('--init_channels', type=int, help='num of init channels')
parser.add_argument('--layers', type=int, help='total number of layers')
#
parser.add_argument('--cutout', type=int, help='cutout length, negative means no cutout')
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
parser.add_argument('--model_config', type=str, help='the model configuration')

# resume
parser.add_argument('--resume', type=str, help='the resume path')
parser.add_argument('--only_base', action='store_true', default=False, help='only train the searched model')
# split data
parser.add_argument('--validate', action='store_true', default=False, help='split train-data into train/val or not')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()

assert torch.cuda.is_available(), 'torch.cuda is not available'

if args.manualSeed is None:
  args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)


def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch version : {}".format(torch.__version__), log)
  print_log("CUDA version : {}".format(torch.version.cuda), log)
  print_log("cuDNN version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
  args.dataset = args.dataset.lower()

  # Mean + Std
  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Data Augmentation
  if args.dataset == 'cifar10' or args.dataset == 'cifar100':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
             transforms.Normalize(mean, std)]
    if args.cutout > 0: lists += [Cutout(args.cutout)]
    train_transform = transforms.Compose(lists)
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Datasets
  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True , transform=train_transform, download=True)
    test_data  = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True)
    num_classes = 100
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Data Loader
  if args.validate:
    indices = list(range(len(train_data)))
    split = int(args.train_portion * len(indices))
    random.shuffle(indices)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
                      pin_memory=True, num_workers=args.workers)
    test_loader  = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
                      pin_memory=True, num_workers=args.workers)
  else:
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
    test_loader  = torch.utils.data.DataLoader(test_data , batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

  # network and criterion
  criterion = torch.nn.CrossEntropyLoss().cuda()
  basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers)
  model = torch.nn.DataParallel(basemodel).cuda()
  print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log)
  print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log)

  # optimizer and LR-scheduler
  base_optimizer = torch.optim.SGD(basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay)
  #base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay)
  base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
  arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

  # snapshot
  checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth')
  if args.resume is not None and os.path.isfile(args.resume):
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log)
  elif os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, genotypes = 0, {}
    print_log('Train model-search from scratch.', log)

  config = load_config(args.model_config)

  if args.only_base:
    print_log('---- Only Train the Searched Model ----', log)
    main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
    return

  # Main loop
  start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
  Arch__acc1 = float('nan')  # defined here so the log line below is valid even if the first epoch is the joint one
  for epoch in range(start_epoch, args.epochs):
    base_scheduler.step()

    basemodel.set_tau( args.tau_max - epoch*1.0/args.epochs*(args.tau_max-args.tau_min) )
    #if epoch + 1 == args.epochs:
    #  torch.cuda.empty_cache()
    #  basemodel.set_gumbel(False)

    need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}], tau={:}'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size, basemodel.get_tau()), log)

    genotype = basemodel.genotype()
    print_log('genotype = {:}'.format(genotype), log)

    print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log)

    # training
    if epoch + 1 == args.epochs:
      train_acc1, train_acc5, train_obj, train_time \
                  = train_joint(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log)
      total_train_time += train_time
    else:
      train_acc1, train_acc5, train_obj, train_time \
                  = train_base(train_loader, None, model, criterion, base_optimizer, None, epoch, log)
      total_train_time += train_time
      Arch__acc1, Arch__acc5, Arch__obj, train_time \
                  = train_arch(None, test_loader, model, criterion, None, arch_optimizer, epoch, log)
      total_train_time += train_time
    # validation
    valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log)
    print_log('{:03d}/{:03d}, Train-Accuracy = {:.2f}, Arch-Accuracy = {:.2f}, Test-Accuracy = {:.2f}'.format(epoch, args.epochs, train_acc1, Arch__acc1, valid_acc1), log)

    # save genotype
    genotypes[epoch] = basemodel.genotype()
    # save checkpoint
    torch.save({'epoch' : epoch + 1,
                'args'  : deepcopy(args),
                'state_dict': basemodel.state_dict(),
                'genotypes' : genotypes,
                'base_optimizer' : base_optimizer.state_dict(),
                'arch_optimizer' : arch_optimizer.state_dict(),
                'base_scheduler' : base_scheduler.state_dict()},
                checkpoint_path)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)

  # clear GPU cache
  torch.cuda.empty_cache()
  main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
  log.close()


def train_base(train_queue, _, model, criterion, base_optimizer, __, epoch, log):
  data_time, batch_time = AverageMeter(), AverageMeter()
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  model.train()

  end = time.time()
  for step, (inputs, targets) in enumerate(train_queue):
    batch, C, H, W = inputs.size()

    #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # update the parameters
    base_optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, targets)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
    base_optimizer.step()

    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    objs.update(loss.item() , batch)
    top1.update(prec1.item(), batch)
    top5.update(prec5.item(), batch)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % args.print_freq == 0 or (step+1) == len(train_queue):
      Sstr = ' TRAIN-BASE ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg, batch_time.sum


def train_arch(_, valid_queue, model, criterion, __, arch_optimizer, epoch, log):
  data_time, batch_time = AverageMeter(), AverageMeter()
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  model.train()

  end = time.time()
  for step, (inputs, targets) in enumerate(valid_queue):
    batch, C, H, W = inputs.size()

    #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # update the architecture
    arch_optimizer.zero_grad()
    outputs = model(inputs)
    arch_loss = criterion(outputs, targets)
    arch_loss.backward()
    arch_optimizer.step()

    prec1, prec5 = obtain_accuracy(outputs.data, targets.data, topk=(1, 5))
    objs.update(arch_loss.item() , batch)
    top1.update(prec1.item(), batch)
    top5.update(prec5.item(), batch)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % args.print_freq == 0 or (step+1) == len(valid_queue):
      Sstr = ' TRAIN-ARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg, batch_time.sum


def train_joint(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log):
  data_time, batch_time = AverageMeter(), AverageMeter()
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  model.train()

  valid_iter = iter(valid_queue)
  end = time.time()
  for step, (inputs, targets) in enumerate(train_queue):
    batch, C, H, W = inputs.size()

    #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # get a random minibatch from the search queue with replacement
    try:
      input_search, target_search = next(valid_iter)
    except StopIteration:
      valid_iter = iter(valid_queue)
      input_search, target_search = next(valid_iter)

    target_search = target_search.cuda(non_blocking=True)

    # update the architecture
    arch_optimizer.zero_grad()
    output_search = model(input_search)
    arch_loss = criterion(output_search, target_search)
    arch_loss.backward()
    arch_optimizer.step()

    # update the parameters
    base_optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, targets)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
    base_optimizer.step()

    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    objs.update(loss.item() , batch)
    top1.update(prec1.item(), batch)
    top5.update(prec5.item(), batch)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % args.print_freq == 0 or (step+1) == len(train_queue):
      Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg, batch_time.sum


def infer(valid_queue, model, criterion, epoch, log):
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

  model.eval()
  with torch.no_grad():
    for step, (inputs, targets) in enumerate(valid_queue):
      batch, C, H, W = inputs.size()
      targets = targets.cuda(non_blocking=True)

      logits = model(inputs)
      loss = criterion(logits, targets)

      prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
      objs.update(loss.item() , batch)
      top1.update(prec1.item(), batch)
      top5.update(prec5.item(), batch)

      if step % args.print_freq == 0 or (step+1) == len(valid_queue):
        Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
        Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
        print_log(Sstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg


if __name__ == '__main__':
  main()
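Note: compared with acc_search_v2.py, this variant splits each epoch into two passes: train_base updates only the network weights on the training split, train_arch updates only the architecture parameters on the held-out split, and only the final epoch falls back to the joint per-batch alternation in train_joint. A stripped-down sketch of that control flow, with plain callables standing in for the real training functions:

  # Sketch only: the per-epoch schedule used in main() above
  def run_epoch(epoch, epochs, train_base, train_arch, train_joint):
    if epoch + 1 == epochs:     # last epoch: joint weight and alpha updates
      return train_joint()
    stats_w = train_base()      # earlier epochs: weights first ...
    stats_a = train_arch()      # ... then architecture parameters
    return stats_w, stats_a

  print(run_epoch(0, 3, lambda: 'w', lambda: 'a', lambda: 'joint'))   # -> ('w', 'a')
  print(run_epoch(2, 3, lambda: 'w', lambda: 'a', lambda: 'joint'))   # -> joint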
exps-cnn/cvpr-vis.py  (Normal file, 94 lines)
@@ -0,0 +1,94 @@
# python ./exps-cnn/cvpr-vis.py --save_dir ./snapshots/NAS-VIS/
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from nas import DMS_V1, DMS_F1
from nas_rnn import DARTS_V2, GDAS
from graphviz import Digraph

parser = argparse.ArgumentParser("Visualize the Networks")
parser.add_argument('--save_dir', type=str, help='The directory to save the network plot.')
args = parser.parse_args()


def plot_cnn(genotype, filename):
  g = Digraph(
      format='pdf',
      edge_attr=dict(fontsize='20', fontname="times"),
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
      engine='dot')
  g.body.extend(['rankdir=LR'])

  g.node("c_{k-2}", fillcolor='darkseagreen2')
  g.node("c_{k-1}", fillcolor='darkseagreen2')
  assert len(genotype) % 2 == 0, '{:}'.format(genotype)
  steps = len(genotype) // 2

  for i in range(steps):
    g.node(str(i), fillcolor='lightblue')

  for i in range(steps):
    for k in [2*i, 2*i + 1]:
      op, j, weight = genotype[k]
      if j == 0:
        u = "c_{k-2}"
      elif j == 1:
        u = "c_{k-1}"
      else:
        u = str(j-2)
      v = str(i)
      g.edge(u, v, label=op, fillcolor="gray")

  g.node("c_{k}", fillcolor='palegoldenrod')
  for i in range(steps):
    g.edge(str(i), "c_{k}", fillcolor="gray")

  g.render(filename, view=False)


def plot_rnn(genotype, filename):
  g = Digraph(
      format='pdf',
      edge_attr=dict(fontsize='20', fontname="times"),
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
      engine='dot')
  g.body.extend(['rankdir=LR'])

  g.node("x_{t}", fillcolor='darkseagreen2')
  g.node("h_{t-1}", fillcolor='darkseagreen2')
  g.node("0", fillcolor='lightblue')
  g.edge("x_{t}", "0", fillcolor="gray")
  g.edge("h_{t-1}", "0", fillcolor="gray")
  steps = len(genotype)

  for i in range(1, steps + 1):
    g.node(str(i), fillcolor='lightblue')

  for i, (op, j) in enumerate(genotype):
    g.edge(str(j), str(i + 1), label=op, fillcolor="gray")

  g.node("h_{t}", fillcolor='palegoldenrod')
  for i in range(1, steps + 1):
    g.edge(str(i), "h_{t}", fillcolor="gray")

  g.render(filename, view=False)


if __name__ == '__main__':
  save_dir = Path(args.save_dir)

  save_path = str(save_dir / 'DMS_V1-normal')
  plot_cnn(DMS_V1.normal, save_path)
  save_path = str(save_dir / 'DMS_V1-reduce')
  plot_cnn(DMS_V1.reduce, save_path)
  save_path = str(save_dir / 'DMS_F1-normal')
  plot_cnn(DMS_F1.normal, save_path)

  save_path = str(save_dir / 'DARTS-V2-RNN')
  plot_rnn(DARTS_V2.recurrent, save_path)

  save_path = str(save_dir / 'GDAS-V1-RNN')
  plot_rnn(GDAS.recurrent, save_path)
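Note: plot_cnn expects a genotype given as a flat sequence of (operation, input-index, weight) triples, two per intermediate node, where input index 0 maps to c_{k-2} and index 1 to c_{k-1}. The real DMS_V1/DMS_F1 genotypes come from the nas package, which is not part of this diff; the call below uses a made-up two-node cell (operation names are assumptions) purely to show the expected shape:

  # Hypothetical genotype for illustration only
  toy_normal = [('sep_conv_3x3', 0, 0.9), ('skip_connect', 1, 0.8),
                ('sep_conv_3x3', 1, 0.7), ('dil_conv_3x3', 2, 0.6)]
  plot_cnn(toy_normal, '/tmp/toy-normal')   # renders /tmp/toy-normal.pdf via graphviz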
exps-cnn/meta_search.py  (Normal file, 312 lines)
@@ -0,0 +1,312 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from datasets import TieredImageNet, MetaBatchSampler
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import Cutout, count_parameters_in_MB
from meta_nas import return_alphas_str, MetaNetwork
from train_utils import main_procedure
from scheduler import load_config

Networks = {'meta': MetaNetwork}

parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--arch', type=str, choices=Networks.keys(), help='Choose networks.')
parser.add_argument('--n_way', type=int, help='N-WAY.')
parser.add_argument('--k_shot', type=int, help='K-SHOT.')
# Learning Parameters
parser.add_argument('--learning_rate_max', type=float, help='initial learning rate')
parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate')
parser.add_argument('--momentum', type=float, help='momentum')
parser.add_argument('--weight_decay', type=float, help='weight decay')
parser.add_argument('--epochs', type=int, help='num of training epochs')
# architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
#
parser.add_argument('--init_channels', type=int, help='num of init channels')
parser.add_argument('--layers', type=int, help='total number of layers')
#
parser.add_argument('--cutout', type=int, help='cutout length, negative means no cutout')
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
parser.add_argument('--model_config', type=str, help='the model configuration')

# resume
parser.add_argument('--resume', type=str, help='the resume path')
parser.add_argument('--only_base', action='store_true', default=False, help='only train the searched model')
# split data
parser.add_argument('--validate', action='store_true', default=False, help='split train-data into train/val or not')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()

assert torch.cuda.is_available(), 'torch.cuda is not available'

if args.manualSeed is None:
  args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)


def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch version : {}".format(torch.__version__), log)
  print_log("CUDA version : {}".format(torch.version.cuda), log)
  print_log("cuDNN version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)

  # Mean + Std
  means, stds = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  # Data Augmentation
  lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(),
           transforms.Normalize(means, stds)]
  if args.cutout > 0: lists += [Cutout(args.cutout)]
  train_transform = transforms.Compose(lists)
  test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(means, stds)])

  train_data = TieredImageNet(args.data_path, 'train', train_transform)
  test_data  = TieredImageNet(args.data_path, 'val'  , test_transform)

  train_sampler = MetaBatchSampler(train_data.labels, args.n_way, args.k_shot * 2, len(train_data) // (args.n_way*args.k_shot))
  test_sampler  = MetaBatchSampler( test_data.labels, args.n_way, args.k_shot * 2, len( test_data) // (args.n_way*args.k_shot))

  train_loader = torch.utils.data.DataLoader(train_data, batch_sampler=train_sampler)
  test_loader  = torch.utils.data.DataLoader( test_data, batch_sampler= test_sampler)

  # network
  basemodel = Networks[args.arch](args.init_channels, args.layers, head='imagenet')
  model = torch.nn.DataParallel(basemodel).cuda()
  print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log)
  print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log)

  # optimizer and LR-scheduler
  #base_optimizer = torch.optim.SGD(basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay)
  base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay)
  base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
  arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

  # snapshot
  checkpoint_path = os.path.join(args.save_path, 'checkpoint-meta-search.pth')
  if args.resume is not None and os.path.isfile(args.resume):
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log)
  elif os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, genotypes = 0, {}
    print_log('Train model-search from scratch.', log)

  config = load_config(args.model_config)

  if args.only_base:
    print_log('---- Only Train the Searched Model ----', log)
    CIFAR_DATA_DIR = os.environ['TORCH_HOME'] + '/cifar.python'
    main_procedure(config, 'cifar10', CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log)
    return

  # Main loop
  start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
  for epoch in range(start_epoch, args.epochs):
    base_scheduler.step()

    need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr())), log)

    genotype = basemodel.genotype()
    print_log('genotype = {:}'.format(genotype), log)
    print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log)

    # training
    train_acc1, train_obj, train_time \
                = train(train_loader, test_loader, model, args.n_way, base_optimizer, arch_optimizer, epoch, log)
    total_train_time += train_time
    # validation
    valid_acc1, valid_obj = infer(test_loader, model, epoch, args.n_way, log)

    print_log('META -> {:}-way {:}-shot : {:03d}/{:03d} : Train Acc : {:.2f}, Test Acc : {:.2f}'.format(args.n_way, args.k_shot, epoch, args.epochs, train_acc1, valid_acc1), log)
    # save genotype
    genotypes[epoch] = basemodel.genotype()

    # save checkpoint
    torch.save({'epoch' : epoch + 1,
                'args'  : deepcopy(args),
                'state_dict': basemodel.state_dict(),
                'genotypes' : genotypes,
                'base_optimizer' : base_optimizer.state_dict(),
                'arch_optimizer' : arch_optimizer.state_dict(),
                'base_scheduler' : base_scheduler.state_dict()},
                checkpoint_path)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)

  # clear GPU cache
  CIFAR_DATA_DIR = os.environ['TORCH_HOME'] + '/cifar.python'
  print_log('test for CIFAR-10', log)
  torch.cuda.empty_cache()
  main_procedure(config, 'cifar10' , CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log)
  print_log('test for CIFAR-100', log)
  torch.cuda.empty_cache()
  main_procedure(config, 'cifar100', CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log)
  log.close()


def euclidean_dist(A, B):
  na, da = A.size()
  nb, db = B.size()
  assert da == db, 'invalid feature dim : {:} vs. {:}'.format(da, db)
  X, Y = A.view(na, 1, da), B.view(1, nb, db)
  return torch.pow(X-Y, 2).sum(2)


def get_loss(features, targets, n_way):
  classes = torch.unique(targets)
  shot = features.size(0) // n_way // 2

  support_index, query_index, labels = [], [], []
  for idx, cls in enumerate( classes.tolist() ):
    indexs = (targets == cls).nonzero().view(-1).tolist()
    support_index.append(indexs[:shot])
    query_index += indexs[shot:]
    labels += [idx] * shot
  query_features   = features[query_index, :]
  support_features = features[support_index, :]
  support_features = torch.mean(support_features, dim=1)

  labels = torch.LongTensor(labels).cuda(non_blocking=True)
  logits = -euclidean_dist(query_features, support_features)
  loss = F.cross_entropy(logits, labels)
  accuracy = obtain_accuracy(logits.data, labels.data, topk=(1,))[0]
  return loss, accuracy


def train(train_queue, valid_queue, model, n_way, base_optimizer, arch_optimizer, epoch, log):
  data_time, batch_time = AverageMeter(), AverageMeter()
  objs, accuracies = AverageMeter(), AverageMeter()
  model.train()

  valid_iter = iter(valid_queue)
  end = time.time()
  for step, (inputs, targets) in enumerate(train_queue):
    batch, C, H, W = inputs.size()

    #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
    #targets = targets.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # get a random minibatch from the search queue with replacement
    try:
      input_search, target_search = next(valid_iter)
    except StopIteration:
      valid_iter = iter(valid_queue)
      input_search, target_search = next(valid_iter)

    #target_search = target_search.cuda(non_blocking=True)

    # update the architecture
    arch_optimizer.zero_grad()
    feature_search = model(input_search)
    arch_loss, arch_accuracy = get_loss(feature_search, target_search, n_way)
    arch_loss.backward()
    arch_optimizer.step()

    # update the parameters
    base_optimizer.zero_grad()
    feature_model = model(inputs)
    model_loss, model_accuracy = get_loss(feature_model, targets, n_way)

    model_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
    base_optimizer.step()

    objs.update(model_loss.item() , batch)
    accuracies.update(model_accuracy.item(), batch)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % args.print_freq == 0 or (step+1) == len(train_queue):
      Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f})'.format(loss=objs, top1=accuracies)
      Istr = 'I : {:}'.format( list(inputs.size()) )
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr, log)

  return accuracies.avg, objs.avg, batch_time.sum


def infer(valid_queue, model, epoch, n_way, log):
  objs, accuracies = AverageMeter(), AverageMeter()

  model.eval()
  with torch.no_grad():
    for step, (inputs, targets) in enumerate(valid_queue):
      batch, C, H, W = inputs.size()
      #targets = targets.cuda(non_blocking=True)

      features = model(inputs)
      loss, accuracy = get_loss(features, targets, n_way)

      objs.update(loss.item() , batch)
      accuracies.update(accuracy.item(), batch)

      if step % (args.print_freq*4) == 0 or (step+1) == len(valid_queue):
        Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
        Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f})'.format(loss=objs, top1=accuracies)
        print_log(Sstr + ' ' + Lstr, log)

  return accuracies.avg, objs.avg


if __name__ == '__main__':
  main()
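Note: get_loss above implements a prototypical-network style episode loss: within each N-way episode the first k_shot examples of every class are averaged into a class prototype, and the remaining examples are classified by the negative squared Euclidean distance to each prototype. A self-contained CPU check of that idea follows; the shapes and feature values are assumptions, since the real features come from MetaNetwork on TieredImageNet episodes.

  import torch
  import torch.nn.functional as F

  def euclidean_dist(A, B):                        # same computation as above, CPU only
    na, da = A.size(); nb, db = B.size()
    assert da == db
    return torch.pow(A.view(na, 1, da) - B.view(1, nb, db), 2).sum(2)

  protos  = torch.randn(5, 64)                     # 5-way prototypes, 64-d features
  queries = protos.repeat_interleave(3, dim=0) + 0.01 * torch.randn(15, 64)
  labels  = torch.arange(5).repeat_interleave(3)   # 3 queries per class
  logits  = -euclidean_dist(queries, protos)       # (15, 5) scores
  loss    = F.cross_entropy(logits, labels)
  print(loss.item(), (logits.argmax(1) == labels).float().mean().item())  # small loss, accuracy 1.0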
exps-cnn/train_base.py  (Normal file, 96 lines)
@@ -0,0 +1,96 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import Cutout, count_parameters_in_MB
from nas import DARTS_V1, DARTS_V2, NASNet, PNASNet, AmoebaNet, ENASNet
from nas import DMS_V1, DMS_F1, GDAS_CC
from meta_nas import META_V1, META_V2
from train_utils import main_procedure
from train_utils_imagenet import main_procedure_imagenet
from scheduler import load_config

models = {'DARTS_V1': DARTS_V1,
          'DARTS_V2': DARTS_V2,
          'NASNet' : NASNet,
          'PNASNet' : PNASNet,
          'ENASNet' : ENASNet,
          'DMS_V1' : DMS_V1,
          'DMS_F1' : DMS_F1,
          'GDAS_CC' : GDAS_CC,
          'META_V1' : META_V1,
          'META_V2' : META_V2,
          'AmoebaNet' : AmoebaNet}


parser = argparse.ArgumentParser("cifar")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['imagenet', 'cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.')
parser.add_argument('--arch', type=str, choices=models.keys(), help='the searched model.')
#
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
parser.add_argument('--model_config', type=str , help='the model configuration')
parser.add_argument('--init_channels', type=int , help='the initial number of channels')
parser.add_argument('--layers', type=int , help='the number of layers.')

# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()

assert torch.cuda.is_available(), 'torch.cuda is not available'

if args.manualSeed is None:
  args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)


def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch version : {}".format(torch.__version__), log)
  print_log("CUDA version : {}".format(torch.version.cuda), log)
  print_log("cuDNN version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
  args.dataset = args.dataset.lower()

  config = load_config(args.model_config)
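  # --arch picks one of the hand-crafted or searched genotypes imported above; this script
  # performs no architecture search itself, it simply retrains the chosen genotype from scratch.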
  genotype = models[args.arch]
  print_log('configuration : {:}'.format(config), log)
  print_log('genotype : {:}'.format(genotype), log)
  # clear GPU cache
  torch.cuda.empty_cache()
  if args.dataset == 'imagenet':
    main_procedure_imagenet(config, args.data_path, args, genotype, args.init_channels, args.layers, log)
  else:
    main_procedure(config, args.dataset, args.data_path, args, genotype, args.init_channels, args.layers, log)
  log.close()


if __name__ == '__main__':
  main()
312
exps-cnn/train_search.py
Normal file
312
exps-cnn/train_search.py
Normal file
@@ -0,0 +1,312 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import Cutout, count_parameters_in_MB
from datasets import TieredImageNet
from nas import return_alphas_str, Network, NetworkV1, NetworkF1
from train_utils import main_procedure
from scheduler import load_config

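# Candidate super-networks for the search; judging by their names, 'share' (NetworkV1) and
# 'fix' (NetworkF1) are weight-sharing / fixed-connection variants of the base search network.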
Networks = {'base': Network, 'share': NetworkV1, 'fix': NetworkF1}


parser = argparse.ArgumentParser("CNN")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'tiered'], help='Choose between Cifar10/100 and TieredImageNet.')
parser.add_argument('--arch', type=str, choices=Networks.keys(), help='Choose networks.')
parser.add_argument('--batch_size', type=int, help='the batch size')
parser.add_argument('--learning_rate_max', type=float, help='initial learning rate')
parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate')
parser.add_argument('--momentum', type=float, help='momentum')
parser.add_argument('--weight_decay', type=float, help='weight decay')
parser.add_argument('--epochs', type=int, help='num of training epochs')
# architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
#
parser.add_argument('--init_channels', type=int, help='num of init channels')
parser.add_argument('--layers', type=int, help='total number of layers')
#
parser.add_argument('--cutout', type=int, help='cutout length, negative means no cutout')
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
parser.add_argument('--model_config', type=str , help='the model configuration')

# resume
parser.add_argument('--resume', type=str , help='the resume path')
parser.add_argument('--only_base', action='store_true', default=False, help='only train the searched model')
# split data
parser.add_argument('--validate', action='store_true', default=False, help='split train-data into train/val or not')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()

assert torch.cuda.is_available(), 'torch.cuda is not available'

if args.manualSeed is None:
  args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)


def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch version : {}".format(torch.__version__), log)
  print_log("CUDA version : {}".format(torch.version.cuda), log)
  print_log("cuDNN version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
  args.dataset = args.dataset.lower()

  # Mean + Std
  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif args.dataset == 'tiered':
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Data Augmentation
  if args.dataset == 'cifar10' or args.dataset == 'cifar100':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
             transforms.Normalize(mean, std)]
    if args.cutout > 0 : lists += [Cutout(args.cutout)]
    train_transform = transforms.Compose(lists)
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
  elif args.dataset == 'tiered':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
    if args.cutout > 0 : lists += [Cutout(args.cutout)]
    train_transform = transforms.Compose(lists)
    test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Datasets
  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train= True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True)
    num_classes, head = 10, 'cifar'
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train= True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True)
    num_classes, head = 100, 'cifar'
  elif args.dataset == 'tiered':
    train_data = TieredImageNet(args.data_path, 'train-val', train_transform)
    test_data = None
    num_classes, head = train_data.n_classes, 'imagenet'
  else:
    raise TypeError("Unknown dataset : {:}".format(args.dataset))
  # Data Loader
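  # When --validate is set, the official training set is split in two: the first `train_portion`
  # of a shuffled index list feeds the weight updates (train_loader) and the remainder feeds the
  # architecture updates (test_loader). Without --validate, weights train on the full training
  # set and the architecture updates instead draw batches from the test split inside train().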
  if args.validate:
    indices = list(range(len(train_data)))
    split = int(args.train_portion * len(indices))
    random.shuffle(indices)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                               sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
                                               pin_memory=True, num_workers=args.workers)
    test_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
                                              sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
                                              pin_memory=True, num_workers=args.workers)
  else:
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)

  # network and criterion
  criterion = torch.nn.CrossEntropyLoss().cuda()
  basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers, head=head)
  model = torch.nn.DataParallel(basemodel).cuda()
  print_log("Network : {:}".format(model), log)
  print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log)
  print_log("Train-transformation : {:}\nTest--transformation : {:}\nClass number : {:}".format(train_transform, test_transform, num_classes), log)

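  # Two optimizers, one per parameter group: SGD with a cosine-annealed learning rate for the
  # network weights (base_parameters) and Adam for the architecture encoding (arch_parameters),
  # mirroring the usual DARTS-style differentiable-search setup.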
  # optimizer and LR-scheduler
  base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay)
  base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
  arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

  # snapshot
  checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth')
  if args.resume is not None and os.path.isfile(args.resume):
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log)
  elif os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
    genotypes = checkpoint['genotypes']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, genotypes = 0, {}
    print_log('Train model-search from scratch.', log)

  config = load_config(args.model_config)

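  # --only_base: skip the search phase entirely and retrain the genotype currently encoded in
  # the (possibly resumed) alphas, using the evaluation-sized network (36 initial channels,
  # 20 cells), then exit.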
  if args.only_base:
    print_log('---- Only Train the Searched Model ----', log)
    main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
    return

  # Main loop
  start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
  for epoch in range(start_epoch, args.epochs):
    base_scheduler.step()

    need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}]'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size), log)

    genotype = basemodel.genotype()
    print_log('genotype = {:}'.format(genotype), log)

    print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log)

    # training
    train_acc1, train_acc5, train_obj, train_time \
        = train(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log)
    total_train_time += train_time
    # validation
    valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log)
    print_log('Base-Search : {:03d}/{:03d} : Train-Acc={:.3f}, Test-Acc={:.3f}'.format(epoch, args.epochs, train_acc1, valid_acc1), log)
    # save genotype
    genotypes[epoch] = basemodel.genotype()
    # save checkpoint
    torch.save({'epoch' : epoch + 1,
                'args' : deepcopy(args),
                'state_dict': basemodel.state_dict(),
                'genotypes' : genotypes,
                'base_optimizer' : base_optimizer.state_dict(),
                'arch_optimizer' : arch_optimizer.state_dict(),
                'base_scheduler' : base_scheduler.state_dict()},
               checkpoint_path)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)

  # clear GPU cache
  torch.cuda.empty_cache()
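  # After the search finishes, the final genotype is retrained from scratch with the
  # evaluation-sized network (36 channels, 20 cells). Note that the CIFAR-10 data path is
  # hard-coded here relative to the TORCH_HOME environment variable.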
  main_procedure(config, 'cifar10', os.environ['TORCH_HOME'] + '/cifar.python', args, basemodel.genotype(), 36, 20, log)
  log.close()


def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log):
  data_time, batch_time = AverageMeter(), AverageMeter()
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
  model.train()

  valid_iter = iter(valid_queue)
  end = time.time()
  for step, (inputs, targets) in enumerate(train_queue):
    batch, C, H, W = inputs.size()

    #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # get a random minibatch from the search queue with replacement
    try:
      input_search, target_search = next(valid_iter)
    except StopIteration:
      valid_iter = iter(valid_queue)
      input_search, target_search = next(valid_iter)

    target_search = target_search.cuda(non_blocking=True)

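    # Each step performs one first-order alternating update: the architecture parameters are
    # updated on a batch from the validation queue, then the network weights are updated on the
    # current training batch. No second-order / unrolled gradient is computed here.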
    # update the architecture
    arch_optimizer.zero_grad()
    output_search = model(input_search)
    arch_loss = criterion(output_search, target_search)
    arch_loss.backward()
    arch_optimizer.step()

    # update the parameters
    base_optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits, targets)

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
    base_optimizer.step()

    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    objs.update(loss.item() , batch)
    top1.update(prec1.item(), batch)
    top5.update(prec5.item(), batch)

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % args.print_freq == 0 or (step+1) == len(train_queue):
      Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg, batch_time.sum


def infer(valid_queue, model, criterion, epoch, log):
  objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

  model.eval()
  with torch.no_grad():
    for step, (inputs, targets) in enumerate(valid_queue):
      batch, C, H, W = inputs.size()
      targets = targets.cuda(non_blocking=True)

      logits = model(inputs)
      loss = criterion(logits, targets)

      prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
      objs.update(loss.item() , batch)
      top1.update(prec1.item(), batch)
      top5.update(prec5.item(), batch)

      if step % args.print_freq == 0 or (step+1) == len(valid_queue):
        Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
        Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
        print_log(Sstr + ' ' + Lstr, log)

  return top1.avg, top5.avg, objs.avg


if __name__ == '__main__':
  main()
184
exps-cnn/train_utils.py
Normal file
184
exps-cnn/train_utils.py
Normal file
@@ -0,0 +1,184 @@
import os, sys, time
from copy import deepcopy
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

from utils import print_log, obtain_accuracy, AverageMeter
from utils import time_string, convert_secs2time
from utils import count_parameters_in_MB
from utils import Cutout
from nas import NetworkCIFAR as Network


def obtain_best(accuracies):
  if len(accuracies) == 0: return (0, 0)
  tops = [value for key, value in accuracies.items()]
  s2b = sorted( tops )
  return s2b[-1]


def main_procedure(config, dataset, data_path, args, genotype, init_channels, layers, log):

  # Mean + Std
  if dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    raise TypeError("Unknown dataset : {:}".format(dataset))
  # Dataset Transformation
  if dataset == 'cifar10' or dataset == 'cifar100':
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
             transforms.Normalize(mean, std)]
    if config.cutout > 0 : lists += [Cutout(config.cutout)]
    train_transform = transforms.Compose(lists)
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
  else:
    raise TypeError("Unknown dataset : {:}".format(dataset))
  # Dataset Definition
  if dataset == 'cifar10':
    train_data = dset.CIFAR10(data_path, train= True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(data_path, train=False, transform=test_transform , download=True)
    class_num = 10
  elif dataset == 'cifar100':
    train_data = dset.CIFAR100(data_path, train= True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(data_path, train=False, transform=test_transform , download=True)
    class_num = 100
  else:
    raise TypeError("Unknown dataset : {:}".format(dataset))

  print_log('-------------------------------------- main-procedure', log)
  print_log('config : {:}'.format(config), log)
  print_log('genotype : {:}'.format(genotype), log)
  print_log('init_channels : {:}'.format(init_channels), log)
  print_log('layers : {:}'.format(layers), log)
  print_log('class_num : {:}'.format(class_num), log)
  basemodel = Network(init_channels, class_num, layers, config.auxiliary, genotype)
  model = torch.nn.DataParallel(basemodel).cuda()

  total_param, aux_param = count_parameters_in_MB(basemodel), count_parameters_in_MB(basemodel.auxiliary_param())
  print_log('Network =>\n{:}'.format(basemodel), log)
  print_log('Parameters : {:} - {:} = {:.3f} MB'.format(total_param, aux_param, total_param - aux_param), log)
  print_log('config : {:}'.format(config), log)
  print_log('genotype : {:}'.format(genotype), log)
  print_log('args : {:}'.format(args), log)
  print_log('Train-Dataset : {:}'.format(train_data), log)
  print_log('Train-Trans : {:}'.format(train_transform), log)
  print_log('Test--Dataset : {:}'.format(test_data ), log)
  print_log('Test--Trans : {:}'.format(test_transform ), log)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True,
                                             num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data , batch_size=config.batch_size, shuffle=False,
                                            num_workers=args.workers, pin_memory=True)

  criterion = torch.nn.CrossEntropyLoss().cuda()

  optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay)
  #optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=True)
  if config.type == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs))
  else:
    raise ValueError('Cannot find the scheduler type : {:}'.format(config.type))

  checkpoint_path = os.path.join(args.save_path, 'checkpoint-{:}-model.pth'.format(dataset))
  if os.path.isfile(checkpoint_path):
    checkpoint = torch.load( checkpoint_path )
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    accuracies = checkpoint['accuracies']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, accuracies = 0, {}
    print_log('Train model from scratch without pre-trained model or snapshot', log)

  # Main loop
  start_time, epoch_time = time.time(), AverageMeter()
  for epoch in range(start_epoch, config.epochs):
    scheduler.step()

    need_time = convert_secs2time(epoch_time.val * (config.epochs-epoch), True)
    print_log("\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} LR={:6.4f} ~ {:6.4f}, Batch={:d}".format(time_string(), epoch, config.epochs, need_time, min(scheduler.get_lr()), max(scheduler.get_lr()), config.batch_size), log)

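    # Drop-path regularization is ramped linearly with epoch, from 0 up to config.drop_path_prob
    # by the end of training; update_drop_path is expected to push the current value into the cells.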
    basemodel.update_drop_path(config.drop_path_prob * epoch / config.epochs)

    train_acc1, train_acc5, train_los = _train(train_loader, model, criterion, optimizer, 'train', epoch, config, args.print_freq, log)

    with torch.no_grad():
      valid_acc1, valid_acc5, valid_los = _train(test_loader, model, criterion, optimizer, 'test', epoch, config, args.print_freq, log)
    accuracies[epoch] = (valid_acc1, valid_acc5)

    torch.save({'epoch' : epoch + 1,
                'args' : deepcopy(args),
                'state_dict': basemodel.state_dict(),
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict(),
                'accuracies': accuracies},
               checkpoint_path)
    best_acc = obtain_best( accuracies )
    print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

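
# _train runs a single pass over `xloader` and is reused for both phases: with mode='train' it
# back-propagates (adding the auxiliary-head loss, scaled by config.auxiliary_weight, whenever
# config.auxiliary is enabled), clips gradients if config.grad_clip > 0 and steps the optimizer;
# with mode='test' it only computes the loss and top-1/top-5 accuracy.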
def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log):
  data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
  if mode == 'train':
    model.train()
  elif mode == 'test':
    model.eval()
  else: raise ValueError("The mode is not right : {:}".format(mode))

  end = time.time()
  for i, (inputs, targets) in enumerate(xloader):
    # measure data loading time
    data_time.update(time.time() - end)
    # calculate prediction and loss
    targets = targets.cuda(non_blocking=True)

    if mode == 'train': optimizer.zero_grad()

    if config.auxiliary and model.training:
      logits, logits_aux = model(inputs)
    else:
      logits = model(inputs)

    loss = criterion(logits, targets)
    if config.auxiliary and model.training:
      loss_aux = criterion(logits_aux, targets)
      loss += config.auxiliary_weight * loss_aux

    if mode == 'train':
      loss.backward()
      if config.grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
      optimizer.step()
    # record
    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    losses.update(loss.item(), inputs.size(0))
    top1.update (prec1.item(), inputs.size(0))
    top5.update (prec5.item(), inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if i % print_freq == 0 or (i+1) == len(xloader):
      Sstr = ' {:5s}'.format(mode) + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, i, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  print_log ('{TIME:} **{mode:}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(TIME=time_string(), mode=mode, top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg), log)
  return top1.avg, top5.avg, losses.avg
207
exps-cnn/train_utils_imagenet.py
Normal file
207
exps-cnn/train_utils_imagenet.py
Normal file
@@ -0,0 +1,207 @@
import os, sys, time
from copy import deepcopy
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms

from utils import print_log, obtain_accuracy, AverageMeter
from utils import time_string, convert_secs2time
from utils import count_parameters_in_MB
from utils import print_FLOPs
from utils import Cutout
from nas import NetworkImageNet as Network


def obtain_best(accuracies):
  if len(accuracies) == 0: return (0, 0)
  tops = [value for key, value in accuracies.items()]
  s2b = sorted( tops )
  return s2b[-1]

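
# Label-smoothed cross-entropy: the one-hot target is softened to
#   q_k = (1 - epsilon) * 1[k == y] + epsilon / num_classes,
# and the loss is the batch average of -sum_k q_k * log p_k.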
class CrossEntropyLabelSmooth(nn.Module):

  def __init__(self, num_classes, epsilon):
    super(CrossEntropyLabelSmooth, self).__init__()
    self.num_classes = num_classes
    self.epsilon = epsilon
    self.logsoftmax = nn.LogSoftmax(dim=1)

  def forward(self, inputs, targets):
    log_probs = self.logsoftmax(inputs)
    targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
    targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
    loss = (-targets * log_probs).mean(0).sum()
    return loss


def main_procedure_imagenet(config, data_path, args, genotype, init_channels, layers, log):

  # training data and testing data
  traindir = os.path.join(data_path, 'train')
  validdir = os.path.join(data_path, 'val')
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  train_data = dset.ImageFolder(
    traindir,
    transforms.Compose([
      transforms.RandomResizedCrop(224),
      transforms.RandomHorizontalFlip(),
      transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.2),
      transforms.ToTensor(),
      normalize,
    ]))
  valid_data = dset.ImageFolder(
    validdir,
    transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      normalize,
    ]))

  train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=config.batch_size, shuffle= True, pin_memory=True, num_workers=args.workers)

  valid_queue = torch.utils.data.DataLoader(
    valid_data, batch_size=config.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

  class_num = 1000

  print_log('-------------------------------------- main-procedure', log)
  print_log('config : {:}'.format(config), log)
  print_log('genotype : {:}'.format(genotype), log)
  print_log('init_channels : {:}'.format(init_channels), log)
  print_log('layers : {:}'.format(layers), log)
  print_log('class_num : {:}'.format(class_num), log)
  basemodel = Network(init_channels, class_num, layers, config.auxiliary, genotype)
  model = torch.nn.DataParallel(basemodel).cuda()

  total_param, aux_param = count_parameters_in_MB(basemodel), count_parameters_in_MB(basemodel.auxiliary_param())
  print_log('Network =>\n{:}'.format(basemodel), log)
  #print_FLOPs(basemodel, (1,3,224,224), [print_log, log])
  print_log('Parameters : {:} - {:} = {:.3f} MB'.format(total_param, aux_param, total_param - aux_param), log)
  print_log('config : {:}'.format(config), log)
  print_log('genotype : {:}'.format(genotype), log)
  print_log('Train-Dataset : {:}'.format(train_data), log)
  print_log('Valid--Dataset : {:}'.format(valid_data), log)
  print_log('Args : {:}'.format(args), log)

  criterion = torch.nn.CrossEntropyLoss().cuda()
  criterion_smooth = CrossEntropyLabelSmooth(class_num, config.label_smooth).cuda()

  optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay)
  #optimizer = torch.optim.SGD(model.parameters(), config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=True)
  if config.type == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(config.epochs))
  elif config.type == 'steplr':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.decay_period, gamma=config.gamma)
  else:
    raise ValueError('Cannot find the scheduler type : {:}'.format(config.type))

  checkpoint_path = os.path.join(args.save_path, 'checkpoint-imagenet-model.pth')
  if os.path.isfile(checkpoint_path):
    checkpoint = torch.load( checkpoint_path )
    start_epoch = checkpoint['epoch']
    basemodel.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    accuracies = checkpoint['accuracies']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, accuracies = 0, {}
    print_log('Train model from scratch without pre-trained model or snapshot', log)

  # Main loop
  start_time, epoch_time = time.time(), AverageMeter()
  for epoch in range(start_epoch, config.epochs):
    scheduler.step()

    need_time = convert_secs2time(epoch_time.val * (config.epochs-epoch), True)
    print_log("\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} LR={:6.4f} ~ {:6.4f}, Batch={:d}".format(time_string(), epoch, config.epochs, need_time, min(scheduler.get_lr()), max(scheduler.get_lr()), config.batch_size), log)

    basemodel.update_drop_path(config.drop_path_prob * epoch / config.epochs)

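    # Training uses the label-smoothed criterion defined above; validation uses plain
    # cross-entropy and passes no optimizer, so _train only measures loss and accuracy.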
    train_acc1, train_acc5, train_los = _train(train_queue, model, criterion_smooth, optimizer, 'train', epoch, config, args.print_freq, log)

    with torch.no_grad():
      valid_acc1, valid_acc5, valid_los = _train(valid_queue, model, criterion, None, 'test' , epoch, config, args.print_freq, log)
    accuracies[epoch] = (valid_acc1, valid_acc5)

    torch.save({'epoch' : epoch + 1,
                'args' : deepcopy(args),
                'state_dict': basemodel.state_dict(),
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict(),
                'accuracies': accuracies},
               checkpoint_path)
    best_acc = obtain_best( accuracies )
    print_log('----> Best Accuracy : Acc@1={:.2f}, Acc@5={:.2f}, Error@1={:.2f}, Error@5={:.2f}'.format(best_acc[0], best_acc[1], 100-best_acc[0], 100-best_acc[1]), log)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()


def _train(xloader, model, criterion, optimizer, mode, epoch, config, print_freq, log):
  data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
  if mode == 'train':
    model.train()
  elif mode == 'test':
    model.eval()
  else: raise ValueError("The mode is not right : {:}".format(mode))

  end = time.time()
  for i, (inputs, targets) in enumerate(xloader):
    # measure data loading time
    data_time.update(time.time() - end)
    # calculate prediction and loss
    targets = targets.cuda(non_blocking=True)

    if mode == 'train': optimizer.zero_grad()

    if config.auxiliary and model.training:
      logits, logits_aux = model(inputs)
    else:
      logits = model(inputs)

    loss = criterion(logits, targets)
    if config.auxiliary and model.training:
      loss_aux = criterion(logits_aux, targets)
      loss += config.auxiliary_weight * loss_aux

    if mode == 'train':
      loss.backward()
      if config.grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
      optimizer.step()
    # record
    prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
    losses.update(loss.item(), inputs.size(0))
    top1.update (prec1.item(), inputs.size(0))
    top5.update (prec5.item(), inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if i % print_freq == 0 or (i+1) == len(xloader):
      Sstr = ' {:5s}'.format(mode) + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, i, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=losses, top1=top1, top5=top5)
      print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)

  print_log ('{TIME:} **{mode:}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'.format(TIME=time_string(), mode=mode, top1=top1, top5=top5, error1=100-top1.avg, error5=100-top5.avg, loss=losses.avg), log)
  return top1.avg, top5.avg, losses.avg
69
exps-cnn/vis-arch.py
Normal file
69
exps-cnn/vis-arch.py
Normal file
@@ -0,0 +1,69 @@
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from graphviz import Digraph

parser = argparse.ArgumentParser("Visualize the Networks")
parser.add_argument('--checkpoint', type=str, help='The path to the checkpoint.')
parser.add_argument('--save_dir', type=str, help='The directory to save the network plot.')
args = parser.parse_args()

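
# Typical invocation (paths are only illustrative):
#   python exps-cnn/vis-arch.py --checkpoint ./snapshots/checkpoint-search.pth --save_dir ./vis
# Each genotype is a flat list of (op, input_index, weight) triples, two per intermediate node:
# input index 0 / 1 refers to the output of the second-previous / previous cell (c_{k-2}, c_{k-1}),
# an index j >= 2 refers to intermediate node j-2, and every intermediate node feeds into c_{k}.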
def plot(genotype, filename):
  g = Digraph(
      format='pdf',
      edge_attr=dict(fontsize='20', fontname="times"),
      node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
      engine='dot')
  g.body.extend(['rankdir=LR'])

  g.node("c_{k-2}", fillcolor='darkseagreen2')
  g.node("c_{k-1}", fillcolor='darkseagreen2')
  assert len(genotype) % 2 == 0
  steps = len(genotype) // 2

  for i in range(steps):
    g.node(str(i), fillcolor='lightblue')

  for i in range(steps):
    for k in [2*i, 2*i + 1]:
      op, j, weight = genotype[k]
      if j == 0:
        u = "c_{k-2}"
      elif j == 1:
        u = "c_{k-1}"
      else:
        u = str(j-2)
      v = str(i)
      g.edge(u, v, label=op, fillcolor="gray")

  g.node("c_{k}", fillcolor='palegoldenrod')
  for i in range(steps):
    g.edge(str(i), "c_{k}", fillcolor="gray")

  g.render(filename, view=False)


if __name__ == '__main__':
  checkpoint = args.checkpoint
  assert os.path.isfile(checkpoint), 'Invalid path for checkpoint : {:}'.format(checkpoint)
  checkpoint = torch.load( checkpoint, map_location='cpu' )
  genotypes = checkpoint['genotypes']
  save_dir = Path(args.save_dir)
  subs = ['normal', 'reduce']
  for sub in subs:
    if not (save_dir / sub).exists():
      (save_dir / sub).mkdir(parents=True, exist_ok=True)

  for key, network in genotypes.items():
    save_path = str(save_dir / 'normal' / 'epoch-{:03d}'.format( int(key) ))
    print('save into {:}'.format(save_path))
    plot(network.normal, save_path)

    save_path = str(save_dir / 'reduce' / 'epoch-{:03d}'.format( int(key) ))
    print('save into {:}'.format(save_path))
    plot(network.reduce, save_path)