Update docs of NATS-Bench
This commit is contained in:
@@ -801,7 +801,6 @@ if __name__ == '__main__':
|
||||
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-VALID.pdf', (0, 100,10), 250)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'BN0', 'BN0-XX-CIFAR010-TEST.pdf' , (0, 100,10), 250)
|
||||
import pdb; pdb.set_trace()
|
||||
"""
|
||||
for x_maxs in [50, 250]:
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
|
@@ -48,7 +48,6 @@ def test_api(api, sss_or_tss=True):
|
||||
print('')
|
||||
params = api.get_net_param(12, 'cifar10', None)
|
||||
|
||||
import pdb; pdb.set_trace()
|
||||
# Obtain the config and create the network
|
||||
config = api.get_net_config(12, 'cifar10')
|
||||
print('{:}\n'.format(config))
|
||||
|
@@ -95,7 +95,7 @@ def main(xargs, api):
|
||||
|
||||
logger.log('{:} use api : {:}'.format(time_string(), api))
|
||||
api.reset_time()
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
|
||||
if xargs.search_space == 'tss':
|
||||
cs = get_topology_config_space(search_space)
|
||||
config2structure = config2topology_func()
|
||||
|
@@ -33,7 +33,7 @@ def main(xargs, api):
|
||||
logger.log('{:} use api : {:}'.format(time_string(), api))
|
||||
api.reset_time()
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
|
||||
if xargs.search_space == 'tss':
|
||||
random_arch = random_topology_func(search_space)
|
||||
else:
|
||||
|
@@ -160,7 +160,7 @@ def main(xargs, api):
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
|
||||
if xargs.search_space == 'tss':
|
||||
random_arch = random_topology_func(search_space)
|
||||
mutate_arch = mutate_topology_func(search_space)
|
||||
|
@@ -124,7 +124,7 @@ def main(xargs, api):
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
|
||||
if xargs.search_space == 'tss':
|
||||
policy = PolicyTopology(search_space)
|
||||
else:
|
||||
|
@@ -342,9 +342,8 @@ def main(xargs):
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
|
||||
|
||||
|
||||
model_config = dict2config(
|
||||
dict(name='generic', C=xargs.channel, N=xargs.num_cells, max_nodes=xargs.max_nodes, num_classes=class_num,
|
||||
space=search_space, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)
|
||||
|
@@ -155,8 +155,8 @@ def main(xargs):
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nas-bench-301')
|
||||
|
||||
search_space = get_search_spaces(xargs.search_space, 'nats-bench')
|
||||
|
||||
model_config = dict2config(
|
||||
dict(name='generic', super_type='search-shape', candidate_Cs=search_space['candidates'], max_num_Cs=search_space['numbers'], num_classes=class_num,
|
||||
genotype=args.genotype, affine=bool(xargs.affine), track_running_stats=bool(xargs.track_running_stats)), None)
|
||||
|
@@ -3,10 +3,10 @@
|
||||
###########################################################################################################################################################
|
||||
# Before run these commands, the files must be properly put.
|
||||
#
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar10
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset cifar100
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NAS-Bench-301-v1_0 --dataset ImageNet16-120
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NAS-Bench-201-v1_1 --dataset cifar10
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120
|
||||
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10
|
||||
###########################################################################################################################################################
|
||||
import os, gc, sys, math, argparse, psutil
|
||||
import numpy as np
|
||||
@@ -140,7 +140,7 @@ if __name__ == '__main__':
|
||||
save_dir = Path(args.save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
meta_file = Path(args.base_path + '.pth')
|
||||
weight_dir = Path(args.base_path + '-archive')
|
||||
weight_dir = Path(args.base_path + '-full')
|
||||
assert meta_file.exists(), 'invalid path for api : {:}'.format(meta_file)
|
||||
assert weight_dir.exists() and weight_dir.is_dir(), 'invalid path for weight dir : {:}'.format(weight_dir)
|
||||
|
||||
|
@@ -395,9 +395,9 @@ if __name__ == '__main__':
|
||||
for xdata in datasets:
|
||||
visualize_tss_info(api201, xdata, to_save_dir)
|
||||
|
||||
api301 = create(None, 'size', verbose=True)
|
||||
api_sss = create(None, 'size', verbose=True)
|
||||
for xdata in datasets:
|
||||
visualize_sss_info(api301, xdata, to_save_dir)
|
||||
visualize_sss_info(api_sss, xdata, to_save_dir)
|
||||
|
||||
visualize_info(None, to_save_dir, 'tss')
|
||||
visualize_info(None, to_save_dir, 'sss')
|
||||
|
Reference in New Issue
Block a user