Update NATS-Bench (sss version 1.2)
187
exps/NATS-algos/bohb.py
Normal file
@@ -0,0 +1,187 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
###################################################################
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
# requires installing hpbandster first:          ##################
#   pip install hpbandster                       ##################
###################################################################
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
###################################################################
import os, sys, time, random, argparse, collections
from copy import deepcopy
from pathlib import Path
import torch
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger
from log_utils import AverageMeter, time_string, convert_secs2time
from nats_bench import create
from models import CellStructure, get_search_spaces

# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace
from hpbandster.optimizers.bohb import BOHB
import hpbandster.core.nameserver as hpns
from hpbandster.core.worker import Worker


def get_topology_config_space(search_space, max_nodes=4):
  cs = ConfigSpace.ConfigurationSpace()
  for i in range(1, max_nodes):
    for j in range(i):
      node_str = '{:}<-{:}'.format(i, j)
      cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
  return cs
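# Illustrative example (the operation names below are an assumption based on the
# NATS-Bench topology search space; the actual list comes from get_search_spaces):
#   cs = get_topology_config_space(['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'])
#   cs.sample_configuration()  # e.g. -> {'1<-0': 'nor_conv_3x3', '2<-0': 'none', '2<-1': 'skip_connect',
#                              #          '3<-0': 'avg_pool_3x3', '3<-1': 'none', '3<-2': 'nor_conv_1x1'}
# where each key 'i<-j' is the edge from node j to node i of the cell.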


def get_size_config_space(search_space):
  cs = ConfigSpace.ConfigurationSpace()
  for ilayer in range(search_space['numbers']):
    node_str = 'layer-{:}'.format(ilayer)
    cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space['candidates']))
  return cs


def config2topology_func(max_nodes=4):
  def config2structure(config):
    genotypes = []
    for i in range(1, max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name = config[node_str]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )
  return config2structure


def config2size_func(search_space):
  def config2structure(config):
    channels = []
    for ilayer in range(search_space['numbers']):
      node_str = 'layer-{:}'.format(ilayer)
      channels.append(str(config[node_str]))
    return ':'.join(channels)
  return config2structure
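# Illustrative example for the size search space (the 5 layers and candidate widths
# {8, 16, ..., 64} are an assumption based on the NATS-Bench sss setting):
#   to_arch = config2size_func({'numbers': 5, 'candidates': [8, 16, 24, 32, 40, 48, 56, 64]})
#   to_arch({'layer-0': 8, 'layer-1': 16, 'layer-2': 24, 'layer-3': 32, 'layer-4': 40})
#   -> '8:16:24:32:40'   (the colon-joined string encodes per-layer channel counts)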


class MyWorker(Worker):

  def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.convert_func = convert_func
    self._dataset = dataset
    self._api = api
    self.total_times = []
    self.trajectory = []

  def compute(self, config, budget, **kwargs):
    arch = self.convert_func( config )
    accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(arch, self._dataset, iepoch=int(budget)-1, hp='12')
    self.trajectory.append((accuracy, arch))
    self.total_times.append(total_time)
    return ({'loss': 100 - accuracy,
             'info': self._api.query_index_by_arch(arch)})
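# BOHB minimizes the returned 'loss', so the worker reports the validation error
# (100 - accuracy). The BOHB `budget` is interpreted as a number of training epochs:
# `iepoch=int(budget)-1` queries the checkpoint after `budget` epochs of the
# 12-epoch ('12') schedule, and `total_time` is the benchmark-simulated search cost
# accumulated so far, not wall-clock time.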


def main(xargs, api):
  torch.set_num_threads(4)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  logger.log('{:} use api : {:}'.format(time_string(), api))
  api.reset_time()
  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
  if xargs.search_space == 'tss':
    cs = get_topology_config_space(search_space)
    config2structure = config2topology_func()
  else:
    cs = get_size_config_space(search_space)
    config2structure = config2size_func(search_space)

  hb_run_id = '0'

  NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0)
  ns_host, ns_port = NS.start()
  num_workers = 1

  workers = []
  for i in range(num_workers):
    w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataset=xargs.dataset, api=api, run_id=hb_run_id, id=i)
    w.run(background=True)
    workers.append(w)

  start_time = time.time()
  bohb = BOHB(configspace=cs, run_id=hb_run_id,
              eta=3, min_budget=1, max_budget=12,
              nameserver=ns_host,
              nameserver_port=ns_port,
              num_samples=xargs.num_samples,
              random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
              ping_interval=10, min_bandwidth=xargs.min_bandwidth)

  results = bohb.run(xargs.n_iters, min_n_workers=num_workers)

  bohb.shutdown(shutdown_workers=True)
  NS.shutdown()

  # print('There are {:} runs.'.format(len(results.get_all_runs())))
  # workers[0].total_times
  # workers[0].trajectory
  current_best_index = []
  for idx in range(len(workers[0].trajectory)):
    trajectory = workers[0].trajectory[:idx+1]
    arch = max(trajectory, key=lambda x: x[0])[1]
    current_best_index.append(api.query_index_by_arch(arch))

  best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
  logger.log('Best found configuration: {:} within {:.3f} s'.format(best_arch, workers[0].total_times[-1]))
  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
  logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()

  return logger.log_dir, current_best_index, workers[0].total_times


if __name__ == '__main__':
  parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet16-120.')
  # general arguments
  parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
  parser.add_argument('--time_budget', type=int, default=20000, help='The total time budget for searching (in seconds).')
  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total number of runs for evaluation.')
  # BOHB
  parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
  parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
  parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
  parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
  parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
  parser.add_argument('--n_iters', default=300, type=int, nargs='?', help='number of iterations for the optimization method')
  # log
  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
  args = parser.parse_args()

  api = create(None, args.search_space, fast_mode=True, verbose=False)

  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'BOHB')
  print('save-dir : {:}'.format(args.save_dir))

  if args.rand_seed < 0:
    save_dir, all_info = None, collections.OrderedDict()
    for i in range(args.loops_if_rand):
      print('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
      args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
      all_info[i] = {'all_archs': all_archs,
                     'all_total_times': all_total_times}
    save_path = save_dir / 'results.pth'
    print('save into {:}'.format(save_path))
    torch.save(all_info, save_path)
  else:
    main(args, api)
91
exps/NATS-algos/random_wo_share.py
Normal file
@@ -0,0 +1,91 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
##############################################################################
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar10 --search_space tss
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
# python ./exps/NATS-algos/random_wo_share.py --dataset ImageNet16-120 --search_space tss
##############################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nats_bench import create
from regularized_ea import random_topology_func, random_size_func


def main(xargs, api):
  torch.set_num_threads(4)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  logger.log('{:} use api : {:}'.format(time_string(), api))
  api.reset_time()

  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
  if xargs.search_space == 'tss':
    random_arch = random_topology_func(search_space)
  else:
    random_arch = random_size_func(search_space)

  best_arch, best_acc, total_time_cost, history = None, -1, [], []
  current_best_index = []
  while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
    arch = random_arch()
    accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
    total_time_cost.append(total_cost)
    history.append(arch)
    if best_arch is None or best_acc < accuracy:
      best_acc, best_arch = accuracy, arch
    logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
    current_best_index.append(api.query_index_by_arch(best_arch))
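  # Note: `total_cost` returned by simulate_train_eval is the cumulative simulated
  # training time (in seconds) since api.reset_time(), so the loop above stops once
  # the *simulated* cost exceeds the budget, independent of wall-clock time.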
  logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visited {:} archs within {:.1f} s.'.format(time_string(), best_arch, best_acc, len(history), total_time_cost[-1]))

  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
  logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
  return logger.log_dir, current_best_index, total_time_cost


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Random NAS")
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet16-120.')
  parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')

  parser.add_argument('--time_budget', type=int, default=20000, help='The total time budget for searching (in seconds).')
  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total number of runs for evaluation.')
  # log
  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
  args = parser.parse_args()

  api = create(None, args.search_space, fast_mode=True, verbose=False)

  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'RANDOM')
  print('save-dir : {:}'.format(args.save_dir))

  if args.rand_seed < 0:
    save_dir, all_info = None, collections.OrderedDict()
    for i in range(args.loops_if_rand):
      print('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
      args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
      all_info[i] = {'all_archs': all_archs,
                     'all_total_times': all_total_times}
    save_path = save_dir / 'results.pth'
    print('save into {:}'.format(save_path))
    torch.save(all_info, save_path)
  else:
    main(args, api)
219
exps/NATS-algos/regularized_ea.py
Normal file
@@ -0,0 +1,219 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################################
# Regularized Evolution for Image Classifier Architecture Search #
##################################################################
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
##################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
from nats_bench import create


class Model(object):

  def __init__(self):
    self.arch = None
    self.accuracy = None

  def __str__(self):
    """Returns a readable description of this model's architecture."""
    return '{:}'.format(self.arch)


def random_topology_func(op_names, max_nodes=4):
  # Return a random architecture
  def random_architecture():
    genotypes = []
    for i in range(1, max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name = random.choice( op_names )
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )
  return random_architecture
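# Illustrative example (the operation set is an assumption based on the NATS-Bench
# topology search space): a sampled CellStructure serializes to a string such as
#   |nor_conv_3x3~0|+|skip_connect~0|none~1|+|avg_pool_3x3~0|nor_conv_1x1~1|skip_connect~2|
# where each |op~j| names the operation on the edge from node j, and '+' separates
# the groups of incoming edges for nodes 1, 2, and 3.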


def random_size_func(info):
  # Return a random architecture
  def random_architecture():
    channels = []
    for i in range(info['numbers']):
      channels.append(
        str(random.choice(info['candidates'])))
    return ':'.join(channels)
  return random_architecture


def mutate_topology_func(op_names):
  """Computes the architecture for a child of the given parent architecture.

  The parent architecture is cloned and mutated to produce the child
  architecture. The child is mutated by randomly switching the operation on
  one randomly chosen edge to a different operation.
  """
  def mutate_topology_func(parent_arch):
    child_arch = deepcopy( parent_arch )
    node_id = random.randint(0, len(child_arch.nodes)-1)
    node_info = list( child_arch.nodes[node_id] )
    snode_id = random.randint(0, len(node_info)-1)
    xop = random.choice( op_names )
    while xop == node_info[snode_id][0]:
      xop = random.choice( op_names )
    node_info[snode_id] = (xop, node_info[snode_id][1])
    child_arch.nodes[node_id] = tuple( node_info )
    return child_arch
  return mutate_topology_func


def mutate_size_func(info):
  """Computes the architecture for a child of the given parent architecture.

  The parent architecture is cloned and mutated to produce the child
  architecture. The child is mutated by randomly switching the channel
  number of one randomly chosen layer to a different candidate.
  """
  def mutate_size_func(parent_arch):
    child_arch = deepcopy(parent_arch)
    child_arch = child_arch.split(':')
    index = random.randint(0, len(child_arch)-1)
    child_arch[index] = str(random.choice(info['candidates']))
    return ':'.join(child_arch)
  return mutate_size_func


def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, dataset):
  """Algorithm for regularized evolution (i.e. aging evolution).

  Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
  Classifier Architecture Search".

  Args:
    cycles: the number of cycles the algorithm should run for.
    population_size: the number of individuals to keep in the population.
    sample_size: the number of individuals that should participate in each tournament.
    time_budget: the upper bound on the total simulated search cost (in seconds).

  Returns:
    history: a list of (accuracy, architecture) tuples, representing all the
      models evaluated during the evolution experiment.
  """
  population = collections.deque()
  api.reset_time()
  history, total_time_cost = [], []  # Not used by the algorithm, only used to report results.
  current_best_index = []
  # Initialize the population with random models.
  while len(population) < population_size:
    model = Model()
    model.arch = random_arch()
    model.accuracy, _, _, total_cost = api.simulate_train_eval(model.arch, dataset, hp='12')
    # Append the info
    population.append(model)
    history.append((model.accuracy, model.arch))
    total_time_cost.append(total_cost)
    current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))

  # Carry out evolution in cycles. Each cycle produces a model and removes another.
  while total_time_cost[-1] < time_budget:
    # Sample randomly chosen models from the current population.
    start_time, sample = time.time(), []
    while len(sample) < sample_size:
      # Inefficient, but written this way for clarity. In the case of neural
      # nets, the efficiency of this line is irrelevant because training neural
      # nets is the rate-determining step.
      candidate = random.choice(list(population))
      sample.append(candidate)

    # The parent is the best model in the sample.
    parent = max(sample, key=lambda i: i.accuracy)

    # Create the child model and store it.
    child = Model()
    child.arch = mutate_arch(parent.arch)
    child.accuracy, _, _, total_cost = api.simulate_train_eval(child.arch, dataset, hp='12')
    # Append the info
    population.append(child)
    history.append((child.accuracy, child.arch))
    current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
    total_time_cost.append(total_cost)

    # Remove the oldest model.
    population.popleft()
  return history, current_best_index, total_time_cost
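# A minimal sketch of one aging-evolution cycle (hypothetical `mutate` helper,
# for illustration only):
#   sample = [random.choice(list(population)) for _ in range(sample_size)]  # tournament
#   parent = max(sample, key=lambda m: m.accuracy)                          # select
#   child = mutate(parent)                                                  # vary
#   population.append(child); population.popleft()                          # age out the oldest
# Removing the *oldest* member rather than the worst is what "regularizes" the
# evolution: every model eventually dies, so good traits must keep re-proving
# themselves instead of letting one early lucky model dominate forever.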


def main(xargs, api):
  torch.set_num_threads(4)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
  if xargs.search_space == 'tss':
    random_arch = random_topology_func(search_space)
    mutate_arch = mutate_topology_func(search_space)
  else:
    random_arch = random_size_func(search_space)
    mutate_arch = mutate_size_func(search_space)

  x_start_time = time.time()
  logger.log('{:} use api : {:}'.format(time_string(), api))
  logger.log('-'*30 + ' start searching with a time budget of {:} s'.format(xargs.time_budget))
  history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, api, xargs.dataset)
  logger.log('{:} regularized_evolution finished with a history of {:} archs, using {:.1f} s (real cost = {:.2f} s).'.format(time_string(), len(history), total_times[-1], time.time()-x_start_time))
  best_arch = max(history, key=lambda x: x[0])[1]
  logger.log('{:} best arch is {:}'.format(time_string(), best_arch))

  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
  logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()
  return logger.log_dir, current_best_index, total_times


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet16-120.')
  parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
  # hyperparameters for the evolutionary algorithm
  parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
  parser.add_argument('--ea_population', type=int, help='The population size in EA.')
  parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
  parser.add_argument('--time_budget', type=int, default=20000, help='The total time budget for searching (in seconds).')
  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total number of runs for evaluation.')
  # log
  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
  args = parser.parse_args()

  api = create(None, args.search_space, fast_mode=True, verbose=False)

  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'R-EA-SS{:}'.format(args.ea_sample_size))
  print('save-dir : {:}'.format(args.save_dir))
  print('xargs : {:}'.format(args))

  if args.rand_seed < 0:
    save_dir, all_info = None, collections.OrderedDict()
    for i in range(args.loops_if_rand):
      print('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
      args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
      all_info[i] = {'all_archs': all_archs,
                     'all_total_times': all_total_times}
    save_path = save_dir / 'results.pth'
    print('save into {:}'.format(save_path))
    torch.save(all_info, save_path)
  else:
    main(args, api)
212
exps/NATS-algos/reinforce.py
Normal file
@@ -0,0 +1,212 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
#####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
#####################################################################################################
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
#####################################################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
from nats_bench import create


class PolicyTopology(nn.Module):

  def __init__(self, search_space, max_nodes=4):
    super(PolicyTopology, self).__init__()
    self.max_nodes = max_nodes
    self.search_space = deepcopy(search_space)
    self.edge2index = {}
    for i in range(1, max_nodes):
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        self.edge2index[ node_str ] = len(self.edge2index)
    self.arch_parameters = nn.Parameter(1e-3*torch.randn(len(self.edge2index), len(search_space)))

  def generate_arch(self, actions):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )

  def genotype(self):
    genotypes = []
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        with torch.no_grad():
          weights = self.arch_parameters[ self.edge2index[node_str] ]
          op_name = self.search_space[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return CellStructure( genotypes )

  def forward(self):
    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
    return alphas


class PolicySize(nn.Module):

  def __init__(self, search_space):
    super(PolicySize, self).__init__()
    self.candidates = search_space['candidates']
    self.numbers = search_space['numbers']
    self.arch_parameters = nn.Parameter(1e-3*torch.randn(self.numbers, len(self.candidates)))

  def generate_arch(self, actions):
    channels = [str(self.candidates[i]) for i in actions]
    return ':'.join(channels)

  def genotype(self):
    channels = []
    for i in range(self.numbers):
      index = self.arch_parameters[i].argmax().item()
      channels.append(str(self.candidates[index]))
    return ':'.join(channels)

  def forward(self):
    alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
    return alphas


class ExponentialMovingAverage(object):
  """Class that maintains an exponential moving average."""

  def __init__(self, momentum):
    self._numerator = 0
    self._denominator = 0
    self._momentum = momentum

  def update(self, value):
    self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
    self._denominator = self._momentum * self._denominator + (1 - self._momentum)

  def value(self):
    """Return the current value of the moving average."""
    return self._numerator / self._denominator
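# This EMA is bias-corrected: after k updates with momentum m,
#   value() = sum_i (1-m) * m^(k-i) * v_i / sum_i (1-m) * m^(k-i),
# so early estimates are not dragged toward the zero initialization. For example,
# with m = 0.9, after a single update(1.0) a plain EMA would read 0.1, while
# value() here returns 0.1 / 0.1 = 1.0.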


def select_action(policy):
  probs = policy()
  m = Categorical(probs)
  action = m.sample()
  # policy.saved_log_probs.append(m.log_prob(action))
  return m.log_prob(action), action.cpu().tolist()
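# REINFORCE uses the score-function (log-derivative) estimator: for a policy pi
# with parameters theta, grad J(theta) = E[ grad log pi(a; theta) * (R - b) ],
# where b is a baseline that reduces variance without biasing the gradient.
# Minimizing the surrogate loss  -log_prob * (reward - baseline)  in main() below
# therefore performs stochastic gradient ascent on the expected reward.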


def main(xargs, api):
  torch.set_num_threads(4)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(args)

  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
  if xargs.search_space == 'tss':
    policy = PolicyTopology(search_space)
  else:
    policy = PolicySize(search_space)
  optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
  # optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
  eps = np.finfo(np.float32).eps.item()
  baseline = ExponentialMovingAverage(xargs.EMA_momentum)
  logger.log('policy    : {:}'.format(policy))
  logger.log('optimizer : {:}'.format(optimizer))
  logger.log('eps       : {:}'.format(eps))

  # load the NAS benchmark API
  logger.log('{:} use api : {:}'.format(time_string(), api))
  api.reset_time()

  # REINFORCE
  x_start_time = time.time()
  logger.log('Will start searching with a time budget of {:} s.'.format(xargs.time_budget))
  total_steps, total_costs, trace = 0, [], []
  current_best_index = []
  while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
    start_time = time.time()
    log_prob, action = select_action( policy )
    arch = policy.generate_arch( action )
    reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp='12')
    trace.append((reward, arch))
    total_costs.append(current_total_cost)

    baseline.update(reward)
    # calculate loss
    policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    # accumulate time
    total_steps += 1
    logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
    # to analyze
    current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]))
  # best_arch = policy.genotype() # first version
  best_arch = max(trace, key=lambda x: x[0])[1]
  logger.log('REINFORCE finished with {:} steps, using {:.1f} s (real cost = {:.3f} s).'.format(total_steps, total_costs[-1], time.time()-x_start_time))
  info = api.query_info_str_by_arch(best_arch, '200' if xargs.search_space == 'tss' else '90')
  logger.log('{:}'.format(info))
  logger.log('-'*100)
  logger.close()

  return logger.log_dir, current_best_index, total_costs


if __name__ == '__main__':
  parser = argparse.ArgumentParser("The REINFORCE Algorithm")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between CIFAR-10/100 and ImageNet16-120.')
  parser.add_argument('--search_space', type=str, choices=['tss', 'sss'], help='Choose the search space.')
  parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
  parser.add_argument('--EMA_momentum', type=float, default=0.9, help='The momentum value for EMA.')
  parser.add_argument('--time_budget', type=int, default=20000, help='The total time budget for searching (in seconds).')
  parser.add_argument('--loops_if_rand', type=int, default=500, help='The total number of runs for evaluation.')
  # log
  parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
  args = parser.parse_args()

  api = create(None, args.search_space, verbose=False)

  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, 'REINFORCE-{:}'.format(args.learning_rate))
  print('save-dir : {:}'.format(args.save_dir))

  if args.rand_seed < 0:
    save_dir, all_info = None, collections.OrderedDict()
    for i in range(args.loops_if_rand):
      print('{:} : {:03d}/{:03d}'.format(time_string(), i, args.loops_if_rand))
      args.rand_seed = random.randint(1, 100000)
      save_dir, all_archs, all_total_times = main(args, api)
      all_info[i] = {'all_archs': all_archs,
                     'all_total_times': all_total_times}
    save_path = save_dir / 'results.pth'
    print('save into {:}'.format(save_path))
    torch.save(all_info, save_path)
  else:
    main(args, api)
47
exps/NATS-algos/run-all.sh
Normal file
@@ -0,0 +1,47 @@
#!/bin/bash
# bash ./exps/NATS-algos/run-all.sh mul
# bash ./exps/NATS-algos/run-all.sh ws
set -e
echo "script name: $0"
echo "$# arguments"
if [ "$#" -ne 1 ] ;then
  echo "Illegal number of parameters: $#"
  echo "Need 1 parameter: the type of algorithms (mul or ws)."
  exit 1
fi

datasets="cifar10 cifar100 ImageNet16-120"
alg_type=$1

if [ "$alg_type" == "mul" ]; then
  search_spaces="tss sss"

  for dataset in ${datasets}
  do
    for search_space in ${search_spaces}
    do
      python ./exps/NATS-algos/reinforce.py --dataset ${dataset} --search_space ${search_space} --learning_rate 0.01
      python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
      python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space}
      python ./exps/NATS-algos/bohb.py --dataset ${dataset} --search_space ${search_space} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
    done
  done

  python exps/experimental/vis-bench-algos.py --search_space tss
  python exps/experimental/vis-bench-algos.py --search_space sss
else
  seeds="777 888 999"
  algos="darts-v1 darts-v2 gdas setn random enas"
  epoch=200
  for seed in ${seeds}
  do
    for alg in ${algos}
    do
      python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
      python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
      python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
    done
  done
fi
525
exps/NATS-algos/search-cell.py
Normal file
@@ -0,0 +1,525 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
####
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create


# The following three functions are used for DARTS-V2
def _concat(xs):
  return torch.cat([x.view(-1) for x in xs])


def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2):
  R = r / _concat(vector).norm()
  for p, v in zip(network.weights, vector):
    p.data.add_(R, v)
  _, logits = network(base_inputs)
  loss = criterion(logits, base_targets)
  grads_p = torch.autograd.grad(loss, network.alphas)

  for p, v in zip(network.weights, vector):
    p.data.sub_(2*R, v)
  _, logits = network(base_inputs)
  loss = criterion(logits, base_targets)
  grads_n = torch.autograd.grad(loss, network.alphas)

  for p, v in zip(network.weights, vector):
    p.data.add_(R, v)
  return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
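# The function above approximates the Hessian-vector product needed by DARTS-V2
# with a central finite difference: perturbing the weights w by +/- R*v gives
#   d^2 L / (d alpha d w) * v  ~=  ( grad_alpha L(w + R*v) - grad_alpha L(w - R*v) ) / (2R),
# which avoids ever materializing the second-order derivative. The final loop
# restores the original weights, so the network is left unchanged.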


def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets):
  # _compute_unrolled_model
  _, logits = network(base_inputs)
  loss = criterion(logits, base_targets)
  LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum']
  with torch.no_grad():
    theta = _concat(network.weights)
    try:
      moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.weights)
      moment = moment.mul_(momentum)
    except KeyError:  # no momentum buffer yet, i.e., before the first optimizer step
      moment = torch.zeros_like(theta)
    dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD*theta
    params = theta.sub(LR, moment+dtheta)
  unrolled_model = deepcopy(network)
  model_dict = unrolled_model.state_dict()
  new_params, offset = {}, 0
  for k, v in network.named_parameters():
    if 'arch_parameters' in k: continue
    v_length = np.prod(v.size())
    new_params[k] = params[offset: offset+v_length].view(v.size())
    offset += v_length
  model_dict.update(new_params)
  unrolled_model.load_state_dict(model_dict)

  unrolled_model.zero_grad()
  _, unrolled_logits = unrolled_model(arch_inputs)
  unrolled_loss = criterion(unrolled_logits, arch_targets)
  unrolled_loss.backward()

  dalpha = unrolled_model.arch_parameters.grad
  vector = [v.grad.data for v in unrolled_model.weights]
  [implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets)

  dalpha.data.sub_(LR, implicit_grads.data)

  if network.arch_parameters.grad is None:
    network.arch_parameters.grad = deepcopy( dalpha )
  else:
    network.arch_parameters.grad.data.copy_( dalpha.data )
  return unrolled_loss.detach(), unrolled_logits.detach()
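# This implements the second-order (unrolled) architecture gradient from DARTS:
#   grad_alpha L_val(w', alpha) - LR * grad^2_{alpha,w} L_train(w, alpha) * grad_{w'} L_val(w', alpha)
# where w' = w - LR * (momentum + grad_w L_train + WD*w) is one virtual SGD step.
# The first term is `dalpha` computed on the unrolled model; the second is the
# finite-difference Hessian-vector product from _hessian_vector_product, scaled by LR.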


def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_inputs = base_inputs.cuda(non_blocking=True)
    arch_inputs = arch_inputs.cuda(non_blocking=True)
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)

    # Update the weights
    if algo == 'setn':
      sampled_arch = network.dync_genotype(True)
      network.set_cal_mode('dynamic', sampled_arch)
    elif algo == 'gdas':
      network.set_cal_mode('gdas', None)
    elif algo.startswith('darts'):
      network.set_cal_mode('joint', None)
    elif algo == 'random':
      network.set_cal_mode('urs', None)
    elif algo == 'enas':
      with torch.no_grad():
        network.controller.eval()
        _, _, sampled_arch = network.controller()
      network.set_cal_mode('dynamic', sampled_arch)
    else:
      raise ValueError('Invalid algo name : {:}'.format(algo))

    network.zero_grad()
    _, logits = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(), base_inputs.size(0))
    base_top1.update(base_prec1.item(), base_inputs.size(0))
    base_top5.update(base_prec5.item(), base_inputs.size(0))

    # update the architecture weights
    if algo == 'setn':
      network.set_cal_mode('joint')
    elif algo == 'gdas':
      network.set_cal_mode('gdas', None)
    elif algo.startswith('darts'):
      network.set_cal_mode('joint', None)
    elif algo == 'random':
      network.set_cal_mode('urs', None)
    elif algo != 'enas':
      raise ValueError('Invalid algo name : {:}'.format(algo))
    network.zero_grad()
    if algo == 'darts-v2':
      arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
      a_optimizer.step()
    elif algo == 'random' or algo == 'enas':
      with torch.no_grad():
        _, logits = network(arch_inputs)
        arch_loss = criterion(logits, arch_targets)
    else:
      _, logits = network(arch_inputs)
      arch_loss = criterion(logits, arch_targets)
      arch_loss.backward()
      a_optimizer.step()
    # record
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    arch_losses.update(arch_loss.item(), arch_inputs.size(0))
    arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg


def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger):
  # prev_baseline: the baseline score (i.e., average val_acc) from the previous epoch
  data_time, batch_time = AverageMeter(), AverageMeter()
  GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()

  controller_num_aggregate = 20
  controller_train_steps = 50
  controller_bl_dec = 0.99
  controller_entropy_weight = 0.0001

  network.eval()
  network.controller.train()
  network.controller.zero_grad()
  loader_iter = iter(xloader)
  for step in range(controller_train_steps * controller_num_aggregate):
    try:
      inputs, targets = next(loader_iter)
    except StopIteration:
      loader_iter = iter(xloader)
      inputs, targets = next(loader_iter)
    inputs = inputs.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - xend)

    log_prob, entropy, sampled_arch = network.controller()
    with torch.no_grad():
      network.set_cal_mode('dynamic', sampled_arch)
      _, logits = network(inputs)
      val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
      val_top1 = val_top1.view(-1) / 100
    reward = val_top1 + controller_entropy_weight * entropy
    if prev_baseline is None:
      baseline = val_top1
    else:
      baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward)

    loss = -1 * log_prob * (reward - baseline)

    # account
    RewardMeter.update(reward.item())
    BaselineMeter.update(baseline.item())
    ValAccMeter.update(val_top1.item()*100)
    LossMeter.update(loss.item())
    EntropyMeter.update(entropy.item())

    # Average gradient over controller_num_aggregate samples
    loss = loss / controller_num_aggregate
    loss.backward(retain_graph=True)

    # measure elapsed time
    batch_time.update(time.time() - xend)
    xend = time.time()
    if (step+1) % controller_num_aggregate == 0:
      grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0)
      GradnormMeter.update(grad_norm)
      optimizer.step()
      network.controller.zero_grad()

    if step % print_freq == 0:
      Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, controller_train_steps * controller_num_aggregate)
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
      Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)

  return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
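# ENAS controller training, roughly following the original paper: gradients are
# accumulated over controller_num_aggregate sampled architectures before each
# optimizer step (an implicit batch over architectures), the reward is the shared
# network's validation top-1 plus a small entropy bonus to keep exploration alive,
# and the baseline is an exponentially decayed moving average (decay controller_bl_dec).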
|
||||
|
||||
|
||||
def get_best_arch(xloader, network, n_samples, algo):
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
if algo == 'random':
|
||||
archs, valid_accs = network.return_topK(n_samples, True), []
|
||||
elif algo == 'setn':
|
||||
archs, valid_accs = network.return_topK(n_samples, False), []
|
||||
elif algo.startswith('darts') or algo == 'gdas':
|
||||
arch = network.genotype
|
||||
archs, valid_accs = [arch], []
|
||||
elif algo == 'enas':
|
||||
archs, valid_accs = [], []
|
||||
for _ in range(n_samples):
|
||||
_, _, sampled_arch = network.controller()
|
||||
archs.append(sampled_arch)
|
||||
else:
|
||||
raise ValueError('Invalid algorithm name : {:}'.format(algo))
|
||||
loader_iter = iter(xloader)
|
||||
for i, sampled_arch in enumerate(archs):
|
||||
network.set_cal_mode('dynamic', sampled_arch)
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
_, logits = network(inputs.cuda(non_blocking=True))
|
||||
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
|
||||
valid_accs.append(val_top1.item())
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
return best_arch, best_valid_acc
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion, algo, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# prediction
|
||||
_, logits = network(arch_inputs.cuda(non_blocking=True))
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def main(xargs):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads( xargs.workers )
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
if xargs.overwite_epochs is None:
|
||||
extra_info = {'class_num': class_num, 'xshape': xshape}
|
||||
else:
|
||||
extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwite_epochs}
|
||||
config = load_config(xargs.config_path, extra_info, logger)
|
||||
search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', (config.batch_size, config.test_batch_size), xargs.workers)
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
|
||||
|
||||
model_config = dict2config(
|
||||
dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
|
||||
space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)
|
||||
logger.log('search space : {:}'.format(search_space))
|
||||
logger.log('model config : {:}'.format(model_config))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.set_algo(xargs.algo)
|
||||
logger.log('{:}'.format(search_model))
|
||||
|
||||
  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
  a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion : {:}'.format(criterion))
  params = count_parameters_in_MB(search_model)
  logger.log('The parameters of the search model = {:.2f} MB'.format(params))
  logger.log('search-space : {:}'.format(search_space))
  if bool(xargs.use_api):
    api = create(None, 'topology', fast_mode=True, verbose=False)
  else:
    api = None
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = search_model.cuda(), criterion.cuda()  # use a single GPU

  if last_info.exists():  # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint = torch.load(last_info['last_checkpoint'])
    genotypes = checkpoint['genotypes']
    baseline = checkpoint['baseline']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict(checkpoint['search_model'])
    w_scheduler.load_state_dict(checkpoint['w_scheduler'])
    w_optimizer.load_state_dict(checkpoint['w_optimizer'])
    a_optimizer.load_state_dict(checkpoint['a_optimizer'])
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> did not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.return_topK(1, True)[0]}
    baseline = None

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

    network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
    if xargs.algo == 'gdas':
      network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1))
      logger.log('[RESET tau as : {:} and drop_path as {:}]'.format(network.tau, network.drop_path))
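    # Both schedules anneal linearly with the epoch index: the first argument of
    # set_drop_path is the training progress (epoch+1)/total_epoch used to scale
    # xargs.drop_path_rate, and for GDAS the Gumbel-Softmax temperature decays as
    #   tau(epoch) = tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1),
    # so architecture sampling sharpens toward a discrete choice late in training.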
    search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
        = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
    if xargs.algo == 'enas':
      ctl_loss, ctl_acc, baseline, ctl_reward \
          = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
      logger.log('[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'.format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))

    genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo)
    if xargs.algo == 'setn' or xargs.algo == 'enas':
      network.set_cal_mode('dynamic', genotype)
    elif xargs.algo == 'gdas':
      network.set_cal_mode('gdas', None)
    elif xargs.algo.startswith('darts'):
      network.set_cal_mode('joint', None)
    elif xargs.algo == 'random':
      network.set_cal_mode('urs', None)
    else:
      raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
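    # How each mode evaluates the supernet: 'dynamic' runs only the chosen
    # genotype's sub-network (SETN / ENAS), 'gdas' uses hard Gumbel-Softmax
    # samples, 'joint' mixes all candidate operations by their softmax weights
    # (DARTS), and 'urs' draws a uniformly random sub-network per forward pass.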
    logger.log('[{:}] - [get_best_arch] : {:} -> {:}'.format(epoch_str, genotype, temp_accuracy))
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger)
    logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
    valid_accuracies[epoch] = valid_a_top1

    genotypes[epoch] = genotype
    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
    # save checkpoint
    save_path = save_checkpoint({'epoch': epoch + 1,
                                 'args': deepcopy(xargs),
                                 'baseline': baseline,
                                 'search_model': search_model.state_dict(),
                                 'w_optimizer': w_optimizer.state_dict(),
                                 'a_optimizer': a_optimizer.state_dict(),
                                 'w_scheduler': w_scheduler.state_dict(),
                                 'genotypes': genotypes,
                                 'valid_accuracies': valid_accuracies},
                                model_base_path, logger)
    last_info = save_checkpoint({
        'epoch': epoch + 1,
        'args': deepcopy(xargs),
        'last_checkpoint': save_path,
    }, logger.path('info'), logger)
    with torch.no_grad():
      logger.log('{:}'.format(search_model.show_alphas()))
    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
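    # NATS-Bench lookup: the '200' hyperparameter string asks the topology API
    # for this genotype's benchmark statistics under the 200-epoch training setup.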
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  # the final post procedure : count the time
  start_time = time.time()
  genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo)
  if xargs.algo == 'setn' or xargs.algo == 'enas':
    network.set_cal_mode('dynamic', genotype)
  elif xargs.algo == 'gdas':
    network.set_cal_mode('gdas', None)
  elif xargs.algo.startswith('darts'):
    network.set_cal_mode('joint', None)
  elif xargs.algo == 'random':
    network.set_cal_mode('urs', None)
  else:
    raise ValueError('Invalid algorithm name : {:}'.format(xargs.algo))
  search_time.update(time.time() - start_time)

  valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, xargs.algo, logger)
  logger.log('Last : the genotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))

  logger.log('\n' + '-' * 100)
  # check the performance from the architecture dataset
  logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(xargs.algo, total_epoch, search_time.sum, genotype))
  if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200')))
  logger.close()


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
  parser.add_argument('--data_path',    type=str, help='Path to dataset')
  parser.add_argument('--dataset',      type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.')
  parser.add_argument('--algo',         type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search algorithm name.')
  parser.add_argument('--use_api',      type=int, default=1, choices=[0, 1], help='Whether use API or not (which will cost much memory).')
  # FOR GDAS
  parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
  parser.add_argument('--tau_max', type=float, default=10,  help='The maximum tau for Gumbel Softmax.')
  # channels and number-of-cells
  parser.add_argument('--max_nodes', type=int, default=4,  help='The maximum number of nodes.')
  parser.add_argument('--channel',   type=int, default=16, help='The number of channels.')
  parser.add_argument('--num_cells', type=int, default=5,  help='The number of cells in one stage.')
  #
  parser.add_argument('--eval_candidate_num', type=int, default=100, help='The number of selected architectures to evaluate.')
  #
  parser.add_argument('--track_running_stats', type=int, default=0, choices=[0, 1], help='Whether use track_running_stats or not in the BN layer.')
  parser.add_argument('--affine',              type=int, default=0, choices=[0, 1], help='Whether use affine=True or False in the BN layer.')
  parser.add_argument('--config_path',         type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
  parser.add_argument('--overwrite_epochs',    type=int, help='The number of epochs to overwrite that value in config files.')
  # architecture learning rate
  parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
  parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding')
  parser.add_argument('--arch_eps',           type=float, default=1e-8, help='epsilon of the Adam optimizer for arch encoding')
  parser.add_argument('--drop_path_rate', type=float, help='The drop path rate.')
  # log
  parser.add_argument('--workers',    type=int, default=2,   help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir',   type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed',  type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  if args.overwrite_epochs is None:
    args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
                                 args.dataset,
                                 '{:}-affine{:}_BN{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.drop_path_rate))
  else:
    args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space),
                                 args.dataset,
                                 '{:}-affine{:}_BN{:}-E{:}-{:}'.format(args.algo, args.affine, args.track_running_stats, args.overwrite_epochs, args.drop_path_rate))

  main(args)
299
exps/NATS-algos/search-size.py
Normal file
@@ -0,0 +1,299 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo fbv2 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo fbv2 --rand_seed 777
####
# python ./exps/NATS-algos/search-size.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777 --use_api 0
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tunas --arch_weight_decay 0 --rand_seed 777
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tunas --arch_weight_decay 0 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create


# Ad-hoc for TuNAS
class ExponentialMovingAverage(object):
  """Class that maintains an exponential moving average."""

  def __init__(self, momentum):
    self._numerator = 0
    self._denominator = 0
    self._momentum = momentum

  def update(self, value):
    self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
    self._denominator = self._momentum * self._denominator + (1 - self._momentum)

  @property
  def value(self):
    """Return the current value of the moving average"""
    return self._numerator / self._denominator


RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
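# The shared denominator makes this a bias-corrected EMA: after t updates the
# denominator equals 1 - momentum**t, which cancels the zero-initialization bias
# of the numerator (the same correction Adam applies to its moment estimates).
# For example, with momentum=0.9 and two calls to update(1.0), the numerator and
# denominator are both 0.19, so .value returns exactly 1.0 rather than a
# zero-biased 0.19.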


def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_inputs = base_inputs.cuda(non_blocking=True)
    arch_inputs = arch_inputs.cuda(non_blocking=True)
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)
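    # Each loader batch carries two disjoint splits: base_* is used below to
    # update the network weights and arch_* to update the architecture
    # parameters, i.e. one step of alternating (first-order) bi-level optimization.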

    # Update the weights
    network.zero_grad()
    _, logits, _ = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(), base_inputs.size(0))
    base_top1.update(base_prec1.item(), base_inputs.size(0))
    base_top5.update(base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.zero_grad()
    _, logits, log_probs = network(arch_inputs)
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    if algo == 'tunas':
      with torch.no_grad():
        RL_BASELINE_EMA.update(arch_prec1.item())
        rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
      rl_log_prob = sum(log_probs)
      arch_loss = -rl_advantage * rl_log_prob
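      # REINFORCE with a moving-average baseline: the loss is
      #   -(reward - baseline) * log p(sampled architecture),
      # where the reward is the sampled architecture's top-1 accuracy and the
      # EMA baseline reduces the variance of the policy gradient; only log_probs
      # carries gradient, so the advantage is computed under torch.no_grad().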
    elif algo == 'tas' or algo == 'fbv2':
      arch_loss = criterion(logits, arch_targets)
    else:
      raise ValueError('invalid algorithm name: {:}'.format(algo))
    arch_loss.backward()
    a_optimizer.step()
    # record
    arch_losses.update(arch_loss.item(), arch_inputs.size(0))
    arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg


def valid_func(xloader, network, criterion, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  with torch.no_grad():
    network.eval()
    for step, (arch_inputs, arch_targets) in enumerate(xloader):
      arch_targets = arch_targets.cuda(non_blocking=True)
      # measure data loading time
      data_time.update(time.time() - end)
      # prediction
      _, logits, _ = network(arch_inputs.cuda(non_blocking=True))
      arch_loss = criterion(logits, arch_targets)
      # record
      arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
      arch_losses.update(arch_loss.item(), arch_inputs.size(0))
      arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
      arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()
  return arch_losses.avg, arch_top1.avg, arch_top5.avg


def main(xargs):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads(xargs.workers)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)

  train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
  if xargs.overwrite_epochs is None:
    extra_info = {'class_num': class_num, 'xshape': xshape}
  else:
    extra_info = {'class_num': class_num, 'xshape': xshape, 'epochs': xargs.overwrite_epochs}
  config = load_config(xargs.config_path, extra_info, logger)
  search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', (config.batch_size, config.test_batch_size), xargs.workers)
  logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
  logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))

  search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')

  model_config = dict2config(
      dict(name='generic', super_type='search-shape', candidate_Cs=search_space['candidates'], max_num_Cs=search_space['numbers'], num_classes=class_num,
           genotype=xargs.genotype, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)
  logger.log('search space : {:}'.format(search_space))
  logger.log('model config : {:}'.format(model_config))
  search_model = get_cell_based_tiny_net(model_config)
  search_model.set_algo(xargs.algo)
  logger.log('{:}'.format(search_model))

  w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
  a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps)
  logger.log('w-optimizer : {:}'.format(w_optimizer))
  logger.log('a-optimizer : {:}'.format(a_optimizer))
  logger.log('w-scheduler : {:}'.format(w_scheduler))
  logger.log('criterion : {:}'.format(criterion))
  params = count_parameters_in_MB(search_model)
  logger.log('The parameters of the search model = {:.2f} MB'.format(params))
  logger.log('search-space : {:}'.format(search_space))
  if bool(xargs.use_api):
    api = create(None, 'size', fast_mode=True, verbose=False)
  else:
    api = None
  logger.log('{:} create API = {:} done'.format(time_string(), api))

  last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
  network, criterion = search_model.cuda(), criterion.cuda()  # use a single GPU

  if last_info.exists():  # automatically resume from previous checkpoint
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch']
    checkpoint = torch.load(last_info['last_checkpoint'])
    genotypes = checkpoint['genotypes']
    valid_accuracies = checkpoint['valid_accuracies']
    search_model.load_state_dict(checkpoint['search_model'])
    w_scheduler.load_state_dict(checkpoint['w_scheduler'])
    w_optimizer.load_state_dict(checkpoint['w_optimizer'])
    a_optimizer.load_state_dict(checkpoint['a_optimizer'])
    logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
  else:
    logger.log("=> did not find the last-info file : {:}".format(last_info))
    start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.random}

  # start training
  start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
  for epoch in range(start_epoch, total_epoch):
    w_scheduler.update(epoch, 0.0)
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
    epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
    logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))

    if xargs.algo == 'fbv2' or xargs.algo == 'tas':
      network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1))
      logger.log('[RESET tau as : {:}]'.format(network.tau))
    search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
        = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, xargs.algo, epoch_str, xargs.print_freq, logger)
    search_time.update(time.time() - start_time)
    logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
    logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))

    genotype = network.genotype
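    # Unlike the cell-based search, no candidate sampling is needed here: the
    # genotype property reads the current best channel choice per layer directly
    # from the architecture parameters.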
    logger.log('[{:}] - [get_best_arch] : {:}'.format(epoch_str, genotype))
    valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger)
    logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
    valid_accuracies[epoch] = valid_a_top1

    genotypes[epoch] = genotype
    logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
    # save checkpoint
    save_path = save_checkpoint({'epoch': epoch + 1,
                                 'args': deepcopy(xargs),
                                 'search_model': search_model.state_dict(),
                                 'w_optimizer': w_optimizer.state_dict(),
                                 'a_optimizer': a_optimizer.state_dict(),
                                 'w_scheduler': w_scheduler.state_dict(),
                                 'genotypes': genotypes,
                                 'valid_accuracies': valid_accuracies},
                                model_base_path, logger)
    last_info = save_checkpoint({
        'epoch': epoch + 1,
        'args': deepcopy(xargs),
        'last_checkpoint': save_path,
    }, logger.path('info'), logger)
    with torch.no_grad():
      logger.log('{:}'.format(search_model.show_alphas()))
    if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '90')))
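    # NATS-Bench lookup: the size API is likewise indexed by training epochs, and
    # '90' asks for this genotype's statistics under the 90-epoch training setup.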
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  # the final post procedure : count the time
  start_time = time.time()
  genotype = network.genotype
  search_time.update(time.time() - start_time)

  valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger)
  logger.log('Last : the genotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))

  logger.log('\n' + '-' * 100)
  # check the performance from the architecture dataset
  logger.log('[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(xargs.algo, total_epoch, search_time.sum, genotype))
  if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '90')))
  logger.close()


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Weight sharing NAS methods to search for the best layer sizes.")
  parser.add_argument('--data_path',    type=str, help='Path to dataset')
  parser.add_argument('--dataset',      type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  parser.add_argument('--search_space', type=str, default='sss', choices=['sss'], help='The search space name.')
  parser.add_argument('--algo',         type=str, choices=['tas', 'fbv2', 'tunas'], help='The search algorithm name.')
  parser.add_argument('--genotype',     type=str, default='|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|', help='The genotype.')
  parser.add_argument('--use_api',      type=int, default=1, choices=[0, 1], help='Whether use API or not (which will cost much memory).')
  # Gumbel-Softmax tau (used by the tas and fbv2 algorithms)
  parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
  parser.add_argument('--tau_max', type=float, default=10,  help='The maximum tau for Gumbel Softmax.')
  #
  parser.add_argument('--track_running_stats', type=int, default=0, choices=[0, 1], help='Whether use track_running_stats or not in the BN layer.')
  parser.add_argument('--affine',              type=int, default=0, choices=[0, 1], help='Whether use affine=True or False in the BN layer.')
  parser.add_argument('--config_path',         type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
  parser.add_argument('--overwrite_epochs',    type=int, help='The number of epochs to overwrite that value in config files.')
  # architecture learning rate
  parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
  parser.add_argument('--arch_weight_decay',  type=float, default=1e-3, help='weight decay for arch encoding')
  parser.add_argument('--arch_eps',           type=float, default=1e-8, help='epsilon of the Adam optimizer for arch encoding')
  # log
  parser.add_argument('--workers',    type=int, default=2,   help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir',   type=str, default='./output/search', help='Folder to save checkpoints and log.')
  parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed',  type=int, help='manual seed')
  args = parser.parse_args()
  if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  dirname = '{:}-affine{:}_BN{:}-AWD{:}'.format(args.algo, args.affine, args.track_running_stats, args.arch_weight_decay)
  if args.overwrite_epochs is not None:
    dirname = dirname + '-E{:}'.format(args.overwrite_epochs)
  args.save_dir = os.path.join('{:}-{:}'.format(args.save_dir, args.search_space), args.dataset, dirname)

  main(args)