Move algos to NAS-Bench-201-algos

D-X-Y
2021-06-03 01:32:00 -07:00
parent 84462da79e
commit 5d7ccd445d
25 changed files with 33 additions and 33 deletions


@@ -0,0 +1,367 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
###################################################################
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
# required to install hpbandster ##################################
# pip install hpbandster ##################################
###################################################################
# bash ./scripts-search/algos/BOHB.sh -1 ##################
###################################################################
import os, sys, time, random, argparse
from copy import deepcopy
from pathlib import Path
import torch
from xautodl.config_utils import load_config
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.procedures import prepare_seed, prepare_logger
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, get_search_spaces
from nas_201_api import NASBench201API as API
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
import ConfigSpace
from hpbandster.optimizers.bohb import BOHB
import hpbandster.core.nameserver as hpns
from hpbandster.core.worker import Worker
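# The NAS-Bench-201 cell is a DAG in which every edge (i <- j) carries one
# operation chosen from `search_space`. Encoding each edge as a categorical
# hyperparameter lets BOHB treat architecture search as ordinary
# hyperparameter optimization.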
def get_configuration_space(max_nodes, search_space):
cs = ConfigSpace.ConfigurationSpace()
# edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
cs.add_hyperparameter(
ConfigSpace.CategoricalHyperparameter(node_str, search_space)
)
return cs
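# Decode a ConfigSpace configuration back into a CellStructure: for every
# target node i, collect the (operation, source-node) pairs of its incoming
# edges, e.g. the value sampled for the hyperparameter "1<-0" becomes
# ("op_name", 0) in the genotype of node 1.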
def config2structure_func(max_nodes):
def config2structure(config):
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = config[node_str]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return config2structure
class MyWorker(Worker):
def __init__(
self,
*args,
convert_func=None,
dataname=None,
nas_bench=None,
time_budget=None,
**kwargs
):
super().__init__(*args, **kwargs)
self.convert_func = convert_func
self._dataname = dataname
self._nas_bench = nas_bench
self.time_budget = time_budget
self.seen_archs = []
self.sim_cost_time = 0
self.real_cost_time = 0
self.is_end = False
def get_the_best(self):
assert len(self.seen_archs) > 0
best_index, best_acc = -1, None
for arch_index in self.seen_archs:
info = self._nas_bench.get_more_info(
arch_index, self._dataname, None, hp="200", is_random=True
)
vacc = info["valid-accuracy"]
if best_acc is None or best_acc < vacc:
best_acc = vacc
best_index = arch_index
assert best_index != -1
return best_index
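    # BOHB calls `compute` once per sampled configuration. Instead of actually
    # training the architecture, we look up its recorded statistics in
    # NAS-Bench-201 and charge the *simulated* training/evaluation time
    # against `time_budget`, so the search mimics a real time-constrained run.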
def compute(self, config, budget, **kwargs):
start_time = time.time()
structure = self.convert_func(config)
arch_index = self._nas_bench.query_index_by_arch(structure)
info = self._nas_bench.get_more_info(
arch_index, self._dataname, None, hp="200", is_random=True
)
cur_time = info["train-all-time"] + info["valid-per-time"]
cur_vacc = info["valid-accuracy"]
self.real_cost_time += time.time() - start_time
if self.sim_cost_time + cur_time <= self.time_budget and not self.is_end:
self.sim_cost_time += cur_time
self.seen_archs.append(arch_index)
return {
"loss": 100 - float(cur_vacc),
"info": {
"seen-arch": len(self.seen_archs),
"sim-test-time": self.sim_cost_time,
"current-arch": arch_index,
},
}
else:
self.is_end = True
return {
"loss": 100,
"info": {
"seen-arch": len(self.seen_archs),
"sim-test-time": self.sim_cost_time,
"current-arch": None,
},
}
def main(xargs, nas_bench):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(
config_path, {"class_num": class_num, "xshape": xshape}, logger
)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {
"config": config,
"train_loader": train_loader,
"valid_loader": valid_loader,
}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
# nas dataset load
assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset)
search_space = get_search_spaces("cell", xargs.search_space_name)
cs = get_configuration_space(xargs.max_nodes, search_space)
config2structure = config2structure_func(xargs.max_nodes)
hb_run_id = "0"
NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0)
ns_host, ns_port = NS.start()
num_workers = 1
# nas_bench = AANASBenchAPI(xargs.arch_nas_dataset)
# logger.log('{:} Create NAS-BENCH-API DONE'.format(time_string()))
workers = []
for i in range(num_workers):
w = MyWorker(
nameserver=ns_host,
nameserver_port=ns_port,
convert_func=config2structure,
dataname=dataname,
nas_bench=nas_bench,
time_budget=xargs.time_budget,
run_id=hb_run_id,
id=i,
)
w.run(background=True)
workers.append(w)
start_time = time.time()
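    # min_budget/max_budget are in units of training epochs and match the 12-
    # and 200-epoch training schedules recorded in NAS-Bench-201; eta=3 sets
    # the successive-halving ratio. Note that compute() above always queries
    # the 200-epoch entry, so the budget only steers BOHB's scheduling.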
bohb = BOHB(
configspace=cs,
run_id=hb_run_id,
eta=3,
min_budget=12,
max_budget=200,
nameserver=ns_host,
nameserver_port=ns_port,
num_samples=xargs.num_samples,
random_fraction=xargs.random_fraction,
bandwidth_factor=xargs.bandwidth_factor,
ping_interval=10,
min_bandwidth=xargs.min_bandwidth,
)
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
bohb.shutdown(shutdown_workers=True)
NS.shutdown()
real_cost_time = time.time() - start_time
id2config = results.get_id2config_mapping()
incumbent = results.get_incumbent_id()
logger.log(
"Best found configuration: {:} within {:.3f} s".format(
id2config[incumbent]["config"], real_cost_time
)
)
best_arch = config2structure(id2config[incumbent]["config"])
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.log(
"workers : {:.1f}s with {:} archs".format(
workers[0].time_budget, len(workers[0].seen_archs)
)
)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch), real_cost_time
if __name__ == "__main__":
parser = argparse.ArgumentParser(
"BOHB: Robust and Efficient Hyperparameter Optimization at Scale"
)
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--time_budget",
type=int,
help="The total time cost budge for searching (in seconds).",
)
# BOHB
parser.add_argument(
"--strategy",
default="sampling",
type=str,
nargs="?",
help="optimization strategy for the acquisition function",
)
parser.add_argument(
"--min_bandwidth",
default=0.3,
type=float,
nargs="?",
help="minimum bandwidth for KDE",
)
parser.add_argument(
"--num_samples",
default=64,
type=int,
nargs="?",
help="number of samples for the acquisition function",
)
parser.add_argument(
"--random_fraction",
default=0.33,
type=float,
nargs="?",
help="fraction of random configurations",
)
parser.add_argument(
"--bandwidth_factor",
default=3,
type=int,
nargs="?",
help="factor multiplied to the bandwidth",
)
parser.add_argument(
"--n_iters",
default=100,
type=int,
nargs="?",
help="number of iterations for optimization method",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print(
"{:} build NAS-Benchmark-API from {:}".format(
time_string(), args.arch_nas_dataset
)
)
nas_bench = API(args.arch_nas_dataset)
    if args.rand_seed is None or args.rand_seed < 0:
save_dir, all_indexes, num, all_times = None, [], 500, []
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index, ctime = main(args, nas_bench)
all_indexes.append(index)
all_times.append(ctime)
print("\n average time : {:.3f} s".format(sum(all_times) / len(all_times)))
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)


@@ -0,0 +1,417 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import sys, time, random, argparse
from copy import deepcopy
import torch
from pathlib import Path
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
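# One epoch of first-order DARTS: for every mini-batch pair, first update the
# network weights w on the "base" split by descending the training loss, then
# update the architecture parameters alpha on the "arch" split. The
# first-order approximation simply treats w as a constant when computing
# d L_val / d alpha.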
def search_func(
xloader,
network,
criterion,
scheduler,
w_optimizer,
a_optimizer,
epoch_str,
print_freq,
logger,
gradient_clip,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
xloader
):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
if gradient_clip > 0:
torch.nn.utils.clip_grad_norm_(network.parameters(), gradient_clip)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(
logits.data, base_targets.data, topk=(1, 5)
)
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
a_optimizer.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = (
"*SEARCH* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
# config_path = 'configs/nas-benchmark/algos/DARTS.config'
config = load_config(
xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
config.batch_size,
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
{
"name": "DARTS-V1",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
else:
model_config = load_config(
xargs.model_config,
{
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
logger.log("search-model :\n{:}".format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
search_model.get_weights(), config
)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(),
lr=xargs.arch_learning_rate,
betas=(0.5, 0.999),
weight_decay=xargs.arch_weight_decay,
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
# logger.log('{:}'.format(search_model))
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
last_info, model_base_path, model_best_path = (
logger.path("info"),
logger.path("model"),
logger.path("best"),
)
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
if last_info.exists(): # automatically resume from previous checkpoint
logger.log(
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
)
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
last_info, start_epoch
)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = (
0,
{"best": -1},
{-1: search_model.genotype()},
)
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
)
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}".format(
epoch_str, need_time, min(w_scheduler.get_lr())
)
)
search_w_loss, search_w_top1, search_w_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
xargs.gradient_clip,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
valid_loader, network, criterion
)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
genotypes["best"] = search_model.genotype()
find_best = True
else:
find_best = False
genotypes[epoch] = search_model.genotype()
logger.log(
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
)
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
# logger.log('arch-parameters :\n{:}'.format( nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu() ))
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
logger.log(
"DARTS-V1 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotypes[total_epoch - 1]
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200")))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("DARTS first order")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument("--config_path", type=str, help="The config path.")
parser.add_argument(
"--model_config",
type=str,
help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.",
)
parser.add_argument("--gradient_clip", type=float, default=5, help="")
    # architecture learning rate
parser.add_argument(
"--arch_learning_rate",
type=float,
default=3e-4,
help="learning rate for arch encoding",
)
parser.add_argument(
"--arch_weight_decay",
type=float,
default=1e-3,
help="weight decay for arch encoding",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)


@@ -0,0 +1,496 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
########################################################
# DARTS: Differentiable Architecture Search, ICLR 2019 #
########################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def _concat(xs):
return torch.cat([x.view(-1) for x in xs])
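# Finite-difference approximation of the Hessian-vector product needed by the
# second-order DARTS gradient:
#   H v = d^2 L_train / (d alpha d w) * v
#       ~ [dL_train(w + R v)/d alpha - dL_train(w - R v)/d alpha] / (2 R)
# The weights are perturbed in place by +R v and -R v and restored afterwards.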
def _hessian_vector_product(
vector, network, criterion, base_inputs, base_targets, r=1e-2
):
R = r / _concat(vector).norm()
for p, v in zip(network.module.get_weights(), vector):
p.data.add_(R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_p = torch.autograd.grad(loss, network.module.get_alphas())
for p, v in zip(network.module.get_weights(), vector):
p.data.sub_(2 * R, v)
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
grads_n = torch.autograd.grad(loss, network.module.get_alphas())
for p, v in zip(network.module.get_weights(), vector):
p.data.add_(R, v)
return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
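# Second-order DARTS update. After one virtual SGD step
#   w' = w - LR * (momentum * buffer + dL_train(w, alpha)/dw + WD * w)
# the architecture gradient is
#   dL_val(w', alpha)/d alpha - LR * H * dL_val(w', alpha)/dw'
# where the Hessian-vector term is estimated by _hessian_vector_product above
# and written back into network.module.arch_parameters.grad.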
def backward_step_unrolled(
network,
criterion,
base_inputs,
base_targets,
w_optimizer,
arch_inputs,
arch_targets,
):
# _compute_unrolled_model
_, logits = network(base_inputs)
loss = criterion(logits, base_targets)
LR, WD, momentum = (
w_optimizer.param_groups[0]["lr"],
w_optimizer.param_groups[0]["weight_decay"],
w_optimizer.param_groups[0]["momentum"],
)
with torch.no_grad():
theta = _concat(network.module.get_weights())
try:
moment = _concat(
w_optimizer.state[v]["momentum_buffer"]
for v in network.module.get_weights()
)
moment = moment.mul_(momentum)
        except KeyError:
moment = torch.zeros_like(theta)
dtheta = (
_concat(torch.autograd.grad(loss, network.module.get_weights()))
+ WD * theta
)
params = theta.sub(LR, moment + dtheta)
unrolled_model = deepcopy(network)
model_dict = unrolled_model.state_dict()
new_params, offset = {}, 0
for k, v in network.named_parameters():
if "arch_parameters" in k:
continue
v_length = np.prod(v.size())
new_params[k] = params[offset : offset + v_length].view(v.size())
offset += v_length
model_dict.update(new_params)
unrolled_model.load_state_dict(model_dict)
unrolled_model.zero_grad()
_, unrolled_logits = unrolled_model(arch_inputs)
unrolled_loss = criterion(unrolled_logits, arch_targets)
unrolled_loss.backward()
dalpha = unrolled_model.module.arch_parameters.grad
vector = [v.grad.data for v in unrolled_model.module.get_weights()]
[implicit_grads] = _hessian_vector_product(
vector, network, criterion, base_inputs, base_targets
)
dalpha.data.sub_(LR, implicit_grads.data)
if network.module.arch_parameters.grad is None:
network.module.arch_parameters.grad = deepcopy(dalpha)
else:
network.module.arch_parameters.grad.data.copy_(dalpha.data)
return unrolled_loss.detach(), unrolled_logits.detach()
def search_func(
xloader,
network,
criterion,
scheduler,
w_optimizer,
a_optimizer,
epoch_str,
print_freq,
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
xloader
):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the architecture-weight
a_optimizer.zero_grad()
arch_loss, arch_logits = backward_step_unrolled(
network,
criterion,
base_inputs,
base_targets,
w_optimizer,
arch_inputs,
arch_targets,
)
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(
arch_logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(
logits.data, base_targets.data, topk=(1, 5)
)
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = (
"*SEARCH* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return base_losses.avg, base_top1.avg, base_top5.avg
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
config = load_config(
xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
config.batch_size,
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "DARTS-V2",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
logger.log("search-model :\n{:}".format(search_model))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
search_model.get_weights(), config
)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(),
lr=xargs.arch_learning_rate,
betas=(0.5, 0.999),
weight_decay=xargs.arch_weight_decay,
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
# logger.log('{:}'.format(search_model))
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
last_info, model_base_path, model_best_path = (
logger.path("info"),
logger.path("model"),
logger.path("best"),
)
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
if last_info.exists(): # automatically resume from previous checkpoint
logger.log(
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
)
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
last_info, start_epoch
)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = (
0,
{"best": -1},
{-1: search_model.genotype()},
)
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
)
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
min_LR = min(w_scheduler.get_lr())
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}".format(
epoch_str, need_time, min_LR
)
)
search_w_loss, search_w_top1, search_w_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
valid_loader, network, criterion
)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
genotypes["best"] = search_model.genotype()
find_best = True
else:
find_best = False
genotypes[epoch] = search_model.genotype()
logger.log(
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
)
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log(
"arch-parameters :\n{:}".format(
nn.functional.softmax(search_model.arch_parameters, dim=-1).cpu()
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log(
"DARTS-V2 : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotypes[total_epoch - 1]
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1]), "200"))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("DARTS Second Order")
parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--config_path", type=str, help="The config path.")
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
    # architecture learning rate
parser.add_argument(
"--arch_learning_rate",
type=float,
default=3e-4,
help="learning rate for arch encoding",
)
parser.add_argument(
"--arch_weight_decay",
type=float,
default=1e-3,
help="weight decay for arch encoding",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)


@@ -0,0 +1,578 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##########################################################################
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
##########################################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
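# ENAS phase 1: train the shared-weight CNN. For every mini-batch the (frozen)
# controller samples one architecture, the shared network is configured to
# that sub-graph via update_arch, and only the operations on the sampled
# sub-graph take part in the forward pass and hence receive gradients.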
def train_shared_cnn(
xloader,
shared_cnn,
controller,
criterion,
scheduler,
optimizer,
epoch_str,
print_freq,
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
losses, top1s, top5s, xend = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
time.time(),
)
shared_cnn.train()
controller.eval()
for step, (inputs, targets) in enumerate(xloader):
scheduler.update(None, 1.0 * step / len(xloader))
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
with torch.no_grad():
_, _, sampled_arch = controller()
optimizer.zero_grad()
shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
loss = criterion(logits, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
optimizer.step()
# record
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1s.update(prec1.item(), inputs.size(0))
top5s.update(prec5.item(), inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = (
"*Train-Shared-CNN* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=losses, top1=top1s, top5=top5s
)
logger.log(Sstr + " " + Tstr + " " + Wstr)
return losses.avg, top1s.avg, top5s.avg
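# ENAS phase 2: train the controller with REINFORCE. The reward is the
# validation top-1 accuracy of a sampled architecture plus an entropy bonus,
# and an exponential-moving-average baseline reduces variance:
#   baseline <- baseline - (1 - ctl_bl_dec) * (baseline - reward)
#   loss = -log_prob * (reward - baseline)
# Gradients are aggregated over ctl_num_aggre samples before each step.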
def train_controller(
xloader,
shared_cnn,
controller,
criterion,
optimizer,
config,
epoch_str,
print_freq,
logger,
):
    # config: the necessary arguments for training the controller
    # config.baseline: the baseline reward (i.e., the moving average of the validation accuracy) from the previous epoch
data_time, batch_time = AverageMeter(), AverageMeter()
(
GradnormMeter,
LossMeter,
ValAccMeter,
EntropyMeter,
BaselineMeter,
RewardMeter,
xend,
) = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
time.time(),
)
shared_cnn.eval()
controller.train()
controller.zero_grad()
# for step, (inputs, targets) in enumerate(xloader):
loader_iter = iter(xloader)
for step in range(config.ctl_train_steps * config.ctl_num_aggre):
try:
inputs, targets = next(loader_iter)
        except StopIteration:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
targets = targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - xend)
log_prob, entropy, sampled_arch = controller()
with torch.no_grad():
shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
val_top1 = val_top1.view(-1) / 100
reward = val_top1 + config.ctl_entropy_w * entropy
if config.baseline is None:
baseline = val_top1
else:
baseline = config.baseline - (1 - config.ctl_bl_dec) * (
config.baseline - reward
)
loss = -1 * log_prob * (reward - baseline)
# account
RewardMeter.update(reward.item())
BaselineMeter.update(baseline.item())
ValAccMeter.update(val_top1.item() * 100)
LossMeter.update(loss.item())
EntropyMeter.update(entropy.item())
# Average gradient over controller_num_aggregate samples
loss = loss / config.ctl_num_aggre
loss.backward(retain_graph=True)
# measure elapsed time
batch_time.update(time.time() - xend)
xend = time.time()
if (step + 1) % config.ctl_num_aggre == 0:
grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
GradnormMeter.update(grad_norm)
optimizer.step()
controller.zero_grad()
if step % print_freq == 0:
Sstr = (
"*Train-Controller* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(
epoch_str, step, config.ctl_train_steps * config.ctl_num_aggre
)
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format(
loss=LossMeter,
top1=ValAccMeter,
reward=RewardMeter,
basel=BaselineMeter,
)
Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr)
return (
LossMeter.avg,
ValAccMeter.avg,
BaselineMeter.avg,
RewardMeter.avg,
baseline.item(),
)
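# Sample n_samples architectures from the controller and score each one with a
# single validation mini-batch through the shared CNN; the architecture with
# the highest top-1 accuracy on its batch is returned.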
def get_best_arch(controller, shared_cnn, xloader, n_samples=10):
with torch.no_grad():
controller.eval()
shared_cnn.eval()
archs, valid_accs = [], []
loader_iter = iter(xloader)
for i in range(n_samples):
try:
inputs, targets = next(loader_iter)
            except StopIteration:
loader_iter = iter(xloader)
inputs, targets = next(loader_iter)
_, _, sampled_arch = controller()
arch = shared_cnn.module.update_arch(sampled_arch)
_, logits = shared_cnn(inputs)
val_top1, val_top5 = obtain_accuracy(
logits.cpu().data, targets.data, topk=(1, 5)
)
archs.append(arch)
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, test_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
logger.log("use config from : {:}".format(xargs.config_path))
config = load_config(
xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
)
_, train_loader, valid_loader = get_nas_search_loaders(
train_data,
test_data,
xargs.dataset,
"configs/nas-benchmark/",
config.batch_size,
xargs.workers,
)
# since ENAS will train the controller on valid-loader, we need to use train transformation for valid-loader
valid_loader.dataset.transform = deepcopy(train_loader.dataset.transform)
if hasattr(valid_loader.dataset, "transforms"):
valid_loader.dataset.transforms = deepcopy(train_loader.dataset.transforms)
# data loader
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "ENAS",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
shared_cnn = get_cell_based_tiny_net(model_config)
controller = shared_cnn.create_controller()
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
shared_cnn.parameters(), config
)
a_optimizer = torch.optim.Adam(
controller.parameters(),
lr=config.controller_lr,
betas=config.controller_betas,
eps=config.controller_eps,
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
# flop, param = get_model_infos(shared_cnn, xshape)
# logger.log('{:}'.format(shared_cnn))
# logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
logger.log("search-space : {:}".format(search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
shared_cnn, controller, criterion = (
torch.nn.DataParallel(shared_cnn).cuda(),
controller.cuda(),
criterion.cuda(),
)
last_info, model_base_path, model_best_path = (
logger.path("info"),
logger.path("model"),
logger.path("best"),
)
if last_info.exists(): # automatically resume from previous checkpoint
logger.log(
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
)
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
baseline = checkpoint["baseline"]
valid_accuracies = checkpoint["valid_accuracies"]
shared_cnn.load_state_dict(checkpoint["shared_cnn"])
controller.load_state_dict(checkpoint["controller"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
last_info, start_epoch
)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes, baseline = 0, {"best": -1}, {}, None
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
)
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}, baseline={:}".format(
epoch_str, need_time, min(w_scheduler.get_lr()), baseline
)
)
cnn_loss, cnn_top1, cnn_top5 = train_shared_cnn(
train_loader,
shared_cnn,
controller,
criterion,
w_scheduler,
w_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
logger.log(
"[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, cnn_loss, cnn_top1, cnn_top5
)
)
ctl_loss, ctl_acc, ctl_baseline, ctl_reward, baseline = train_controller(
valid_loader,
shared_cnn,
controller,
criterion,
a_optimizer,
dict2config(
{
"baseline": baseline,
"ctl_train_steps": xargs.controller_train_steps,
"ctl_num_aggre": xargs.controller_num_aggregate,
"ctl_entropy_w": xargs.controller_entropy_weight,
"ctl_bl_dec": xargs.controller_bl_dec,
},
None,
),
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] controller : loss={:.2f}, accuracy={:.2f}%, baseline={:.2f}, reward={:.2f}, current-baseline={:.4f}, time-cost={:.1f} s".format(
epoch_str,
ctl_loss,
ctl_acc,
ctl_baseline,
ctl_reward,
baseline,
search_time.sum,
)
)
best_arch, _ = get_best_arch(controller, shared_cnn, valid_loader)
shared_cnn.module.update_arch(best_arch)
_, best_valid_acc, _ = valid_func(valid_loader, shared_cnn, criterion)
genotypes[epoch] = best_arch
# check the best accuracy
valid_accuracies[epoch] = best_valid_acc
if best_valid_acc > valid_accuracies["best"]:
valid_accuracies["best"] = best_valid_acc
genotypes["best"] = best_arch
find_best = True
else:
find_best = False
logger.log(
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
)
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"baseline": baseline,
"shared_cnn": shared_cnn.state_dict(),
"controller": controller.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, best_valid_acc
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
logger.log(
"During searching, the best architecture is {:}".format(genotypes["best"])
)
logger.log("Its accuracy is {:.2f}%".format(valid_accuracies["best"]))
logger.log(
"Randomly select {:} architectures and select the best.".format(
xargs.controller_num_samples
)
)
start_time = time.time()
final_arch, _ = get_best_arch(
controller, shared_cnn, valid_loader, xargs.controller_num_samples
)
search_time.update(time.time() - start_time)
shared_cnn.module.update_arch(final_arch)
final_loss, final_top1, final_top5 = valid_func(valid_loader, shared_cnn, criterion)
logger.log("The Selected Final Architecture : {:}".format(final_arch))
logger.log(
"Loss={:.3f}, Accuracy@1={:.2f}%, Accuracy@5={:.2f}%".format(
final_loss, final_top1, final_top5
)
)
logger.log(
"ENAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, final_arch
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(final_arch)))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("ENAS")
parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--config_path", type=str, help="The config file to train ENAS."
)
parser.add_argument("--controller_train_steps", type=int, help=".")
parser.add_argument("--controller_num_aggregate", type=int, help=".")
parser.add_argument(
"--controller_entropy_weight",
type=float,
help="The weight for the entropy of the controller.",
)
parser.add_argument("--controller_bl_dec", type=float, help=".")
parser.add_argument("--controller_num_samples", type=int, help=".")
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)


@@ -0,0 +1,404 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import sys, time, random, argparse
from copy import deepcopy
import torch
from xautodl.config_utils import load_config, dict2config
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
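# One GDAS search epoch: the model samples one operation per edge through
# Gumbel-Softmax inside its forward pass (with a straight-through estimator
# for the architecture parameters), so each step trains only the sampled
# sub-graph. The loop below alternates weight updates on the "base" split
# with architecture updates on the "arch" split.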
def search_func(
xloader,
network,
criterion,
scheduler,
w_optimizer,
a_optimizer,
epoch_str,
print_freq,
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
xloader
):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(
logits.data, base_targets.data, topk=(1, 5)
)
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
a_optimizer.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = (
"*SEARCH* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
return (
base_losses.avg,
base_top1.avg,
base_top5.avg,
arch_losses.avg,
arch_top1.avg,
arch_top5.avg,
)
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
# config_path = 'configs/nas-benchmark/algos/GDAS.config'
config = load_config(
xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
config.batch_size,
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
{
"name": "GDAS",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
else:
model_config = load_config(
xargs.model_config,
{
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
logger.log("search-model :\n{:}".format(search_model))
logger.log("model-config : {:}".format(model_config))
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
search_model.get_weights(), config
)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(),
lr=xargs.arch_learning_rate,
betas=(0.5, 0.999),
weight_decay=xargs.arch_weight_decay,
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
logger.log("search-space [{:} ops] : {:}".format(len(search_space), search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
last_info, model_base_path, model_best_path = (
logger.path("info"),
logger.path("model"),
logger.path("best"),
)
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
if last_info.exists(): # automatically resume from previous checkpoint
logger.log(
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
)
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
last_info, start_epoch
)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = (
0,
{"best": -1},
{-1: search_model.genotype()},
)
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
)
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
search_model.set_tau(
xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
)
logger.log(
"\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}".format(
epoch_str, need_time, search_model.get_tau(), min(w_scheduler.get_lr())
)
)
(
search_w_loss,
search_w_top1,
search_w_top5,
valid_a_loss,
valid_a_top1,
valid_a_top5,
) = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
genotypes["best"] = search_model.genotype()
find_best = True
else:
find_best = False
genotypes[epoch] = search_model.genotype()
logger.log(
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
)
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
with torch.no_grad():
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log(
"GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotypes[total_epoch - 1]
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[total_epoch - 1], "200")))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("GDAS")
parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument(
"--config_path", type=str, help="The path of the configuration."
)
parser.add_argument(
"--model_config",
type=str,
help="The path of the model configuration. When this arg is set, it will cover max_nodes / channels / num_cells.",
)
    # architecture learning rate
parser.add_argument(
"--arch_learning_rate",
type=float,
default=3e-4,
help="learning rate for arch encoding",
)
parser.add_argument(
"--arch_weight_decay",
type=float,
default=1e-3,
help="weight decay for arch encoding",
)
parser.add_argument("--tau_min", type=float, help="The minimum tau for Gumbel")
parser.add_argument("--tau_max", type=float, help="The maximum tau for Gumbel")
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)

@@ -0,0 +1,382 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import os, sys, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def search_func(
xloader, network, criterion, scheduler, w_optimizer, epoch_str, print_freq, logger
):
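    # Single-path training: draw a uniformly random sub-architecture for every
    # batch and update only the shared supernet weights on it.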
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.train()
end = time.time()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
xloader
):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
network.module.random_genotype(True)
w_optimizer.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
nn.utils.clip_grad_norm_(network.parameters(), 5)
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(
logits.data, base_targets.data, topk=(1, 5)
)
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = (
"*SEARCH* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr)
return base_losses.avg, base_top1.avg, base_top5.avg
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
network.eval()
end = time.time()
with torch.no_grad():
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
network.module.random_genotype(True)
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def search_find_best(xloader, network, n_samples):
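    # Score n_samples randomly drawn architectures, each on a single mini-batch
    # of validation data, and return the best-scoring one.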
with torch.no_grad():
network.eval()
archs, valid_accs = [], []
# print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
for i in range(n_samples):
arch = network.module.random_genotype(True)
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                # restart the loader when it is exhausted
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(
logits.cpu().data, targets.data, topk=(1, 5)
)
archs.append(arch)
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
config = load_config(
xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
(config.batch_size, config.test_batch_size),
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "RANDOM",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": bool(xargs.track_running_stats),
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
search_model.parameters(), config
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
last_info, model_base_path, model_best_path = (
logger.path("info"),
logger.path("model"),
logger.path("best"),
)
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
if last_info.exists(): # automatically resume from previous checkpoint
logger.log(
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
)
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
last_info, start_epoch
)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {}
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
)
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}".format(
epoch_str, need_time, min(w_scheduler.get_lr())
)
)
        # selected_arch = search_find_best(valid_loader, network, xargs.select_num)
search_w_loss, search_w_top1, search_w_top5 = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
valid_loader, network, criterion
)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5
)
)
cur_arch, cur_valid_acc = search_find_best(
valid_loader, network, xargs.select_num
)
logger.log(
"[{:}] find-the-best : {:}, accuracy@1={:.2f}%".format(
epoch_str, cur_arch, cur_valid_acc
)
)
genotypes[epoch] = cur_arch
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
if valid_a_top1 > valid_accuracies["best"]:
valid_accuracies["best"] = valid_a_top1
find_best = True
else:
find_best = False
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
if find_best:
logger.log(
"<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.".format(
epoch_str, valid_a_top1
)
)
copy_checkpoint(model_base_path, model_best_path, logger)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
logger.log("\n" + "-" * 200)
logger.log("Pre-searching costs {:.1f} s".format(search_time.sum))
start_time = time.time()
best_arch, best_acc = search_find_best(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
logger.log(
"RANDOM-NAS finds the best one : {:} with accuracy={:.2f}%, with {:.1f} s.".format(
best_arch, best_acc, search_time.sum
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(best_arch, "200")))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Random search for NAS.")
parser.add_argument("--data_path", type=str, help="The path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument(
"--config_path", type=str, help="The path to the configuration."
)
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--select_num",
type=int,
help="The number of selected architectures to evaluate.",
)
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)

@@ -0,0 +1,189 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##############################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_search_spaces
from nas_201_api import NASBench201API as API
from R_EA import train_and_eval, random_architecture_func
def main(xargs, nas_bench):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(
config_path, {"class_num": class_num, "xshape": xshape}, logger
)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {
"config": config,
"train_loader": train_loader,
"valid_loader": valid_loader,
}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
search_space = get_search_spaces("cell", xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space)
    # e.g., x = random_arch()
x_start_time = time.time()
logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
best_arch, best_acc, total_time_cost, history = None, -1, 0, []
# for idx in range(xargs.random_num):
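    # Random search under a wall-clock budget: repeatedly sample a random
    # architecture, look up its (simulated) training cost and accuracy in the
    # benchmark, and keep the best architecture seen so far.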
while total_time_cost < xargs.time_budget:
arch = random_arch()
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
if total_time_cost + cost_time > xargs.time_budget:
break
else:
total_time_cost += cost_time
history.append(arch)
if best_arch is None or best_acc < accuracy:
best_acc, best_arch = accuracy, arch
logger.log(
"[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)
)
logger.log(
"{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).".format(
time_string(),
best_arch,
best_acc,
len(history),
total_time_cost,
time.time() - x_start_time,
)
)
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Random NAS")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
# parser.add_argument('--random_num', type=int, help='The number of random selected architectures.')
parser.add_argument(
"--time_budget",
type=int,
help="The total time cost budge for searching (in seconds).",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print(
"{:} build NAS-Benchmark-API from {:}".format(
time_string(), args.arch_nas_dataset
)
)
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append(index)
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)

@@ -0,0 +1,7 @@
# NAS Algorithms evaluated in NAS-Bench-201
The Python files in this folder are used to reproduce the results in our NAS-Bench-201 paper.
We have upgraded the code to be more general and extensible; see [NATS-algos](https://github.com/D-X-Y/AutoDL-Projects/tree/main/exps/NATS-algos).
**Notice**: On 24 May 2021, the code in the `AutoDL` repo was re-organized. If you hit a `module not found` error, please let me know and I will fix it ASAP.
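
As a quick sanity check before running any script, the benchmark API that all of these algorithms rely on can be queried directly. The sketch below is illustrative only; the `.pth` filename stands in for a local copy of the NAS-Bench-201 file and depends on the release you downloaded.

```python
# Minimal sketch: probe NAS-Bench-201 the same way the search scripts do.
from nas_201_api import NASBench201API as API

api = API("NAS-Bench-201-v1_1-096897.pth")  # path to your local benchmark file
# every script converts an architecture to its benchmark index first;
# here we simply probe the 0-th architecture with the 12-epoch setting
info = api.get_more_info(0, "cifar10-valid", None, hp="12", is_random=True)
print(info["valid-accuracy"], info["train-all-time"] + info["valid-per-time"])
```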

@@ -0,0 +1,399 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################################
# Regularized Evolution for Image Classifier Architecture Search #
##################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, get_search_spaces
from nas_201_api import NASBench201API as API
class Model(object):
def __init__(self):
self.arch = None
self.accuracy = None
def __str__(self):
"""Prints a readable version of this bitstring."""
return "{:}".format(self.arch)
# This function mimics the training and evaluation procedure for a single architecture `arch`.
# The time cost is computed as the total training time for a few epochs (e.g., 12) plus the evaluation time for one epoch.
# For use_012_epoch_training = True, the architecture is trained for 12 epochs, with the LR decayed from 0.1 to 0;
# in this case, the LR schedule has fully converged.
# For use_012_epoch_training = False, the architecture is planned to be trained for 200 epochs, but we early-stop the procedure.
#
def train_and_eval(
arch, nas_bench, extra_info, dataname="cifar10-valid", use_012_epoch_training=True
):
if use_012_epoch_training and nas_bench is not None:
arch_index = nas_bench.query_index_by_arch(arch)
        assert arch_index >= 0, "cannot find this arch : {:}".format(arch)
info = nas_bench.get_more_info(
arch_index, dataname, iepoch=None, hp="12", is_random=True
)
valid_acc, time_cost = (
info["valid-accuracy"],
info["train-all-time"] + info["valid-per-time"],
)
# _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
elif not use_012_epoch_training and nas_bench is not None:
        # Please contact me if you want to use the following logic, because it has some potential issues.
        # Please use `use_012_epoch_training=False` for cifar10 only.
        # It does return values for cifar100 and ImageNet16-120, but those results have some potential issues (please email me for details).
arch_index, nepoch = nas_bench.query_index_by_arch(arch), 25
        assert arch_index >= 0, "cannot find this arch : {:}".format(arch)
xoinfo = nas_bench.get_more_info(
arch_index, "cifar10-valid", iepoch=None, hp="12"
)
xocost = nas_bench.get_cost_info(arch_index, "cifar10-valid", hp="200")
info = nas_bench.get_more_info(
arch_index, dataname, nepoch, hp="200", is_random=True
) # use the validation accuracy after 25 training epochs, which is used in our ICLR submission (not the camera ready).
cost = nas_bench.get_cost_info(arch_index, dataname, hp="200")
        # The following code estimates the time cost.
        # When we built NAS-Bench-201, architectures were trained on different machines, so those time records are not directly comparable.
        # When we created the checkpoints for converged_LR, all experiments ran on 1080Ti GPUs, so the time for each architecture can be fairly compared.
nums = {
"ImageNet16-120-train": 151700,
"ImageNet16-120-valid": 3000,
"cifar10-valid-train": 25000,
"cifar10-valid-valid": 25000,
"cifar100-train": 50000,
"cifar100-valid": 5000,
}
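        # Scale the measured cifar10-valid per-epoch times by the target dataset's
        # size and relative latency to estimate the train/valid cost on `dataname`.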
estimated_train_cost = (
xoinfo["train-per-time"]
/ nums["cifar10-valid-train"]
* nums["{:}-train".format(dataname)]
/ xocost["latency"]
* cost["latency"]
* nepoch
)
estimated_valid_cost = (
xoinfo["valid-per-time"]
/ nums["cifar10-valid-valid"]
* nums["{:}-valid".format(dataname)]
/ xocost["latency"]
* cost["latency"]
)
try:
valid_acc, time_cost = (
info["valid-accuracy"],
estimated_train_cost + estimated_valid_cost,
)
        except KeyError:  # fall back when 'valid-accuracy' is unavailable
valid_acc, time_cost = (
info["valtest-accuracy"],
estimated_train_cost + estimated_valid_cost,
)
else:
# train a model from scratch.
raise ValueError("NOT IMPLEMENT YET")
return valid_acc, time_cost
def random_architecture_func(max_nodes, op_names):
# return a random architecture
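    # For max_nodes = 4, the sampled genotype has the nested form
    # (((op, 0),), ((op, 0), (op, 1)), ((op, 0), (op, 1), (op, 2))):
    # node i receives one operation from every predecessor j < i.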
def random_architecture():
genotypes = []
for i in range(1, max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = random.choice(op_names)
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)
return random_architecture
def mutate_arch_func(op_names):
"""Computes the architecture for a child of the given parent architecture.
    The parent architecture is cloned and mutated to produce the child architecture, by randomly switching one operation to another.
"""
    def mutate_arch(parent_arch):
child_arch = deepcopy(parent_arch)
node_id = random.randint(0, len(child_arch.nodes) - 1)
node_info = list(child_arch.nodes[node_id])
snode_id = random.randint(0, len(node_info) - 1)
xop = random.choice(op_names)
while xop == node_info[snode_id][0]:
xop = random.choice(op_names)
node_info[snode_id] = (xop, node_info[snode_id][1])
child_arch.nodes[node_id] = tuple(node_info)
return child_arch
    return mutate_arch
def regularized_evolution(
cycles,
population_size,
sample_size,
time_budget,
random_arch,
mutate_arch,
nas_bench,
extra_info,
dataname,
):
"""Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
Classifier Architecture Search".
Args:
cycles: the number of cycles the algorithm should run for.
population_size: the number of individuals to keep in the population.
sample_size: the number of individuals that should participate in each tournament.
      time_budget: the upper bound of the search cost (in seconds).
Returns:
history: a list of `Model` instances, representing all the models computed
during the evolution experiment.
"""
population = collections.deque()
history, total_time_cost = (
[],
0,
) # Not used by the algorithm, only used to report results.
# Initialize the population with random models.
while len(population) < population_size:
model = Model()
model.arch = random_arch()
model.accuracy, time_cost = train_and_eval(
model.arch, nas_bench, extra_info, dataname
)
population.append(model)
history.append(model)
total_time_cost += time_cost
# Carry out evolution in cycles. Each cycle produces a model and removes
# another.
# while len(history) < cycles:
while total_time_cost < time_budget:
# Sample randomly chosen models from the current population.
start_time, sample = time.time(), []
while len(sample) < sample_size:
# Inefficient, but written this way for clarity. In the case of neural
# nets, the efficiency of this line is irrelevant because training neural
# nets is the rate-determining step.
candidate = random.choice(list(population))
sample.append(candidate)
# The parent is the best model in the sample.
parent = max(sample, key=lambda i: i.accuracy)
# Create the child model and store it.
child = Model()
child.arch = mutate_arch(parent.arch)
total_time_cost += time.time() - start_time
child.accuracy, time_cost = train_and_eval(
child.arch, nas_bench, extra_info, dataname
)
if total_time_cost + time_cost > time_budget: # return
return history, total_time_cost
else:
total_time_cost += time_cost
population.append(child)
history.append(child)
# Remove the oldest model.
population.popleft()
return history, total_time_cost
def main(xargs, nas_bench):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(
config_path, {"class_num": class_num, "xshape": xshape}, logger
)
# To split data
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
# data loader
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {
"config": config,
"train_loader": train_loader,
"valid_loader": valid_loader,
}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
search_space = get_search_spaces("cell", xargs.search_space_name)
random_arch = random_architecture_func(xargs.max_nodes, search_space)
mutate_arch = mutate_arch_func(search_space)
    # e.g., x = random_arch(); y = mutate_arch(x)
x_start_time = time.time()
logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
logger.log(
"-" * 30
+ " start searching with the time budget of {:} s".format(xargs.time_budget)
)
history, total_cost = regularized_evolution(
xargs.ea_cycles,
xargs.ea_population,
xargs.ea_sample_size,
xargs.time_budget,
random_arch,
mutate_arch,
nas_bench if args.ea_fast_by_api else None,
extra_info,
dataname,
)
logger.log(
"{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format(
time_string(), len(history), total_cost, time.time() - x_start_time
)
)
best_arch = max(history, key=lambda i: i.accuracy)
best_arch = best_arch.arch
logger.log("{:} best arch is {:}".format(time_string(), best_arch))
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.")
parser.add_argument("--ea_population", type=int, help="The population size in EA.")
parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.")
parser.add_argument(
"--ea_fast_by_api",
type=int,
help="Use our API to speed up the experiments or not.",
)
parser.add_argument(
"--time_budget",
type=int,
help="The total time cost budge for searching (in seconds).",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
args.ea_fast_by_api = args.ea_fast_by_api > 0
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print(
"{:} build NAS-Benchmark-API from {:}".format(
time_string(), args.arch_nas_dataset
)
)
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append(index)
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)

@@ -0,0 +1,476 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
######################################################################################
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
######################################################################################
import sys, time, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, get_nas_search_loaders
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
from nas_201_api import NASBench201API as API
def search_func(
xloader,
network,
criterion,
scheduler,
w_optimizer,
a_optimizer,
epoch_str,
print_freq,
logger,
):
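    # SETN alternation: sample an architecture from the learned distribution
    # ('dynamic' mode) to train the shared weights, then optimize the
    # architecture parameters with all candidates active ('joint' mode).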
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
network.train()
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
xloader
):
scheduler.update(None, 1.0 * step / len(xloader))
base_targets = base_targets.cuda(non_blocking=True)
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# update the weights
sampled_arch = network.module.dync_genotype(True)
network.module.set_cal_mode("dynamic", sampled_arch)
# network.module.set_cal_mode( 'urs' )
network.zero_grad()
_, logits = network(base_inputs)
base_loss = criterion(logits, base_targets)
base_loss.backward()
w_optimizer.step()
# record
base_prec1, base_prec5 = obtain_accuracy(
logits.data, base_targets.data, topk=(1, 5)
)
base_losses.update(base_loss.item(), base_inputs.size(0))
base_top1.update(base_prec1.item(), base_inputs.size(0))
base_top5.update(base_prec5.item(), base_inputs.size(0))
# update the architecture-weight
network.module.set_cal_mode("joint")
network.zero_grad()
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
arch_loss.backward()
a_optimizer.step()
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or step + 1 == len(xloader):
Sstr = (
"*SEARCH* "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=base_losses, top1=base_top1, top5=base_top5
)
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
loss=arch_losses, top1=arch_top1, top5=arch_top5
)
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
# print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
# print (network.module.arch_parameters)
return (
base_losses.avg,
base_top1.avg,
base_top5.avg,
arch_losses.avg,
arch_top1.avg,
arch_top5.avg,
)
def get_best_arch(xloader, network, n_samples):
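    # Rank the top-K architectures under the current architecture distribution
    # by their accuracy on a single mini-batch of validation data.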
with torch.no_grad():
network.eval()
archs, valid_accs = network.module.return_topK(n_samples), []
# print ('obtain the top-{:} architectures'.format(n_samples))
loader_iter = iter(xloader)
for i, sampled_arch in enumerate(archs):
network.module.set_cal_mode("dynamic", sampled_arch)
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                # restart the loader when it is exhausted
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
_, logits = network(inputs)
val_top1, val_top5 = obtain_accuracy(
logits.cpu().data, targets.data, topk=(1, 5)
)
valid_accs.append(val_top1.item())
best_idx = np.argmax(valid_accs)
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
return best_arch, best_valid_acc
def valid_func(xloader, network, criterion):
data_time, batch_time = AverageMeter(), AverageMeter()
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
end = time.time()
with torch.no_grad():
network.eval()
for step, (arch_inputs, arch_targets) in enumerate(xloader):
arch_targets = arch_targets.cuda(non_blocking=True)
# measure data loading time
data_time.update(time.time() - end)
# prediction
_, logits = network(arch_inputs)
arch_loss = criterion(logits, arch_targets)
# record
arch_prec1, arch_prec5 = obtain_accuracy(
logits.data, arch_targets.data, topk=(1, 5)
)
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
return arch_losses.avg, arch_top1.avg, arch_top5.avg
def main(xargs):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
config = load_config(
xargs.config_path, {"class_num": class_num, "xshape": xshape}, logger
)
search_loader, _, valid_loader = get_nas_search_loaders(
train_data,
valid_data,
xargs.dataset,
"configs/nas-benchmark/",
(config.batch_size, config.test_batch_size),
xargs.workers,
)
logger.log(
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
if xargs.model_config is None:
model_config = dict2config(
dict(
name="SETN",
C=xargs.channel,
N=xargs.num_cells,
max_nodes=xargs.max_nodes,
num_classes=class_num,
space=search_space,
affine=False,
track_running_stats=bool(xargs.track_running_stats),
),
None,
)
else:
model_config = load_config(
xargs.model_config,
dict(
num_classes=class_num,
space=search_space,
affine=False,
track_running_stats=bool(xargs.track_running_stats),
),
None,
)
logger.log("search space : {:}".format(search_space))
search_model = get_cell_based_tiny_net(model_config)
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
search_model.get_weights(), config
)
a_optimizer = torch.optim.Adam(
search_model.get_alphas(),
lr=xargs.arch_learning_rate,
betas=(0.5, 0.999),
weight_decay=xargs.arch_weight_decay,
)
logger.log("w-optimizer : {:}".format(w_optimizer))
logger.log("a-optimizer : {:}".format(a_optimizer))
logger.log("w-scheduler : {:}".format(w_scheduler))
logger.log("criterion : {:}".format(criterion))
flop, param = get_model_infos(search_model, xshape)
logger.log("FLOP = {:.2f} M, Params = {:.2f} MB".format(flop, param))
logger.log("search-space : {:}".format(search_space))
if xargs.arch_nas_dataset is None:
api = None
else:
api = API(xargs.arch_nas_dataset)
logger.log("{:} create API = {:} done".format(time_string(), api))
last_info, model_base_path, model_best_path = (
logger.path("info"),
logger.path("model"),
logger.path("best"),
)
network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
if last_info.exists(): # automatically resume from previous checkpoint
logger.log(
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
)
last_info = torch.load(last_info)
start_epoch = last_info["epoch"]
checkpoint = torch.load(last_info["last_checkpoint"])
genotypes = checkpoint["genotypes"]
valid_accuracies = checkpoint["valid_accuracies"]
search_model.load_state_dict(checkpoint["search_model"])
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
logger.log(
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
last_info, start_epoch
)
)
else:
logger.log("=> do not find the last-info file : {:}".format(last_info))
init_genotype, _ = get_best_arch(valid_loader, network, xargs.select_num)
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: init_genotype}
# start training
start_time, search_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
AverageMeter(),
config.epochs + config.warmup,
)
for epoch in range(start_epoch, total_epoch):
w_scheduler.update(epoch, 0.0)
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
)
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
logger.log(
"\n[Search the {:}-th epoch] {:}, LR={:}".format(
epoch_str, need_time, min(w_scheduler.get_lr())
)
)
(
search_w_loss,
search_w_top1,
search_w_top5,
search_a_loss,
search_a_top1,
search_a_top5,
) = search_func(
search_loader,
network,
criterion,
w_scheduler,
w_optimizer,
a_optimizer,
epoch_str,
xargs.print_freq,
logger,
)
search_time.update(time.time() - start_time)
logger.log(
"[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
)
)
logger.log(
"[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
epoch_str, search_a_loss, search_a_top1, search_a_top5
)
)
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
network.module.set_cal_mode("dynamic", genotype)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
valid_loader, network, criterion
)
logger.log(
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format(
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype
)
)
# search_model.set_cal_mode('urs')
# valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
# logger.log('[{:}] URS---evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# search_model.set_cal_mode('joint')
# valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
# logger.log('[{:}] JOINT-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# search_model.set_cal_mode('select')
# valid_a_loss , valid_a_top1 , valid_a_top5 = valid_func(valid_loader, network, criterion)
# logger.log('[{:}] Selec-evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
# check the best accuracy
valid_accuracies[epoch] = valid_a_top1
genotypes[epoch] = genotype
logger.log(
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
)
# save checkpoint
save_path = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(xargs),
"search_model": search_model.state_dict(),
"w_optimizer": w_optimizer.state_dict(),
"a_optimizer": a_optimizer.state_dict(),
"w_scheduler": w_scheduler.state_dict(),
"genotypes": genotypes,
"valid_accuracies": valid_accuracies,
},
model_base_path,
logger,
)
last_info = save_checkpoint(
{
"epoch": epoch + 1,
"args": deepcopy(args),
"last_checkpoint": save_path,
},
logger.path("info"),
logger,
)
with torch.no_grad():
logger.log("{:}".format(search_model.show_alphas()))
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
    # final post-processing: time the selection of the final architecture
start_time = time.time()
genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.select_num)
search_time.update(time.time() - start_time)
network.module.set_cal_mode("dynamic", genotype)
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
valid_loader, network, criterion
)
logger.log(
"Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(
genotype, valid_a_top1
)
)
logger.log("\n" + "-" * 100)
# check the performance from the architecture dataset
logger.log(
"SETN : run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
total_epoch, search_time.sum, genotype
)
)
if api is not None:
logger.log("{:}".format(api.query_by_arch(genotype, "200")))
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("SETN")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--select_num",
type=int,
help="The number of selected architectures to evaluate.",
)
parser.add_argument(
"--track_running_stats",
type=int,
choices=[0, 1],
help="Whether use track_running_stats or not in the BN layer.",
)
parser.add_argument(
"--config_path", type=str, help="The path of the configuration."
)
    # architecture learning rate
parser.add_argument(
"--arch_learning_rate",
type=float,
default=3e-4,
help="learning rate for arch encoding",
)
parser.add_argument(
"--arch_weight_decay",
type=float,
default=1e-3,
help="weight decay for arch encoding",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
main(args)

@@ -0,0 +1,294 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
#####################################################################################################
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
#####################################################################################################
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
from pathlib import Path
import torch
import torch.nn as nn
from torch.distributions import Categorical
from xautodl.config_utils import load_config, dict2config, configure2str
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
get_optim_scheduler,
)
from xautodl.utils import get_model_infos, obtain_accuracy
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, get_search_spaces
from nas_201_api import NASBench201API as API
from R_EA import train_and_eval


class Policy(nn.Module):
def __init__(self, max_nodes, search_space):
super(Policy, self).__init__()
self.max_nodes = max_nodes
self.search_space = deepcopy(search_space)
self.edge2index = {}
for i in range(1, max_nodes):
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
self.edge2index[node_str] = len(self.edge2index)
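        # one row of learnable logits per edge; a row-wise softmax (see forward)
        # gives an independent categorical distribution over ops for each edge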
self.arch_parameters = nn.Parameter(
1e-3 * torch.randn(len(self.edge2index), len(search_space))
)

    def generate_arch(self, actions):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
op_name = self.search_space[actions[self.edge2index[node_str]]]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)

    def genotype(self):
genotypes = []
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
with torch.no_grad():
weights = self.arch_parameters[self.edge2index[node_str]]
op_name = self.search_space[weights.argmax().item()]
xlist.append((op_name, j))
genotypes.append(tuple(xlist))
return CellStructure(genotypes)

    def forward(self):
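        # alphas[e] is the op-sampling distribution for edge e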
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
return alphas


class ExponentialMovingAverage(object):
"""Class that maintains an exponential moving average."""

    def __init__(self, momentum):
self._numerator = 0
self._denominator = 0
self._momentum = momentum

    def update(self, value):
self._numerator = (
self._momentum * self._numerator + (1 - self._momentum) * value
)
self._denominator = self._momentum * self._denominator + (1 - self._momentum)

    def value(self):
"""Return the current value of the moving average"""
return self._numerator / self._denominator


def select_action(policy):
probs = policy()
m = Categorical(probs)
action = m.sample()
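    # `probs` is an (edges x ops) matrix, so this draws one op index per edge;
    # log_prob holds one log-probability per edge and is summed later in the loss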
# policy.saved_log_probs.append(m.log_prob(action))
return m.log_prob(action), action.cpu().tolist()


def main(xargs, nas_bench):
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
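    # note: prepare_logger reads the global `args`, which is the same namespace
    # object that __main__ passes in as `xargs`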
logger = prepare_logger(args)
if xargs.dataset == "cifar10":
dataname = "cifar10-valid"
else:
dataname = xargs.dataset
if xargs.data_path is not None:
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
split_Fpath = "configs/nas-benchmark/cifar-split.txt"
cifar_split = load_config(split_Fpath, None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
logger.log("Load split file from {:}".format(split_Fpath))
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(
config_path, {"class_num": class_num, "xshape": xshape}, logger
)
        # build the validation set: a copy of the training data with the validation
        # transform; the train_split / valid_split indices keep the two subsets disjoint
train_data_v2 = deepcopy(train_data)
train_data_v2.transform = valid_data.transform
valid_data = train_data_v2
search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
        # data loaders over the disjoint train / valid index splits
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
num_workers=xargs.workers,
pin_memory=True,
)
valid_loader = torch.utils.data.DataLoader(
valid_data,
batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
num_workers=xargs.workers,
pin_memory=True,
)
logger.log(
"||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
xargs.dataset, len(train_loader), len(valid_loader), config.batch_size
)
)
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
extra_info = {
"config": config,
"train_loader": train_loader,
"valid_loader": valid_loader,
}
else:
config_path = "configs/nas-benchmark/algos/R-EA.config"
config = load_config(config_path, None, logger)
extra_info = {"config": config, "train_loader": None, "valid_loader": None}
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
search_space = get_search_spaces("cell", xargs.search_space_name)
policy = Policy(xargs.max_nodes, search_space)
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
# optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
eps = np.finfo(np.float32).eps.item()
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
logger.log("policy : {:}".format(policy))
logger.log("optimizer : {:}".format(optimizer))
logger.log("eps : {:}".format(eps))
# nas dataset load
logger.log("{:} use nas_bench : {:}".format(time_string(), nas_bench))
# REINFORCE
# attempts = 0
x_start_time = time.time()
logger.log(
"Will start searching with time budget of {:} s.".format(xargs.time_budget)
)
total_steps, total_costs, trace = 0, 0, []
# for istep in range(xargs.RL_steps):
while total_costs < xargs.time_budget:
start_time = time.time()
log_prob, action = select_action(policy)
arch = policy.generate_arch(action)
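        # as used here, train_and_eval (imported from R_EA) returns a scalar reward
        # (the validation accuracy when a benchmark is available) and the simulated
        # train/eval time in seconds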
reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
trace.append((reward, arch))
        # charge the simulated train/eval cost against the budget; stop once it would overflow
if total_costs + cost_time < xargs.time_budget:
total_costs += cost_time
else:
break
baseline.update(reward)
        # REINFORCE loss with an EMA baseline: -log pi(action) * (reward - baseline), summed over edges
policy_loss = (-log_prob * (reward - baseline.value())).sum()
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
        # also charge the real wall-clock overhead of this policy update against the budget
total_costs += time.time() - start_time
total_steps += 1
logger.log(
"step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format(
total_steps, baseline.value(), policy_loss.item(), policy.genotype()
)
)
# logger.log('----> {:}'.format(policy.arch_parameters))
# logger.log('')
# best_arch = policy.genotype() # first version
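    # final choice: the best-rewarded architecture seen during the search,
    # rather than the argmax of the learned policy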
best_arch = max(trace, key=lambda x: x[0])[1]
logger.log(
"REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).".format(
total_steps, total_costs, time.time() - x_start_time
)
)
info = nas_bench.query_by_arch(best_arch, "200")
if info is None:
logger.log("Did not find this architecture : {:}.".format(best_arch))
else:
logger.log("{:}".format(info))
logger.log("-" * 100)
logger.close()
return logger.log_dir, nas_bench.query_index_by_arch(best_arch)


if __name__ == "__main__":
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
parser.add_argument("--data_path", type=str, help="Path to dataset")
parser.add_argument(
"--dataset",
type=str,
choices=["cifar10", "cifar100", "ImageNet16-120"],
help="Choose between Cifar10/100 and ImageNet-16.",
)
# channels and number-of-cells
parser.add_argument("--search_space_name", type=str, help="The search space name.")
parser.add_argument("--max_nodes", type=int, help="The maximum number of nodes.")
parser.add_argument("--channel", type=int, help="The number of channels.")
parser.add_argument(
"--num_cells", type=int, help="The number of cells in one stage."
)
parser.add_argument(
"--learning_rate", type=float, help="The learning rate for REINFORCE."
)
# parser.add_argument('--RL_steps', type=int, help='The steps for REINFORCE.')
parser.add_argument(
"--EMA_momentum", type=float, help="The momentum value for EMA."
)
parser.add_argument(
"--time_budget",
type=int,
help="The total time cost budge for searching (in seconds).",
)
# log
parser.add_argument(
"--workers",
type=int,
default=2,
help="number of data loading workers (default: 2)",
)
parser.add_argument(
"--save_dir", type=str, help="Folder to save checkpoints and log."
)
parser.add_argument(
"--arch_nas_dataset",
type=str,
help="The path to load the architecture dataset (tiny-nas-benchmark).",
)
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
# if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
nas_bench = None
else:
print(
"{:} build NAS-Benchmark-API from {:}".format(
time_string(), args.arch_nas_dataset
)
)
nas_bench = API(args.arch_nas_dataset)
if args.rand_seed < 0:
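        # negative seed: repeat the whole search 500 times with fresh random
        # seeds and save the index of every discovered architecture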
save_dir, all_indexes, num = None, [], 500
for i in range(num):
print("{:} : {:03d}/{:03d}".format(time_string(), i, num))
args.rand_seed = random.randint(1, 100000)
save_dir, index = main(args, nas_bench)
all_indexes.append(index)
torch.save(all_indexes, save_dir / "results.pth")
else:
main(args, nas_bench)