Sync NATS-Bench's d11018d
This commit is contained in:
59
lib/nats_bench/api_test.py
Normal file
59
lib/nats_bench/api_test.py
Normal file
@@ -0,0 +1,59 @@
|
||||
##############################################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.08 ##########################
|
||||
##############################################################################
|
||||
# NATS-Bench: Benchmarking NAS algorithms for Architecture Topology and Size #
|
||||
##############################################################################
|
||||
"""This file is used to quickly test the API."""
|
||||
import random
|
||||
|
||||
from nats_bench.api_size import NATSsize
|
||||
from nats_bench.api_topology import NATStopology
|
||||
|
||||
|
||||
def test_nats_bench_tss(benchmark_dir):
|
||||
return test_nats_bench(benchmark_dir, True)
|
||||
|
||||
|
||||
def test_nats_bench_sss(benchmark_dir):
|
||||
return test_nats_bench(benchmark_dir, False)
|
||||
|
||||
|
||||
def test_nats_bench(benchmark_dir, is_tss, verbose=False):
|
||||
if is_tss:
|
||||
api = NATStopology(benchmark_dir, True, verbose)
|
||||
else:
|
||||
api = NATSsize(benchmark_dir, True, verbose)
|
||||
|
||||
test_indexes = [random.randint(0, len(api) - 1) for _ in range(10)]
|
||||
key2dataset = {'cifar10': 'CIFAR-10',
|
||||
'cifar100': 'CIFAR-100',
|
||||
'ImageNet16-120': 'ImageNet16-120'}
|
||||
|
||||
for index in test_indexes:
|
||||
print('\n\nEvaluate the {:5d}-th architecture.'.format(index))
|
||||
|
||||
for key, dataset in key2dataset.items():
|
||||
# Query the loss / accuracy / time for the `index`-th candidate
|
||||
# architecture on CIFAR-10
|
||||
# info is a dict, where you can easily figure out the meaning by key
|
||||
info = api.get_more_info(index, key)
|
||||
print(' -->> The performance on {:}: {:}'.format(dataset, info))
|
||||
|
||||
# Query the flops, params, latency. info is a dict.
|
||||
info = api.get_cost_info(index, key)
|
||||
print(' -->> The cost info on {:}: {:}'.format(dataset, info))
|
||||
|
||||
# Simulate the training of the `index`-th candidate:
|
||||
validation_accuracy, latency, time_cost, current_total_time_cost = api.simulate_train_eval(
|
||||
index, dataset=key, hp='12')
|
||||
print(' -->> The validation accuracy={:}, latency={:}, '
|
||||
'the current time cost={:} s, accumulated time cost={:} s'
|
||||
.format(validation_accuracy, latency, time_cost,
|
||||
current_total_time_cost))
|
||||
|
||||
# Print the configuration of the `index`-th architecture on CIFAR-10
|
||||
config = api.get_net_config(index, key)
|
||||
print(' -->> The configuration on {:} is {:}'.format(dataset, config))
|
||||
|
||||
# Show the information of the `index`-th architecture
|
||||
api.show(index)
|
Reference in New Issue
Block a user