add autodl
This commit is contained in:
29
AutoDL-Projects/exps/NATS-algos/README.md
Normal file
29
AutoDL-Projects/exps/NATS-algos/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# NAS Algorithms evaluated in [NATS-Bench](https://arxiv.org/abs/2009.00437)
|
||||
|
||||
The Python files in this folder are used to re-produce the results in ``NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size''.
|
||||
|
||||
- [`search-size.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/search-size.py) contains codes for weight-sharing-based search on the size search space.
|
||||
- [`search-cell.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/search-cell.py) contains codes for weight-sharing-based search on the topology search space.
|
||||
- [`bohb.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/bohb.py) contains the BOHB algorithm for both size and topology search spaces.
|
||||
- [`random_wo_share.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/random_wo_share.py) contains the random search algorithm for both search spaces.
|
||||
- [`regularized_ea.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/regularized_ea.py) contains the REA algorithm for both search spaces.
|
||||
- [`reinforce.py`](https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NATS-algos/reinforce.py) contains the REINFORCE algorithm for both search spaces.
|
||||
|
||||
## Requirements
|
||||
|
||||
- `nats_bench`>=v1.2 : you can use `pip install nats_bench` to install or from [sources](https://github.com/D-X-Y/NATS-Bench)
|
||||
- `hpbandster` : if you want to run BOHB
|
||||
|
||||
## Citation
|
||||
|
||||
If you find that this project helps your research, please consider citing the related paper:
|
||||
```
|
||||
@article{dong2021nats,
|
||||
title = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
|
||||
author = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
|
||||
doi = {10.1109/TPAMI.2021.3054824},
|
||||
journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
|
||||
year = {2021},
|
||||
note = {\mbox{doi}:\url{10.1109/TPAMI.2021.3054824}}
|
||||
}
|
||||
```
|
276
AutoDL-Projects/exps/NATS-algos/bohb.py
Normal file
276
AutoDL-Projects/exps/NATS-algos/bohb.py
Normal file
@@ -0,0 +1,276 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
###################################################################
|
||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale #
|
||||
# required to install hpbandster ##################################
|
||||
# pip install hpbandster ##################################
|
||||
###################################################################
|
||||
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space tss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||
# OMP_NUM_THREADS=4 python exps/NATS-algos/bohb.py --search_space sss --dataset cifar10 --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3 --rand_seed 1
|
||||
###################################################################
|
||||
import os, sys, time, random, argparse, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
from xautodl.config_utils import load_config
|
||||
from xautodl.datasets import get_datasets, SearchDataset
|
||||
from xautodl.procedures import prepare_seed, prepare_logger
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import CellStructure, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
# BOHB: Robust and Efficient Hyperparameter Optimization at Scale, ICML 2018
|
||||
import ConfigSpace
|
||||
from hpbandster.optimizers.bohb import BOHB
|
||||
import hpbandster.core.nameserver as hpns
|
||||
from hpbandster.core.worker import Worker
|
||||
|
||||
|
||||
def get_topology_config_space(search_space, max_nodes=4):
|
||||
cs = ConfigSpace.ConfigurationSpace()
|
||||
# edge2index = {}
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
cs.add_hyperparameter(
|
||||
ConfigSpace.CategoricalHyperparameter(node_str, search_space)
|
||||
)
|
||||
return cs
|
||||
|
||||
|
||||
def get_size_config_space(search_space):
|
||||
cs = ConfigSpace.ConfigurationSpace()
|
||||
for ilayer in range(search_space["numbers"]):
|
||||
node_str = "layer-{:}".format(ilayer)
|
||||
cs.add_hyperparameter(
|
||||
ConfigSpace.CategoricalHyperparameter(node_str, search_space["candidates"])
|
||||
)
|
||||
return cs
|
||||
|
||||
|
||||
def config2topology_func(max_nodes=4):
|
||||
def config2structure(config):
|
||||
genotypes = []
|
||||
for i in range(1, max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_name = config[node_str]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return CellStructure(genotypes)
|
||||
|
||||
return config2structure
|
||||
|
||||
|
||||
def config2size_func(search_space):
|
||||
def config2structure(config):
|
||||
channels = []
|
||||
for ilayer in range(search_space["numbers"]):
|
||||
node_str = "layer-{:}".format(ilayer)
|
||||
channels.append(str(config[node_str]))
|
||||
return ":".join(channels)
|
||||
|
||||
return config2structure
|
||||
|
||||
|
||||
class MyWorker(Worker):
|
||||
def __init__(self, *args, convert_func=None, dataset=None, api=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.convert_func = convert_func
|
||||
self._dataset = dataset
|
||||
self._api = api
|
||||
self.total_times = []
|
||||
self.trajectory = []
|
||||
|
||||
def compute(self, config, budget, **kwargs):
|
||||
arch = self.convert_func(config)
|
||||
accuracy, latency, time_cost, total_time = self._api.simulate_train_eval(
|
||||
arch, self._dataset, iepoch=int(budget) - 1, hp="12"
|
||||
)
|
||||
self.trajectory.append((accuracy, arch))
|
||||
self.total_times.append(total_time)
|
||||
return {"loss": 100 - accuracy, "info": self._api.query_index_by_arch(arch)}
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
torch.set_num_threads(4)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
logger.log("{:} use api : {:}".format(time_string(), api))
|
||||
api.reset_time()
|
||||
search_space = get_search_spaces(xargs.search_space, "nats-bench")
|
||||
if xargs.search_space == "tss":
|
||||
cs = get_topology_config_space(search_space)
|
||||
config2structure = config2topology_func()
|
||||
else:
|
||||
cs = get_size_config_space(search_space)
|
||||
config2structure = config2size_func(search_space)
|
||||
|
||||
hb_run_id = "0"
|
||||
|
||||
NS = hpns.NameServer(run_id=hb_run_id, host="localhost", port=0)
|
||||
ns_host, ns_port = NS.start()
|
||||
num_workers = 1
|
||||
|
||||
workers = []
|
||||
for i in range(num_workers):
|
||||
w = MyWorker(
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
convert_func=config2structure,
|
||||
dataset=xargs.dataset,
|
||||
api=api,
|
||||
run_id=hb_run_id,
|
||||
id=i,
|
||||
)
|
||||
w.run(background=True)
|
||||
workers.append(w)
|
||||
|
||||
start_time = time.time()
|
||||
bohb = BOHB(
|
||||
configspace=cs,
|
||||
run_id=hb_run_id,
|
||||
eta=3,
|
||||
min_budget=1,
|
||||
max_budget=12,
|
||||
nameserver=ns_host,
|
||||
nameserver_port=ns_port,
|
||||
num_samples=xargs.num_samples,
|
||||
random_fraction=xargs.random_fraction,
|
||||
bandwidth_factor=xargs.bandwidth_factor,
|
||||
ping_interval=10,
|
||||
min_bandwidth=xargs.min_bandwidth,
|
||||
)
|
||||
|
||||
results = bohb.run(xargs.n_iters, min_n_workers=num_workers)
|
||||
|
||||
bohb.shutdown(shutdown_workers=True)
|
||||
NS.shutdown()
|
||||
|
||||
# print('There are {:} runs.'.format(len(results.get_all_runs())))
|
||||
# workers[0].total_times
|
||||
# workers[0].trajectory
|
||||
current_best_index = []
|
||||
for idx in range(len(workers[0].trajectory)):
|
||||
trajectory = workers[0].trajectory[: idx + 1]
|
||||
arch = max(trajectory, key=lambda x: x[0])[1]
|
||||
current_best_index.append(api.query_index_by_arch(arch))
|
||||
|
||||
best_arch = max(workers[0].trajectory, key=lambda x: x[0])[1]
|
||||
logger.log(
|
||||
"Best found configuration: {:} within {:.3f} s".format(
|
||||
best_arch, workers[0].total_times[-1]
|
||||
)
|
||||
)
|
||||
info = api.query_info_str_by_arch(
|
||||
best_arch, "200" if xargs.search_space == "tss" else "90"
|
||||
)
|
||||
logger.log("{:}".format(info))
|
||||
logger.log("-" * 100)
|
||||
logger.close()
|
||||
|
||||
return logger.log_dir, current_best_index, workers[0].total_times
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
"BOHB: Robust and Efficient Hyperparameter Optimization at Scale"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
# general arg
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_budget",
|
||||
type=int,
|
||||
default=20000,
|
||||
help="The total time cost budge for searching (in seconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
|
||||
)
|
||||
# BOHB
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
default="sampling",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="optimization strategy for the acquisition function",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_bandwidth",
|
||||
default=0.3,
|
||||
type=float,
|
||||
nargs="?",
|
||||
help="minimum bandwidth for KDE",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
default=64,
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of samples for the acquisition function",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_fraction",
|
||||
default=0.33,
|
||||
type=float,
|
||||
nargs="?",
|
||||
help="fraction of random configurations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bandwidth_factor",
|
||||
default=3,
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="factor multiplied to the bandwidth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iters",
|
||||
default=300,
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of iterations for optimization method",
|
||||
)
|
||||
# log
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
api = create(None, args.search_space, fast_mode=False, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
"{:}-T{:}".format(args.dataset, args.time_budget),
|
||||
"BOHB",
|
||||
)
|
||||
print("save-dir : {:}".format(args.save_dir))
|
||||
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_info = None, collections.OrderedDict()
|
||||
for i in range(args.loops_if_rand):
|
||||
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, all_archs, all_total_times = main(args, api)
|
||||
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
|
||||
save_path = save_dir / "results.pth"
|
||||
print("save into {:}".format(save_path))
|
||||
torch.save(all_info, save_path)
|
||||
else:
|
||||
main(args, api)
|
156
AutoDL-Projects/exps/NATS-algos/random_wo_share.py
Normal file
156
AutoDL-Projects/exps/NATS-algos/random_wo_share.py
Normal file
@@ -0,0 +1,156 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##############################################################################
|
||||
# Random Search for Hyper-Parameter Optimization, JMLR 2012 ##################
|
||||
##############################################################################
|
||||
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar10 --search_space tss
|
||||
# python ./exps/NATS-algos/random_wo_share.py --dataset cifar100 --search_space tss
|
||||
# python ./exps/NATS-algos/random_wo_share.py --dataset ImageNet16-120 --search_space tss
|
||||
##############################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||
from xautodl.datasets import get_datasets, SearchDataset
|
||||
from xautodl.procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
get_optim_scheduler,
|
||||
)
|
||||
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import CellStructure, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
def random_topology_func(op_names, max_nodes=4):
|
||||
# Return a random architecture
|
||||
def random_architecture():
|
||||
genotypes = []
|
||||
for i in range(1, max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_name = random.choice(op_names)
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return CellStructure(genotypes)
|
||||
|
||||
return random_architecture
|
||||
|
||||
|
||||
def random_size_func(info):
|
||||
# Return a random architecture
|
||||
def random_architecture():
|
||||
channels = []
|
||||
for i in range(info["numbers"]):
|
||||
channels.append(str(random.choice(info["candidates"])))
|
||||
return ":".join(channels)
|
||||
|
||||
return random_architecture
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
torch.set_num_threads(4)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
logger.log("{:} use api : {:}".format(time_string(), api))
|
||||
api.reset_time()
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, "nats-bench")
|
||||
if xargs.search_space == "tss":
|
||||
random_arch = random_topology_func(search_space)
|
||||
else:
|
||||
random_arch = random_size_func(search_space)
|
||||
|
||||
best_arch, best_acc, total_time_cost, history = None, -1, [], []
|
||||
current_best_index = []
|
||||
while len(total_time_cost) == 0 or total_time_cost[-1] < xargs.time_budget:
|
||||
arch = random_arch()
|
||||
accuracy, _, _, total_cost = api.simulate_train_eval(
|
||||
arch, xargs.dataset, hp="12"
|
||||
)
|
||||
total_time_cost.append(total_cost)
|
||||
history.append(arch)
|
||||
if best_arch is None or best_acc < accuracy:
|
||||
best_acc, best_arch = accuracy, arch
|
||||
logger.log(
|
||||
"[{:03d}] : {:} : accuracy = {:.2f}%".format(len(history), arch, accuracy)
|
||||
)
|
||||
current_best_index.append(api.query_index_by_arch(best_arch))
|
||||
logger.log(
|
||||
"{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s.".format(
|
||||
time_string(), best_arch, best_acc, len(history), total_time_cost[-1]
|
||||
)
|
||||
)
|
||||
|
||||
info = api.query_info_str_by_arch(
|
||||
best_arch, "200" if xargs.search_space == "tss" else "90"
|
||||
)
|
||||
logger.log("{:}".format(info))
|
||||
logger.log("-" * 100)
|
||||
logger.close()
|
||||
return logger.log_dir, current_best_index, total_time_cost
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Random NAS")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--time_budget",
|
||||
type=int,
|
||||
default=20000,
|
||||
help="The total time cost budge for searching (in seconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
|
||||
)
|
||||
# log
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
"{:}-T{:}".format(args.dataset, args.time_budget),
|
||||
"RANDOM",
|
||||
)
|
||||
print("save-dir : {:}".format(args.save_dir))
|
||||
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_info = None, collections.OrderedDict()
|
||||
for i in range(args.loops_if_rand):
|
||||
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, all_archs, all_total_times = main(args, api)
|
||||
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
|
||||
save_path = save_dir / "results.pth"
|
||||
print("save into {:}".format(save_path))
|
||||
torch.save(all_info, save_path)
|
||||
else:
|
||||
main(args, api)
|
302
AutoDL-Projects/exps/NATS-algos/regularized_ea.py
Normal file
302
AutoDL-Projects/exps/NATS-algos/regularized_ea.py
Normal file
@@ -0,0 +1,302 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
##################################################################
|
||||
# Regularized Evolution for Image Classifier Architecture Search #
|
||||
##################################################################
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
|
||||
# python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --use_proxy 0
|
||||
##################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||
from xautodl.datasets import get_datasets, SearchDataset
|
||||
from xautodl.procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
get_optim_scheduler,
|
||||
)
|
||||
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import CellStructure, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
class Model(object):
|
||||
def __init__(self):
|
||||
self.arch = None
|
||||
self.accuracy = None
|
||||
|
||||
def __str__(self):
|
||||
"""Prints a readable version of this bitstring."""
|
||||
return "{:}".format(self.arch)
|
||||
|
||||
|
||||
def random_topology_func(op_names, max_nodes=4):
|
||||
# Return a random architecture
|
||||
def random_architecture():
|
||||
genotypes = []
|
||||
for i in range(1, max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_name = random.choice(op_names)
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return CellStructure(genotypes)
|
||||
|
||||
return random_architecture
|
||||
|
||||
|
||||
def random_size_func(info):
|
||||
# Return a random architecture
|
||||
def random_architecture():
|
||||
channels = []
|
||||
for i in range(info["numbers"]):
|
||||
channels.append(str(random.choice(info["candidates"])))
|
||||
return ":".join(channels)
|
||||
|
||||
return random_architecture
|
||||
|
||||
|
||||
def mutate_topology_func(op_names):
|
||||
"""Computes the architecture for a child of the given parent architecture.
|
||||
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another.
|
||||
"""
|
||||
|
||||
def mutate_topology_func(parent_arch):
|
||||
child_arch = deepcopy(parent_arch)
|
||||
node_id = random.randint(0, len(child_arch.nodes) - 1)
|
||||
node_info = list(child_arch.nodes[node_id])
|
||||
snode_id = random.randint(0, len(node_info) - 1)
|
||||
xop = random.choice(op_names)
|
||||
while xop == node_info[snode_id][0]:
|
||||
xop = random.choice(op_names)
|
||||
node_info[snode_id] = (xop, node_info[snode_id][1])
|
||||
child_arch.nodes[node_id] = tuple(node_info)
|
||||
return child_arch
|
||||
|
||||
return mutate_topology_func
|
||||
|
||||
|
||||
def mutate_size_func(info):
|
||||
"""Computes the architecture for a child of the given parent architecture.
|
||||
The parent architecture is cloned and mutated to produce the child architecture. The child architecture is mutated by randomly switch one operation to another.
|
||||
"""
|
||||
|
||||
def mutate_size_func(parent_arch):
|
||||
child_arch = deepcopy(parent_arch)
|
||||
child_arch = child_arch.split(":")
|
||||
index = random.randint(0, len(child_arch) - 1)
|
||||
child_arch[index] = str(random.choice(info["candidates"]))
|
||||
return ":".join(child_arch)
|
||||
|
||||
return mutate_size_func
|
||||
|
||||
|
||||
def regularized_evolution(
|
||||
cycles,
|
||||
population_size,
|
||||
sample_size,
|
||||
time_budget,
|
||||
random_arch,
|
||||
mutate_arch,
|
||||
api,
|
||||
use_proxy,
|
||||
dataset,
|
||||
):
|
||||
"""Algorithm for regularized evolution (i.e. aging evolution).
|
||||
|
||||
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
|
||||
Classifier Architecture Search".
|
||||
|
||||
Args:
|
||||
cycles: the number of cycles the algorithm should run for.
|
||||
population_size: the number of individuals to keep in the population.
|
||||
sample_size: the number of individuals that should participate in each tournament.
|
||||
time_budget: the upper bound of searching cost
|
||||
|
||||
Returns:
|
||||
history: a list of `Model` instances, representing all the models computed
|
||||
during the evolution experiment.
|
||||
"""
|
||||
population = collections.deque()
|
||||
api.reset_time()
|
||||
history, total_time_cost = (
|
||||
[],
|
||||
[],
|
||||
) # Not used by the algorithm, only used to report results.
|
||||
current_best_index = []
|
||||
# Initialize the population with random models.
|
||||
while len(population) < population_size:
|
||||
model = Model()
|
||||
model.arch = random_arch()
|
||||
model.accuracy, _, _, total_cost = api.simulate_train_eval(
|
||||
model.arch, dataset, hp="12" if use_proxy else api.full_train_epochs
|
||||
)
|
||||
# Append the info
|
||||
population.append(model)
|
||||
history.append((model.accuracy, model.arch))
|
||||
total_time_cost.append(total_cost)
|
||||
current_best_index.append(
|
||||
api.query_index_by_arch(max(history, key=lambda x: x[0])[1])
|
||||
)
|
||||
|
||||
# Carry out evolution in cycles. Each cycle produces a model and removes another.
|
||||
while total_time_cost[-1] < time_budget:
|
||||
# Sample randomly chosen models from the current population.
|
||||
start_time, sample = time.time(), []
|
||||
while len(sample) < sample_size:
|
||||
# Inefficient, but written this way for clarity. In the case of neural
|
||||
# nets, the efficiency of this line is irrelevant because training neural
|
||||
# nets is the rate-determining step.
|
||||
candidate = random.choice(list(population))
|
||||
sample.append(candidate)
|
||||
|
||||
# The parent is the best model in the sample.
|
||||
parent = max(sample, key=lambda i: i.accuracy)
|
||||
|
||||
# Create the child model and store it.
|
||||
child = Model()
|
||||
child.arch = mutate_arch(parent.arch)
|
||||
child.accuracy, _, _, total_cost = api.simulate_train_eval(
|
||||
child.arch, dataset, hp="12" if use_proxy else api.full_train_epochs
|
||||
)
|
||||
# Append the info
|
||||
population.append(child)
|
||||
history.append((child.accuracy, child.arch))
|
||||
current_best_index.append(
|
||||
api.query_index_by_arch(max(history, key=lambda x: x[0])[1])
|
||||
)
|
||||
total_time_cost.append(total_cost)
|
||||
|
||||
# Remove the oldest model.
|
||||
population.popleft()
|
||||
return history, current_best_index, total_time_cost
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
torch.set_num_threads(4)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, "nats-bench")
|
||||
if xargs.search_space == "tss":
|
||||
random_arch = random_topology_func(search_space)
|
||||
mutate_arch = mutate_topology_func(search_space)
|
||||
else:
|
||||
random_arch = random_size_func(search_space)
|
||||
mutate_arch = mutate_size_func(search_space)
|
||||
|
||||
x_start_time = time.time()
|
||||
logger.log("{:} use api : {:}".format(time_string(), api))
|
||||
logger.log(
|
||||
"-" * 30
|
||||
+ " start searching with the time budget of {:} s".format(xargs.time_budget)
|
||||
)
|
||||
history, current_best_index, total_times = regularized_evolution(
|
||||
xargs.ea_cycles,
|
||||
xargs.ea_population,
|
||||
xargs.ea_sample_size,
|
||||
xargs.time_budget,
|
||||
random_arch,
|
||||
mutate_arch,
|
||||
api,
|
||||
xargs.use_proxy > 0,
|
||||
xargs.dataset,
|
||||
)
|
||||
logger.log(
|
||||
"{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).".format(
|
||||
time_string(), len(history), total_times[-1], time.time() - x_start_time
|
||||
)
|
||||
)
|
||||
best_arch = max(history, key=lambda x: x[0])[1]
|
||||
logger.log("{:} best arch is {:}".format(time_string(), best_arch))
|
||||
|
||||
info = api.query_info_str_by_arch(
|
||||
best_arch, "200" if xargs.search_space == "tss" else "90"
|
||||
)
|
||||
logger.log("{:}".format(info))
|
||||
logger.log("-" * 100)
|
||||
logger.close()
|
||||
return logger.log_dir, current_best_index, total_times
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
# hyperparameters for REA
|
||||
parser.add_argument("--ea_cycles", type=int, help="The number of cycles in EA.")
|
||||
parser.add_argument("--ea_population", type=int, help="The population size in EA.")
|
||||
parser.add_argument("--ea_sample_size", type=int, help="The sample size in EA.")
|
||||
parser.add_argument(
|
||||
"--time_budget",
|
||||
type=int,
|
||||
default=20000,
|
||||
help="The total time cost budge for searching (in seconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_proxy",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Whether to use the proxy (H0) task or not.",
|
||||
)
|
||||
#
|
||||
parser.add_argument(
|
||||
"--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
|
||||
)
|
||||
# log
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
"{:}-T{:}{:}".format(
|
||||
args.dataset, args.time_budget, "" if args.use_proxy > 0 else "-FULL"
|
||||
),
|
||||
"R-EA-SS{:}".format(args.ea_sample_size),
|
||||
)
|
||||
print("save-dir : {:}".format(args.save_dir))
|
||||
print("xargs : {:}".format(args))
|
||||
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_info = None, collections.OrderedDict()
|
||||
for i in range(args.loops_if_rand):
|
||||
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, all_archs, all_total_times = main(args, api)
|
||||
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
|
||||
save_path = save_dir / "results.pth"
|
||||
print("save into {:}".format(save_path))
|
||||
torch.save(all_info, save_path)
|
||||
else:
|
||||
main(args, api)
|
268
AutoDL-Projects/exps/NATS-algos/reinforce.py
Normal file
268
AutoDL-Projects/exps/NATS-algos/reinforce.py
Normal file
@@ -0,0 +1,268 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
#####################################################################################################
|
||||
# modified from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py #
|
||||
#####################################################################################################
|
||||
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space tss --learning_rate 0.01
|
||||
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space tss --learning_rate 0.01
|
||||
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space tss --learning_rate 0.01
|
||||
# python ./exps/NATS-algos/reinforce.py --dataset cifar10 --search_space sss --learning_rate 0.01
|
||||
# python ./exps/NATS-algos/reinforce.py --dataset cifar100 --search_space sss --learning_rate 0.01
|
||||
# python ./exps/NATS-algos/reinforce.py --dataset ImageNet16-120 --search_space sss --learning_rate 0.01
|
||||
#####################################################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||
from xautodl.datasets import get_datasets, SearchDataset
|
||||
from xautodl.procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
get_optim_scheduler,
|
||||
)
|
||||
from xautodl.utils import get_model_infos, obtain_accuracy
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import CellStructure, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
class PolicyTopology(nn.Module):
|
||||
def __init__(self, search_space, max_nodes=4):
|
||||
super(PolicyTopology, self).__init__()
|
||||
self.max_nodes = max_nodes
|
||||
self.search_space = deepcopy(search_space)
|
||||
self.edge2index = {}
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
self.edge2index[node_str] = len(self.edge2index)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(len(self.edge2index), len(search_space))
|
||||
)
|
||||
|
||||
def generate_arch(self, actions):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_name = self.search_space[actions[self.edge2index[node_str]]]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return CellStructure(genotypes)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.search_space[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return CellStructure(genotypes)
|
||||
|
||||
def forward(self):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
return alphas
|
||||
|
||||
|
||||
class PolicySize(nn.Module):
|
||||
def __init__(self, search_space):
|
||||
super(PolicySize, self).__init__()
|
||||
self.candidates = search_space["candidates"]
|
||||
self.numbers = search_space["numbers"]
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(self.numbers, len(self.candidates))
|
||||
)
|
||||
|
||||
def generate_arch(self, actions):
|
||||
channels = [str(self.candidates[i]) for i in actions]
|
||||
return ":".join(channels)
|
||||
|
||||
def genotype(self):
|
||||
channels = []
|
||||
for i in range(self.numbers):
|
||||
index = self.arch_parameters[i].argmax().item()
|
||||
channels.append(str(self.candidates[index]))
|
||||
return ":".join(channels)
|
||||
|
||||
def forward(self):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
return alphas
|
||||
|
||||
|
||||
class ExponentialMovingAverage(object):
|
||||
"""Class that maintains an exponential moving average."""
|
||||
|
||||
def __init__(self, momentum):
|
||||
self._numerator = 0
|
||||
self._denominator = 0
|
||||
self._momentum = momentum
|
||||
|
||||
def update(self, value):
|
||||
self._numerator = (
|
||||
self._momentum * self._numerator + (1 - self._momentum) * value
|
||||
)
|
||||
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
|
||||
|
||||
def value(self):
|
||||
"""Return the current value of the moving average"""
|
||||
return self._numerator / self._denominator
|
||||
|
||||
|
||||
def select_action(policy):
|
||||
probs = policy()
|
||||
m = Categorical(probs)
|
||||
action = m.sample()
|
||||
# policy.saved_log_probs.append(m.log_prob(action))
|
||||
return m.log_prob(action), action.cpu().tolist()
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
# torch.set_num_threads(4)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, "nats-bench")
|
||||
if xargs.search_space == "tss":
|
||||
policy = PolicyTopology(search_space)
|
||||
else:
|
||||
policy = PolicySize(search_space)
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=xargs.learning_rate)
|
||||
# optimizer = torch.optim.SGD(policy.parameters(), lr=xargs.learning_rate)
|
||||
eps = np.finfo(np.float32).eps.item()
|
||||
baseline = ExponentialMovingAverage(xargs.EMA_momentum)
|
||||
logger.log("policy : {:}".format(policy))
|
||||
logger.log("optimizer : {:}".format(optimizer))
|
||||
logger.log("eps : {:}".format(eps))
|
||||
|
||||
# nas dataset load
|
||||
logger.log("{:} use api : {:}".format(time_string(), api))
|
||||
api.reset_time()
|
||||
|
||||
# REINFORCE
|
||||
x_start_time = time.time()
|
||||
logger.log(
|
||||
"Will start searching with time budget of {:} s.".format(xargs.time_budget)
|
||||
)
|
||||
total_steps, total_costs, trace = 0, [], []
|
||||
current_best_index = []
|
||||
while len(total_costs) == 0 or total_costs[-1] < xargs.time_budget:
|
||||
start_time = time.time()
|
||||
log_prob, action = select_action(policy)
|
||||
arch = policy.generate_arch(action)
|
||||
reward, _, _, current_total_cost = api.simulate_train_eval(
|
||||
arch, xargs.dataset, hp="12"
|
||||
)
|
||||
trace.append((reward, arch))
|
||||
total_costs.append(current_total_cost)
|
||||
|
||||
baseline.update(reward)
|
||||
# calculate loss
|
||||
policy_loss = (-log_prob * (reward - baseline.value())).sum()
|
||||
optimizer.zero_grad()
|
||||
policy_loss.backward()
|
||||
optimizer.step()
|
||||
# accumulate time
|
||||
total_steps += 1
|
||||
logger.log(
|
||||
"step [{:3d}] : average-reward={:.3f} : policy_loss={:.4f} : {:}".format(
|
||||
total_steps, baseline.value(), policy_loss.item(), policy.genotype()
|
||||
)
|
||||
)
|
||||
# to analyze
|
||||
current_best_index.append(
|
||||
api.query_index_by_arch(max(trace, key=lambda x: x[0])[1])
|
||||
)
|
||||
# best_arch = policy.genotype() # first version
|
||||
best_arch = max(trace, key=lambda x: x[0])[1]
|
||||
logger.log(
|
||||
"REINFORCE finish with {:} steps and {:.1f} s (real cost={:.3f}).".format(
|
||||
total_steps, total_costs[-1], time.time() - x_start_time
|
||||
)
|
||||
)
|
||||
info = api.query_info_str_by_arch(
|
||||
best_arch, "200" if xargs.search_space == "tss" else "90"
|
||||
)
|
||||
logger.log("{:}".format(info))
|
||||
logger.log("-" * 100)
|
||||
logger.close()
|
||||
|
||||
return logger.log_dir, current_best_index, total_costs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("The REINFORCE Algorithm")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
choices=["tss", "sss"],
|
||||
help="Choose the search space.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate", type=float, help="The learning rate for REINFORCE."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--EMA_momentum", type=float, default=0.9, help="The momentum value for EMA."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_budget",
|
||||
type=int,
|
||||
default=20000,
|
||||
help="The total time cost budge for searching (in seconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loops_if_rand", type=int, default=500, help="The total runs for evaluation."
|
||||
)
|
||||
# log
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch_nas_dataset",
|
||||
type=str,
|
||||
help="The path to load the architecture dataset (tiny-nas-benchmark).",
|
||||
)
|
||||
parser.add_argument("--print_freq", type=int, help="print frequency (default: 200)")
|
||||
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
|
||||
api = create(None, args.search_space, fast_mode=True, verbose=False)
|
||||
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
"{:}-T{:}".format(args.dataset, args.time_budget),
|
||||
"REINFORCE-{:}".format(args.learning_rate),
|
||||
)
|
||||
print("save-dir : {:}".format(args.save_dir))
|
||||
|
||||
if args.rand_seed < 0:
|
||||
save_dir, all_info = None, collections.OrderedDict()
|
||||
for i in range(args.loops_if_rand):
|
||||
print("{:} : {:03d}/{:03d}".format(time_string(), i, args.loops_if_rand))
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
save_dir, all_archs, all_total_times = main(args, api)
|
||||
all_info[i] = {"all_archs": all_archs, "all_total_times": all_total_times}
|
||||
save_path = save_dir / "results.pth"
|
||||
print("save into {:}".format(save_path))
|
||||
torch.save(all_info, save_path)
|
||||
else:
|
||||
main(args, api)
|
51
AutoDL-Projects/exps/NATS-algos/run-all.sh
Normal file
51
AutoDL-Projects/exps/NATS-algos/run-all.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
# bash ./exps/NATS-algos/run-all.sh mul
|
||||
# bash ./exps/NATS-algos/run-all.sh ws
|
||||
set -e
|
||||
echo script name: $0
|
||||
echo $# arguments
|
||||
if [ "$#" -ne 1 ] ;then
|
||||
echo "Input illegal number of parameters " $#
|
||||
echo "Need 1 parameters for type of algorithms."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
alg_type=$1
|
||||
|
||||
if [ "$alg_type" == "mul" ]; then
|
||||
# datasets="cifar10 cifar100 ImageNet16-120"
|
||||
run_four_algorithms(){
|
||||
dataset=$1
|
||||
search_space=$2
|
||||
time_budget=$3
|
||||
python ./exps/NATS-algos/reinforce.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --learning_rate 0.01
|
||||
python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3
|
||||
python ./exps/NATS-algos/random_wo_share.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget}
|
||||
python ./exps/NATS-algos/bohb.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --num_samples 4 --random_fraction 0.0 --bandwidth_factor 3
|
||||
}
|
||||
# The topology search space
|
||||
run_four_algorithms "cifar10" "tss" "20000"
|
||||
run_four_algorithms "cifar100" "tss" "40000"
|
||||
run_four_algorithms "ImageNet16-120" "tss" "120000"
|
||||
|
||||
# The size search space
|
||||
run_four_algorithms "cifar10" "sss" "20000"
|
||||
run_four_algorithms "cifar100" "sss" "40000"
|
||||
run_four_algorithms "ImageNet16-120" "sss" "60000"
|
||||
# python exps/experimental/vis-bench-algos.py --search_space tss
|
||||
# python exps/experimental/vis-bench-algos.py --search_space sss
|
||||
else
|
||||
seeds="777 888 999"
|
||||
algos="darts-v1 darts-v2 gdas setn random enas"
|
||||
epoch=200
|
||||
for seed in ${seeds}
|
||||
do
|
||||
for alg in ${algos}
|
||||
do
|
||||
python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||
python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||
python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo ${alg} --rand_seed ${seed} --overwite_epochs ${epoch}
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
879
AutoDL-Projects/exps/NATS-algos/search-cell.py
Normal file
879
AutoDL-Projects/exps/NATS-algos/search-cell.py
Normal file
@@ -0,0 +1,879 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
######################################################################################
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
||||
####
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v2
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v2
|
||||
####
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo gdas
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo gdas
|
||||
####
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo setn --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo setn
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo setn
|
||||
####
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
|
||||
####
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
|
||||
####
|
||||
# The following scripts are added in 20 Mar 2022
|
||||
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777
|
||||
######################################################################################
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||
from xautodl.datasets import get_datasets, get_nas_search_loaders
|
||||
from xautodl.procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
get_optim_scheduler,
|
||||
)
|
||||
from xautodl.utils import count_parameters_in_MB, obtain_accuracy
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
# The following three functions are used for DARTS-V2
|
||||
def _concat(xs):
|
||||
return torch.cat([x.view(-1) for x in xs])
|
||||
|
||||
|
||||
def _hessian_vector_product(
|
||||
vector, network, criterion, base_inputs, base_targets, r=1e-2
|
||||
):
|
||||
R = r / _concat(vector).norm()
|
||||
for p, v in zip(network.weights, vector):
|
||||
p.data.add_(R, v)
|
||||
_, logits = network(base_inputs)
|
||||
loss = criterion(logits, base_targets)
|
||||
grads_p = torch.autograd.grad(loss, network.alphas)
|
||||
|
||||
for p, v in zip(network.weights, vector):
|
||||
p.data.sub_(2 * R, v)
|
||||
_, logits = network(base_inputs)
|
||||
loss = criterion(logits, base_targets)
|
||||
grads_n = torch.autograd.grad(loss, network.alphas)
|
||||
|
||||
for p, v in zip(network.weights, vector):
|
||||
p.data.add_(R, v)
|
||||
return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
|
||||
|
||||
|
||||
def backward_step_unrolled(
|
||||
network,
|
||||
criterion,
|
||||
base_inputs,
|
||||
base_targets,
|
||||
w_optimizer,
|
||||
arch_inputs,
|
||||
arch_targets,
|
||||
):
|
||||
# _compute_unrolled_model
|
||||
_, logits = network(base_inputs)
|
||||
loss = criterion(logits, base_targets)
|
||||
LR, WD, momentum = (
|
||||
w_optimizer.param_groups[0]["lr"],
|
||||
w_optimizer.param_groups[0]["weight_decay"],
|
||||
w_optimizer.param_groups[0]["momentum"],
|
||||
)
|
||||
with torch.no_grad():
|
||||
theta = _concat(network.weights)
|
||||
try:
|
||||
moment = _concat(
|
||||
w_optimizer.state[v]["momentum_buffer"] for v in network.weights
|
||||
)
|
||||
moment = moment.mul_(momentum)
|
||||
except:
|
||||
moment = torch.zeros_like(theta)
|
||||
dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD * theta
|
||||
params = theta.sub(LR, moment + dtheta)
|
||||
unrolled_model = deepcopy(network)
|
||||
model_dict = unrolled_model.state_dict()
|
||||
new_params, offset = {}, 0
|
||||
for k, v in network.named_parameters():
|
||||
if "arch_parameters" in k:
|
||||
continue
|
||||
v_length = np.prod(v.size())
|
||||
new_params[k] = params[offset : offset + v_length].view(v.size())
|
||||
offset += v_length
|
||||
model_dict.update(new_params)
|
||||
unrolled_model.load_state_dict(model_dict)
|
||||
|
||||
unrolled_model.zero_grad()
|
||||
_, unrolled_logits = unrolled_model(arch_inputs)
|
||||
unrolled_loss = criterion(unrolled_logits, arch_targets)
|
||||
unrolled_loss.backward()
|
||||
|
||||
dalpha = unrolled_model.arch_parameters.grad
|
||||
vector = [v.grad.data for v in unrolled_model.weights]
|
||||
[implicit_grads] = _hessian_vector_product(
|
||||
vector, network, criterion, base_inputs, base_targets
|
||||
)
|
||||
|
||||
dalpha.data.sub_(LR, implicit_grads.data)
|
||||
|
||||
if network.arch_parameters.grad is None:
|
||||
network.arch_parameters.grad = deepcopy(dalpha)
|
||||
else:
|
||||
network.arch_parameters.grad.data.copy_(dalpha.data)
|
||||
return unrolled_loss.detach(), unrolled_logits.detach()
|
||||
|
||||
|
||||
def search_func(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
w_optimizer,
|
||||
a_optimizer,
|
||||
epoch_str,
|
||||
print_freq,
|
||||
algo,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
end = time.time()
|
||||
network.train()
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
|
||||
xloader
|
||||
):
|
||||
scheduler.update(None, 1.0 * step / len(xloader))
|
||||
base_inputs = base_inputs.cuda(non_blocking=True)
|
||||
arch_inputs = arch_inputs.cuda(non_blocking=True)
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# Update the weights
|
||||
if algo == "setn":
|
||||
sampled_arch = network.dync_genotype(True)
|
||||
network.set_cal_mode("dynamic", sampled_arch)
|
||||
elif algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif algo == "random":
|
||||
network.set_cal_mode("urs", None)
|
||||
elif algo == "enas":
|
||||
with torch.no_grad():
|
||||
network.controller.eval()
|
||||
_, _, sampled_arch = network.controller()
|
||||
network.set_cal_mode("dynamic", sampled_arch)
|
||||
else:
|
||||
raise ValueError("Invalid algo name : {:}".format(algo))
|
||||
|
||||
network.zero_grad()
|
||||
_, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
w_optimizer.step()
|
||||
# record
|
||||
base_prec1, base_prec5 = obtain_accuracy(
|
||||
logits.data, base_targets.data, topk=(1, 5)
|
||||
)
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
base_top1.update(base_prec1.item(), base_inputs.size(0))
|
||||
base_top5.update(base_prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture-weight
|
||||
if algo == "setn":
|
||||
network.set_cal_mode("joint")
|
||||
elif algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif algo == "random":
|
||||
network.set_cal_mode("urs", None)
|
||||
elif algo != "enas":
|
||||
raise ValueError("Invalid algo name : {:}".format(algo))
|
||||
network.zero_grad()
|
||||
if algo == "darts-v2":
|
||||
arch_loss, logits = backward_step_unrolled(
|
||||
network,
|
||||
criterion,
|
||||
base_inputs,
|
||||
base_targets,
|
||||
w_optimizer,
|
||||
arch_inputs,
|
||||
arch_targets,
|
||||
)
|
||||
a_optimizer.step()
|
||||
elif algo == "random" or algo == "enas":
|
||||
with torch.no_grad():
|
||||
_, logits = network(arch_inputs)
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
else:
|
||||
_, logits = network(arch_inputs)
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
arch_loss.backward()
|
||||
a_optimizer.step()
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(
|
||||
logits.data, arch_targets.data, topk=(1, 5)
|
||||
)
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||
Sstr = (
|
||||
"*SEARCH* "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
|
||||
loss=base_losses, top1=base_top1, top5=base_top5
|
||||
)
|
||||
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
|
||||
loss=arch_losses, top1=arch_top1, top5=arch_top5
|
||||
)
|
||||
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
|
||||
return (
|
||||
base_losses.avg,
|
||||
base_top1.avg,
|
||||
base_top5.avg,
|
||||
arch_losses.avg,
|
||||
arch_top1.avg,
|
||||
arch_top5.avg,
|
||||
)
|
||||
|
||||
|
||||
def train_controller(
|
||||
xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger
|
||||
):
|
||||
# config. (containing some necessary arg)
|
||||
# baseline: The baseline score (i.e. average val_acc) from the previous epoch
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
(
|
||||
GradnormMeter,
|
||||
LossMeter,
|
||||
ValAccMeter,
|
||||
EntropyMeter,
|
||||
BaselineMeter,
|
||||
RewardMeter,
|
||||
xend,
|
||||
) = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
time.time(),
|
||||
)
|
||||
|
||||
controller_num_aggregate = 20
|
||||
controller_train_steps = 50
|
||||
controller_bl_dec = 0.99
|
||||
controller_entropy_weight = 0.0001
|
||||
|
||||
network.eval()
|
||||
network.controller.train()
|
||||
network.controller.zero_grad()
|
||||
loader_iter = iter(xloader)
|
||||
for step in range(controller_train_steps * controller_num_aggregate):
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - xend)
|
||||
|
||||
log_prob, entropy, sampled_arch = network.controller()
|
||||
with torch.no_grad():
|
||||
network.set_cal_mode("dynamic", sampled_arch)
|
||||
_, logits = network(inputs)
|
||||
val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
val_top1 = val_top1.view(-1) / 100
|
||||
reward = val_top1 + controller_entropy_weight * entropy
|
||||
if prev_baseline is None:
|
||||
baseline = val_top1
|
||||
else:
|
||||
baseline = prev_baseline - (1 - controller_bl_dec) * (
|
||||
prev_baseline - reward
|
||||
)
|
||||
|
||||
loss = -1 * log_prob * (reward - baseline)
|
||||
|
||||
# account
|
||||
RewardMeter.update(reward.item())
|
||||
BaselineMeter.update(baseline.item())
|
||||
ValAccMeter.update(val_top1.item() * 100)
|
||||
LossMeter.update(loss.item())
|
||||
EntropyMeter.update(entropy.item())
|
||||
|
||||
# Average gradient over controller_num_aggregate samples
|
||||
loss = loss / controller_num_aggregate
|
||||
loss.backward(retain_graph=True)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - xend)
|
||||
xend = time.time()
|
||||
if (step + 1) % controller_num_aggregate == 0:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
network.controller.parameters(), 5.0
|
||||
)
|
||||
GradnormMeter.update(grad_norm)
|
||||
optimizer.step()
|
||||
network.controller.zero_grad()
|
||||
|
||||
if step % print_freq == 0:
|
||||
Sstr = (
|
||||
"*Train-Controller* "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(
|
||||
epoch_str, step, controller_train_steps * controller_num_aggregate
|
||||
)
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})".format(
|
||||
loss=LossMeter,
|
||||
top1=ValAccMeter,
|
||||
reward=RewardMeter,
|
||||
basel=BaselineMeter,
|
||||
)
|
||||
Estr = "Entropy={:.4f} ({:.4f})".format(EntropyMeter.val, EntropyMeter.avg)
|
||||
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Estr)
|
||||
|
||||
return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
|
||||
|
||||
|
||||
def get_best_arch(xloader, network, n_samples, algo):
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
if algo == "random":
|
||||
archs, valid_accs = network.return_topK(n_samples, True), []
|
||||
elif algo == "setn":
|
||||
archs, valid_accs = network.return_topK(n_samples, False), []
|
||||
elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1":
|
||||
arch = network.genotype
|
||||
archs, valid_accs = [arch], []
|
||||
elif algo == "enas":
|
||||
archs, valid_accs = [], []
|
||||
for _ in range(n_samples):
|
||||
_, _, sampled_arch = network.controller()
|
||||
archs.append(sampled_arch)
|
||||
else:
|
||||
raise ValueError("Invalid algorithm name : {:}".format(algo))
|
||||
loader_iter = iter(xloader)
|
||||
for i, sampled_arch in enumerate(archs):
|
||||
network.set_cal_mode("dynamic", sampled_arch)
|
||||
try:
|
||||
inputs, targets = next(loader_iter)
|
||||
except:
|
||||
loader_iter = iter(xloader)
|
||||
inputs, targets = next(loader_iter)
|
||||
_, logits = network(inputs.cuda(non_blocking=True))
|
||||
val_top1, val_top5 = obtain_accuracy(
|
||||
logits.cpu().data, targets.data, topk=(1, 5)
|
||||
)
|
||||
valid_accs.append(val_top1.item())
|
||||
best_idx = np.argmax(valid_accs)
|
||||
best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
|
||||
return best_arch, best_valid_acc
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion, algo, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# prediction
|
||||
_, logits = network(arch_inputs.cuda(non_blocking=True))
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(
|
||||
logits.data, arch_targets.data, topk=(1, 5)
|
||||
)
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def main(xargs):
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(xargs.workers)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
xargs.dataset, xargs.data_path, -1
|
||||
)
|
||||
if xargs.overwite_epochs is None:
|
||||
extra_info = {"class_num": class_num, "xshape": xshape}
|
||||
else:
|
||||
extra_info = {
|
||||
"class_num": class_num,
|
||||
"xshape": xshape,
|
||||
"epochs": xargs.overwite_epochs,
|
||||
}
|
||||
config = load_config(xargs.config_path, extra_info, logger)
|
||||
search_loader, train_loader, valid_loader = get_nas_search_loaders(
|
||||
train_data,
|
||||
valid_data,
|
||||
xargs.dataset,
|
||||
"configs/nas-benchmark/",
|
||||
(config.batch_size, config.test_batch_size),
|
||||
xargs.workers,
|
||||
)
|
||||
logger.log(
|
||||
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
|
||||
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
|
||||
)
|
||||
)
|
||||
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, "nats-bench")
|
||||
|
||||
model_config = dict2config(
|
||||
dict(
|
||||
name="generic",
|
||||
C=xargs.channel,
|
||||
N=xargs.num_cells,
|
||||
max_nodes=xargs.max_nodes,
|
||||
num_classes=class_num,
|
||||
space=search_space,
|
||||
affine=bool(xargs.affine),
|
||||
track_running_stats=bool(xargs.track_running_stats),
|
||||
),
|
||||
None,
|
||||
)
|
||||
logger.log("search space : {:}".format(search_space))
|
||||
logger.log("model config : {:}".format(model_config))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.set_algo(xargs.algo)
|
||||
logger.log("{:}".format(search_model))
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
|
||||
search_model.weights, config
|
||||
)
|
||||
a_optimizer = torch.optim.Adam(
|
||||
search_model.alphas,
|
||||
lr=xargs.arch_learning_rate,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=xargs.arch_weight_decay,
|
||||
eps=xargs.arch_eps,
|
||||
)
|
||||
logger.log("w-optimizer : {:}".format(w_optimizer))
|
||||
logger.log("a-optimizer : {:}".format(a_optimizer))
|
||||
logger.log("w-scheduler : {:}".format(w_scheduler))
|
||||
logger.log("criterion : {:}".format(criterion))
|
||||
params = count_parameters_in_MB(search_model)
|
||||
logger.log("The parameters of the search model = {:.2f} MB".format(params))
|
||||
logger.log("search-space : {:}".format(search_space))
|
||||
if bool(xargs.use_api):
|
||||
api = create(None, "topology", fast_mode=True, verbose=False)
|
||||
else:
|
||||
api = None
|
||||
logger.log("{:} create API = {:} done".format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = (
|
||||
logger.path("info"),
|
||||
logger.path("model"),
|
||||
logger.path("best"),
|
||||
)
|
||||
network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU
|
||||
|
||||
last_info, model_base_path, model_best_path = (
|
||||
logger.path("info"),
|
||||
logger.path("model"),
|
||||
logger.path("best"),
|
||||
)
|
||||
|
||||
if last_info.exists(): # automatically resume from previous checkpoint
|
||||
logger.log(
|
||||
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
|
||||
)
|
||||
last_info = torch.load(last_info)
|
||||
start_epoch = last_info["epoch"]
|
||||
checkpoint = torch.load(last_info["last_checkpoint"])
|
||||
genotypes = checkpoint["genotypes"]
|
||||
baseline = checkpoint["baseline"]
|
||||
valid_accuracies = checkpoint["valid_accuracies"]
|
||||
search_model.load_state_dict(checkpoint["search_model"])
|
||||
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
|
||||
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
|
||||
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
|
||||
logger.log(
|
||||
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
|
||||
last_info, start_epoch
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = (
|
||||
0,
|
||||
{"best": -1},
|
||||
{-1: network.return_topK(1, True)[0]},
|
||||
)
|
||||
baseline = None
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = (
|
||||
time.time(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
config.epochs + config.warmup,
|
||||
)
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = "Time Left: {:}".format(
|
||||
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
|
||||
)
|
||||
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
|
||||
logger.log(
|
||||
"\n[Search the {:}-th epoch] {:}, LR={:}".format(
|
||||
epoch_str, need_time, min(w_scheduler.get_lr())
|
||||
)
|
||||
)
|
||||
|
||||
network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
|
||||
if xargs.algo == "gdas" or xargs.algo == "gdas_v1":
|
||||
network.set_tau(
|
||||
xargs.tau_max
|
||||
- (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
|
||||
)
|
||||
logger.log(
|
||||
"[RESET tau as : {:} and drop_path as {:}]".format(
|
||||
network.tau, network.drop_path
|
||||
)
|
||||
)
|
||||
(
|
||||
search_w_loss,
|
||||
search_w_top1,
|
||||
search_w_top5,
|
||||
search_a_loss,
|
||||
search_a_top1,
|
||||
search_a_top5,
|
||||
) = search_func(
|
||||
search_loader,
|
||||
network,
|
||||
criterion,
|
||||
w_scheduler,
|
||||
w_optimizer,
|
||||
a_optimizer,
|
||||
epoch_str,
|
||||
xargs.print_freq,
|
||||
xargs.algo,
|
||||
logger,
|
||||
)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log(
|
||||
"[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
|
||||
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
|
||||
epoch_str, search_a_loss, search_a_top1, search_a_top5
|
||||
)
|
||||
)
|
||||
if xargs.algo == "enas":
|
||||
ctl_loss, ctl_acc, baseline, ctl_reward = train_controller(
|
||||
valid_loader,
|
||||
network,
|
||||
criterion,
|
||||
a_optimizer,
|
||||
baseline,
|
||||
epoch_str,
|
||||
xargs.print_freq,
|
||||
logger,
|
||||
)
|
||||
logger.log(
|
||||
"[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}".format(
|
||||
epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward
|
||||
)
|
||||
)
|
||||
|
||||
genotype, temp_accuracy = get_best_arch(
|
||||
valid_loader, network, xargs.eval_candidate_num, xargs.algo
|
||||
)
|
||||
if xargs.algo == "setn" or xargs.algo == "enas":
|
||||
network.set_cal_mode("dynamic", genotype)
|
||||
elif xargs.algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif xargs.algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif xargs.algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif xargs.algo == "random":
|
||||
network.set_cal_mode("urs", None)
|
||||
else:
|
||||
raise ValueError("Invalid algorithm name : {:}".format(xargs.algo))
|
||||
logger.log(
|
||||
"[{:}] - [get_best_arch] : {:} -> {:}".format(
|
||||
epoch_str, genotype, temp_accuracy
|
||||
)
|
||||
)
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
|
||||
valid_loader, network, criterion, xargs.algo, logger
|
||||
)
|
||||
logger.log(
|
||||
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format(
|
||||
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype
|
||||
)
|
||||
)
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
|
||||
genotypes[epoch] = genotype
|
||||
logger.log(
|
||||
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
|
||||
)
|
||||
# save checkpoint
|
||||
save_path = save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"args": deepcopy(xargs),
|
||||
"baseline": baseline,
|
||||
"search_model": search_model.state_dict(),
|
||||
"w_optimizer": w_optimizer.state_dict(),
|
||||
"a_optimizer": a_optimizer.state_dict(),
|
||||
"w_scheduler": w_scheduler.state_dict(),
|
||||
"genotypes": genotypes,
|
||||
"valid_accuracies": valid_accuracies,
|
||||
},
|
||||
model_base_path,
|
||||
logger,
|
||||
)
|
||||
last_info = save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"args": deepcopy(args),
|
||||
"last_checkpoint": save_path,
|
||||
},
|
||||
logger.path("info"),
|
||||
logger,
|
||||
)
|
||||
with torch.no_grad():
|
||||
logger.log("{:}".format(search_model.show_alphas()))
|
||||
if api is not None:
|
||||
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "200")))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
# the final post procedure : count the time
|
||||
start_time = time.time()
|
||||
genotype, temp_accuracy = get_best_arch(
|
||||
valid_loader, network, xargs.eval_candidate_num, xargs.algo
|
||||
)
|
||||
if xargs.algo == "setn" or xargs.algo == "enas":
|
||||
network.set_cal_mode("dynamic", genotype)
|
||||
elif xargs.algo == "gdas":
|
||||
network.set_cal_mode("gdas", None)
|
||||
elif xargs.algo == "gdas_v1":
|
||||
network.set_cal_mode("gdas_v1", None)
|
||||
elif xargs.algo.startswith("darts"):
|
||||
network.set_cal_mode("joint", None)
|
||||
elif xargs.algo == "random":
|
||||
network.set_cal_mode("urs", None)
|
||||
else:
|
||||
raise ValueError("Invalid algorithm name : {:}".format(xargs.algo))
|
||||
search_time.update(time.time() - start_time)
|
||||
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
|
||||
valid_loader, network, criterion, xargs.algo, logger
|
||||
)
|
||||
logger.log(
|
||||
"Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(
|
||||
genotype, valid_a_top1
|
||||
)
|
||||
)
|
||||
|
||||
logger.log("\n" + "-" * 100)
|
||||
# check the performance from the architecture dataset
|
||||
logger.log(
|
||||
"[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
|
||||
xargs.algo, total_epoch, search_time.sum, genotype
|
||||
)
|
||||
)
|
||||
if api is not None:
|
||||
logger.log("{:}".format(api.query_by_arch(genotype, "200")))
|
||||
logger.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
|
||||
parser.add_argument("--data_path", type=str, help="Path to dataset")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
default="tss",
|
||||
choices=["tss"],
|
||||
help="The search space name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--algo",
|
||||
type=str,
|
||||
choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"],
|
||||
help="The search space name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_api",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[0, 1],
|
||||
help="Whether use API or not (which will cost much memory).",
|
||||
)
|
||||
# FOR GDAS
|
||||
parser.add_argument(
|
||||
"--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax."
|
||||
)
|
||||
# channels and number-of-cells
|
||||
parser.add_argument(
|
||||
"--max_nodes", type=int, default=4, help="The maximum number of nodes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--channel", type=int, default=16, help="The number of channels."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_cells", type=int, default=5, help="The number of cells in one stage."
|
||||
)
|
||||
#
|
||||
parser.add_argument(
|
||||
"--eval_candidate_num",
|
||||
type=int,
|
||||
default=100,
|
||||
help="The number of selected architectures to evaluate.",
|
||||
)
|
||||
#
|
||||
parser.add_argument(
|
||||
"--track_running_stats",
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help="Whether use track_running_stats or not in the BN layer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--affine",
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help="Whether use affine=True or False in the BN layer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
default="./configs/nas-benchmark/algos/weight-sharing.config",
|
||||
help="The path of configuration.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwite_epochs",
|
||||
type=int,
|
||||
help="The number of epochs to overwrite that value in config files.",
|
||||
)
|
||||
# architecture leraning rate
|
||||
parser.add_argument(
|
||||
"--arch_learning_rate",
|
||||
type=float,
|
||||
default=3e-4,
|
||||
help="learning rate for arch encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch_weight_decay",
|
||||
type=float,
|
||||
default=1e-3,
|
||||
help="weight decay for arch encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding"
|
||||
)
|
||||
parser.add_argument("--drop_path_rate", type=float, help="The drop path rate.")
|
||||
# log
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_freq", type=int, default=200, help="print frequency (default: 200)"
|
||||
)
|
||||
parser.add_argument("--rand_seed", type=int, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
if args.overwite_epochs is None:
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
args.dataset,
|
||||
"{:}-affine{:}_BN{:}-{:}".format(
|
||||
args.algo, args.affine, args.track_running_stats, args.drop_path_rate
|
||||
),
|
||||
)
|
||||
else:
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space),
|
||||
args.dataset,
|
||||
"{:}-affine{:}_BN{:}-E{:}-{:}".format(
|
||||
args.algo,
|
||||
args.affine,
|
||||
args.track_running_stats,
|
||||
args.overwite_epochs,
|
||||
args.drop_path_rate,
|
||||
),
|
||||
)
|
||||
|
||||
main(args)
|
582
AutoDL-Projects/exps/NATS-algos/search-size.py
Normal file
582
AutoDL-Projects/exps/NATS-algos/search-size.py
Normal file
@@ -0,0 +1,582 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||
###########################################################################################################################################
|
||||
#
|
||||
# In this file, we aims to evaluate three kinds of channel searching strategies:
|
||||
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
|
||||
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
|
||||
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
|
||||
#
|
||||
# For simplicity, we use tas, mask_gumbel, and mask_rl to refer these three strategies. Their official implementations are at the following links:
|
||||
# - TAS: https://github.com/D-X-Y/AutoDL-Projects/blob/main/docs/NeurIPS-2019-TAS.md
|
||||
# - FBNetV2: https://github.com/facebookresearch/mobile-vision
|
||||
# - TuNAS: https://github.com/google-research/google-research/tree/master/tunas
|
||||
####
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --warmup_ratio 0.25
|
||||
####
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo tas --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo tas --rand_seed 777
|
||||
####
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_gumbel --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_gumbel --rand_seed 777
|
||||
####
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777 --use_api 0
|
||||
# python ./exps/NATS-algos/search-size.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo mask_rl --arch_weight_decay 0 --rand_seed 777
|
||||
# python ./exps/NATS-algos/search-size.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo mask_rl --arch_weight_decay 0 --rand_seed 777
|
||||
###########################################################################################################################################
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xautodl.config_utils import load_config, dict2config, configure2str
|
||||
from xautodl.datasets import get_datasets, get_nas_search_loaders
|
||||
from xautodl.procedures import (
|
||||
prepare_seed,
|
||||
prepare_logger,
|
||||
save_checkpoint,
|
||||
copy_checkpoint,
|
||||
get_optim_scheduler,
|
||||
)
|
||||
from xautodl.utils import count_parameters_in_MB, obtain_accuracy
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
# Ad-hoc for RL algorithms.
|
||||
class ExponentialMovingAverage(object):
|
||||
"""Class that maintains an exponential moving average."""
|
||||
|
||||
def __init__(self, momentum):
|
||||
self._numerator = 0
|
||||
self._denominator = 0
|
||||
self._momentum = momentum
|
||||
|
||||
def update(self, value):
|
||||
self._numerator = (
|
||||
self._momentum * self._numerator + (1 - self._momentum) * value
|
||||
)
|
||||
self._denominator = self._momentum * self._denominator + (1 - self._momentum)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""Return the current value of the moving average"""
|
||||
return self._numerator / self._denominator
|
||||
|
||||
|
||||
RL_BASELINE_EMA = ExponentialMovingAverage(0.95)
|
||||
|
||||
|
||||
def search_func(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
w_optimizer,
|
||||
a_optimizer,
|
||||
enable_controller,
|
||||
algo,
|
||||
epoch_str,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
end = time.time()
|
||||
network.train()
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
|
||||
xloader
|
||||
):
|
||||
scheduler.update(None, 1.0 * step / len(xloader))
|
||||
base_inputs = base_inputs.cuda(non_blocking=True)
|
||||
arch_inputs = arch_inputs.cuda(non_blocking=True)
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# Update the weights
|
||||
network.zero_grad()
|
||||
_, logits, _ = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
w_optimizer.step()
|
||||
# record
|
||||
base_prec1, base_prec5 = obtain_accuracy(
|
||||
logits.data, base_targets.data, topk=(1, 5)
|
||||
)
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
base_top1.update(base_prec1.item(), base_inputs.size(0))
|
||||
base_top5.update(base_prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture-weight
|
||||
network.zero_grad()
|
||||
a_optimizer.zero_grad()
|
||||
_, logits, log_probs = network(arch_inputs)
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(
|
||||
logits.data, arch_targets.data, topk=(1, 5)
|
||||
)
|
||||
if algo == "mask_rl":
|
||||
with torch.no_grad():
|
||||
RL_BASELINE_EMA.update(arch_prec1.item())
|
||||
rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
|
||||
rl_log_prob = sum(log_probs)
|
||||
arch_loss = -rl_advantage * rl_log_prob
|
||||
elif algo == "tas" or algo == "mask_gumbel":
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
else:
|
||||
raise ValueError("invalid algorightm name: {:}".format(algo))
|
||||
if enable_controller:
|
||||
arch_loss.backward()
|
||||
a_optimizer.step()
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % print_freq == 0 or step + 1 == len(xloader):
|
||||
Sstr = (
|
||||
"*SEARCH* "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
|
||||
loss=base_losses, top1=base_top1, top5=base_top5
|
||||
)
|
||||
Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
|
||||
loss=arch_losses, top1=arch_top1, top5=arch_top5
|
||||
)
|
||||
logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
|
||||
return (
|
||||
base_losses.avg,
|
||||
base_top1.avg,
|
||||
base_top5.avg,
|
||||
arch_losses.avg,
|
||||
arch_top1.avg,
|
||||
arch_top5.avg,
|
||||
)
|
||||
|
||||
|
||||
def valid_func(xloader, network, criterion, logger):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
end = time.time()
|
||||
with torch.no_grad():
|
||||
network.eval()
|
||||
for step, (arch_inputs, arch_targets) in enumerate(xloader):
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# prediction
|
||||
_, logits, _ = network(arch_inputs.cuda(non_blocking=True))
|
||||
arch_loss = criterion(logits, arch_targets)
|
||||
# record
|
||||
arch_prec1, arch_prec5 = obtain_accuracy(
|
||||
logits.data, arch_targets.data, topk=(1, 5)
|
||||
)
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
|
||||
arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return arch_losses.avg, arch_top1.avg, arch_top5.avg
|
||||
|
||||
|
||||
def main(xargs):
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
# torch.set_num_threads(xargs.workers)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
xargs.dataset, xargs.data_path, -1
|
||||
)
|
||||
if xargs.overwite_epochs is None:
|
||||
extra_info = {"class_num": class_num, "xshape": xshape}
|
||||
else:
|
||||
extra_info = {
|
||||
"class_num": class_num,
|
||||
"xshape": xshape,
|
||||
"epochs": xargs.overwite_epochs,
|
||||
}
|
||||
config = load_config(xargs.config_path, extra_info, logger)
|
||||
search_loader, train_loader, valid_loader = get_nas_search_loaders(
|
||||
train_data,
|
||||
valid_data,
|
||||
xargs.dataset,
|
||||
"configs/nas-benchmark/",
|
||||
(config.batch_size, config.test_batch_size),
|
||||
xargs.workers,
|
||||
)
|
||||
logger.log(
|
||||
"||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
|
||||
xargs.dataset, len(search_loader), len(valid_loader), config.batch_size
|
||||
)
|
||||
)
|
||||
logger.log("||||||| {:10s} ||||||| Config={:}".format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, "nats-bench")
|
||||
|
||||
model_config = dict2config(
|
||||
dict(
|
||||
name="generic",
|
||||
super_type="search-shape",
|
||||
candidate_Cs=search_space["candidates"],
|
||||
max_num_Cs=search_space["numbers"],
|
||||
num_classes=class_num,
|
||||
genotype=args.genotype,
|
||||
affine=bool(xargs.affine),
|
||||
track_running_stats=bool(xargs.track_running_stats),
|
||||
),
|
||||
None,
|
||||
)
|
||||
logger.log("search space : {:}".format(search_space))
|
||||
logger.log("model config : {:}".format(model_config))
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.set_algo(xargs.algo)
|
||||
logger.log("{:}".format(search_model))
|
||||
|
||||
w_optimizer, w_scheduler, criterion = get_optim_scheduler(
|
||||
search_model.weights, config
|
||||
)
|
||||
a_optimizer = torch.optim.Adam(
|
||||
search_model.alphas,
|
||||
lr=xargs.arch_learning_rate,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=xargs.arch_weight_decay,
|
||||
eps=xargs.arch_eps,
|
||||
)
|
||||
logger.log("w-optimizer : {:}".format(w_optimizer))
|
||||
logger.log("a-optimizer : {:}".format(a_optimizer))
|
||||
logger.log("w-scheduler : {:}".format(w_scheduler))
|
||||
logger.log("criterion : {:}".format(criterion))
|
||||
params = count_parameters_in_MB(search_model)
|
||||
logger.log("The parameters of the search model = {:.2f} MB".format(params))
|
||||
logger.log("search-space : {:}".format(search_space))
|
||||
if bool(xargs.use_api):
|
||||
api = create(None, "size", fast_mode=True, verbose=False)
|
||||
else:
|
||||
api = None
|
||||
logger.log("{:} create API = {:} done".format(time_string(), api))
|
||||
|
||||
last_info, model_base_path, model_best_path = (
|
||||
logger.path("info"),
|
||||
logger.path("model"),
|
||||
logger.path("best"),
|
||||
)
|
||||
network, criterion = search_model.cuda(), criterion.cuda() # use a single GPU
|
||||
|
||||
last_info, model_base_path, model_best_path = (
|
||||
logger.path("info"),
|
||||
logger.path("model"),
|
||||
logger.path("best"),
|
||||
)
|
||||
|
||||
if last_info.exists(): # automatically resume from previous checkpoint
|
||||
logger.log(
|
||||
"=> loading checkpoint of the last-info '{:}' start".format(last_info)
|
||||
)
|
||||
last_info = torch.load(last_info)
|
||||
start_epoch = last_info["epoch"]
|
||||
checkpoint = torch.load(last_info["last_checkpoint"])
|
||||
genotypes = checkpoint["genotypes"]
|
||||
valid_accuracies = checkpoint["valid_accuracies"]
|
||||
search_model.load_state_dict(checkpoint["search_model"])
|
||||
w_scheduler.load_state_dict(checkpoint["w_scheduler"])
|
||||
w_optimizer.load_state_dict(checkpoint["w_optimizer"])
|
||||
a_optimizer.load_state_dict(checkpoint["a_optimizer"])
|
||||
logger.log(
|
||||
"=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
|
||||
last_info, start_epoch
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {"best": -1}, {-1: network.random}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = (
|
||||
time.time(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
config.epochs + config.warmup,
|
||||
)
|
||||
for epoch in range(start_epoch, total_epoch):
|
||||
w_scheduler.update(epoch, 0.0)
|
||||
need_time = "Time Left: {:}".format(
|
||||
convert_secs2time(epoch_time.val * (total_epoch - epoch), True)
|
||||
)
|
||||
epoch_str = "{:03d}-{:03d}".format(epoch, total_epoch)
|
||||
|
||||
if (
|
||||
xargs.warmup_ratio is None
|
||||
or xargs.warmup_ratio <= float(epoch) / total_epoch
|
||||
):
|
||||
enable_controller = True
|
||||
network.set_warmup_ratio(None)
|
||||
else:
|
||||
enable_controller = False
|
||||
network.set_warmup_ratio(
|
||||
1.0 - float(epoch) / total_epoch / xargs.warmup_ratio
|
||||
)
|
||||
|
||||
logger.log(
|
||||
"\n[Search the {:}-th epoch] {:}, LR={:}, controller-warmup={:}, enable_controller={:}".format(
|
||||
epoch_str,
|
||||
need_time,
|
||||
min(w_scheduler.get_lr()),
|
||||
network.warmup_ratio,
|
||||
enable_controller,
|
||||
)
|
||||
)
|
||||
|
||||
if xargs.algo == "mask_gumbel" or xargs.algo == "tas":
|
||||
network.set_tau(
|
||||
xargs.tau_max
|
||||
- (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
|
||||
)
|
||||
logger.log("[RESET tau as : {:}]".format(network.tau))
|
||||
(
|
||||
search_w_loss,
|
||||
search_w_top1,
|
||||
search_w_top5,
|
||||
search_a_loss,
|
||||
search_a_top1,
|
||||
search_a_top5,
|
||||
) = search_func(
|
||||
search_loader,
|
||||
network,
|
||||
criterion,
|
||||
w_scheduler,
|
||||
w_optimizer,
|
||||
a_optimizer,
|
||||
enable_controller,
|
||||
xargs.algo,
|
||||
epoch_str,
|
||||
xargs.print_freq,
|
||||
logger,
|
||||
)
|
||||
search_time.update(time.time() - start_time)
|
||||
logger.log(
|
||||
"[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s".format(
|
||||
epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%".format(
|
||||
epoch_str, search_a_loss, search_a_top1, search_a_top5
|
||||
)
|
||||
)
|
||||
|
||||
genotype = network.genotype
|
||||
logger.log("[{:}] - [get_best_arch] : {:}".format(epoch_str, genotype))
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
|
||||
valid_loader, network, criterion, logger
|
||||
)
|
||||
logger.log(
|
||||
"[{:}] evaluate : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}% | {:}".format(
|
||||
epoch_str, valid_a_loss, valid_a_top1, valid_a_top5, genotype
|
||||
)
|
||||
)
|
||||
valid_accuracies[epoch] = valid_a_top1
|
||||
|
||||
genotypes[epoch] = genotype
|
||||
logger.log(
|
||||
"<<<--->>> The {:}-th epoch : {:}".format(epoch_str, genotypes[epoch])
|
||||
)
|
||||
# save checkpoint
|
||||
save_path = save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"args": deepcopy(xargs),
|
||||
"search_model": search_model.state_dict(),
|
||||
"w_optimizer": w_optimizer.state_dict(),
|
||||
"a_optimizer": a_optimizer.state_dict(),
|
||||
"w_scheduler": w_scheduler.state_dict(),
|
||||
"genotypes": genotypes,
|
||||
"valid_accuracies": valid_accuracies,
|
||||
},
|
||||
model_base_path,
|
||||
logger,
|
||||
)
|
||||
last_info = save_checkpoint(
|
||||
{
|
||||
"epoch": epoch + 1,
|
||||
"args": deepcopy(args),
|
||||
"last_checkpoint": save_path,
|
||||
},
|
||||
logger.path("info"),
|
||||
logger,
|
||||
)
|
||||
with torch.no_grad():
|
||||
logger.log("{:}".format(search_model.show_alphas()))
|
||||
if api is not None:
|
||||
logger.log("{:}".format(api.query_by_arch(genotypes[epoch], "90")))
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
# the final post procedure : count the time
|
||||
start_time = time.time()
|
||||
genotype = network.genotype
|
||||
search_time.update(time.time() - start_time)
|
||||
|
||||
valid_a_loss, valid_a_top1, valid_a_top5 = valid_func(
|
||||
valid_loader, network, criterion, logger
|
||||
)
|
||||
logger.log(
|
||||
"Last : the gentotype is : {:}, with the validation accuracy of {:.3f}%.".format(
|
||||
genotype, valid_a_top1
|
||||
)
|
||||
)
|
||||
|
||||
logger.log("\n" + "-" * 100)
|
||||
# check the performance from the architecture dataset
|
||||
logger.log(
|
||||
"[{:}] run {:} epochs, cost {:.1f} s, last-geno is {:}.".format(
|
||||
xargs.algo, total_epoch, search_time.sum, genotype
|
||||
)
|
||||
)
|
||||
if api is not None:
|
||||
logger.log("{:}".format(api.query_by_arch(genotype, "90")))
|
||||
logger.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
|
||||
parser.add_argument("--data_path", type=str, help="Path to dataset")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
choices=["cifar10", "cifar100", "ImageNet16-120"],
|
||||
help="Choose between Cifar10/100 and ImageNet-16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_space",
|
||||
type=str,
|
||||
default="sss",
|
||||
choices=["sss"],
|
||||
help="The search space name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--algo",
|
||||
type=str,
|
||||
choices=["tas", "mask_gumbel", "mask_rl"],
|
||||
help="The search space name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--genotype",
|
||||
type=str,
|
||||
default="|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|",
|
||||
help="The genotype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_api",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[0, 1],
|
||||
help="Whether use API or not (which will cost much memory).",
|
||||
)
|
||||
# FOR GDAS
|
||||
parser.add_argument(
|
||||
"--tau_min", type=float, default=0.1, help="The minimum tau for Gumbel Softmax."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tau_max", type=float, default=10, help="The maximum tau for Gumbel Softmax."
|
||||
)
|
||||
# FOR ALL
|
||||
parser.add_argument(
|
||||
"--warmup_ratio", type=float, help="The warmup ratio, if None, not use warmup."
|
||||
)
|
||||
#
|
||||
parser.add_argument(
|
||||
"--track_running_stats",
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help="Whether use track_running_stats or not in the BN layer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--affine",
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help="Whether use affine=True or False in the BN layer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
default="./configs/nas-benchmark/algos/weight-sharing.config",
|
||||
help="The path of configuration.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwite_epochs",
|
||||
type=int,
|
||||
help="The number of epochs to overwrite that value in config files.",
|
||||
)
|
||||
# architecture leraning rate
|
||||
parser.add_argument(
|
||||
"--arch_learning_rate",
|
||||
type=float,
|
||||
default=3e-4,
|
||||
help="learning rate for arch encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch_weight_decay",
|
||||
type=float,
|
||||
default=1e-3,
|
||||
help="weight decay for arch encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch_eps", type=float, default=1e-8, help="weight decay for arch encoding"
|
||||
)
|
||||
# log
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="number of data loading workers (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_dir",
|
||||
type=str,
|
||||
default="./output/search",
|
||||
help="Folder to save checkpoints and log.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_freq", type=int, default=200, help="print frequency (default: 200)"
|
||||
)
|
||||
parser.add_argument("--rand_seed", type=int, help="manual seed")
|
||||
args = parser.parse_args()
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
dirname = "{:}-affine{:}_BN{:}-AWD{:}-WARM{:}".format(
|
||||
args.algo,
|
||||
args.affine,
|
||||
args.track_running_stats,
|
||||
args.arch_weight_decay,
|
||||
args.warmup_ratio,
|
||||
)
|
||||
if args.overwite_epochs is not None:
|
||||
dirname = dirname + "-E{:}".format(args.overwite_epochs)
|
||||
args.save_dir = os.path.join(
|
||||
"{:}-{:}".format(args.save_dir, args.search_space), args.dataset, dirname
|
||||
)
|
||||
|
||||
main(args)
|
Reference in New Issue
Block a user