Create NATS

This commit is contained in:
D-X-Y
2020-07-30 13:07:11 +00:00
parent df45e68366
commit 6061d74631
21 changed files with 1336 additions and 126 deletions

View File

@@ -1,9 +1,11 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NAS-Bench-201/test-nas-api.py
# Usage: python exps/NAS-Bench-201/test-nas-api.py #
###############################################################
import os, sys, time, torch, argparse
import numpy as np
@@ -21,7 +23,7 @@ import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nas_201_api import NASBench201API, NASBench301API
from nats_bench import create
from log_utils import time_string
from models import get_cell_based_tiny_net, CellStructure
@@ -97,15 +99,14 @@ def test_issue_81_82(api):
if __name__ == '__main__':
api201 = NASBench201API(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), verbose=True)
api201 = create(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), 'topology', True)
test_issue_81_82(api201)
# test_api(api201, False)
print ('Test {:} done'.format(api201))
api201 = NASBench201API(None, verbose=True)
api201 = create(None, 'topology', True) # use the default file path
test_issue_81_82(api201)
test_api(api201, False)
print ('Test {:} done'.format(api201))
# api301 = NASBench301API(None, verbose=True)
# test_api(api301, True)
api301 = create(None, 'size', True)
test_api(api301, True)

View File

@@ -16,7 +16,7 @@ from log_utils import AverageMeter, time_string, convert_secs2time
from config_utils import dict2config
# NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import NASBench301API, ArchResults, ResultsCount
from nas_201_api import ArchResults, ResultsCount
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders