Update NATS-Bench (sss version 1.2)

This commit is contained in:
D-X-Y
2020-08-30 08:04:52 +00:00
parent 469a207945
commit 5f151d1970
15 changed files with 317 additions and 229 deletions

187
exps/NATS-algos/bohb.py Normal file
View File

@@ -0,0 +1,187 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
###################################################################
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
# required to install hpbandster ##################################
# pip install hpbandster ##################################
###################################################################
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
###################################################################
import os, sys, time, random, argparse, collections
from copy import deepcopy
from pathlib import Path
import torch
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger
from log_utils import AverageMeter, time_string, convert_secs2time
from nats_bench import create
from models import CellStructure, get_search_spaces
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace
from hpbandster.optimizers.bohb import BOHB
import hpbandster.core.nameserver as hpns
from hpbandster.core.worker import Worker
def get_topology_config_space(search_space, max_nodes=4):
cs = ConfigSpace.ConfigurationSpace()
#edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
return cs
def get_size_config_space(search_space):
cs = ConfigSpace.ConfigurationSpace()
for ilayer in range(search_space['numbers']):
node_str = 'layer-{:}'.format(ilayer)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates']))
return cs
def config2topology_func(max_nodes=4):
def config2structure(config):
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = config[node_str]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
return config2structure
def config2size_func(search_space):
def config2structure(config):
channels = []
for ilayer in range(search_space['numbers']):
node_str = 'layer-{:}'.format(ilayer)
channels.append(str(config[node_str]))
return ':'.join(channels)
return config2structure
class MyWorker(Worker):
def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self._dataset = dataset
self._api = api
self.total_times = []
self.trajectory = []
def compute(self, config, budget, **kwargs):
arch = self.convert_func( config )
accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(arch, self._dataset, iepoch=int(budget)-1, hp='12')
self.trajectory.append((accuracy, arch))
self.total_times.append(total_time)
return ({'loss': 100 - accuracy,
'info': self._api.query_index_by_arch(arch)})
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
logger.log('{:} use api : {:}'.format(time_string(), api))
api.reset_time()
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
if xargs.search_space == 'tss':
cs = get_topology_config_space(search_space)
config2structure = config2topology_func()
else:
cs = get_size_config_space(search_space)
config2structure = config2size_func(search_space)
hb_run_id = '0'
NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0)
ns_host, ns_port = NS.start()
num_workers = 1
workers = []
for i in range(num_workers):
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataset=xargs.dataset, api=api, run_id=hb_run_id, id=i)
w.run(background=True)
workers.append(w)
start_time = time.time()
bohb = BOHB(configspace=cs, run_id=hb_run_id,
eta=3, min_budget=1, max_budget=12,
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
ping_interval=10, min_bandwidth=xargs.min_bandwidth)
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
# print('There are {:} runs.'.format(len(results.get_all_runs())))
# workers[0].total_times
# workers[0].trajectory
current_best_index = []
for idx in range(len(workers[0].trajectory)):
trajectory = workers[0].trajectory[:idx+1]
arch = max(trajectory, key=lambda x: x[0])[1]
current_best_index.append(api.query_index_by_arch(arch))
best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
logger.log('Best found configuration: {:} within {:.3f} s'.format(best_arch, workers[0].total_times[-1]))
info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, current_best_index, workers[0].total_times
if __name__ == '__main__':
parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# general arg
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# BOHB
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
parser.add_argument('--n_iters', default=300, type=int, nargs='?', help='number of iterations for optimization method')
# log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'BOHB')
print('save-dir : {:}'.format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {'all_archs': all_archs,
'all_total_times': all_total_times}
save_path = save_dir / 'results.pth'
print('save into {:}'.format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)

View File

@@ -0,0 +1,91 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
##############################################################################
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar10 --search_space tss
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
# python ./exps/NATS-algos/random_wo_share.py --dataset ImageNet16-120 --search_space tss
##############################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nats_bench import create
from regularized_ea import random_topology_func, random_size_func
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
logger.log('{:} use api : {:}'.format(time_string(), api))
api.reset_time()
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
if xargs.search_space == 'tss':
random_arch = random_topology_func(search_space)
else:
random_arch = random_size_func(search_space)
best_arch, best_acc, total_time_cost, history = None, -1, [], []
current_best_index = []
while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
arch = random_arch()
accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
total_time_cost.append(total_cost)
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
current_best_index.append(api.query_index_by_arch(best_arch))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost[-1]))
info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, current_best_index, total_time_cost
if __name__ == '__main__':
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM')
print('save-dir : {:}'.format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {'all_archs': all_archs,
'all_total_times': all_total_times}
save_path = save_dir / 'results.pth'
print('save into {:}'.format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)

View File

@@ -0,0 +1,219 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################################
# Regularized Evolution for Image Classifier Architecture Search #
##################################################################
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
##################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
from nats_bench import create
class Model(object):
def __init__(self):
self.arch = None
self.accuracy = None
def __str__(self):
"""Prints a readable version of this bitstring."""
return '{:}'.format(self.arch)
def random_topology_func(op_names, max_nodes=4):
# Return a random architecture
def random_architecture():
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = random.choice( op_names )
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
return random_architecture
def random_size_func(info):
# Return a random architecture
def random_architecture():
channels = []
for i in range(info['numbers']):
channels.append(
str(random.choice(info['candidates'])))
return ':'.join(channels)
return random_architecture
def mutate_topology_func(op_names):
"""Computes the architecture for a child of the given parent architecture.
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another.
"""
def mutate_topology_func(parent_arch):
child_arch = deepcopy( parent_arch )
node_id = random.randint(0, len(child_arch.nodes)-1)
node_info = list( child_arch.nodes[node_id] )
snode_id = random.randint(0, len(node_info)-1)
xop = random.choice( op_names )
while xop == node_info[snode_id][0]:
xop = random.choice( op_names )
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple( node_info )
return child_arch
return mutate_topology_func
def mutate_size_func(info):
"""Computes the architecture for a child of the given parent architecture.
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another.
"""
def mutate_size_func(parent_arch):
child_arch = deepcopy(parent_arch)
child_arch = child_arch.split(':')
index = random.randint(0, len(child_arch)-1)
child_arch[index] = str(random.choice(info['candidates']))
return ':'.join(child_arch)
return mutate_size_func
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset):
"""Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
Classifier Architecture Search".
Args:
cycles: the number of cycles the algorithm should run for.
population_size: the number of individuals to keep in the population.
sample_size: the number of individuals that should participate in each tournament.
time_budget: the upper bound of searching cost
Returns:
history: a list of `Model` instances, representing all the models computed
during the evolution experiment.
"""
population = collections.deque()
api.reset_time()
history, total_time_cost = [], [] # Not used by the algorithm, only used to report results.
current_best_index = []
# Initialize the population with random models.
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
# Append the info
population.append(model)
history.append((model.accuracy, model.arch))
total_time_cost.append(total_cost)
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
# Carry out evolution in cycles. Each cycle produces a model and removes another.
while total_time_cost[-1] < time_budget:
# Sample randomly chosen models from the current population.
start_time, sample = time.time(), []
while len(sample) < sample_size:
# Inefficient, but written this way for clarity. In the case of neural
# nets, the efficiency of this line is irrelevant because training neural
# nets is the rate-determining step.
candidate = random.choice(list(population))
sample.append(candidate)
# The parent is the best model in the sample.
parent = max(sample, key=lambda i: i.accuracy)
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
# Append the info
population.append(child)
history.append((child.accuracy, child.arch))
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
total_time_cost.append(total_cost)
# Remove the oldest model.
population.popleft()
return history, current_best_index, total_time_cost
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
if xargs.search_space == 'tss':
random_arch = random_topology_func(search_space)
mutate_arch = mutate_topology_func(search_space)
else:
random_arch = random_size_func(search_space)
mutate_arch = mutate_size_func(search_space)
x_start_time = time.time()
logger.log('{:} use api : {:}'.format(time_string(), api))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
best_arch = max(history, key=lambda x: x[0])[1]
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, current_best_index, total_times
if __name__ == '__main__':
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
# channels and number-of-cells
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
print('save-dir : {:}'.format(args.save_dir))
print('xargs : {:}'.format(args))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {'all_archs': all_archs,
'all_total_times': all_total_times}
save_path = save_dir / 'results.pth'
print('save into {:}'.format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)

