Reformulate via black

This commit is contained in:
D-X-Y
2021-03-17 09:25:58 +00:00
parent a9093e41e1
commit f98edea22a
59 changed files with 12289 additions and 8918 deletions

View File

@@ -8,37 +8,45 @@
import os, sys, time, tqdm, argparse
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from datasets import get_datasets
from nats_bench import create
def show_time(api, epoch=12):
print('Show the time for {:} with {:}-epoch-training'.format(api, epoch))
all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0
for index in tqdm.tqdm(range(len(api))):
info = api.get_more_info(index, 'ImageNet16-120', hp=epoch)
imagenet_time = info['train-all-time']
info = api.get_more_info(index, 'cifar10-valid', hp=epoch)
cifar10_time = info['train-all-time']
info = api.get_more_info(index, 'cifar100', hp=epoch)
cifar100_time = info['train-all-time']
# accumulate the time
all_cifar10_time += cifar10_time
all_cifar100_time += cifar100_time
all_imagenet_time += imagenet_time
print('The total training time for CIFAR-10 (held-out train set) is {:} seconds'.format(all_cifar10_time))
print('The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_cifar100_time, all_cifar100_time / all_cifar10_time))
print('The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10'.format(all_imagenet_time, all_imagenet_time / all_cifar10_time))
print("Show the time for {:} with {:}-epoch-training".format(api, epoch))
all_cifar10_time, all_cifar100_time, all_imagenet_time = 0, 0, 0
for index in tqdm.tqdm(range(len(api))):
info = api.get_more_info(index, "ImageNet16-120", hp=epoch)
imagenet_time = info["train-all-time"]
info = api.get_more_info(index, "cifar10-valid", hp=epoch)
cifar10_time = info["train-all-time"]
info = api.get_more_info(index, "cifar100", hp=epoch)
cifar100_time = info["train-all-time"]
# accumulate the time
all_cifar10_time += cifar10_time
all_cifar100_time += cifar100_time
all_imagenet_time += imagenet_time
print("The total training time for CIFAR-10 (held-out train set) is {:} seconds".format(all_cifar10_time))
print(
"The total training time for CIFAR-100 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
all_cifar100_time, all_cifar100_time / all_cifar10_time
)
)
print(
"The total training time for ImageNet-16-120 (held-out train set) is {:} seconds, {:.2f} times longer than that on CIFAR-10".format(
all_imagenet_time, all_imagenet_time / all_cifar10_time
)
)
if __name__ == '__main__':
if __name__ == "__main__":
api_nats_tss = create(None, 'tss', fast_mode=True, verbose=False)
show_time(api_nats_tss, 12)
api_nats_sss = create(None, 'sss', fast_mode=True, verbose=False)
show_time(api_nats_sss, 12)
api_nats_tss = create(None, "tss", fast_mode=True, verbose=False)
show_time(api_nats_tss, 12)
api_nats_sss = create(None, "sss", fast_mode=True, verbose=False)
show_time(api_nats_sss, 12)