Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1
This commit is contained in:
@@ -5,4 +5,5 @@ from .api import NASBench201API
|
||||
from .api import ArchResults, ResultsCount
|
||||
|
||||
# Version history of the NAS-Bench-201 API; only the last assignment is live.
# NAS_BENCH_201_API_VERSION="v1.1" # [2020.02.25]
# NAS_BENCH_201_API_VERSION="v1.2" # [2020.03.09]
NAS_BENCH_201_API_VERSION="v1.3" # [2020.03.16]
|
||||
|
@@ -3,11 +3,14 @@
|
||||
############################################################################################
|
||||
# NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search, ICLR 2020 #
|
||||
############################################################################################
|
||||
# The history of benchmark files:
|
||||
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
|
||||
# [2020.03.08] Next version (coming soon)
|
||||
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.
|
||||
#
|
||||
# I'm still actively enhancing this benchmark. Please feel free to contact me if you have any question w.r.t. NAS-Bench-201.
|
||||
#
|
||||
import os, copy, random, torch, numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Text, Union, Dict
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
@@ -44,9 +47,12 @@ class NASBench201API(object):
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
self.filename = None
|
||||
if isinstance(file_path_or_dict, str) or isinstance(file_path_or_dict, Path):
|
||||
file_path_or_dict = str(file_path_or_dict)
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
self.filename = Path(file_path_or_dict).name
|
||||
file_path_or_dict = torch.load(file_path_or_dict)
|
||||
elif isinstance(file_path_or_dict, dict):
|
||||
file_path_or_dict = copy.deepcopy( file_path_or_dict )
|
||||
@@ -76,7 +82,7 @@ class NASBench201API(object):
|
||||
return len(self.meta_archs)
|
||||
|
||||
def __repr__(self):
    """Return a one-line summary of this benchmark object.

    The summary contains the class name, the number of evaluated
    architectures (`len(self.evaluated_indexes)`), the total number of
    architectures (`len(self.meta_archs)`) and `self.filename`, which is
    None when the object was not created from a file path.
    """
    # A single return: the earlier variant without `file=` made this one unreachable.
    return ('{name}({num}/{total} architectures, file={filename})'.format(
        name=self.__class__.__name__,
        num=len(self.evaluated_indexes),
        total=len(self.meta_archs),
        filename=self.filename))
|
||||
|
||||
def random(self):
|
||||
"""Return a random index of all architectures."""
|
||||
@@ -98,9 +104,10 @@ class NASBench201API(object):
|
||||
else: arch_index = -1
|
||||
return arch_index
|
||||
|
||||
# Overwrite all information of the 'index'-th architecture in the search space.
|
||||
# It will load its data from 'archive_root'.
|
||||
def reload(self, archive_root: Text, index: int):
|
||||
"""Overwrite all information of the 'index'-th architecture in the search space.
|
||||
It will load its data from 'archive_root'.
|
||||
"""
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
||||
@@ -109,6 +116,13 @@ class NASBench201API(object):
|
||||
assert isinstance(xdata, dict) and 'full' in xdata and 'less' in xdata, 'invalid format of data in {:}'.format(xfile_path)
|
||||
self.arch2infos_less[index] = ArchResults.create_from_state_dict( xdata['less'] )
|
||||
self.arch2infos_full[index] = ArchResults.create_from_state_dict( xdata['full'] )
|
||||
|
||||
def clear_params(self, index: int, use_12epochs_result: bool):
    """Call `clear_params()` on the stored result of the `index`-th architecture.

    According to the original docstring this removes the architecture's
    weights to save memory; the actual work is done by the stored result
    object.

    Args:
        index: key of the architecture in the selected result table.
        use_12epochs_result: when true, use `self.arch2infos_less`;
            otherwise use `self.arch2infos_full`.
    """
    results_table = self.arch2infos_less if use_12epochs_result else self.arch2infos_full
    results_table[index].clear_params()
