Reformulate via black

This commit is contained in:
D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions

View File

@@ -11,14 +11,17 @@ import os, sys, time, random, argparse
from copy import deepcopy
from pathlib import Path
import torch
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace
from hpbandster.optimizers.bohb import BOHB
@@ -27,209 +30,258 @@ from hpbandster.core.worker import Worker
def get_configuration_space(max_nodes, search_space):
cs = ConfigSpace.ConfigurationSpace()
#edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
return cs
cs = ConfigSpace.ConfigurationSpace()
# edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
cs.add_hyperparameter(ConfigSpace.CategoricalHyperparameter(node_str, search_space))
return cs
def config2structure_func(max_nodes):
def config2structure(config):
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = config[node_str]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
return config2structure
def config2structure(config):
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = config[node_str]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return config2structure
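For reference, a minimal self-contained sketch of what config2structure does with a sampled configuration. This is a hypothetical stand-in for illustration only (the real script builds a CellStructure object from the lib package); the op names are the five standard NAS-Bench-201 operations and max_nodes=4 is assumed.

# Hypothetical helper, not repository code: maps a sampled configuration dict
# onto the textual NAS-Bench-201 architecture string.
def config_to_arch_str(config, max_nodes=4):
    node_strs = []
    for i in range(1, max_nodes):
        edges = []
        for j in range(i):
            op_name = config["{:}<-{:}".format(i, j)]
            edges.append("{:}~{:}".format(op_name, j))
        node_strs.append("|" + "|".join(edges) + "|")
    return "+".join(node_strs)

sampled = {
    "1<-0": "nor_conv_3x3",
    "2<-0": "skip_connect", "2<-1": "nor_conv_1x1",
    "3<-0": "none", "3<-1": "avg_pool_3x3", "3<-2": "nor_conv_3x3",
}
print(config_to_arch_str(sampled))
# |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|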
class MyWorker(Worker):
def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self._dataname = dataname
self._nas_bench = nas_bench
self.time_budget = time_budget
self.seen_archs = []
self.sim_cost_time = 0
self.real_cost_time = 0
self.is_end = False
def __init__(self, *args, convert_func=None, dataname=None, nas_bench=None, time_budget=None, **kwargs):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self._dataname = dataname
self._nas_bench = nas_bench
self.time_budget = time_budget
self.seen_archs = []
self.sim_cost_time = 0
self.real_cost_time = 0
self.is_end = False
def get_the_best(self):
assert len(self.seen_archs) > 0
best_index, best_acc = -1, None
for arch_index in self.seen_archs:
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True)
vacc = info["valid-accuracy"]
if best_acc is None or best_acc < vacc:
best_acc = vacc
best_index = arch_index
assert best_index != -1
return best_index
def get_the_best(self):
assert len(self.seen_archs) > 0
best_index, best_acc = -1, None
for arch_index in self.seen_archs:
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
vacc = info['valid-accuracy']
if best_acc is None or best_acc < vacc:
best_acc = vacc
best_index = arch_index
assert best_index != -1
return best_index
def compute(self, config, budget, **kwargs):
start_time = time.time()
structure = self.convert_func( config )
arch_index = self._nas_bench.query_index_by_arch( structure )
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp='200', is_random=True)
cur_time = info['train-all-time'] + info['valid-per-time']
cur_vacc = info['valid-accuracy']
self.real_cost_time += (time.time() - start_time)
if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
self.sim_cost_time += cur_time
self.seen_archs.append( arch_index )
return ({'loss': 100 - float(cur_vacc),
'info': {'seen-arch' : len(self.seen_archs),
'sim-test-time' : self.sim_cost_time,
'current-arch' : arch_index}
})
else:
self.is_end = True
return ({'loss': 100,
'info': {'seen-arch' : len(self.seen_archs),
'sim-test-time' : self.sim_cost_time,
'current-arch' : None}
})
def compute(self, config, budget, **kwargs):
start_time = time.time()
structure = self.convert_func(config)
arch_index = self._nas_bench.query_index_by_arch(structure)
info = self._nas_bench.get_more_info(arch_index, self._dataname, None, hp="200", is_random=True)
cur_time = info["train-all-time"] + info["valid-per-time"]
cur_vacc = info["valid-accuracy"]
self.real_cost_time += time.time() - start_time
if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
self.sim_cost_time += cur_time
self.seen_archs.append(arch_index)
return {
"loss": 100 - float(cur_vacc),
"info": {
"seen-arch": len(self.seen_archs),
"sim-test-time": self.sim_cost_time,
"current-arch": arch_index,
},
}
else:
self.is_end = True
return {
"loss": 100,
"info": {"seen-arch": len(self.seen_archs), "sim-test-time": self.sim_cost_time, "current-arch": None},
}
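A minimal sketch, with made-up per-architecture costs, of the simulated time-budget bookkeeping that compute() performs: the benchmark-reported cost (info["train-all-time"] + info["valid-per-time"]) is accumulated in place of real wall-clock time, and evaluation stops once the accumulated simulated cost would exceed the budget.

# Hypothetical numbers for illustration; the real costs come from the benchmark API.
def simulate_budget(costs, time_budget):
    used, evaluated = 0.0, 0
    for cost in costs:
        if used + cost > time_budget:
            break
        used += cost
        evaluated += 1
    return evaluated, used

print(simulate_budget([3000.0, 4500.0, 5200.0, 4100.0], time_budget=12000))  # -> (2, 7500.0)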
def main(xargs, nas_bench):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
else:
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, None, logger)
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
# nas dataset load
assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)
search_space = get_search_spaces('cell', xargs.search_space_name)
cs = get_configuration_space(xargs.max_nodes, search_space)
# nas dataset load
assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)
search_space = get_search_spaces("cell", xargs.search_space_name)
cs = get_configuration_space(xargs.max_nodes, search_space)
config2structure = config2structure_func(xargs.max_nodes)
hb_run_id = '0'
config2structure = config2structure_func(xargs.max_nodes)
hb_run_id = "0"
NS = hpns.NameServer(run_id=hb_run_id, host='localhost', port=0)
ns_host, ns_port = NS.start()
num_workers = 1
NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0)
ns_host, ns_port = NS.start()
num_workers = 1
#nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
#logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
workers = []
for i in range(num_workers):
w = MyWorker(nameserver=ns_host, nameserver_port=ns_port, convert_func=config2structure, dataname=dataname, nas_bench=nas_bench, time_budget=xargs.time_budget, run_id=hb_run_id, id=i)
w.run(background=True)
workers.append(w)
start_time = time.time()
bohb = BOHB(configspace=cs,
run_id=hb_run_id,
eta=3, min_budget=12, max_budget=200,
# nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
# logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
workers = []
for i in range(num_workers):
w = MyWorker(
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
random_fraction=xargs.random_fraction, bandwidth_factor=xargs.bandwidth_factor,
ping_interval=10, min_bandwidth=xargs.min_bandwidth)
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
convert_func=config2structure,
dataname=dataname,
nas_bench=nas_bench,
time_budget=xargs.time_budget,
run_id=hb_run_id,
id=i,
)
w.run(background=True)
workers.append(w)
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
start_time = time.time()
bohb = BOHB(
configspace=cs,
run_id=hb_run_id,
eta=3,
min_budget=12,
max_budget=200,
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
random_fraction=xargs.random_fraction,
bandwidth_factor=xargs.bandwidth_factor,
ping_interval=10,
min_bandwidth=xargs.min_bandwidth,
)
real_cost_time = time.time() - start_time
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
id2config = results.get_id2config_mapping()
incumbent = results.get_incumbent_id()
logger.log('Best found configuration: {:} within {:.3f} s'.format(id2config[incumbent]['config'], real_cost_time))
best_arch = config2structure( id2config[incumbent]['config'] )
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)
real_cost_time = time.time() - start_time
logger.log('workers : {:.1f}s with {:} archs'.format(workers[0].time_budget, len(workers[0].seen_archs)))
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch ), real_cost_time
id2config = results.get_id2config_mapping()
incumbent = results.get_incumbent_id()
logger.log("Best found configuration: {:} within {:.3f} s".format(id2config[incumbent]["config"], real_cost_time))
best_arch = config2structure(id2config[incumbent]["config"])
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.log("workers : {:.1f}s with {:} archs".format(workers[0].time_budget, len(workers[0].seen_archs)))
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time
if __name__ == '__main__':
parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
# BOHB
parser.add_argument('--strategy', default="sampling", type=str, nargs='?', help='optimization strategy for the acquisition function')
parser.add_argument('--min_bandwidth', default=.3, type=float, nargs='?', help='minimum bandwidth for KDE')
parser.add_argument('--num_samples', default=64, type=int, nargs='?', help='number of samples for the acquisition function')
parser.add_argument('--random_fraction', default=.33, type=float, nargs='?', help='fraction of random configurations')
parser.add_argument('--bandwidth_factor', default=3, type=int, nargs='?', help='factor multiplied to the bandwidth')
parser.add_argument('--n_iters', default=100, type=int, nargs='?', help='number of iterations for optimization method')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num, all_times = None, [], 500, []
for i in range(num):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index, ctime = main(args, nas_bench)
all_indexes.append( index )
all_times.append( ctime )
print ('\n average time : {:.3f} s'.format(sum(all_times)/len(all_times)))
torch.save(all_indexes, save_dir / 'results.pth')
else:
main(args, nas_bench)
if __name__ == "__main__":
parser = argparse.ArgumentParser("BOHB: Robust and Efficient Hyperparameter Optimization at Scale")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
# BOHB
parser.add_argument(
"--strategy", default="sampling", type=str, nargs="?", help="optimization strategy for the acquisition function"
)
parser.add_argument("--min_bandwidth", default=0.3, type=float, nargs="?", help="minimum bandwidth for KDE")
parser.add_argument(
"--num_samples", default=64, type=int, nargs="?", help="number of samples for the acquisition function"
)
parser.add_argument(
"--random_fraction", default=0.33, type=float, nargs="?", help="fraction of random configurations"
)
parser.add_argument("--bandwidth_factor", default=3, type=int, nargs="?", help="factor multiplied to the bandwidth")
parser.add_argument(
"--n_iters", default=100, type=int, nargs="?", help="number of iterations for optimization method"
)
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num, all_times = None, [], 500, []
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index, ctime = main(args, nas_bench)
all_indexes.append(index)
all_times.append(ctime)
print("\n average time : {:.3f} s".format(sum(all_times) / len(all_times)))
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)

View File

@@ -7,232 +7,330 @@ import sys, time, random, argparse
from copy import deepcopy
import torch
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger, gradient_clip):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
if gradient_clip > 0: torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
a_optimizer.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
def search_func(
xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger, gradient_clip
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
if gradient_clip > 0:
torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
a_optimizer.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg
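A hedged toy sketch of the alternating first-order update that search_func performs per mini-batch: one SGD step on the network weights using the base split, then one Adam step on the architecture parameters using the arch split. The tensors and "operations" below are stand-ins, not the repository's search model.

import torch

torch.manual_seed(0)
weights = torch.nn.Parameter(torch.randn(4, 2))   # stand-in for the super-net weights
alphas = torch.nn.Parameter(torch.zeros(3))        # stand-in for the architecture parameters
w_optimizer = torch.optim.SGD([weights], lr=0.025, momentum=0.9)
a_optimizer = torch.optim.Adam([alphas], lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)
criterion = torch.nn.CrossEntropyLoss()

def forward(x):
    # softmax-weighted mixture of candidate "operations", DARTS-style
    mix = torch.softmax(alphas, dim=-1)
    candidates = [x @ weights, torch.relu(x @ weights), torch.tanh(x @ weights)]
    return sum(m * op for m, op in zip(mix, candidates))

base_x, base_y = torch.randn(8, 4), torch.randint(0, 2, (8,))
arch_x, arch_y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# weight step on the "base" split, clipping the weight gradients as gradient_clip does
w_optimizer.zero_grad()
criterion(forward(base_x), base_y).backward()
torch.nn.utils.clip_grad_norm_([weights], 5)
w_optimizer.step()

# architecture step on the "arch" split
a_optimizer.zero_grad()
criterion(forward(arch_x), arch_y).backward()
a_optimizer.step()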
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
# config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config({'name': 'DARTS-V1', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
else:
model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
flop, param = get_model_infos(search_model, xshape)
#logger.log('{:}'.format(search_model))
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
search_space = get_search_spaces("cell", xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
{
"name": "DARTS-V1",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
else:
model_config = load_config(
xargs.model_config,
{
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
logger.log("search-model :\n{:}".format(search_model))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
# logger.log('{:}'.format(search_model))
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()}
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger, xargs.gradient_clip)
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies['best']:
valid_accuracies['best'] = valid_a_top1
genotypes['best'] = search_model.genotype()
find_best = True
else: find_best = False
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr())))
genotypes[epoch] = search_model.genotype()
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
#logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
search_w_loss, search_w_top1, search_w_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
xargs.gradient_clip,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
genotypes["best"] = search_model.genotype()
find_best = True
else:
find_best = False
logger.log('\n' + '-'*100)
logger.log('DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
logger.close()
genotypes[epoch] = search_model.genotype()
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
# logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
logger.log(
"DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotypes[total_epoch - 1]
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200")))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("DARTS first order")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether use track_running_stats or not in the BN layer.')
parser.add_argument('--config_path', type=str, help='The config path.')
parser.add_argument('--model_config', type=str, help='The path of the model configuration. When this arg is set, it will override max_nodes / channels / num_cells.')
parser.add_argument('--gradient_clip', type=float, default=5, help='')
# architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser("DARTS first order")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument("--config_path", type=str, help="The config path.")
parser.add_argument(
"--model_config",
type=str,
help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.",
)
parser.add_argument("--gradient_clip", type=float, default=5, help="")
# architecture learning rate
parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding")
parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)

View File

@@ -9,290 +9,381 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def _concat(xs):
return torch.cat([x.view(-1) for x in xs])
return torch.cat([x.view(-1) for x in xs])
def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2):
R = r / _concat(vector).norm()
for p, v in zip(network.module.get_weights(), vector):
p.data.add_(R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_p = torch.autograd.grad(loss, network.module.get_alphas())
R = r / _concat(vector).norm()
for p, v in zip(network.module.get_weights(), vector):
p.data.add_(R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_p = torch.autograd.grad(loss, network.module.get_alphas())
for p, v in zip(network.module.get_weights(), vector):
p.data.sub_(2*R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_n = torch.autograd.grad(loss, network.module.get_alphas())
for p, v in zip(network.module.get_weights(), vector):
p.data.sub_(2 * R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_n = torch.autograd.grad(loss, network.module.get_alphas())
for p, v in zip(network.module.get_weights(), vector):
p.data.add_(R, v)
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
for p, v in zip(network.module.get_weights(), vector):
p.data.add_(R, v)
return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
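A hedged sanity-check sketch (not part of the commit) of the central finite difference above: it approximates the mixed second derivative applied to a vector v as (grad_alpha L(w + R v) - grad_alpha L(w - R v)) / (2 R). On a toy bilinear loss L = sum(w * alpha), the mixed Hessian block is the identity, so the approximation should recover v itself.

import torch

w = torch.randn(5, requires_grad=True)
alpha = torch.randn(5, requires_grad=True)
v = torch.randn(5)
R = 1e-2 / v.norm()  # same scaling rule as r / ||vector|| above

def grad_alpha(w_value):
    # gradient of the toy loss with respect to alpha at the perturbed weights
    loss = (w_value * alpha).sum()
    return torch.autograd.grad(loss, alpha)[0]

approx = (grad_alpha(w + R * v) - grad_alpha(w - R * v)) / (2 * R)
print(torch.allclose(approx, v, atol=1e-3))  # True up to floating-point error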
def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets):
# _compute_unrolled_model
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum']
with torch.no_grad():
theta = _concat(network.module.get_weights())
try:
moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.module.get_weights())
moment = moment.mul_(momentum)
except:
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, network.module.get_weights())) + WD*theta
params = theta.sub(LR, moment+dtheta)
unrolled_model = deepcopy(network)
model_dict = unrolled_model.state_dict()
new_params, offset = {}, 0
for k, v in network.named_parameters():
if 'arch_parameters' in k: continue
v_length = np.prod(v.size())
new_params[k] = params[offset: offset+v_length].view(v.size())
offset += v_length
model_dict.update(new_params)
unrolled_model.load_state_dict(model_dict)
# _compute_unrolled_model
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
LR, WD, momentum = (
w_optimizer.param_groups[0]["lr"],
w_optimizer.param_groups[0]["weight_decay"],
w_optimizer.param_groups[0]["momentum"],
)
with torch.no_grad():
theta = _concat(network.module.get_weights())
try:
moment = _concat(w_optimizer.state[v]["momentum_buffer"] for v in network.module.get_weights())
moment = moment.mul_(momentum)
except:
moment = torch.zeros_like(theta)
dtheta = _concat(torch.autograd.grad(loss, network.module.get_weights())) + WD * theta
params = theta.sub(LR, moment + dtheta)
unrolled_model = deepcopy(network)
model_dict = unrolled_model.state_dict()
new_params, offset = {}, 0
for k, v in network.named_parameters():
if "arch_parameters" in k:
continue
v_length = np.prod(v.size())
new_params[k] = params[offset : offset + v_length].view(v.size())
offset += v_length
model_dict.update(new_params)
unrolled_model.load_state_dict(model_dict)
unrolled_model.zero_grad()
_, unrolled_logits = unrolled_model(arch_inputs)
unrolled_loss = criterion(unrolled_logits, arch_targets)
unrolled_loss.backward()
unrolled_model.zero_grad()
_, unrolled_logits = unrolled_model(arch_inputs)
unrolled_loss = criterion(unrolled_logits, arch_targets)
unrolled_loss.backward()
dalpha = unrolled_model.module.arch_parameters.grad
vector = [v.grad.data for v in unrolled_model.module.get_weights()]
[implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets)
dalpha.data.sub_(LR, implicit_grads.data)
dalpha = unrolled_model.module.arch_parameters.grad
vector = [v.grad.data for v in unrolled_model.module.get_weights()]
[implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets)
dalpha.data.sub_(LR, implicit_grads.data)
if network.module.arch_parameters.grad is None:
network.module.arch_parameters.grad = deepcopy(dalpha)
else:
network.module.arch_parameters.grad.data.copy_(dalpha.data)
return unrolled_loss.detach(), unrolled_logits.detach()
if network.module.arch_parameters.grad is None:
network.module.arch_parameters.grad = deepcopy( dalpha )
else:
network.module.arch_parameters.grad.data.copy_( dalpha.data )
return unrolled_loss.detach(), unrolled_logits.detach()
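For context, and assuming the standard derivation from the DARTS paper (Liu et al., ICLR 2019), backward_step_unrolled computes the second-order approximation

\nabla_{\alpha} \mathcal{L}_{val}\big(w - \xi \nabla_{w} \mathcal{L}_{train}(w, \alpha),\, \alpha\big) \approx \nabla_{\alpha} \mathcal{L}_{val}(w', \alpha) - \xi\, \nabla^{2}_{\alpha, w} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha),

where w' = w - \xi \nabla_{w} \mathcal{L}_{train}(w, \alpha) is the one-step unrolled weight (the deepcopy'd unrolled_model), \xi corresponds to LR, and the second term is the finite-difference estimate returned by _hessian_vector_product; this is why dalpha is corrected with dalpha.data.sub_(LR, implicit_grads.data).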
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the architecture-weight
a_optimizer.zero_grad()
arch_loss, arch_logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg
# update the architecture-weight
a_optimizer.zero_grad()
arch_loss, arch_logits = backward_step_unrolled(
network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets
)
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'DARTS-V2', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
flop, param = get_model_infos(search_model, xshape)
#logger.log('{:}'.format(search_model))
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "DARTS-V2",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
logger.log("search-model :\n{:}".format(search_model))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
# logger.log('{:}'.format(search_model))
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
min_LR = min(w_scheduler.get_lr())
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min_LR))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()}
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies['best']:
valid_accuracies['best'] = valid_a_top1
genotypes['best'] = search_model.genotype()
find_best = True
else: find_best = False
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
min_LR = min(w_scheduler.get_lr())
logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min_LR))
genotypes[epoch] = search_model.genotype()
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
search_w_loss, search_w_top1, search_w_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
genotypes["best"] = search_model.genotype()
find_best = True
else:
find_best = False
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
logger.close()
genotypes[epoch] = search_model.genotype()
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log(
"arch-parameters :\n{:}".format(nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu())
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log(
"DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotypes[total_epoch - 1]
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1]), "200"))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("DARTS Second Order")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--config_path', type=str, help='The config path.')
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--track_running_stats', type=int, choices=[0, 1], help='Whether to use track_running_stats in the BN layers.')
# architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser("DARTS Second Order")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--config_path", type=str, help="The config path.")
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
# architecture learning rate
parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding")
parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)

View File

@@ -9,336 +9,447 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler, optimizer, epoch_str, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()
shared_cnn.train()
controller.eval()
data_time, batch_time = AverageMeter(), AverageMeter()
losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()
for step, (inputs, targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
with torch.no_grad():
_, _, sampled_arch = controller()
shared_cnn.train()
controller.eval()
optimizer.zero_grad()
shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
loss = criterion(logits, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1s.update (prec1.item(), inputs.size(0))
top5s.update (prec5.item(), inputs.size(0))
for step, (inputs, targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
with torch.no_grad():
_, _, sampled_arch = controller()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=losses, top1=top1s, top5=top5s)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
return losses.avg, top1s.avg, top5s.avg
optimizer.zero_grad()
shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
loss = criterion(logits, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1s.update(prec1.item(), inputs.size(0))
top5s.update(prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*Train-Shared-CNN* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=losses, top1=top1s, top5=top5s
)
logger.log(Sstr + " " + Tstr + " " + Wstr)
return losses.avg, top1s.avg, top5s.avg
def train_controller(xloader, shared_cnn, controller, criterion, optimizer, config, epoch_str, print_freq, logger):
# config: a config namespace containing the necessary controller-training arguments
# baseline: The baseline score (i.e. average val_acc) from the previous epoch
data_time, batch_time = AverageMeter(), AverageMeter()
GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
shared_cnn.eval()
controller.train()
controller.zero_grad()
#for step, (inputs, targets) in enumerate(xloader):
loader_iter = iter(xloader)
for step in range(config.ctl_train_steps * config.ctl_num_aggre):
try:
inputs, targets = next(loader_iter)
except StopIteration:  # restart the loader once it is exhausted
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
log_prob, entropy, sampled_arch = controller()
with torch.no_grad():
shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
val_top1 = val_top1.view(-1) / 100
reward = val_top1 + config.ctl_entropy_w * entropy
if config.baseline is None:
baseline = val_top1
else:
baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward)
loss = -1 * log_prob * (reward - baseline)
# account
RewardMeter.update(reward.item())
BaselineMeter.update(baseline.item())
ValAccMeter.update(val_top1.item()*100)
LossMeter.update(loss.item())
EntropyMeter.update(entropy.item())
# Average gradient over controller_num_aggregate samples
loss = loss / config.ctl_num_aggre
loss.backward(retain_graph=True)
# config: a config namespace containing the necessary controller-training arguments
# baseline: The baseline score (i.e. average val_acc) from the previous epoch
data_time, batch_time = AverageMeter(), AverageMeter()
GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
time.time(),
)
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
if (step+1) % config.ctl_num_aggre == 0:
grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
GradnormMeter.update(grad_norm)
optimizer.step()
controller.zero_grad()
shared_cnn.eval()
controller.train()
controller.zero_grad()
# for step, (inputs, targets) in enumerate(xloader):
loader_iter = iter(xloader)
for step in range(config.ctl_train_steps * config.ctl_num_aggre):
try:
inputs, targets = next(loader_iter)
except StopIteration:  # restart the loader once it is exhausted
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
if step % print_freq == 0:
Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre)
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
log_prob, entropy, sampled_arch = controller()
with torch.no_grad():
shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
val_top1 = val_top1.view(-1) / 100
reward = val_top1 + config.ctl_entropy_w * entropy
if config.baseline is None:
baseline = val_top1
else:
baseline = config.baseline - (1 - config.ctl_bl_dec) * (config.baseline - reward)
return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item()
loss = -1 * log_prob * (reward - baseline)
# account
RewardMeter.update(reward.item())
BaselineMeter.update(baseline.item())
ValAccMeter.update(val_top1.item() * 100)
LossMeter.update(loss.item())
EntropyMeter.update(entropy.item())
# Average gradient over controller_num_aggregate samples
loss = loss / config.ctl_num_aggre
loss.backward(retain_graph=True)
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
if (step + 1) % config.ctl_num_aggre == 0:
grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
GradnormMeter.update(grad_norm)
optimizer.step()
controller.zero_grad()
if step % print_freq == 0:
Sstr = (
"*Train-Controller* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre)
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format(
loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter
)
Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr)
return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg, baseline.item()
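# The baseline update above keeps an exponential moving average of the reward:
# baseline <- baseline - (1 - ctl_bl_dec) * (baseline - reward), which is algebraically
# identical to ctl_bl_dec * baseline + (1 - ctl_bl_dec) * reward. The snippet below is a
# minimal sketch of that recursion with hypothetical reward values; it is not code from
# this commit.
def sketch_update_baseline(baseline, reward, bl_dec=0.99):
    if baseline is None:  # first controller step: bootstrap the baseline
        return reward
    return baseline - (1 - bl_dec) * (baseline - reward)

ema_baseline = None
for sampled_reward in [0.62, 0.70, 0.55, 0.73]:  # e.g. scaled top-1 accuracies in [0, 1]
    ema_baseline = sketch_update_baseline(ema_baseline, sampled_reward)
print(round(ema_baseline, 4))  # slowly tracks the average reward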
def get_best_arch(controller, shared_cnn, xloader, n_samples=10):
with torch.no_grad():
controller.eval()
shared_cnn.eval()
archs, valid_accs = [], []
loader_iter = iter(xloader)
for i in range(n_samples):
try:
inputs, targets = next(loader_iter)
except StopIteration:  # restart the loader once it is exhausted
with torch.no_grad():
controller.eval()
shared_cnn.eval()
archs, valid_accs = [], []
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
for i in range(n_samples):
try:
inputs, targets = next(loader_iter)
except StopIteration:  # restart the loader once it is exhausted
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, _, sampled_arch = controller()
arch = shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
_, _, sampled_arch = controller()
arch = shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
archs.append( arch )
valid_accs.append( val_top1.item() )
archs.append(arch)
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
logger.log('use config from : {:}'.format(xargs.config_path))
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
_, train_loader, valid_loader = get_nas_search_loaders(train_data, test_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
# since ENAS will train the controller on the valid-loader, we need to use the training transformation for the valid-loader
valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
if hasattr(valid_loader.dataset, 'transforms'):
valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
# data loader
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
train_data, test_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
logger.log("use config from : {:}".format(xargs.config_path))
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
_, train_loader, valid_loader = get_nas_search_loaders(
train_data, test_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers
)
# since ENAS will train the controller on the valid-loader, we need to use the training transformation for the valid-loader
valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
if hasattr(valid_loader.dataset, "transforms"):
valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
# data loader
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'ENAS', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
shared_cnn = get_cell_based_tiny_net(model_config)
controller = shared_cnn.create_controller()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config)
a_optimizer = torch.optim.Adam(controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
#flop, param = get_model_infos(shared_cnn, xshape)
#logger.log('{:}'.format(shared_cnn))
#logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
logger.log('search-space : {:}'.format(search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "ENAS",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
shared_cnn = get_cell_based_tiny_net(model_config)
controller = shared_cnn.create_controller()
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
w_optimizer, w_scheduler, criterion = get_optim_scheduler(shared_cnn.parameters(), config)
a_optimizer = torch.optim.Adam(
controller.parameters(), lr=config.controller_lr, betas=config.controller_betas, eps=config.controller_eps
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
# flop, param = get_model_infos(shared_cnn, xshape)
# logger.log('{:}'.format(shared_cnn))
# logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
logger.log("search-space : {:}".format(search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
shared_cnn, controller, criterion = torch.nn.DataParallel(shared_cnn).cuda(), controller.cuda(), criterion.cuda()
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
baseline = checkpoint['baseline']
valid_accuracies = checkpoint['valid_accuracies']
shared_cnn.load_state_dict( checkpoint['shared_cnn'] )
controller.load_state_dict( checkpoint['controller'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes, baseline = 0, {'best': -1}, {}, None
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr()), baseline))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
baseline = checkpoint["baseline"]
valid_accuracies = checkpoint["valid_accuracies"]
shared_cnn.load_state_dict(checkpoint["shared_cnn"])
controller.load_state_dict(checkpoint["controller"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes, baseline = 0, {"best": -1}, {}, None
cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(train_loader, shared_cnn, controller, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
logger.log('[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline \
= train_controller(valid_loader, shared_cnn, controller, criterion, a_optimizer, \
dict2config({'baseline': baseline,
'ctl_train_steps': xargs.controller_train_steps, 'ctl_num_aggre': xargs.controller_num_aggregate,
'ctl_entropy_w': xargs.controller_entropy_weight,
'ctl_bl_dec' : xargs.controller_bl_dec}, None), \
epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s'.format(epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum))
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
shared_cnn.module.update_arch(best_arch)
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format(
epoch_str, need_time, min(w_scheduler.get_lr()), baseline
)
)
genotypes[epoch] = best_arch
# check the best accuracy
valid_accuracies[epoch] = best_valid_acc
if best_valid_acc > valid_accuracies['best']:
valid_accuracies['best'] = best_valid_acc
genotypes['best'] = best_arch
find_best = True
else: find_best = False
cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
train_loader,
shared_cnn,
controller,
criterion,
w_scheduler,
w_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
logger.log(
"[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, cnn_loss, cnn_top1, cnn_top5
)
)
ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
valid_loader,
shared_cnn,
controller,
criterion,
a_optimizer,
dict2config(
{
"baseline": baseline,
"ctl_train_steps": xargs.controller_train_steps,
"ctl_num_aggre": xargs.controller_num_aggregate,
"ctl_entropy_w": xargs.controller_entropy_weight,
"ctl_bl_dec": xargs.controller_bl_dec,
},
None,
),
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s".format(
epoch_str, ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline, search_time.sum
)
)
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
shared_cnn.module.update_arch(best_arch)
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'baseline' : baseline,
'shared_cnn' : shared_cnn.state_dict(),
'controller' : controller.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, best_valid_acc))
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
genotypes[epoch] = best_arch
# check the best accuracy
valid_accuracies[epoch] = best_valid_acc
if best_valid_acc > valid_accuracies["best"]:
valid_accuracies["best"] = best_valid_acc
genotypes["best"] = best_arch
find_best = True
else:
find_best = False
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"baseline": baseline,
"shared_cnn": shared_cnn.state_dict(),
"controller": controller.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, best_valid_acc
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
logger.log("During searching, the best architecture is {:}".format(genotypes["best"]))
logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"]))
logger.log("Randomly select {:} architectures and select the best.".format(xargs.controller_num_samples))
start_time = time.time()
logger.log('\n' + '-'*100)
logger.log('During searching, the best architecture is {:}'.format(genotypes['best']))
logger.log('Its accuracy is {:.2f}%'.format(valid_accuracies['best']))
logger.log('Randomly select {:} architectures and select the best.'.format(xargs.controller_num_samples))
start_time = time.time()
final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
search_time.update(time.time() - start_time)
shared_cnn.module.update_arch(final_arch)
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
logger.log('The Selected Final Architecture : {:}'.format(final_arch))
logger.log('Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%'.format(final_loss, final_top1, final_top5))
logger.log('ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, final_arch))
if api is not None: logger.log('{:}'.format( api.query_by_arch(final_arch) ))
logger.close()
final_arch, _ = get_best_arch(controller, shared_cnn, valid_loader, xargs.controller_num_samples)
search_time.update(time.time() - start_time)
shared_cnn.module.update_arch(final_arch)
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
logger.log("The Selected Final Architecture : {:}".format(final_arch))
logger.log("Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format(final_loss, final_top1, final_top5))
logger.log(
"ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(total_epoch, search_time.sum, final_arch)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(final_arch)))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("ENAS")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--track_running_stats', type=int, choices=[0, 1], help='Whether to use track_running_stats in the BN layers.')
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--config_path', type=str, help='The config file to train ENAS.')
parser.add_argument('--controller_train_steps', type=int, help='The number of controller training steps per epoch.')
parser.add_argument('--controller_num_aggregate', type=int, help='The number of sampled architectures over which controller gradients are aggregated before each update.')
parser.add_argument('--controller_entropy_weight', type=float, help='The weight for the entropy of the controller.')
parser.add_argument('--controller_bl_dec', type=float, help='The decay factor for the moving-average reward baseline.')
parser.add_argument('--controller_num_samples', type=int, help='The number of architectures sampled when selecting the final architecture.')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser("ENAS")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--config_path", type=str, help="The config file to train ENAS.")
parser.add_argument("--controller_train_steps", type=int, help=".")
parser.add_argument("--controller_num_aggregate", type=int, help=".")
parser.add_argument("--controller_entropy_weight", type=float, help="The weight for the entropy of the controller.")
parser.add_argument("--controller_bl_dec", type=float, help=".")
parser.add_argument("--controller_num_samples", type=int, help=".")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)
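# dict2config (imported from lib/config_utils) is used above to turn a plain dict of
# controller hyper-parameters into an attribute-style config, which is why train_controller
# can read config.ctl_train_steps, config.ctl_num_aggre, config.ctl_entropy_w and
# config.ctl_bl_dec. The stand-in below only sketches that behaviour under that assumption
# (it is NOT the repository's implementation, and the second argument is ignored here);
# all numeric values are hypothetical.
from types import SimpleNamespace

def sketch_dict2config(xdict, extra=None):
    return SimpleNamespace(**xdict)

ctl_config = sketch_dict2config(
    {
        "baseline": None,  # no baseline yet at the first epoch
        "ctl_train_steps": 500,
        "ctl_num_aggre": 20,
        "ctl_entropy_w": 0.0001,
        "ctl_bl_dec": 0.99,
    },
    None,
)
print(ctl_config.ctl_train_steps * ctl_config.ctl_num_aggre)  # inner-loop iterations per epoch, as in range(...) above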

View File

@@ -7,211 +7,308 @@ import sys, time, random, argparse
from copy import deepcopy
import torch
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
a_optimizer.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
a_optimizer.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
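# The loop above alternates two optimizers on two data splits: the network weights are
# updated on the "base" split and the architecture parameters on the "arch" split. The toy
# snippet below distils that two-optimizer pattern with hypothetical tensors (gradient
# clipping and the real supernet are omitted); it is a sketch, not code from this commit.
import torch

toy_weights = torch.nn.Parameter(torch.randn(4, 4))
toy_alphas = torch.nn.Parameter(torch.zeros(4))
toy_w_optimizer = torch.optim.SGD([toy_weights], lr=0.1)
toy_a_optimizer = torch.optim.Adam([toy_alphas], lr=3e-4)

def toy_loss(x):
    # mix the weight output with softmax-normalized architecture parameters
    mixed = (toy_weights @ x) * torch.softmax(toy_alphas, dim=-1)
    return mixed.pow(2).mean()

for _ in range(3):
    base_x, arch_x = torch.randn(4), torch.randn(4)
    # weight step on the base split
    toy_w_optimizer.zero_grad()
    toy_loss(base_x).backward()
    toy_w_optimizer.step()
    # architecture step on the arch split
    toy_a_optimizer.zero_grad()
    toy_loss(arch_x).backward()
    toy_a_optimizer.step()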
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
#config_path = 'configs/nas-benchmark/algos/GDAS.config'
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', config.batch_size, xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
# config_path = 'configs/nas-benchmark/algos/GDAS.config'
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data, valid_data, xargs.dataset, "configs/nas-benchmark/", config.batch_size, xargs.workers
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config({'name': 'GDAS', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
else:
model_config = load_config(xargs.model_config, {'num_classes': class_num, 'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
logger.log('search-model :\n{:}'.format(search_model))
logger.log('model-config : {:}'.format(model_config))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
flop, param = get_model_infos(search_model, xshape)
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
logger.log('search-space [{:} ops] : {:}'.format(len(search_space), search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
search_space = get_search_spaces("cell", xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
{
"name": "GDAS",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
else:
model_config = load_config(
xargs.model_config,
{
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
logger.log("search-model :\n{:}".format(search_model))
logger.log("model-config : {:}".format(model_config))
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
logger.log("search-space [{:} ops] : {:}".format(len(search_space), search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
search_model.set_tau( xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (total_epoch-1) )
logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr())))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: search_model.genotype()}
search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss , valid_a_top1 , valid_a_top5 ))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies['best']:
valid_accuracies['best'] = valid_a_top1
genotypes['best'] = search_model.genotype()
find_best = True
else: find_best = False
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
search_model.set_tau(xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1))
logger.log(
"\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}".format(
epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr())
)
)
genotypes[epoch] = search_model.genotype()
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
genotypes["best"] = search_model.genotype()
find_best = True
else:
find_best = False
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotypes[total_epoch-1]))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch-1], '200')))
logger.close()
genotypes[epoch] = search_model.genotype()
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log(
"GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotypes[total_epoch - 1]
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200")))
logger.close()
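# A minimal sketch of the linear Gumbel-softmax temperature schedule that
# search_model.set_tau(...) receives in the training loop above: tau decays
# linearly from tau_max at epoch 0 down to tau_min at the last epoch. The
# default values below are illustrative only; the real ones come from the
# --tau_max / --tau_min command-line flags.
def tau_schedule_sketch(epoch, total_epoch, tau_max=10.0, tau_min=0.1):
    # e.g. over 50 epochs: 10.0 at epoch 0, ~4.95 at epoch 25, 0.1 at epoch 49
    return tau_max - (tau_max - tau_min) * epoch / (total_epoch - 1)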
if __name__ == '__main__':
parser = argparse.ArgumentParser("GDAS")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
  parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether to use track_running_stats in the BN layer.')
parser.add_argument('--config_path', type=str, help='The path of the configuration.')
  parser.add_argument('--model_config', type=str, help='The path of the model configuration. When this arg is set, it will override max_nodes / channels / num_cells.')
  # architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--tau_min', type=float, help='The minimum tau for Gumbel')
parser.add_argument('--tau_max', type=float, help='The maximum tau for Gumbel')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser("GDAS")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument("--config_path", type=str, help="The path of the configuration.")
parser.add_argument(
"--model_config",
type=str,
help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.",
)
    # architecture learning rate
parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding")
parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding")
parser.add_argument("--tau_min", type=float, help="The minimum tau for Gumbel")
parser.add_argument("--tau_max", type=float, help="The maximum tau for Gumbel")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)


@@ -9,229 +9,306 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
network.module.random_genotype( True )
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
return base_losses.avg, base_top1.avg, base_top5.avg
# update the weights
network.module.random_genotype(True)
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr)
return base_losses.avg, base_top1.avg, base_top5.avg
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
network.module.random_genotype( True )
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
network.module.random_genotype(True)
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def search_find_best(xloader, network, n_samples):
with torch.no_grad():
network.eval()
archs, valid_accs = [], []
#print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
for i in range(n_samples):
arch = network.module.random_genotype( True )
try:
inputs, targets = next(loader_iter)
except:
with torch.no_grad():
network.eval()
archs, valid_accs = [], []
# print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
for i in range(n_samples):
arch = network.module.random_genotype(True)
try:
inputs, targets = next(loader_iter)
except:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
archs.append( arch )
valid_accs.append( val_top1.item() )
archs.append(arch)
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
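# A self-contained sketch of the sample-and-select logic in search_find_best
# above: draw n candidates, score each one once, and keep the argmax. The
# sample_candidate / score_candidate callables are hypothetical stand-ins for
# network.module.random_genotype(True) and the one-mini-batch top-1 accuracy.
def pick_best_sketch(sample_candidate, score_candidate, n_samples):
    candidates = [sample_candidate() for _ in range(n_samples)]
    scores = [score_candidate(candidate) for candidate in candidates]
    best_idx = max(range(n_samples), key=lambda i: scores[i])
    return candidates[best_idx], scores[best_idx]


# Toy usage: candidates are integers in [0, 10] and the score prefers values close to 7.
# pick_best_sketch(lambda: random.randint(0, 10), lambda c: -abs(c - 7), 16)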
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
(config.batch_size, config.test_batch_size), xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
(config.batch_size, config.test_batch_size),
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
model_config = dict2config({'name': 'RANDOM', 'C': xargs.channel, 'N': xargs.num_cells,
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
'space' : search_space,
'affine' : False, 'track_running_stats': bool(xargs.track_running_stats)}, None)
search_model = get_cell_based_tiny_net(model_config)
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
if xargs.arch_nas_dataset is None: api = None
else : api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "RANDOM",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.parameters(), config)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {}
# selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
search_w_loss, search_w_top1, search_w_top5 = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
logger.log('[{:}] find-the-best : {:}, accuracy@1={:.2f}%'.format(epoch_str, cur_arch, cur_valid_acc))
genotypes[epoch] = cur_arch
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies['best']:
valid_accuracies['best'] = valid_a_top1
find_best = True
else: find_best = False
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr())))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
if find_best:
logger.log('<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'.format(epoch_str, valid_a_top1))
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
# selected_arch = search_find_best(valid_loader, network, criterion, xargs.select_num)
search_w_loss, search_w_top1, search_w_top5 = search_func(
search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
cur_arch, cur_valid_acc = search_find_best(valid_loader, network, xargs.select_num)
logger.log("[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format(epoch_str, cur_arch, cur_valid_acc))
genotypes[epoch] = cur_arch
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
find_best = True
else:
find_best = False
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 200)
logger.log("Pre-searching costs {:.1f} s".format(search_time.sum))
start_time = time.time()
logger.log('\n' + '-'*200)
logger.log('Pre-searching costs {:.1f} s'.format(search_time.sum))
start_time = time.time()
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
logger.log('RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.'.format(best_arch, best_acc, search_time.sum))
if api is not None: logger.log('{:}'.format(api.query_by_arch(best_arch, '200')))
logger.close()
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
logger.log(
"RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.".format(
best_arch, best_acc, search_time.sum
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(best_arch, "200")))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("Random search for NAS.")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--config_path', type=str, help='The path to the configuration.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
  parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether to use track_running_stats in the BN layer.')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Random search for NAS.")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--config_path", type=str, help="The path to the configuration.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--select_num", type=int, help="The number of selected architectures to evaluate.")
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)


@@ -7,113 +7,145 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nas_201_api import NASBench201API as API
from R_EA import train_and_eval, random_architecture_func
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nas_201_api import NASBench201API as API
from R_EA import train_and_eval, random_architecture_func
def main(xargs, nas_bench):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
else:
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, None, logger)
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
search_space = get_search_spaces('cell', xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space)
#x =random_arch() ; y = mutate_arch(x)
x_start_time = time.time()
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
#for idx in range(xargs.random_num):
while total_time_cost < xargs.time_budget:
arch = random_arch()
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
if total_time_cost + cost_time > xargs.time_budget: break
else: total_time_cost += cost_time
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time))
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
search_space = get_search_spaces("cell", xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space)
# x =random_arch() ; y = mutate_arch(x)
x_start_time = time.time()
logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
# for idx in range(xargs.random_num):
while total_time_cost < xargs.time_budget:
arch = random_arch()
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
if total_time_cost + cost_time > xargs.time_budget:
break
else:
total_time_cost += cost_time
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log("[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy))
logger.log(
"{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).".format(
time_string(), best_arch, best_acc, len(history), total_time_cost, time.time() - x_start_time
)
)
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
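# A compact sketch of the budget accounting in the search loop above: a new
# evaluation is only charged to the budget if it still fits, otherwise the
# loop stops without counting it. `sample_arch` and `eval_once` are
# hypothetical stand-ins for random_arch() and train_and_eval(...).
def budgeted_random_search_sketch(sample_arch, eval_once, time_budget):
    best_arch, best_acc, total_cost, history = None, -1.0, 0.0, []
    while total_cost < time_budget:
        arch = sample_arch()
        acc, cost = eval_once(arch)
        if total_cost + cost > time_budget:
            break
        total_cost += cost
        history.append(arch)
        if best_arch is None or best_acc < acc:
            best_acc, best_arch = acc, arch
    return best_arch, best_acc, total_cost, history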
if __name__ == '__main__':
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
#parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
  parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append( index )
torch.save(all_indexes, save_dir / 'results.pth')
else:
main(args, nas_bench)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
# parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append(index)
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)


@@ -9,260 +9,324 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
class Model(object):
def __init__(self):
self.arch = None
self.accuracy = None
def __str__(self):
"""Prints a readable version of this bitstring."""
return "{:}".format(self.arch)
def __init__(self):
self.arch = None
self.accuracy = None
def __str__(self):
"""Prints a readable version of this bitstring."""
return '{:}'.format(self.arch)
# This function mimics the training and evaluating procedure for a single architecture `arch`.
# The time_cost is calculated as the total training time for a few epochs (e.g., 12) plus the evaluation time for one epoch.
# For use_012_epoch_training = True, the architecture is trained for 12 epochs, with the LR decayed from 0.1 to 0.
# In this case, the LR scheduler has converged.
# For use_012_epoch_training = False, the architecture is planned to be trained for 200 epochs, but we stop its training early.
#
def train_and_eval(arch, nas_bench, extra_info, dataname='cifar10-valid', use_012_epoch_training=True):
#
def train_and_eval(arch, nas_bench, extra_info, dataname="cifar10-valid", use_012_epoch_training=True):
if use_012_epoch_training and nas_bench is not None:
arch_index = nas_bench.query_index_by_arch( arch )
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp='12', is_random=True)
valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
elif not use_012_epoch_training and nas_bench is not None:
# Please contact me if you want to use the following logic, because it has some potential issues.
# Please use `use_012_epoch_training=False` for cifar10 only.
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
arch_index, nepoch = nas_bench.query_index_by_arch( arch ), 25
assert arch_index >= 0, 'can not find this arch : {:}'.format(arch)
xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', iepoch=None, hp='12')
xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', hp='200')
info = nas_bench.get_more_info(arch_index, dataname, nepoch, hp='200', is_random=True) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
cost = nas_bench.get_cost_info(arch_index, dataname, hp='200')
    # The following code is used to estimate the time cost.
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
# When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000,
'cifar10-valid-train' : 25000, 'cifar10-valid-valid' : 25000,
'cifar100-train' : 50000, 'cifar100-valid' : 5000}
estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch
estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency']
try:
valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
except:
valid_acc, time_cost = info['valtest-accuracy'], estimated_train_cost + estimated_valid_cost
else:
# train a model from scratch.
    raise ValueError('NOT IMPLEMENTED YET')
return valid_acc, time_cost
if use_012_epoch_training and nas_bench is not None:
arch_index = nas_bench.query_index_by_arch(arch)
assert arch_index >= 0, "can not find this arch : {:}".format(arch)
info = nas_bench.get_more_info(arch_index, dataname, iepoch=None, hp="12", is_random=True)
valid_acc, time_cost = info["valid-accuracy"], info["train-all-time"] + info["valid-per-time"]
# _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
elif not use_012_epoch_training and nas_bench is not None:
# Please contact me if you want to use the following logic, because it has some potential issues.
# Please use `use_012_epoch_training=False` for cifar10 only.
# It did return values for cifar100 and ImageNet16-120, but it has some potential issues. (Please email me for more details)
arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25
assert arch_index >= 0, "can not find this arch : {:}".format(arch)
xoinfo = nas_bench.get_more_info(arch_index, "cifar10-valid", iepoch=None, hp="12")
xocost = nas_bench.get_cost_info(arch_index, "cifar10-valid", hp="200")
info = nas_bench.get_more_info(
arch_index, dataname, nepoch, hp="200", is_random=True
) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
cost = nas_bench.get_cost_info(arch_index, dataname, hp="200")
        # The following code is used to estimate the time cost.
# When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
# When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
nums = {
"ImageNet16-120-train": 151700,
"ImageNet16-120-valid": 3000,
"cifar10-valid-train": 25000,
"cifar10-valid-valid": 25000,
"cifar100-train": 50000,
"cifar100-valid": 5000,
}
estimated_train_cost = (
xoinfo["train-per-time"]
/ nums["cifar10-valid-train"]
* nums["{:}-train".format(dataname)]
/ xocost["latency"]
* cost["latency"]
* nepoch
)
estimated_valid_cost = (
xoinfo["valid-per-time"]
/ nums["cifar10-valid-valid"]
* nums["{:}-valid".format(dataname)]
/ xocost["latency"]
* cost["latency"]
)
try:
valid_acc, time_cost = info["valid-accuracy"], estimated_train_cost + estimated_valid_cost
except:
valid_acc, time_cost = info["valtest-accuracy"], estimated_train_cost + estimated_valid_cost
else:
# train a model from scratch.
raise ValueError("NOT IMPLEMENT YET")
return valid_acc, time_cost
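# A toy sketch of the time-cost scaling above, with illustrative numbers only
# (they are not taken from NAS-Bench-201): the per-epoch time measured on
# cifar10-valid is rescaled by the target dataset size and by the latency
# ratio between the reference and target architectures, then multiplied by
# the number of training epochs.
def estimate_train_cost_sketch(ref_epoch_time, ref_train_size, tgt_train_size, ref_latency, tgt_latency, nepoch):
    return ref_epoch_time / ref_train_size * tgt_train_size / ref_latency * tgt_latency * nepoch


# Example: 20 s/epoch measured on 25k cifar10-valid images, scaled to 50k cifar100
# images, a target architecture with 1.2x the reference latency, trained for 25 epochs:
# estimate_train_cost_sketch(20.0, 25000, 50000, 1.0, 1.2, 25)  ->  about 1200 seconds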
def random_architecture_func(max_nodes, op_names):
# return a random architecture
def random_architecture():
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = random.choice( op_names )
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
return random_architecture
# return a random architecture
def random_architecture():
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return random_architecture
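# A minimal sketch of the genotype layout that random_architecture() builds
# before it is wrapped in CellStructure: for each node i there is one
# (op_name, j) pair per predecessor node j. The operation names below are the
# usual NAS-Bench-201 candidates, listed here for illustration only.
NAS_BENCH_201_OPS_SKETCH = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]


def random_genotype_sketch(max_nodes=4, op_names=NAS_BENCH_201_OPS_SKETCH):
    genotypes = []
    for i in range(1, max_nodes):
        genotypes.append(tuple((random.choice(op_names), j) for j in range(i)))
    return genotypes


# e.g. [(('skip_connect', 0),), (('none', 0), ('nor_conv_3x3', 1)), (('avg_pool_3x3', 0), ('nor_conv_1x1', 1), ('none', 2))]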
def mutate_arch_func(op_names):
"""Computes the architecture for a child of the given parent architecture.
  The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switching one operation to another.
"""
def mutate_arch_func(parent_arch):
child_arch = deepcopy( parent_arch )
node_id = random.randint(0, len(child_arch.nodes)-1)
node_info = list( child_arch.nodes[node_id] )
snode_id = random.randint(0, len(node_info)-1)
xop = random.choice( op_names )
while xop == node_info[snode_id][0]:
xop = random.choice( op_names )
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple( node_info )
return child_arch
return mutate_arch_func
"""Computes the architecture for a child of the given parent architecture.
    The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switching one operation to another.
"""
def mutate_arch_func(parent_arch):
child_arch = deepcopy(parent_arch)
node_id = random.randint(0, len(child_arch.nodes) - 1)
node_info = list(child_arch.nodes[node_id])
snode_id = random.randint(0, len(node_info) - 1)
xop = random.choice(op_names)
while xop == node_info[snode_id][0]:
xop = random.choice(op_names)
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple(node_info)
return child_arch
return mutate_arch_func
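# A sketch of the single-operation mutation described in the docstring above,
# applied to the plain-tuple genotype layout from the previous sketch (no
# CellStructure involved); it mirrors the closure returned by mutate_arch_func.
def mutate_genotype_sketch(genotypes, op_names):
    child = deepcopy(genotypes)
    node_id = random.randint(0, len(child) - 1)
    node_info = list(child[node_id])
    edge_id = random.randint(0, len(node_info) - 1)
    # pick a different operation for one randomly chosen incoming edge
    new_op = random.choice([op for op in op_names if op != node_info[edge_id][0]])
    node_info[edge_id] = (new_op, node_info[edge_id][1])
    child[node_id] = tuple(node_info)
    return child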
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname):
"""Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
Classifier Architecture Search".
Args:
cycles: the number of cycles the algorithm should run for.
population_size: the number of individuals to keep in the population.
sample_size: the number of individuals that should participate in each tournament.
    time_budget: the upper bound of the search cost
def regularized_evolution(
cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info, dataname
):
"""Algorithm for regularized evolution (i.e. aging evolution).
Returns:
history: a list of `Model` instances, representing all the models computed
during the evolution experiment.
"""
population = collections.deque()
history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results.
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
Classifier Architecture Search".
# Initialize the population with random models.
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname)
population.append(model)
history.append(model)
total_time_cost += time_cost
Args:
cycles: the number of cycles the algorithm should run for.
population_size: the number of individuals to keep in the population.
sample_size: the number of individuals that should participate in each tournament.
      time_budget: the upper bound of the total search cost (in seconds)
# Carry out evolution in cycles. Each cycle produces a model and removes
# another.
#while len(history) < cycles:
while total_time_cost < time_budget:
# Sample randomly chosen models from the current population.
start_time, sample = time.time(), []
while len(sample) < sample_size:
# Inefficient, but written this way for clarity. In the case of neural
# nets, the efficiency of this line is irrelevant because training neural
# nets is the rate-determining step.
candidate = random.choice(list(population))
sample.append(candidate)
Returns:
history: a list of `Model` instances, representing all the models computed
during the evolution experiment.
"""
population = collections.deque()
history, total_time_cost = [], 0 # Not used by the algorithm, only used to report results.
# The parent is the best model in the sample.
parent = max(sample, key=lambda i: i.accuracy)
# Initialize the population with random models.
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info, dataname)
population.append(model)
history.append(model)
total_time_cost += time_cost
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
total_time_cost += time.time() - start_time
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname)
if total_time_cost + time_cost > time_budget: # return
return history, total_time_cost
else:
total_time_cost += time_cost
population.append(child)
history.append(child)
# Carry out evolution in cycles. Each cycle produces a model and removes
# another.
# while len(history) < cycles:
while total_time_cost < time_budget:
# Sample randomly chosen models from the current population.
start_time, sample = time.time(), []
while len(sample) < sample_size:
# Inefficient, but written this way for clarity. In the case of neural
# nets, the efficiency of this line is irrelevant because training neural
# nets is the rate-determining step.
candidate = random.choice(list(population))
sample.append(candidate)
# Remove the oldest model.
population.popleft()
return history, total_time_cost
# The parent is the best model in the sample.
parent = max(sample, key=lambda i: i.accuracy)
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
total_time_cost += time.time() - start_time
child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info, dataname)
if total_time_cost + time_cost > time_budget: # return
return history, total_time_cost
else:
total_time_cost += time_cost
population.append(child)
history.append(child)
# Remove the oldest model.
population.popleft()
return history, total_time_cost
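# To make "Algorithm 1" concrete, here is a toy, self-contained run of the same
# aging-evolution loop on a synthetic fitness function (a 6-bit string scored by
# its number of ones) instead of the NAS-Bench-201 evaluation; all parameters
# below are arbitrary.
import collections, random

def toy_regularized_evolution(cycles, population_size, sample_size):
    def random_arch():
        return [random.randint(0, 1) for _ in range(6)]

    def mutate(arch):
        child = list(arch)
        i = random.randrange(len(child))
        child[i] ^= 1
        return child

    def accuracy(arch):  # stand-in fitness, not a trained network
        return sum(arch)

    population, history = collections.deque(), []
    while len(population) < population_size:  # seed with random models
        arch = random_arch()
        population.append((arch, accuracy(arch)))
        history.append(population[-1])
    for _ in range(cycles):  # tournament selection + aging
        sample = [random.choice(list(population)) for _ in range(sample_size)]
        parent = max(sample, key=lambda m: m[1])
        child = mutate(parent[0])
        population.append((child, accuracy(child)))
        history.append(population[-1])
        population.popleft()  # drop the oldest member, not the worst
    return max(history, key=lambda m: m[1])

print(toy_regularized_evolution(cycles=50, population_size=10, sample_size=3))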
def main(xargs, nas_bench):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
else:
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, None, logger)
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
search_space = get_search_spaces('cell', xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space)
mutate_arch = mutate_arch_func(search_space)
#x =random_arch() ; y = mutate_arch(x)
x_start_time = time.time()
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
logger.log('-'*30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info, dataname)
  logger.log('{:} regularized_evolution finished with a history of {:} archs in {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time()-x_start_time))
best_arch = max(history, key=lambda i: i.accuracy)
best_arch = best_arch.arch
logger.log('{:} best arch is {:}'.format(time_string(), best_arch))
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
search_space = get_search_spaces("cell", xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space)
mutate_arch = mutate_arch_func(search_space)
# x =random_arch() ; y = mutate_arch(x)
x_start_time = time.time()
logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
logger.log("-" * 30 + " start searching with the time budget of {:} s".format(xargs.time_budget))
history, total_cost = regularized_evolution(
xargs.ea_cycles,
xargs.ea_population,
xargs.ea_sample_size,
xargs.time_budget,
random_arch,
mutate_arch,
nas_bench if args.ea_fast_by_api else None,
extra_info,
dataname,
)
logger.log(
"{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format(
time_string(), len(history), total_cost, time.time() - x_start_time
)
)
best_arch = max(history, key=lambda i: i.accuracy)
best_arch = best_arch.arch
logger.log("{:} best arch is {:}".format(time_string(), best_arch))
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
parser.add_argument('--ea_population', type=int, help='The population size in EA.')
parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
parser.add_argument('--ea_fast_by_api', type=int, help='Use our API to speed up the experiments or not.')
  parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
args.ea_fast_by_api = args.ea_fast_by_api > 0
if __name__ == "__main__":
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.")
parser.add_argument("--ea_population", type=int, help="The population size in EA.")
parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.")
parser.add_argument("--ea_fast_by_api", type=int, help="Use our API to speed up the experiments or not.")
parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
args.ea_fast_by_api = args.ea_fast_by_api > 0
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append( index )
torch.save(all_indexes, save_dir / 'results.pth')
else:
main(args, nas_bench)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append(index)
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)

View File

@@ -9,274 +9,363 @@ from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
sampled_arch = network.module.dync_genotype(True)
network.module.set_cal_mode('dynamic', sampled_arch)
#network.module.set_cal_mode( 'urs' )
network.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update (base_prec1.item(), base_inputs.size(0))
base_top5.update (base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
network.module.set_cal_mode( 'joint' )
network.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
#print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
#print (network.module.arch_parameters)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
# update the weights
sampled_arch = network.module.dync_genotype(True)
network.module.set_cal_mode("dynamic", sampled_arch)
# network.module.set_cal_mode( 'urs' )
network.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
network.module.set_cal_mode("joint")
network.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = "*SEARCH* " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
# print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
# print (network.module.arch_parameters)
return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
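# The loop above interleaves two updates per batch: a weight step on a sampled
# sub-architecture (the "base" split) and an architecture-parameter step in
# "joint" mode (the "arch" split). A stripped-down sketch of that alternating
# pattern with a toy model standing in for the SETN super-network; the shapes,
# data, and learning rates are made up for illustration.
import torch
import torch.nn as nn

torch.manual_seed(0)
weights = nn.Linear(8, 2)                           # stands in for the shared weights
arch_params = nn.Parameter(1e-3 * torch.randn(4))   # stands in for the alphas
w_opt = torch.optim.SGD(weights.parameters(), lr=0.1)
a_opt = torch.optim.Adam([arch_params], lr=3e-4)
criterion = nn.CrossEntropyLoss()

for step in range(10):
    base_x, base_y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    arch_x, arch_y = torch.randn(16, 8), torch.randint(0, 2, (16,))
    # (i) weight step on the "base" split
    w_opt.zero_grad()
    criterion(weights(base_x), base_y).backward()
    w_opt.step()
    # (ii) architecture step on the "arch" split; here the alphas merely gate the logits
    a_opt.zero_grad()
    gate = torch.softmax(arch_params, dim=-1)[:2]
    criterion(weights(arch_x) * gate, arch_y).backward()
    a_opt.step()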
def get_best_arch(xloader, network, n_samples):
with torch.no_grad():
network.eval()
archs, valid_accs = network.module.return_topK(n_samples), []
#print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
for i, sampled_arch in enumerate(archs):
network.module.set_cal_mode('dynamic', sampled_arch)
try:
inputs, targets = next(loader_iter)
except:
with torch.no_grad():
network.eval()
archs, valid_accs = network.module.return_topK(n_samples), []
# print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
for i, sampled_arch in enumerate(archs):
network.module.set_cal_mode("dynamic", sampled_arch)
try:
inputs, targets = next(loader_iter)
except:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(logits.cpu().data, targets.data, topk=(1, 5))
valid_accs.append(val_top1.item())
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
with torch.no_grad():
network.eval()
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update (arch_prec1.item(), arch_inputs.size(0))
arch_top5.update (arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
with torch.no_grad():
network.eval()
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
(config.batch_size, config.test_batch_size), xargs.workers)
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
config = load_config(xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
(config.batch_size, config.test_batch_size),
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
dict(name='SETN', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
space=search_space, affine=False, track_running_stats=bool(xargs.track_running_stats)), None)
else:
model_config = load_config(xargs.model_config, dict(num_classes=class_num, space=search_space, affine=False,
track_running_stats=bool(xargs.track_running_stats)), None)
logger.log('search space : {:}'.format(search_space))
search_model = get_cell_based_tiny_net(model_config)
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
logger.log('w-optimizer : {:}'.format(w_optimizer))
logger.log('a-optimizer : {:}'.format(a_optimizer))
logger.log('w-scheduler : {:}'.format(w_scheduler))
logger.log('criterion : {:}'.format(criterion))
flop, param = get_model_infos(search_model, xshape)
logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
logger.log('search-space : {:}'.format(search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log('{:} create API = {:} done'.format(time_string(), api))
search_space = get_search_spaces("cell", xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
dict(
name="SETN",
C=xargs.channel,
N=xargs.num_cells,
max_nodes=xargs.max_nodes,
num_classes=class_num,
space=search_space,
affine=False,
track_running_stats=bool(xargs.track_running_stats),
),
None,
)
else:
model_config = load_config(
xargs.model_config,
dict(
num_classes=class_num,
space=search_space,
affine=False,
track_running_stats=bool(xargs.track_running_stats),
),
None,
)
logger.log("search space : {:}".format(search_space))
search_model = get_cell_based_tiny_net(model_config)
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.get_weights(), config)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(), lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
logger.log("search-space : {:}".format(search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info['epoch']
checkpoint = torch.load(last_info['last_checkpoint'])
genotypes = checkpoint['genotypes']
valid_accuracies = checkpoint['valid_accuracies']
search_model.load_state_dict( checkpoint['search_model'] )
w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
w_optimizer.load_state_dict ( checkpoint['w_optimizer'] )
a_optimizer.load_state_dict ( checkpoint['a_optimizer'] )
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: init_genotype}
last_info, model_base_path, model_best_path = logger.path("info"), logger.path("model"), logger.path("best")
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
# start training
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.val * (total_epoch-epoch), True) )
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
if last_info.exists(): # automatically resume from previous checkpoint
logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: init_genotype}
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
search_time.update(time.time() - start_time)
logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log("\n[Search the {:}-th epoch] {:}, LR={:}".format(epoch_str, need_time, min(w_scheduler.get_lr())))
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
network.module.set_cal_mode('dynamic', genotype)
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log('[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype))
#search_model.set_cal_mode('urs')
#valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
#logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
#search_model.set_cal_mode('joint')
#valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
#logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
#search_model.set_cal_mode('select')
#valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
#logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
logger.log(
"[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, search_a_loss, search_a_top1, search_a_top5
)
)
genotypes[epoch] = genotype
logger.log('<<<--->>> The {:}-th epoch : {:}'.format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint({'epoch' : epoch + 1,
'args' : deepcopy(xargs),
'search_model': search_model.state_dict(),
'w_optimizer' : w_optimizer.state_dict(),
'a_optimizer' : a_optimizer.state_dict(),
'w_scheduler' : w_scheduler.state_dict(),
'genotypes' : genotypes,
'valid_accuracies' : valid_accuracies},
model_base_path, logger)
last_info = save_checkpoint({
'epoch': epoch + 1,
'args' : deepcopy(args),
'last_checkpoint': save_path,
}, logger.path('info'), logger)
with torch.no_grad():
logger.log('{:}'.format(search_model.show_alphas()))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotypes[epoch], '200')))
# measure elapsed time
epoch_time.update(time.time() - start_time)
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
network.module.set_cal_mode("dynamic", genotype)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype
)
)
# search_model.set_cal_mode('urs')
# valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
# logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# search_model.set_cal_mode('joint')
# valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
# logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# search_model.set_cal_mode('select')
# valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
# logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
genotypes[epoch] = genotype
logger.log("<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch]))
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
with torch.no_grad():
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
# the final post procedure : count the time
start_time = time.time()
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
network.module.set_cal_mode("dynamic", genotype)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(valid_loader, network, criterion)
logger.log("Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(genotype, valid_a_top1))
# the final post procedure : count the time
start_time = time.time()
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
network.module.set_cal_mode('dynamic', genotype)
valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
  logger.log('Last : the genotype is : {:}, with the validation accuracy of {:.3f}%.'.format(genotype, valid_a_top1))
logger.log('\n' + '-'*100)
# check the performance from the architecture dataset
logger.log('SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(total_epoch, search_time.sum, genotype))
if api is not None: logger.log('{:}'.format(api.query_by_arch(genotype, '200') ))
logger.close()
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log("SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(total_epoch, search_time.sum, genotype))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotype, "200")))
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser("SETN")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--select_num', type=int, help='The number of selected architectures to evaluate.')
  parser.add_argument('--track_running_stats',type=int, choices=[0,1],help='Whether to use track_running_stats or not in the BN layer.')
parser.add_argument('--config_path', type=str, help='The path of the configuration.')
  # architecture learning rate
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, help='manual seed')
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
main(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser("SETN")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--select_num", type=int, help="The number of selected architectures to evaluate.")
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument("--config_path", type=str, help="The path of the configuration.")
    # architecture learning rate
parser.add_argument("--arch_learning_rate", type=float, default=3e-4, help="learning rate for arch encoding")
parser.add_argument("--arch_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)

View File

@@ -10,212 +10,245 @@ from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_201_api import NASBench201API as API
from models import CellStructure, get_search_spaces
from R_EA import train_and_eval
class Policy(nn.Module):
def __init__(self, max_nodes, search_space):
super(Policy, self).__init__()
self.max_nodes = max_nodes
self.search_space = deepcopy(search_space)
self.edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
self.edge2index[node_str] = len(self.edge2index)
self.arch_parameters = nn.Parameter(1e-3 * torch.randn(len(self.edge2index), len(search_space)))
def __init__(self, max_nodes, search_space):
super(Policy, self).__init__()
self.max_nodes = max_nodes
self.search_space = deepcopy(search_space)
self.edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
self.edge2index[ node_str ] = len(self.edge2index)
self.arch_parameters = nn.Parameter( 1e-3*torch.randn(len(self.edge2index), len(search_space)) )
def generate_arch(self, actions):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = self.search_space[actions[self.edge2index[node_str]]]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
def generate_arch(self, actions):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
op_name = self.search_space[ actions[ self.edge2index[ node_str ] ] ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.search_space[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
with torch.no_grad():
weights = self.arch_parameters[ self.edge2index[node_str] ]
op_name = self.search_space[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return CellStructure( genotypes )
def forward(self):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas
def forward(self):
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas
class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""
"""Class that maintains an exponential moving average."""
def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum
def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum
def update(self, value):
self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
def update(self, value):
self._numerator = self._momentum * self._numerator + (1 - self._momentum) * value
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator
def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator
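# Because update() decays a denominator alongside the numerator, value() is
# bias-corrected when only a few rewards have been seen. A quick numeric check,
# assuming the ExponentialMovingAverage class defined above; the reward values
# are arbitrary.
ema = ExponentialMovingAverage(momentum=0.9)
for reward in [0.70, 0.72, 0.68]:
    ema.update(reward)
    print(round(ema.value(), 4))
# the first printed value is exactly 0.7: a plain momentum average without the
# decayed denominator would instead start heavily biased toward zero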
def select_action(policy):
probs = policy()
m = Categorical(probs)
action = m.sample()
#policy.saved_log_probs.append(m.log_prob(action))
return m.log_prob(action), action.cpu().tolist()
probs = policy()
m = Categorical(probs)
action = m.sample()
# policy.saved_log_probs.append(m.log_prob(action))
return m.log_prob(action), action.cpu().tolist()
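# select_action() draws one operation index per edge from a Categorical over the
# softmaxed arch parameters; the REINFORCE step later in main() then scales
# -log_prob by (reward - EMA baseline). A minimal sketch of that gradient step
# on a toy 2-edge, 3-op policy; the logits, reward, and baseline are made up.
import torch
from torch.distributions import Categorical

arch_parameters = torch.nn.Parameter(1e-3 * torch.randn(2, 3))  # 2 edges, 3 ops
optimizer = torch.optim.Adam([arch_parameters], lr=1e-2)

probs = torch.nn.functional.softmax(arch_parameters, dim=-1)
m = Categorical(probs)
action = m.sample()              # one op index per edge
log_prob = m.log_prob(action)    # shape (2,)

reward, baseline = 0.73, 0.70    # hypothetical accuracy and EMA baseline value
policy_loss = (-log_prob * (reward - baseline)).sum()
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
print(action.tolist(), policy_loss.item())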
def main(xargs, nas_bench):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads( xargs.workers )
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == 'cifar10':
dataname = 'cifar10-valid'
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log('Load split file from {:}'.format(split_Fpath))
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
else:
config_path = 'configs/nas-benchmark/algos/R-EA.config'
config = load_config(config_path, None, logger)
extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
search_space = get_search_spaces('cell', xargs.search_space_name)
policy = Policy(xargs.max_nodes, search_space)
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
#optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
eps = np.finfo(np.float32).eps.item()
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
logger.log('policy : {:}'.format(policy))
logger.log('optimizer : {:}'.format(optimizer))
logger.log('eps : {:}'.format(eps))
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, {"class_num": class_num, "xshape": xshape}, logger)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": train_loader, "valid_loader": valid_loader}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
# nas dataset load
logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
search_space = get_search_spaces("cell", xargs.search_space_name)
policy = Policy(xargs.max_nodes, search_space)
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
# optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
eps = np.finfo(np.float32).eps.item()
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
logger.log("policy : {:}".format(policy))
logger.log("optimizer : {:}".format(optimizer))
logger.log("eps : {:}".format(eps))
# REINFORCE
# attempts = 0
x_start_time = time.time()
logger.log('Will start searching with time budget of {:} s.'.format(xargs.time_budget))
total_steps, total_costs, trace = 0, 0, []
#for istep in range(xargs.RL_steps):
while total_costs < xargs.time_budget:
start_time = time.time()
log_prob, action = select_action( policy )
arch = policy.generate_arch( action )
reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
trace.append( (reward, arch) )
# accumulate time
if total_costs + cost_time < xargs.time_budget:
total_costs += cost_time
else: break
# nas dataset load
logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
baseline.update(reward)
# calculate loss
policy_loss = ( -log_prob * (reward - baseline.value()) ).sum()
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
# accumulate time
total_costs += time.time() - start_time
total_steps += 1
logger.log('step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}'.format(total_steps, baseline.value(), policy_loss.item(), policy.genotype()))
#logger.log('----> {:}'.format(policy.arch_parameters))
#logger.log('')
# REINFORCE
# attempts = 0
x_start_time = time.time()
logger.log("Will start searching with time budget of {:} s.".format(xargs.time_budget))
total_steps, total_costs, trace = 0, 0, []
# for istep in range(xargs.RL_steps):
while total_costs < xargs.time_budget:
start_time = time.time()
log_prob, action = select_action(policy)
arch = policy.generate_arch(action)
reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
trace.append((reward, arch))
# accumulate time
if total_costs + cost_time < xargs.time_budget:
total_costs += cost_time
else:
break
# best_arch = policy.genotype() # first version
best_arch = max(trace, key=lambda x: x[0])[1]
  logger.log('REINFORCE finished with {:} steps and {:.1f} s (real cost={:.3f}).'.format(total_steps, total_costs, time.time()-x_start_time))
info = nas_bench.query_by_arch(best_arch, '200')
if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
else : logger.log('{:}'.format(info))
logger.log('-'*100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch( best_arch )
baseline.update(reward)
# calculate loss
policy_loss = (-log_prob * (reward - baseline.value())).sum()
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
# accumulate time
total_costs += time.time() - start_time
total_steps += 1
logger.log(
"step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format(
total_steps, baseline.value(), policy_loss.item(), policy.genotype()
)
)
# logger.log('----> {:}'.format(policy.arch_parameters))
# logger.log('')
# best_arch = policy.genotype() # first version
best_arch = max(trace, key=lambda x: x[0])[1]
logger.log(
"REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).".format(
total_steps, total_costs, time.time() - x_start_time
)
)
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
if __name__ == '__main__':
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
# channels and number-of-cells
parser.add_argument('--search_space_name', type=str, help='The search space name.')
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
parser.add_argument('--channel', type=int, help='The number of channels.')
parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
parser.add_argument('--learning_rate', type=float, help='The learning rate for REINFORCE.')
#parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
parser.add_argument('--EMA_momentum', type=float, help='The momentum value for EMA.')
  parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
# log
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
args = parser.parse_args()
#if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append( index )
torch.save(all_indexes, save_dir / 'results.pth')
else:
main(args, nas_bench)
if __name__ == "__main__":
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument("--num_cells", type=int, help="The number of cells in one stage.")
parser.add_argument("--learning_rate", type=float, help="The learning rate for REINFORCE.")
# parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
parser.add_argument("--EMA_momentum", type=float, help="The momentum value for EMA.")
parser.add_argument("--time_budget", type=int, help="The total time cost budge for searching (in seconds).")
# log
parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
parser.add_argument("--save_dir", type=str, help="Folder to save checkpoints and log.")
parser.add_argument(
"--arch_nas_dataset", type=str, help="The path to load the architecture dataset (tiny-nas-benchmark)."
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print("{:} build NAS-Benchmark-API from {:}".format(time_string(), args.arch_nas_dataset))
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append(index)
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)
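# ----------------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this commit): with --rand_seed < 0 the
# block above repeats the search over 500 random seeds and saves the per-run best
# architecture indices as a plain Python list in <save_dir>/results.pth. One minimal way
# to inspect such a file afterwards is sketched below; the path is a hypothetical example,
# not a path emitted by this script verbatim.
# ----------------------------------------------------------------------------------
import collections

import torch

indexes = torch.load("output/REINFORCE-cifar10/results.pth")  # list of int architecture indices
counter = collections.Counter(indexes)
print("{:} runs, {:} distinct architectures".format(len(indexes), len(counter)))
for arch_index, hits in counter.most_common(5):
    print("architecture #{:5d} selected {:3d} times".format(arch_index, hits))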