update NAS-Bench

This commit is contained in:
D-X-Y
2020-03-09 19:38:00 +11:00
parent 9a83814a46
commit e59eb804cb
35 changed files with 693 additions and 64 deletions

View File

@@ -7,7 +7,8 @@
# [2020.03.08] Next version (coming soon)
#
#
import os, sys, copy, random, torch, numpy as np
import os, copy, random, torch, numpy as np
from typing import List, Text, Union, Dict, Any
from collections import OrderedDict, defaultdict
@@ -43,7 +44,7 @@ This is the class for API of NAS-Bench-201.
class NASBench201API(object):
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
def __init__(self, file_path_or_dict, verbose=True):
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
if isinstance(file_path_or_dict, str):
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
@@ -69,7 +70,7 @@ class NASBench201API(object):
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
self.archstr2index[ arch ] = idx
def __getitem__(self, index):
def __getitem__(self, index: int):
return copy.deepcopy( self.meta_archs[index] )
def __len__(self):
@@ -99,7 +100,7 @@ class NASBench201API(object):
# Overwrite all information of the 'index'-th architecture in the search space.
# It will load its data from 'archive_root'.
def reload(self, archive_root, index):
def reload(self, archive_root: Text, index: int):
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
@@ -141,7 +142,8 @@ class NASBench201API(object):
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
# -- cifar100 : training the model on the CIFAR-100 training set.
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None,
use_12epochs_result: bool = False):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
@@ -177,7 +179,7 @@ class NASBench201API(object):
return best_index, highest_accuracy
# return the topology structure of the `index`-th architecture
def arch(self, index):
def arch(self, index: int):
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
return copy.deepcopy(self.meta_archs[index])
@@ -238,7 +240,7 @@ class NASBench201API(object):
# `is_random`
# When is_random=True, the performance of a random architecture will be returned
# When is_random=False, the performanceo of all trials will be averaged.
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
else : basestr, arch2infos = '200epochs', self.arch2infos_full
archresult = arch2infos[index]
@@ -301,7 +303,7 @@ class NASBench201API(object):
If the index < 0: it will loop for all architectures and print their information one by one.
else: it will print the information of the 'index'-th archiitecture.
"""
def show(self, index=-1):
def show(self, index: int = -1) -> None:
if index < 0: # show all architectures
print(self)
for i, idx in enumerate(self.evaluated_indexes):
@@ -336,8 +338,8 @@ class NASBench201API(object):
# for i, node in enumerate(arch):
# print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
@staticmethod
def str2lists(xstr):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
def str2lists(xstr: Text) -> List[Any]:
# assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):