Updates
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
# NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
||||
##############################################################################
|
||||
@@ -19,15 +19,15 @@ from datasets import get_datasets
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
def show_time(api):
|
||||
print('Show the time for {:} with 12-epoch-training'.format(api))
|
||||
def show_time(api, epoch=12):
|
||||
print('Show the time for {:} with {:}-epoch-training'.format(api, epoch))
|
||||
all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0
|
||||
for index in tqdm.tqdm(range(len(api))):
|
||||
info = api.get_more_info(index, 'ImageNet16-120', hp='12')
|
||||
info = api.get_more_info(index, 'ImageNet16-120', hp=epoch)
|
||||
imagenet_time = info['train-all-time']
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp='12')
|
||||
info = api.get_more_info(index, 'cifar10-valid', hp=epoch)
|
||||
cifar10_time = info['train-all-time']
|
||||
info = api.get_more_info(index, 'cifar100', hp='12')
|
||||
info = api.get_more_info(index, 'cifar100', hp=epoch)
|
||||
cifar100_time = info['train-all-time']
|
||||
# accumulate the time
|
||||
all_cifar10_time += cifar10_time
|
||||
@@ -41,8 +41,8 @@ def show_time(api):
|
||||
if __name__ == '__main__':
|
||||
|
||||
api_nats_tss = create(None, 'tss', fast_mode=True, verbose=False)
|
||||
show_time(api_nats_tss)
|
||||
show_time(api_nats_tss, 12)
|
||||
|
||||
api_nats_sss = create(None, 'sss', fast_mode=True, verbose=False)
|
||||
show_time(api_nats_sss)
|
||||
show_time(api_nats_sss, 12)
|
||||
|
||||
|
Reference in New Issue
Block a user