Create NATS
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
|
||||
###############################################################
|
||||
# Usage: python exps/NAS-Bench-201/test-nas-api.py
|
||||
# Usage: python exps/NAS-Bench-201/test-nas-api.py #
|
||||
###############################################################
|
||||
import os, sys, time, torch, argparse
|
||||
import numpy as np
|
||||
@@ -21,7 +23,7 @@ import matplotlib.ticker as ticker
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from nas_201_api import NASBench201API, NASBench301API
|
||||
from nats_bench import create
|
||||
from log_utils import time_string
|
||||
from models import get_cell_based_tiny_net, CellStructure
|
||||
|
||||
@@ -97,15 +99,14 @@ def test_issue_81_82(api):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
api201 = NASBench201API(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), verbose=True)
|
||||
api201 = create(os.path.join(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_0-e61699.pth'), 'topology', True)
|
||||
test_issue_81_82(api201)
|
||||
# test_api(api201, False)
|
||||
print ('Test {:} done'.format(api201))
|
||||
|
||||
api201 = NASBench201API(None, verbose=True)
|
||||
api201 = create(None, 'topology', True) # use the default file path
|
||||
test_issue_81_82(api201)
|
||||
test_api(api201, False)
|
||||
print ('Test {:} done'.format(api201))
|
||||
|
||||
# api301 = NASBench301API(None, verbose=True)
|
||||
# test_api(api301, True)
|
||||
api301 = create(None, 'size', True)
|
||||
test_api(api301, True)
|
||||
|
@@ -16,7 +16,7 @@ from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from config_utils import dict2config
|
||||
# NAS-Bench-201 related module or function
|
||||
from models import CellStructure, get_cell_based_tiny_net
|
||||
from nas_201_api import NASBench301API, ArchResults, ResultsCount
|
||||
from nas_201_api import ArchResults, ResultsCount
|
||||
from procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user