Update NATS-Bench (tss version 1.0) and remove the trace of 301

This commit is contained in:
D-X-Y
2020-09-16 08:28:27 +00:00
parent bd9288f45d
commit 9db28392c2
10 changed files with 169 additions and 249 deletions

View File

@@ -26,7 +26,7 @@ from log_utils import time_string
from models import get_cell_based_tiny_net, CellStructure
def test_api(api, is_301=True):
def test_api(api, sss_or_tss=True):
print('{:} start testing the api : {:}'.format(time_string(), api))
api.clear_params(12)
api.reload(index=12)
@@ -39,7 +39,7 @@ def test_api(api, is_301=True):
info = api.query_by_index(113, 'cifar100')
print('{:}\n'.format(info))
info = api.query_meta_info_by_index(115, '90' if is_301 else '200')
info = api.query_meta_info_by_index(115, '90' if sss_or_tss else '200')
print('{:}\n'.format(info))
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
@@ -48,6 +48,7 @@ def test_api(api, is_301=True):
print('')
params = api.get_net_param(12, 'cifar10', None)
import pdb; pdb.set_trace()
# Obtain the config and create the network
config = api.get_net_config(12, 'cifar10')
print('{:}\n'.format(config))
@@ -74,7 +75,7 @@ def test_api(api, is_301=True):
print('{:}\n'.format(info))
print('{:} finish testing the api : {:}'.format(time_string(), api))
if not is_301:
if not sss_or_tss:
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
matrix = api.str2matrix(arch_str)
print('Compute the adjacency matrix of {:}'.format(arch_str))
@@ -88,13 +89,13 @@ if __name__ == '__main__':
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
for fast_mode in [True, False]:
for verbose in [True, False]:
api201 = create(None, 'tss', fast_mode=fast_mode, verbose=True)
api_nats_tss = create(None, 'tss', fast_mode=fast_mode, verbose=True)
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
test_api(api201, False)
test_api(api_nats_tss, False)
for fast_mode in [True, False]:
for verbose in [True, False]:
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
api301 = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api301))
test_api(api301, True)
api_nats_sss = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)