|
||||
|
||||
# This function is used to query the information of a specific architecture
|
||||
# 'arch' can be an architecture index or an architecture string
|
||||
@@ -162,6 +176,7 @@ class NASBench201API(object):
|
||||
return archInfo
|
||||
|
||||
def find_best(self, dataset, metric_on_set, FLOP_max=None, Param_max=None, use_12epochs_result=False):
|
||||
"""Find the architecture with the highest accuracy based on some constraints."""
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
best_index, highest_accuracy = -1, None
|
||||
@@ -255,6 +270,65 @@ class NASBench201API(object):
|
||||
# `is_random`
|
||||
# When is_random=True, the performance of a random architecture will be returned
|
||||
# When is_random=False, the performance of all trials will be averaged.
|
||||
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
    """Collect the train/valid/test metrics of one architecture into a flat dict.

    Args:
        index: key of the architecture in the selected result table.
        dataset: dataset name forwarded to `get_metrics`. 'cifar10-valid' and
            'cifar10' select different evaluation splits (see below).
        iepoch: forwarded unchanged to `get_metrics`.
        use_12epochs_result: when true, read from `self.arch2infos_less`;
            otherwise read from `self.arch2infos_full`.
        is_random: when it is exactly the bool True, one seed is drawn with
            `random.choice` from `get_dataset_seeds(dataset)` and used for
            every query, so all metrics come from the same trial. Any other
            value is forwarded to `get_metrics` unchanged.

    Returns:
        A dict that always holds 'train-loss', 'train-accuracy',
        'train-per-time' and 'train-all-time'. The same four suffixes are
        added under the prefixes 'valid-', 'test-' and 'valtest-' for each
        split whose metrics could be fetched. Every '*-per-time' value is the
        split's 'all_time' divided by (training 'iepoch' + 1).

    Raises:
        Whatever `get_metrics` raises for the 'train' split, and for the
        'x-valid' split when dataset is 'cifar10-valid'. Failures on the
        other splits are treated as "split not available".
    """
    arch2infos = self.arch2infos_less if use_12epochs_result else self.arch2infos_full
    archresult = arch2infos[index]
    # if randomly select one trial, select the seed at first
    if isinstance(is_random, bool) and is_random:
        seeds = archresult.get_dataset_seeds(dataset)
        is_random = random.choice(seeds)

    def query(setname):
        # Mandatory lookup: errors propagate to the caller.
        return archresult.get_metrics(dataset, setname, iepoch=iepoch, is_random=is_random)

    def query_or_none(setname):
        # Optional lookup: a missing split yields None. `Exception` (not a bare
        # except) keeps KeyboardInterrupt / SystemExit from being swallowed.
        try:
            return query(setname)
        except Exception:
            return None

    # collect the training information
    train_info = query('train')
    total = train_info['iepoch'] + 1
    xinfo = {'train-loss'    : train_info['loss'],
             'train-accuracy': train_info['accuracy'],
             'train-per-time': train_info['all_time'] / total,
             'train-all-time': train_info['all_time']}

    # collect the evaluation information
    if dataset == 'cifar10-valid':
        valid_info = query('x-valid')
        test_info = query_or_none('ori-test')
        valtest_info = None
    else:
        # results on the proposed test set
        test_info = query_or_none('ori-test' if dataset == 'cifar10' else 'x-test')
        # results on the proposed validation set
        valid_info = query_or_none('x-valid')
        # 'ori-test' is reported as 'valtest' for every dataset except cifar10
        valtest_info = query_or_none('ori-test') if dataset != 'cifar10' else None

    for prefix, info in (('valid', valid_info), ('test', test_info), ('valtest', valtest_info)):
        if info is None:
            continue
        xinfo[prefix + '-loss'] = info['loss']
        xinfo[prefix + '-accuracy'] = info['accuracy']
        xinfo[prefix + '-per-time'] = info['all_time'] / total
        xinfo[prefix + '-all-time'] = info['all_time']
    return xinfo
|
||||
""" # The following logic is deprecated after March 15 2020, where the benchmark file upgrades from NAS-Bench-201-v1_0-e61699.pth to NAS-Bench-201-v1_1-096897.pth.
|
||||
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
@@ -312,6 +386,7 @@ class NASBench201API(object):
|
||||
xifo['est-valid-loss'] = est_valid_info['loss']
|
||||
xifo['est-valid-accuracy'] = est_valid_info['accuracy']
|
||||
return xifo
|
||||
"""
|
||||
|
||||
|
||||
def show(self, index: int = -1) -> None:
|
||||
@@ -349,6 +424,26 @@ class NASBench201API(object):
|
||||
print('This index ({:}) is out of range (0~{:}).'.format(index, len(self.meta_archs)))
|
||||
|
||||
|
||||
def statistics(self, dataset: Text, use_12epochs_result: bool) -> Dict[int, int]:
    """Build a histogram of how many trials each architecture has on `dataset`.

    Args:
        dataset: one of 'cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'.
        use_12epochs_result: when true, read from `self.arch2infos_less`;
            otherwise read from `self.arch2infos_full`.

    Returns:
        A dict mapping a trial count to the number of architectures having
        that many entries in `dataset_seed[dataset]`. Architectures whose
        `dataset_seed` has no entry for `dataset` are counted under key 0.

    Raises:
        ValueError: if `dataset` is not one of the supported names.
    """
    valid_datasets = ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120']
    if dataset not in valid_datasets:
        raise ValueError('{:} not in {:}'.format(dataset, valid_datasets))
    arch2infos = self.arch2infos_less if use_12epochs_result else self.arch2infos_full
    histogram = defaultdict(int)
    for position in range(len(self)):
        seeds_by_dataset = arch2infos[position].dataset_seed
        num_trials = len(seeds_by_dataset[dataset]) if dataset in seeds_by_dataset else 0
        histogram[num_trials] += 1
    return dict(histogram)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def str2lists(arch_str: Text) -> List[tuple]:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user