View File

@@ -0,0 +1,212 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
#####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
#####################################################################################################
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
#####################################################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
from nats_bench import create
class PolicyTopology(nn.Module):
def __init__(self, search_space, max_nodes=4):
super(PolicyTopology, self).__init__()
self.max_nodes = max_nodes
self.search_space = deepcopy(search_space)
self.edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
self.edge2index[ node_str ] = len(self.edge2index)
self.arch_parameters = nn.Parameter(1e-3*torch.randn(len(self.edge2index), len(search_space)))
def generate_arch(self, actions):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.search_space[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
def forward(self):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas
class PolicySize(nn.Module):
def __init__(self, search_space):
super(PolicySize, self).__init__()
self.candidates = search_space['candidates']
self.numbers = search_space['numbers']
self.arch_parameters = nn.Parameter(1e-3*torch.randn(self.numbers, len(self.candidates)))
def generate_arch(self, actions):
channels = [str(self.candidates[i]) for i in actions]
return ':'.join(channels)
def genotype(self):
channels = []
for i in range(self.numbers):
index = self.arch_parameters[i].argmax().item()
channels.append(str(self.candidates[index]))
return ':'.join(channels)
def forward(self):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""
def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum
def update(self, value):
self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator
def select_action(policy):
probs = policy()
m = Categorical(probs)
action = m.sample()
# policy.saved_log_probs.append(m.log_prob(action))
return m.log_prob(action), action.cpu().tolist()
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
if xargs.search_space == 'tss':
policy = PolicyTopology(search_space)
else:
policy = PolicySize(search_space)
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
#optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
eps = np.finfo(np.float32).eps.item()
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
logger.log('policy : {:}'.format(policy))
logger.log('optimizer : {:}'.format(optimizer))
logger.log('eps : {:}'.format(eps))
# nas dataset load
logger.log('{:} use api : {:}'.format(time_string(), api))
api.reset_time()
# REINFORCE
x_start_time = time.time()
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
total_steps, total_costs, trace = 0, [], []
current_best_index = []
while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
start_time = time.time()
log_prob, action = select_action( policy )
arch = policy.generate_arch( action )
reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
trace.append((reward, arch))
total_costs.append(current_total_cost)
baseline.update(reward)
# calculate loss
policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
# accumulate time
total_steps += 1
logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
# to analyze
current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]))
# best_arch = policy.genotype() # first version
best_arch = max(trace, key=lambda x: x[0])[1]
logger.log('REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs[-1], time.time()-x_start_time))
info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, current_best_index, total_costs
if __name__ == '__main__':
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.')
parser.add_argument('--time_budget', type=int, default=20000, help='The total time cost budge for searching (in seconds).')
parser.add_argument('--loops_if_rand', type=int, default=500, help='The total runs for evaluation.')
# log
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
api = create(None, args.search_space, verbose=False)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
print('save-dir : {:}'.format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {'all_archs': all_archs,
'all_total_times': all_total_times}
save_path = save_dir / 'results.pth'
print('save into {:}'.format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)

View File

@@ -0,0 +1,47 @@
#!/bin/bash
# bash ./exps/NATS-algos/run-all.sh mul
# bash ./exps/NATS-algos/run-all.sh ws
set -e
echo script name: $0
echo $# arguments
if [ "$#" -ne 1 ] ;then
echo "Input illegal number of parameters " $#
echo "Need 1 parameters for type of algorithms."
exit 1
fi
datasets="cifar10 cifar100 ImageNet16-120"
alg_type=$1
if [ "$alg_type" == "mul" ]; then
search_spaces="tss sss"
for dataset in ${datasets}
do
for search_space in ${search_spaces}
do
python ./exps/NATS-algos/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
python ./exps/NATS-algos/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
done
done
python exps/experimental/vis-bench-algos.py --search_space tss
python exps/experimental/vis-bench-algos.py --search_space sss
else
seeds="777 888 999"
algos="darts-v1 darts-v2 gdas setn random enas"
epoch=200
for seed in ${seeds}
do
for alg in ${algos}
do
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
done
done
fi

View File

@@ -0,0 +1,525 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create
# The following three functions are used for DARTS-V2
def _concat(xs):
return torch.cat([x.view(-1) for x in xs])
def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2):
R = r / _concat(vector).norm()
for p, v in zip(network.weights, vector):
p.data.add_(R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_p = torch.autograd.grad(loss, network.alphas)
for p, v in zip(network.weights, vector):
p.data.sub_(2*R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_n = torch.autograd.grad(loss, network.alphas)
for p, v in zip(network.weights, vector):
p.data.add_(R, v)
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets):
# _compute_unrolled_model
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum']
with torch.no_grad():
theta = _concat(network.weights)
try:
moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.weights)
moment = moment.mul_(momentum)
except:
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD*theta
params = theta.sub(LR, moment+dtheta)
unrolled_model = deepcopy(network)
model_dict = unrolled_model.state_dict()
new_params, offset = {}, 0
for k, v in network.named_parameters():
if 'arch_parameters' in k: continue
v_length = np.prod(v.size())
new_params[k] = params[offset: offset+v_length].view(v.size())
offset += v_length
model_dict.update(new_params)
unrolled_model.load_state_dict(model_dict)
unrolled_model.zero_grad()
_, unrolled_logits = unrolled_model(arch_inputs)
unrolled_loss = criterion(unrolled_logits, arch_targets)
unrolled_loss.backward()
dalpha = unrolled_model.arch_parameters.grad
vector = [v.grad.data for v in unrolled_model.weights]
[implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets)
dalpha.data.sub_(LR, implicit_grads.data)
if network.arch_parameters.grad is None:
network.arch_parameters.grad = deepcopy( dalpha )
else:
network.arch_parameters.grad.data.copy_( dalpha.data )
return unrolled_loss.detach(), unrolled_logits.detach()
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_inputs = base_inputs.cuda(non_blocking=True)
arch_inputs = arch_inputs.cuda(non_blocking=True)
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# Update the weights
if algo == 'setn':
sampled_arch = network.dync_genotype(True)
network.set_cal_mode('dynamic', sampled_arch)
elif algo == 'gdas':
network.set_cal_mode('gdas', None)
elif algo.startswith('darts'):
network.set_cal_mode('joint', None)
elif algo == 'random':
network.set_cal_mode('urs', None)
elif algo == 'enas':
with torch.no_grad():
network.controller.eval()
_, _, sampled_arch = network.controller()
network.set_cal_mode('dynamic', sampled_arch)
else:
raise ValueError('Invalid algo name : {:}'.format(algo))
network.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
if algo == 'setn':
network.set_cal_mode('joint')
elif algo == 'gdas':
network.set_cal_mode('gdas', None)
elif algo.startswith('darts'):
network.set_cal_mode('joint', None)
elif algo == 'random':
network.set_cal_mode('urs', None)
elif algo != 'enas':
raise ValueError('Invalid algo name : {:}'.format(algo))
network.zero_grad()
if algo == 'darts-v2':
arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
a_optimizer.step()
elif algo == 'random' or algo == 'enas':
with torch.no_grad():
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
else:
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger):
# config. (containing some necessary arg)
# baseline: The baseline score (i.e. average val_acc) from the previous epoch
data_time, batch_time = AverageMeter(), AverageMeter()
GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
controller_num_aggregate = 20
controller_train_steps = 50
controller_bl_dec = 0.99
controller_entropy_weight = 0.0001
network.eval()
network.controller.train()
network.controller.zero_grad()
loader_iter = iter(xloader)
for step in range(controller_train_steps * controller_num_aggregate):
try:
inputs, targets = next(loader_iter)
except:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
inputs = inputs.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
log_prob, entropy, sampled_arch = network.controller()
with torch.no_grad():
network.set_cal_mode('dynamic', sampled_arch)
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
val_top1 = val_top1.view(-1) / 100
reward = val_top1 + controller_entropy_weight * entropy
if prev_baseline is None:
baseline = val_top1
else:
baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward)
loss = -1 * log_prob * (reward - baseline)
# account
RewardMeter.update(reward.item())
BaselineMeter.update(baseline.item())
ValAccMeter.update(val_top1.item()*100)
LossMeter.update(loss.item())
EntropyMeter.update(entropy.item())
# Average gradient over controller_num_aggregate samples
loss = loss / controller_num_aggregate
loss.backward(retain_graph=True)
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
if (step+1) % controller_num_aggregate == 0:
grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0)
GradnormMeter.update(grad_norm)
optimizer.step()
network.controller.zero_grad()
if step % print_freq == 0:
Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, controller_train_steps * controller_num_aggregate)
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
def get_best_arch(xloader, network, n_samples, algo):
with torch.no_grad():
network.eval()
if algo == 'random':
archs, valid_accs = network.return_topK(n_samples, True), []
elif algo == 'setn':
archs, valid_accs = network.return_topK(n_samples, False), []
elif algo.startswith('darts') or algo == 'gdas':
arch = network.genotype
archs, valid_accs = [arch], []
elif algo == 'enas':
archs, valid_accs = [], []
for _ in range(n_samples):
_, _, sampled_arch = network.controller()
archs.append(sampled_arch)
else:
raise ValueError('Invalid algorithm name : {:}'.format(algo))
loader_iter = iter(xloader)
for i, sampled_arch in enumerate(archs):
network.set_cal_mode('dynamic', sampled_arch)
try:
inputs, targets = next(loader_iter)
except:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, logits = network(inputs.cuda(non_blocking=True))
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
def valid_func(xloader, network, criterion, algo, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
with torch.no_grad():
network.eval()
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs.cuda(non_blocking=True))
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
if xargs.overwite_epochs is None:
extra_info = {'class_num': class_num, 'xshape': xshape}
else:
extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs}
config = load_config(xargs.config_path, extra_info, logger)
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', (config.batch_size, config.test_batch_size), xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
model_config = dict2config(
dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)
logger.log('search space : {:}'.format(search_space))
logger.log('model config : {:}'.format(model_config))
search_model = get_cell_based_tiny_net(model_config)
search_model.set_algo(xargs.algo)
logger.log('{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
params = count_parameters_in_MB(search_model)
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
logger.log('search-space : {:}'.format(search_space))
if bool(xargs.use_api):
api = create(None, 'topology', fast_mode=True, verbose=False)
else:
api = None
logger.log('{:} create API = {:} done'.format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
baseline = checkpoint['baseline']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.return_topK(1, True)[0]}
baseline = None
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True))
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate)
if xargs.algo == 'gdas':
network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(network.tau, network.drop_path))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
if xargs.algo == 'enas':
ctl_loss, ctl_acc, baseline, ctl_reward \
= train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
logger.log('[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'.format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo)
if xargs.algo == 'setn' or xargs.algo == 'enas':
network.set_cal_mode('dynamic', genotype)
elif xargs.algo == 'gdas':
network.set_cal_mode('gdas', None)
elif xargs.algo.startswith('darts'):
network.set_cal_mode('joint', None)
elif xargs.algo == 'random':
network.set_cal_mode('urs', None)
else:
raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(epoch_str, genotype, temp_accuracy))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
valid_accuracies[epoch] = valid_a_top1
genotypes[epoch] = genotype
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'baseline' : baseline,
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
with torch.no_grad():
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
# the final post procedure : count the time
start_time = time.time()
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo)
if xargs.algo == 'setn' or xargs.algo == 'enas':
network.set_cal_mode('dynamic', genotype)
elif xargs.algo == 'gdas':
network.set_cal_mode('gdas', None)
elif xargs.algo.startswith('darts'):
network.set_cal_mode('joint', None)
elif xargs.algo == 'random':
network.set_cal_mode('urs', None)
else:
raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
search_time.update(time.time() - start_time)
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger)
logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(xargs.algo, total_epoch, search_time.sum, genotype))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') ))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
parser.add_argument('--data_path' , type=str, help='Path to dataset')
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.')
parser.add_argument('--algo' , type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.')
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
# FOR GDAS
parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
# channels and number-of-cells
parser.add_argument('--max_nodes' , type=int, default=4, help='The maximum number of nodes.')
parser.add_argument('--channel' , type=int, default=16, help='The number of channels.')
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
#
parser.add_argument('--eval_candidate_num', type=int, default=100, help='The number of selected architectures to evaluate.')
#
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--arch_eps' , type=float, default=1e-8, help='weight decay for arch encoding')
parser.add_argument('--drop_path_rate' , type=float, help='The drop path rate.')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.overwite_epochs is None:
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
args.dataset,
'{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
else:
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
args.dataset,
'{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwite_epochs, args.drop_path_rate))
main(args)

View File

@@ -0,0 +1,299 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create
# Ad-hoc for TuNAS
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""
def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum
def update(self, value):
self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
@property
def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator
RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_inputs = base_inputs.cuda(non_blocking=True)
arch_inputs = arch_inputs.cuda(non_blocking=True)
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# Update the weights
network.zero_grad()
_, logits, _ = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
network.zero_grad()
_, logits, log_probs = network(arch_inputs)
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
if algo == 'tunas':
with torch.no_grad():
RL_BASELINE_EMA.update(arch_prec1.item())
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
rl_log_prob = sum(log_probs)
arch_loss = - rl_advantage * rl_log_prob
elif algo == 'tas' or algo == 'fbv2':
arch_loss = criterion(logits, arch_targets)
else:
raise ValueError('invalid algorightm name: {:}'.format(algo))
arch_loss.backward()
a_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def valid_func(xloader, network, criterion, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
with torch.no_grad():
network.eval()
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits, _ = network(arch_inputs.cuda(non_blocking=True))
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
if xargs.overwite_epochs is None:
extra_info = {'class_num': class_num, 'xshape': xshape}
else:
extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs}
config = load_config(xargs.config_path, extra_info, logger)
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', (config.batch_size, config.test_batch_size), xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
model_config = dict2config(
dict(name='generic', super_type='search-shape', candidate_Cs=search_space['candidates'], max_num_Cs=search_space['numbers'], num_classes=class_num,
genotype=args.genotype, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)
logger.log('search space : {:}'.format(search_space))
logger.log('model config : {:}'.format(model_config))
search_model = get_cell_based_tiny_net(model_config)
search_model.set_algo(xargs.algo)
logger.log('{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
params = count_parameters_in_MB(search_model)
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
logger.log('search-space : {:}'.format(search_space))
if bool(xargs.use_api):
api = create(None, 'size', fast_mode=True, verbose=False)
else:
api = None
logger.log('{:} create API = {:} done'.format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.random}
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch-epoch), True))
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
if xargs.algo == 'fbv2' or xargs.algo == 'tas':
network.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
logger.log('[RESET tau as : {:}]'.format(network.tau))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
genotype = network.genotype
logger.log('[{:}] - [get_best_arch] : {:}'.format(epoch_str, genotype))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion, logger)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
valid_accuracies[epoch] = valid_a_top1
genotypes[epoch] = genotype
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
with torch.no_grad():
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '90')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
# the final post procedure : count the time
start_time = time.time()
genotype = network.genotype
search_time.update(time.time() - start_time)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger)
logger.log('Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(xargs.algo, total_epoch, search_time.sum, genotype))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '90') ))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
parser.add_argument('--data_path' , type=str, help='Path to dataset')
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
parser.add_argument('--algo' , type=str, choices=['tas', 'fbv2', 'tunas'], help='The search space name.')
parser.add_argument('--genotype' , type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
# FOR GDAS
parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
#
parser.add_argument('--track_running_stats',type=int, default=0, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--affine' , type=int, default=0, choices=[0,1],help='Whether use affine=True or False in the BN layer.')
parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
parser.add_argument('--overwite_epochs', type=int, help='The number of epochs to overwrite that value in config files.')
# architecture leraning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--arch_eps' , type=float, default=1e-8, help='weight decay for arch encoding')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
dirname = '{:}-affine{:}_BN{:}-AWD{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay)
if args.overwite_epochs is not None:
dirname = dirname + '-E{:}'.format(args.overwite_epochs)
args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname)
main(args)