Update NATS-Bench (tss version 1.0) and remove the trace of 301
This commit is contained in:
@@ -26,7 +26,7 @@ from log_utils import time_string
|
||||
from models import get_cell_based_tiny_net, CellStructure
|
||||
|
||||
|
||||
def test_api(api, is_301=True):
|
||||
def test_api(api, sss_or_tss=True):
|
||||
print('{:} start testing the api : {:}'.format(time_string(), api))
|
||||
api.clear_params(12)
|
||||
api.reload(index=12)
|
||||
@@ -39,7 +39,7 @@ def test_api(api, is_301=True):
|
||||
info = api.query_by_index(113, 'cifar100')
|
||||
print('{:}\n'.format(info))
|
||||
|
||||
info = api.query_meta_info_by_index(115, '90' if is_301 else '200')
|
||||
info = api.query_meta_info_by_index(115, '90' if sss_or_tss else '200')
|
||||
print('{:}\n'.format(info))
|
||||
|
||||
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
|
||||
@@ -48,6 +48,7 @@ def test_api(api, is_301=True):
|
||||
print('')
|
||||
params = api.get_net_param(12, 'cifar10', None)
|
||||
|
||||
import pdb; pdb.set_trace()
|
||||
# Obtain the config and create the network
|
||||
config = api.get_net_config(12, 'cifar10')
|
||||
print('{:}\n'.format(config))
|
||||
@@ -74,7 +75,7 @@ def test_api(api, is_301=True):
|
||||
print('{:}\n'.format(info))
|
||||
print('{:} finish testing the api : {:}'.format(time_string(), api))
|
||||
|
||||
if not is_301:
|
||||
if not sss_or_tss:
|
||||
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
|
||||
matrix = api.str2matrix(arch_str)
|
||||
print('Compute the adjacency matrix of {:}'.format(arch_str))
|
||||
@@ -88,13 +89,13 @@ if __name__ == '__main__':
|
||||
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
|
||||
for fast_mode in [True, False]:
|
||||
for verbose in [True, False]:
|
||||
api201 = create(None, 'tss', fast_mode=fast_mode, verbose=True)
|
||||
api_nats_tss = create(None, 'tss', fast_mode=fast_mode, verbose=True)
|
||||
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
|
||||
test_api(api201, False)
|
||||
test_api(api_nats_tss, False)
|
||||
|
||||
for fast_mode in [True, False]:
|
||||
for verbose in [True, False]:
|
||||
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
|
||||
api301 = create(None, 'size', fast_mode=fast_mode, verbose=True)
|
||||
print('{:} --->>> {:}'.format(time_string(), api301))
|
||||
test_api(api301, True)
|
||||
api_nats_sss = create(None, 'size', fast_mode=fast_mode, verbose=True)
|
||||
print('{:} --->>> {:}'.format(time_string(), api_nats_sss))
|
||||
test_api(api_nats_sss, True)
|
||||
|
Reference in New Issue
Block a user