Update find_best API
This commit is contained in:
@@ -92,6 +92,10 @@ class ImageNet16(data.Dataset):
|
||||
#std_data = np.mean(np.mean(std_data, axis=0), axis=0)
|
||||
#print ('Std : {:}'.format(std_data))
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets))))
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.data[index], self.targets[index] - 1
|
||||
|
||||
@@ -114,16 +118,16 @@ class ImageNet16(data.Dataset):
|
||||
return False
|
||||
return True
|
||||
|
||||
#
|
||||
"""
|
||||
if __name__ == '__main__':
|
||||
train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
|
||||
valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
|
||||
train = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None)
|
||||
valid = ImageNet16('~/.torch/cifar.python/ImageNet16', False, None)
|
||||
|
||||
print ( len(train) )
|
||||
print ( len(valid) )
|
||||
image, label = train[111]
|
||||
trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
|
||||
validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
|
||||
trainX = ImageNet16('~/.torch/cifar.python/ImageNet16', True , None, 200)
|
||||
validX = ImageNet16('~/.torch/cifar.python/ImageNet16', False , None, 200)
|
||||
print ( len(trainX) )
|
||||
print ( len(validX) )
|
||||
#import pdb; pdb.set_trace()
|
||||
"""
|
||||
|
@@ -482,6 +482,7 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
best_index, highest_accuracy = -1, None
|
||||
evaluated_indexes = sorted(list(self.evaluated_indexes))
|
||||
for arch_index in evaluated_indexes:
|
||||
self._prepare_info(arch_index)
|
||||
arch_info = self.arch2infos_dict[arch_index][hp]
|
||||
info = arch_info.get_compute_costs(dataset) # the information of costs
|
||||
flop, param, latency = info['flops'], info['params'], info['latency']
|
||||
@@ -622,6 +623,8 @@ class NASBenchMetaAPI(metaclass=abc.ABCMeta):
|
||||
print('<' * 40 + '------------' + '<' * 40)
|
||||
else:
|
||||
if 0 <= index < len(self.meta_archs):
|
||||
if index not in self.evaluated_indexes:
|
||||
self._prepare_info(index)
|
||||
if index not in self.evaluated_indexes:
|
||||
print('The {:}-th architecture has not been evaluated '
|
||||
'or not saved.'.format(index))
|
||||
|
Reference in New Issue
Block a user