update NAS-Bench
This commit is contained in:
@@ -7,7 +7,8 @@
|
||||
# [2020.03.08] Next version (coming soon)
|
||||
#
|
||||
#
|
||||
import os, sys, copy, random, torch, numpy as np
|
||||
import os, copy, random, torch, numpy as np
|
||||
from typing import List, Text, Union, Dict, Any
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
|
||||
@@ -43,7 +44,7 @@ This is the class for API of NAS-Bench-201.
|
||||
class NASBench201API(object):
|
||||
|
||||
""" The initialization function that takes the dataset file path (or a dict loaded from that path) as input. """
|
||||
def __init__(self, file_path_or_dict, verbose=True):
|
||||
def __init__(self, file_path_or_dict: Union[Text, Dict], verbose: bool=True):
|
||||
if isinstance(file_path_or_dict, str):
|
||||
if verbose: print('try to create the NAS-Bench-201 api from {:}'.format(file_path_or_dict))
|
||||
assert os.path.isfile(file_path_or_dict), 'invalid path : {:}'.format(file_path_or_dict)
|
||||
@@ -69,7 +70,7 @@ class NASBench201API(object):
|
||||
assert arch not in self.archstr2index, 'This [{:}]-th arch {:} already in the dict ({:}).'.format(idx, arch, self.archstr2index[arch])
|
||||
self.archstr2index[ arch ] = idx
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index: int):
|
||||
return copy.deepcopy( self.meta_archs[index] )
|
||||
|
||||
def __len__(self):
|
||||
@@ -99,7 +100,7 @@ class NASBench201API(object):
|
||||
|
||||
# Overwrite all information of the 'index'-th architecture in the search space.
|
||||
# It will load its data from 'archive_root'.
|
||||
def reload(self, archive_root, index):
|
||||
def reload(self, archive_root: Text, index: int):
|
||||
assert os.path.isdir(archive_root), 'invalid directory : {:}'.format(archive_root)
|
||||
xfile_path = os.path.join(archive_root, '{:06d}-FULL.pth'.format(index))
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index of {:}'.format(index)
|
||||
@@ -141,7 +142,8 @@ class NASBench201API(object):
|
||||
# -- cifar10 : training the model on the CIFAR-10 training + validation set.
|
||||
# -- cifar100 : training the model on the CIFAR-100 training set.
|
||||
# -- ImageNet16-120 : training the model on the ImageNet16-120 training set.
|
||||
def query_by_index(self, arch_index, dataname=None, use_12epochs_result=False):
|
||||
def query_by_index(self, arch_index: int, dataname: Union[None, Text] = None,
|
||||
use_12epochs_result: bool = False):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
assert arch_index in arch2infos, 'arch_index [{:}] does not in arch2info with {:}'.format(arch_index, basestr)
|
||||
@@ -177,7 +179,7 @@ class NASBench201API(object):
|
||||
return best_index, highest_accuracy
|
||||
|
||||
# return the topology structure of the `index`-th architecture
|
||||
def arch(self, index):
|
||||
def arch(self, index: int):
|
||||
assert 0 <= index < len(self.meta_archs), 'invalid index : {:} vs. {:}.'.format(index, len(self.meta_archs))
|
||||
return copy.deepcopy(self.meta_archs[index])
|
||||
|
||||
@@ -238,7 +240,7 @@ class NASBench201API(object):
|
||||
# `is_random`
|
||||
# When is_random=True, the performance of a random architecture will be returned
|
||||
# When is_random=False, the performanceo of all trials will be averaged.
|
||||
def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
def get_more_info(self, index: int, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
|
||||
if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
|
||||
else : basestr, arch2infos = '200epochs', self.arch2infos_full
|
||||
archresult = arch2infos[index]
|
||||
@@ -301,7 +303,7 @@ class NASBench201API(object):
|
||||
If the index < 0: it will loop for all architectures and print their information one by one.
|
||||
else: it will print the information of the 'index'-th archiitecture.
|
||||
"""
|
||||
def show(self, index=-1):
|
||||
def show(self, index: int = -1) -> None:
|
||||
if index < 0: # show all architectures
|
||||
print(self)
|
||||
for i, idx in enumerate(self.evaluated_indexes):
|
||||
@@ -336,8 +338,8 @@ class NASBench201API(object):
|
||||
# for i, node in enumerate(arch):
|
||||
# print('the {:}-th node is the sum of these {:} nodes with op: {:}'.format(i+1, len(node), node))
|
||||
@staticmethod
|
||||
def str2lists(xstr):
|
||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
def str2lists(xstr: Text) -> List[Any]:
|
||||
# assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
nodestrs = xstr.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
|
Reference in New Issue
Block a user