Reformulate via black

D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions

View File

@@ -12,14 +12,17 @@ import os, sys, time, random, argparse, collections
from copy import deepcopy
from pathlib import Path
import torch
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger
from log_utils import AverageMeter, time_string, convert_secs2time
from nats_bench import create
from models import CellStructure, get_search_spaces
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace
from hpbandster.optimizers.bohb import BOHB
@@ -28,161 +31,193 @@ from hpbandster.core.worker import Worker
def get_topology_config_space(search_space, max_nodes=4):
cs = ConfigSpace.ConfigurationSpace()
# edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
return cs
def get_size_config_space(search_space):
cs = ConfigSpace.ConfigurationSpace()
for ilayer in range(search_space["numbers"]):
node_str = "layer-{:}".format(ilayer)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space["candidates"]))
return cs
def config2topology_func(max_nodes=4):
def config2structure(config):
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = config[node_str]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return config2structure
def config2size_func(search_space):
def config2structure(config):
channels = []
for ilayer in range(search_space["numbers"]):
node_str = "layer-{:}".format(ilayer)
channels.append(str(config[node_str]))
return ":".join(channels)
return config2structure
class MyWorker(Worker):
def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self._dataset = dataset
self._api = api
self.total_times = []
self.trajectory = []
def compute(self, config, budget, **kwargs):
arch = self.convert_func(config)
accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(
arch, self._dataset, iepoch=int(budget) - 1, hp="12"
)
self.trajectory.append((accuracy, arch))
self.total_times.append(total_time)
return {"loss": 100 - accuracy, "info": self._api.query_index_by_arch(arch)}
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
logger.log("{:} use api : {:}".format(time_string(), api))
api.reset_time()
search_space = get_search_spaces(xargs.search_space, "nats-bench")
if xargs.search_space == "tss":
cs = get_topology_config_space(search_space)
config2structure = config2topology_func()
else:
cs = get_size_config_space(search_space)
config2structure = config2size_func(search_space)
hb_run_id = "0"
NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0)
ns_host, ns_port = NS.start()
num_workers = 1
workers = []
for i in range(num_workers):
w = MyWorker(
nameserver=ns_host,
nameserver_port=ns_port,
convert_func=config2structure,
dataset=xargs.dataset,
api=api,
run_id=hb_run_id,
id=i,
)
w.run(background=True)
workers.append(w)
start_time = time.time()
bohb = BOHB(
configspace=cs,
run_id=hb_run_id,
eta=3,
min_budget=1,
max_budget=12,
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
random_fraction=xargs.random_fraction,
bandwidth_factor=xargs.bandwidth_factor,
ping_interval=10,
min_bandwidth=xargs.min_bandwidth,
)
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
# print('There are {:} runs.'.format(len(results.get_all_runs())))
# workers[0].total_times
# workers[0].trajectory
current_best_index = []
for idx in range(len(workers[0].trajectory)):
trajectory = workers[0].trajectory[: idx + 1]
arch = max(trajectory, key=lambda x: x[0])[1]
current_best_index.append(api.query_index_by_arch(arch))
best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
logger.log("Best found configuration: {:} within {:.3f} s".format(best_arch, workers[0].total_times[-1]))
info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90")
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, current_best_index, workers[0].total_times
if __name__ == "__main__":
parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# general arg
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
parser.add_argument(
"--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)."
)
parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.")
# BOHB
parser.add_argument(
"--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function"
)
parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE")
parser.add_argument(
"--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function"
)
parser.add_argument(
"--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations"
)
parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth")
parser.add_argument(
"--n_iters", default=300, type=int, nargs="?", help="number of iterations for optimization method"
)
# log
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=False, verbose=False)
args.save_dir = os.path.join(
"{:}-{:}".format(args.save_dir, args.search_space), "{:}-T{:}".format(args.dataset, args.time_budget), "BOHB"
)
print("save-dir : {:}".format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
save_path = save_dir / "results.pth"
print("save into {:}".format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)
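As a side note on the file above: BOHB treats every cell edge as one categorical hyperparameter, so a sampled configuration maps one-to-one onto a NATS-Bench topology genotype. Below is a minimal, self-contained sketch of that mapping (not part of this commit); the three operation names are assumed for illustration, and config2topology_func in the file additionally wraps the result in CellStructure.

# Sketch only: mirrors get_topology_config_space / config2topology_func on a toy operation set.
import ConfigSpace

ops = ["none", "skip_connect", "nor_conv_3x3"]  # assumed operations, for illustration
max_nodes = 4

cs = ConfigSpace.ConfigurationSpace()
for i in range(1, max_nodes):
    for j in range(i):
        # one categorical hyperparameter per edge "i<-j"
        cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter("{:}<-{:}".format(i, j), ops))

config = cs.sample_configuration()
genotypes = []
for i in range(1, max_nodes):
    # collect the sampled (operation, predecessor) pairs feeding node i
    genotypes.append(tuple((config["{:}<-{:}".format(i, j)], j) for j in range(i)))
print(genotypes)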

View File

@@ -13,80 +13,93 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nats_bench import create
from regularized_ea import random_topology_func, random_size_func
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
logger.log("{:} use api : {:}".format(time_string(), api))
api.reset_time()
search_space = get_search_spaces(xargs.search_space, "nats-bench")
if xargs.search_space == "tss":
random_arch = random_topology_func(search_space)
else:
random_arch = random_size_func(search_space)
best_arch, best_acc, total_time_cost, history = None, -1, [], []
current_best_index = []
while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
arch = random_arch()
accuracy, _, _, total_cost = api.simulate_train_eval(arch, xargs.dataset, hp="12")
total_time_cost.append(total_cost)
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy))
current_best_index.append(api.query_index_by_arch(best_arch))
logger.log(
"{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.".format(
time_string(), best_arch, best_acc, len(history), total_time_cost[-1]
)
)
info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90")
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, current_best_index, total_time_cost
if __name__ == "__main__":
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
parser.add_argument(
"--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)."
)
parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.")
# log
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join(
"{:}-{:}".format(args.save_dir, args.search_space), "{:}-T{:}".format(args.dataset, args.time_budget), "RANDOM"
)
print("save-dir : {:}".format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
save_path = save_dir / "results.pth"
print("save into {:}".format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)
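For context, the search loop above is budget-controlled random sampling against the simulated training cost returned by the benchmark API. Here is a minimal sketch (not part of this commit) of the same control flow, with a toy evaluator standing in for api.simulate_train_eval; the accuracy and cost numbers are made up for illustration.

import random

def toy_evaluate(arch):
    # stand-in for api.simulate_train_eval: return (accuracy, simulated cost in seconds)
    return random.uniform(0.0, 100.0), random.uniform(10.0, 100.0)

def random_search(time_budget=1000.0):
    best_arch, best_acc, spent = None, -1.0, 0.0
    while spent < time_budget:
        arch = random.randrange(15625)  # stand-in for random_arch()
        accuracy, cost = toy_evaluate(arch)
        spent += cost
        if accuracy > best_acc:
            best_acc, best_arch = accuracy, arch
    return best_arch, best_acc, spent

print(random_search())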

View File

@@ -17,214 +17,242 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
from nats_bench import create
class Model(object):
def __init__(self):
self.arch = None
self.accuracy = None
def __str__(self):
"""Prints a readable version of this bitstring."""
return "{:}".format(self.arch)
def random_topology_func(op_names, max_nodes=4):
# Return a random architecture
def random_architecture():
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return random_architecture
def random_size_func(info):
# Return a random architecture
def random_architecture():
channels = []
for i in range(info["numbers"]):
channels.append(str(random.choice(info["candidates"])))
return ":".join(channels)
return random_architecture
def mutate_topology_func(op_names):
"""Computes the architecture for a child of the given parent architecture.
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another.
"""
def mutate_topology_func(parent_arch):
child_arch = deepcopy( parent_arch )
node_id = random.randint(0, len(child_arch.nodes)-1)
node_info = list( child_arch.nodes[node_id] )
snode_id = random.randint(0, len(node_info)-1)
xop = random.choice( op_names )
while xop == node_info[snode_id][0]:
xop = random.choice( op_names )
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple( node_info )
return child_arch
return mutate_topology_func
"""Computes the architecture for a child of the given parent architecture.
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switching one operation to another.
"""
def mutate_topology_func(parent_arch):
child_arch = deepcopy(parent_arch)
node_id = random.randint(0, len(child_arch.nodes) - 1)
node_info = list(child_arch.nodes[node_id])
snode_id = random.randint(0, len(node_info) - 1)
xop = random.choice(op_names)
while xop == node_info[snode_id][0]:
xop = random.choice(op_names)
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple(node_info)
return child_arch
return mutate_topology_func
def mutate_size_func(info):
"""Computes the architecture for a child of the given parent architecture.
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another.
"""
def mutate_size_func(parent_arch):
child_arch = deepcopy(parent_arch)
child_arch = child_arch.split(':')
index = random.randint(0, len(child_arch)-1)
child_arch[index] = str(random.choice(info['candidates']))
return ':'.join(child_arch)
return mutate_size_func
"""Computes the architecture for a child of the given parent architecture.
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switching one operation to another.
"""
def mutate_size_func(parent_arch):
child_arch = deepcopy(parent_arch)
child_arch = child_arch.split(":")
index = random.randint(0, len(child_arch) - 1)
child_arch[index] = str(random.choice(info["candidates"]))
return ":".join(child_arch)
return mutate_size_func
def regularized_evolution(
cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset
):
"""Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
Classifier Architecture Search".
Args:
cycles: the number of cycles the algorithm should run for.
population_size: the number of individuals to keep in the population.
sample_size: the number of individuals that should participate in each tournament.
time_budget: the upper bound of the search cost
Returns:
history: a list of `Model` instances, representing all the models computed
during the evolution experiment.
"""
population = collections.deque()
api.reset_time()
history, total_time_cost = [], [] # Not used by the algorithm, only used to report results.
current_best_index = []
# Initialize the population with random models.
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, _, _, total_cost = api.simulate_train_eval(
model.arch, dataset, hp="12" if use_proxy else api.full_train_epochs
)
# Append the info
population.append(model)
history.append((model.accuracy, model.arch))
total_time_cost.append(total_cost)
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
# Carry out evolution in cycles. Each cycle produces a model and removes another.
while total_time_cost[-1] < time_budget:
# Sample randomly chosen models from the current population.
start_time, sample = time.time(), []
while len(sample) < sample_size:
# Inefficient, but written this way for clarity. In the case of neural
# nets, the efficiency of this line is irrelevant because training neural
# nets is the rate-determining step.
candidate = random.choice(list(population))
sample.append(candidate)
# The parent is the best model in the sample.
parent = max(sample, key=lambda i: i.accuracy)
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
child.accuracy, _, _, total_cost = api.simulate_train_eval(
child.arch, dataset, hp="12" if use_proxy else api.full_train_epochs
)
# Append the info
population.append(child)
history.append((child.accuracy, child.arch))
current_best_index.append(api.query_index_by_arch(max(history, key=lambda x: x[0])[1]))
total_time_cost.append(total_cost)
# Remove the oldest model.
population.popleft()
return history, current_best_index, total_time_cost
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, "nats-bench")
if xargs.search_space == "tss":
random_arch = random_topology_func(search_space)
mutate_arch = mutate_topology_func(search_space)
else:
random_arch = random_size_func(search_space)
mutate_arch = mutate_size_func(search_space)
x_start_time = time.time()
logger.log("{:} use api : {:}".format(time_string(), api))
logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget))
history, current_best_index, total_times = regularized_evolution(
xargs.ea_cycles,
xargs.ea_population,
xargs.ea_sample_size,
xargs.time_budget,
random_arch,
mutate_arch,
api,
xargs.use_proxy > 0,
xargs.dataset,
)
logger.log(
"{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format(
time_string(), len(history), total_times[-1], time.time() - x_start_time
)
)
best_arch = max(history, key=lambda x: x[0])[1]
logger.log("{:} best arch is {:}".format(time_string(), best_arch))
info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90")
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, current_best_index, total_times
if __name__ == "__main__":
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
# hyperparameters for REA
parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.")
parser.add_argument("--ea_population", type=int, help="The population size in EA.")
parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.")
parser.add_argument(
"--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)."
)
parser.add_argument("--use_proxy", type=int, default=1, help="Whether to use the proxy (H0) task or not.")
#
parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.")
# log
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join(
"{:}-{:}".format(args.save_dir, args.search_space),
"{:}-T{:}{:}".format(args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"),
"R-EA-SS{:}".format(args.ea_sample_size),
)
print("save-dir : {:}".format(args.save_dir))
print("xargs : {:}".format(args))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
save_path = save_dir / "results.pth"
print("save into {:}".format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)
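To keep the aging-evolution logic above easy to see in isolation, here is a minimal sketch (not part of this commit) of the same loop structure as regularized_evolution() on a toy problem: fitness is simply the number of ones in a bit string, and the population size, sample size, and cycle count are arbitrary.

import collections
import random

def random_arch():
    return [random.randint(0, 1) for _ in range(8)]

def mutate_arch(parent):
    child = list(parent)
    flip = random.randrange(len(child))
    child[flip] = 1 - child[flip]
    return child

population, history = collections.deque(), []
population_size, sample_size, cycles = 10, 3, 100

while len(population) < population_size:  # initialize the population with random models
    arch = random_arch()
    population.append((sum(arch), arch))
    history.append((sum(arch), arch))

for _ in range(cycles):
    # tournament selection: the parent is the best of a random sample
    sample = [random.choice(list(population)) for _ in range(sample_size)]
    parent = max(sample, key=lambda x: x[0])
    # mutate, evaluate and append the child, then age out the oldest model
    child = mutate_arch(parent[1])
    population.append((sum(child), child))
    history.append((sum(child), child))
    population.popleft()

print(max(history, key=lambda x: x[0]))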

View File

@@ -3,12 +3,12 @@
#####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
#####################################################################################################
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
#####################################################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
@@ -17,197 +17,216 @@ from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import CellStructure, get_search_spaces
from nats_bench import create
class PolicyTopology(nn.Module):
def __init__(self, search_space, max_nodes=4):
super(PolicyTopology, self).__init__()
self.max_nodes = max_nodes
self.search_space = deepcopy(search_space)
self.edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
self.edge2index[node_str] = len(self.edge2index)
self.arch_parameters = nn.Parameter(1e-3 * torch.randn(len(self.edge2index), len(search_space)))
def generate_arch(self, actions):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = self.search_space[actions[self.edge2index[node_str]]]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.search_space[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
def forward(self):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas
class PolicySize(nn.Module):
def __init__(self, search_space):
super(PolicySize, self).__init__()
self.candidates = search_space["candidates"]
self.numbers = search_space["numbers"]
self.arch_parameters = nn.Parameter(1e-3 * torch.randn(self.numbers, len(self.candidates)))
def generate_arch(self, actions):
channels = [str(self.candidates[i]) for i in actions]
return ":".join(channels)
def genotype(self):
channels = []
for i in range(self.numbers):
index = self.arch_parameters[i].argmax().item()
channels.append(str(self.candidates[index]))
return ":".join(channels)
def forward(self):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""
"""Class that maintains an exponential moving average."""
def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum
def update(self, value):
self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator
def select_action(policy):
probs = policy()
m = Categorical(probs)
action = m.sample()
# policy.saved_log_probs.append(m.log_prob(action))
return m.log_prob(action), action.cpu().tolist()
def main(xargs, api):
torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
search_space = get_search_spaces(xargs.search_space, "nats-bench")
if xargs.search_space == "tss":
policy = PolicyTopology(search_space)
else:
policy = PolicySize(search_space)
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
# optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
eps = np.finfo(np.float32).eps.item()
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
logger.log("policy : {:}".format(policy))
logger.log("optimizer : {:}".format(optimizer))
logger.log("eps : {:}".format(eps))
# nas dataset load
logger.log("{:} use api : {:}".format(time_string(), api))
api.reset_time()
# REINFORCE
x_start_time = time.time()
logger.log("Will start searching with time budget of {:} s.".format(xargs.time_budget))
total_steps, total_costs, trace = 0, [], []
current_best_index = []
while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
start_time = time.time()
log_prob, action = select_action(policy)
arch = policy.generate_arch(action)
reward, _, _, current_total_cost = api.simulate_train_eval(arch, xargs.dataset, hp="12")
trace.append((reward, arch))
total_costs.append(current_total_cost)
baseline.update(reward)
# calculate loss
policy_loss = (-log_prob * (reward - baseline.value())).sum()
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
# accumulate time
total_steps += 1
logger.log(
"step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format(
total_steps, baseline.value(), policy_loss.item(), policy.genotype()
)
)
# to analyze
current_best_index.append(api.query_index_by_arch(max(trace, key=lambda x: x[0])[1]))
# best_arch = policy.genotype() # first version
best_arch = max(trace, key=lambda x: x[0])[1]
logger.log(
"REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).".format(
total_steps, total_costs[-1], time.time() - x_start_time
)
)
info = api.query_info_str_by_arch(best_arch, "200" if xargs.search_space == "tss" else "90")
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, current_best_index, total_costs
if __name__ == "__main__":
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
parser.add_argument("--search_space", type=str, choices=["tss", "sss"], help="Choose the search space.")
parser.add_argument("--learning_rate", type=float, help="The learning rate for REINFORCE.")
parser.add_argument("--EMA_momentum", type=float, default=0.9, help="The momentum value for EMA.")
parser.add_argument(
"--time_budget", type=int, default=20000, help="The total time cost budge for searching (in seconds)."
)
parser.add_argument("--loops_if_rand", type=int, default=500, help="The total runs for evaluation.")
# log
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
api = create(None, args.search_space, fast_mode=True, verbose=False)
args.save_dir = os.path.join(
"{:}-{:}".format(args.save_dir, args.search_space),
"{:}-T{:}".format(args.dataset, args.time_budget),
"REINFORCE-{:}".format(args.learning_rate),
)
print("save-dir : {:}".format(args.save_dir))
if args.rand_seed < 0:
save_dir, all_info = None, collections.OrderedDict()
for i in range(args.loops_if_rand):
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
args.rand_seed = random.randint(1, 100000)
save_dir, all_archs, all_total_times = main(args, api)
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
save_path = save_dir / "results.pth"
print("save into {:}".format(save_path))
torch.save(all_info, save_path)
else:
main(args, api)
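The policy-gradient step in the file above is plain REINFORCE with an exponential-moving-average baseline. Here is a minimal sketch (not part of this commit) of the same update on a toy 3-armed bandit; the per-arm rewards and learning rate are assumed for illustration, and the baseline arithmetic mirrors ExponentialMovingAverage.update/value.

import torch
from torch.distributions import Categorical

logits = torch.nn.Parameter(1e-3 * torch.randn(3))  # stand-in for the policy's arch_parameters
optimizer = torch.optim.Adam([logits], lr=0.05)
arm_rewards = [0.2, 0.5, 0.9]  # toy rewards, highest for the last arm
numerator, denominator, momentum = 0.0, 0.0, 0.9  # EMA baseline state

for step in range(200):
    probs = torch.nn.functional.softmax(logits, dim=-1)
    dist = Categorical(probs)
    action = dist.sample()
    reward = arm_rewards[int(action.item())]
    # update the moving-average baseline, then use (reward - baseline) as the advantage
    numerator = momentum * numerator + (1 - momentum) * reward
    denominator = momentum * denominator + (1 - momentum)
    baseline = numerator / denominator
    policy_loss = (-dist.log_prob(action) * (reward - baseline)).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

print(torch.nn.functional.softmax(logits, dim=-1))  # probabilities should concentrate on the best arm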

File diff suppressed because it is too large

View File

@@ -32,294 +32,420 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import count_parameters_in_MB, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nats_bench import create
# Ad-hoc for RL algorithms.
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""
"""Class that maintains an exponential moving average."""
def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum
def update(self, value):
self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
@property
def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator
RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
def search_func(
xloader,
network,
criterion,
scheduler,
w_optimizer,
a_optimizer,
enable_controller,
algo,
epoch_str,
print_freq,
logger,
):
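    """One search epoch: for each batch, take one optimizer step on the network weights
    using the base split, then (if the controller is enabled) one step on the architecture
    parameters using the arch split with the algorithm-specific loss."""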
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_inputs = base_inputs.cuda(non_blocking=True)
arch_inputs = arch_inputs.cuda(non_blocking=True)
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# Update the weights
network.zero_grad()
_, logits, _ = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
network.zero_grad()
a_optimizer.zero_grad()
_, logits, log_probs = network(arch_inputs)
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
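        # Architecture loss depends on the algorithm: "mask_rl" uses a REINFORCE-style
        # policy gradient with the EMA baseline, while "tas" / "mask_gumbel" back-propagate
        # the cross-entropy through the differentiable (Gumbel-softmax) relaxation.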
if algo == "mask_rl":
with torch.no_grad():
RL_BASELINE_EMA.update(arch_prec1.item())
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
rl_log_prob = sum(log_probs)
arch_loss = -rl_advantage * rl_log_prob
elif algo == "tas" or algo == "mask_gumbel":
arch_loss = criterion(logits, arch_targets)
else:
raise ValueError("invalid algorightm name: {:}".format(algo))
if enable_controller:
arch_loss.backward()
a_optimizer.step()
# record
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
def valid_func(xloader, network, criterion, logger):
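    """Evaluate the currently derived architecture on the validation loader and return
    the average loss, top-1 accuracy, and top-5 accuracy."""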
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
with torch.no_grad():
network.eval()
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits, _ = network(arch_inputs.cuda(non_blocking=True))
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
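    # Prepare the dataset, the search/validation loaders, and the training configuration.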
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
if xargs.overwite_epochs is None:
extra_info = {"class_num": class_num, "xshape": xshape}
else:
extra_info = {"class_num": class_num, "xshape": xshape, "epochs": xargs.overwite_epochs}
config = load_config(xargs.config_path, extra_info, logger)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
(config.batch_size, config.test_batch_size),
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces(xargs.search_space, "nats-bench")
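    # Build the one-shot super-network for the size (channel) search space: "candidates"
    # lists the allowed channel numbers per layer and "numbers" gives the layer count.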
model_config = dict2config(
dict(
name="generic",
super_type="search-shape",
candidate_Cs=search_space["candidates"],
max_num_Cs=search_space["numbers"],
num_classes=class_num,
genotype=args.genotype,
affine=bool(xargs.affine),
track_running_stats=bool(xargs.track_running_stats),
),
None,
)
logger.log("search space : {:}".format(search_space))
logger.log("model config : {:}".format(model_config))
search_model = get_cell_based_tiny_net(model_config)
search_model.set_algo(xargs.algo)
logger.log("{:}".format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
a_optimizer = torch.optim.Adam(
search_model.alphas,
lr=xargs.arch_learning_rate,
betas=(0.5, 0.999),
weight_decay=xargs.arch_weight_decay,
eps=xargs.arch_eps,
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
params = count_parameters_in_MB(search_model)
logger.log("The parameters of the search model = {:.2f} MB".format(params))
logger.log("search-space : {:}".format(search_space))
if bool(xargs.use_api):
api = create(None, "size", fast_mode=True, verbose=False)
else:
api = None
logger.log("{:} create API = {:} done".format(time_string(), api))
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU
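    # Automatically resume from the last checkpoint recorded in the "info" file, if present;
    # otherwise start from scratch with an initial random genotype.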
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.random}
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
if xargs.warmup_ratio is None or xargs.warmup_ratio <= float(epoch) / total_epoch:
enable_controller = True
network.set_warmup_ratio(None)
else:
enable_controller = False
network.set_warmup_ratio(1.0 - float(epoch) / total_epoch / xargs.warmup_ratio)
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}".format(
epoch_str, need_time, min(w_scheduler.get_lr()), network.warmup_ratio, enable_controller
)
)
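        # For the differentiable algorithms, anneal the Gumbel-softmax temperature linearly
        # from tau_max to tau_min over the course of the search.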
if xargs.algo == "mask_gumbel" or xargs.algo == "tas":
network.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1))
logger.log("[RESET tau as : {:}]".format(network.tau))
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
enable_controller,
xargs.algo,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
logger.log(
"[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, search_a_loss, search_a_top1, search_a_top5
)
)
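        # Derive the current genotype from the relaxed architecture parameters and evaluate
        # it on the validation split.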
genotype = network.genotype
logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype))
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype
)
)
valid_accuracies[epoch] = valid_a_top1
genotypes[epoch] = genotype
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
with torch.no_grad():
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "90")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
# the final post procedure : count the time
start_time = time.time()
genotype = network.genotype
search_time.update(time.time() - start_time)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion, logger)
logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1))
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log(
"[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
xargs.algo, total_epoch, search_time.sum, genotype
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotype, "90")))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
parser.add_argument("--search_space", type=str, default="sss", choices=["sss"], help="The search space name.")
parser.add_argument("--algo", type=str, choices=["tas", "mask_gumbel", "mask_rl"], help="The search space name.")
parser.add_argument(
"--genotype",
type=str,
default="|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|",
help="The genotype.",
)
parser.add_argument(
"--use_api", type=int, default=1, choices=[0, 1], help="Whether use API or not (which will cost much memory)."
)
# FOR GDAS
parser.add_argument("--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax.")
parser.add_argument("--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax.")
# FOR ALL
parser.add_argument("--warmup_ratio", type=float, help="The warmup ratio, if None, not use warmup.")
#
parser.add_argument(
"--track_running_stats",
type=int,
default=0,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument(
"--affine", type=int, default=0, choices=[0, 1], help="Whether use affine=True or False in the BN layer."
)
parser.add_argument(
"--config_path",
type=str,
default="./configs/nas-benchmark/algos/weight-sharing.config",
help="The path of configuration.",
)
parser.add_argument(
"--overwite_epochs", type=int, help="The number of epochs to overwrite that value in config files."
)
    # architecture learning rate
parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding")
parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding")
parser.add_argument("--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, default="./output/search", help="Folder to save checkpoints and log.")
parser.add_argument("--print_freq", type=int, default=200, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
dirname = "{:}-affine{:}_BN{:}-AWD{:}-WARM{:}".format(
args.algo, args.affine, args.track_running_stats, args.arch_weight_decay, args.warmup_ratio
)
if args.overwite_epochs is not None:
dirname = dirname + "-E{:}".format(args.overwite_epochs)
args.save_dir = os.path.join("{:}-{:}".format(args.save_dir, args.search_space), args.dataset, dirname)
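    # Example invocation (a sketch; the script name, data path, and parameter values below are
    # placeholders -- only the flags themselves come from the argparse definitions above):
    #   python this_script.py --dataset cifar10 --data_path ${TORCH_HOME}/cifar.python \
    #       --algo tas --warmup_ratio 0.25 --rand_seed 777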
main(args)