Reformulate via black

This commit is contained in:
D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions

View File

@@ -10,16 +10,18 @@ import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use('agg')
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
@@ -27,78 +29,78 @@ from models import get_cell_based_tiny_net, CellStructure
def test_api(api, sss_or_tss=True):
print('{:} start testing the api : {:}'.format(time_string(), api))
api.clear_params(12)
api.reload(index=12)
# Query the informations of 1113-th architecture
info_strs = api.query_info_str_by_arch(1113)
print(info_strs)
info = api.query_by_index(113)
print('{:}\n'.format(info))
info = api.query_by_index(113, 'cifar100')
print('{:}\n'.format(info))
print("{:} start testing the api : {:}".format(time_string(), api))
api.clear_params(12)
api.reload(index=12)
info = api.query_meta_info_by_index(115, '90' if sss_or_tss else '200')
print('{:}\n'.format(info))
# Query the informations of 1113-th architecture
info_strs = api.query_info_str_by_arch(1113)
print(info_strs)
info = api.query_by_index(113)
print("{:}\n".format(info))
info = api.query_by_index(113, "cifar100")
print("{:}\n".format(info))
for dataset in ['cifar10', 'cifar100', 'ImageNet16-120']:
for xset in ['train', 'test', 'valid']:
best_index, highest_accuracy = api.find_best(dataset, xset)
print('')
params = api.get_net_param(12, 'cifar10', None)
info = api.query_meta_info_by_index(115, "90" if sss_or_tss else "200")
print("{:}\n".format(info))
# Obtain the config and create the network
config = api.get_net_config(12, 'cifar10')
print('{:}\n'.format(config))
network = get_cell_based_tiny_net(config)
network.load_state_dict(next(iter(params.values())))
for dataset in ["cifar10", "cifar100", "ImageNet16-120"]:
for xset in ["train", "test", "valid"]:
best_index, highest_accuracy = api.find_best(dataset, xset)
print("")
params = api.get_net_param(12, "cifar10", None)
# Obtain the cost information
info = api.get_cost_info(12, 'cifar10')
print('{:}\n'.format(info))
info = api.get_latency(12, 'cifar10')
print('{:}\n'.format(info))
for index in [13, 15, 19, 200]:
info = api.get_latency(index, 'cifar10')
# Obtain the config and create the network
config = api.get_net_config(12, "cifar10")
print("{:}\n".format(config))
network = get_cell_based_tiny_net(config)
network.load_state_dict(next(iter(params.values())))
# Count the number of architectures
info = api.statistics('cifar100', '12')
print('{:} statistics results : {:}\n'.format(time_string(), info))
# Obtain the cost information
info = api.get_cost_info(12, "cifar10")
print("{:}\n".format(info))
info = api.get_latency(12, "cifar10")
print("{:}\n".format(info))
for index in [13, 15, 19, 200]:
info = api.get_latency(index, "cifar10")
# Show the information of the 123-th architecture
api.show(123)
# Count the number of architectures
info = api.statistics("cifar100", "12")
print("{:} statistics results : {:}\n".format(time_string(), info))
# Obtain both cost and performance information
info = api.get_more_info(1234, 'cifar10')
print('{:}\n'.format(info))
print('{:} finish testing the api : {:}'.format(time_string(), api))
# Show the information of the 123-th architecture
api.show(123)
if not sss_or_tss:
arch_str = '|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|'
matrix = api.str2matrix(arch_str)
print('Compute the adjacency matrix of {:}'.format(arch_str))
print(matrix)
info = api.simulate_train_eval(123, 'cifar10')
print('simulate_train_eval : {:}\n\n'.format(info))
# Obtain both cost and performance information
info = api.get_more_info(1234, "cifar10")
print("{:}\n".format(info))
print("{:} finish testing the api : {:}".format(time_string(), api))
if not sss_or_tss:
arch_str = "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|"
matrix = api.str2matrix(arch_str)
print("Compute the adjacency matrix of {:}".format(arch_str))
print(matrix)
info = api.simulate_train_eval(123, "cifar10")
print("simulate_train_eval : {:}\n\n".format(info))
if __name__ == '__main__':
if __name__ == "__main__":
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
for fast_mode in [True, False]:
for verbose in [True, False]:
api_nats_tss = create(None, 'tss', fast_mode=fast_mode, verbose=True)
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
test_api(api_nats_tss, False)
del api_nats_tss
gc.collect()
# api201 = create('./output/NATS-Bench-topology/process-FULL', 'topology', fast_mode=True, verbose=True)
for fast_mode in [True, False]:
for verbose in [True, False]:
api_nats_tss = create(None, "tss", fast_mode=fast_mode, verbose=True)
print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose))
test_api(api_nats_tss, False)
del api_nats_tss
gc.collect()
for fast_mode in [True, False]:
for verbose in [True, False]:
print('{:} create with fast_mode={:} and verbose={:}'.format(time_string(), fast_mode, verbose))
api_nats_sss = create(None, 'size', fast_mode=fast_mode, verbose=True)
print('{:} --->>> {:}'.format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)
del api_nats_sss
gc.collect()
for fast_mode in [True, False]:
for verbose in [True, False]:
print("{:} create with fast_mode={:} and verbose={:}".format(time_string(), fast_mode, verbose))
api_nats_sss = create(None, "size", fast_mode=fast_mode, verbose=True)
print("{:} --->>> {:}".format(time_string(), api_nats_sss))
test_api(api_nats_sss, True)
del api_nats_sss
gc.collect()