Upgrade API of NAS-Bench-201

This commit is contained in:
D-X-Y
2020-03-10 19:08:56 +11:00
parent c8f2a93ecf
commit d783193392
10 changed files with 623 additions and 178 deletions

View File

@@ -4,7 +4,6 @@
import os, sys, time, argparse, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
from collections import defaultdict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
@@ -15,8 +14,7 @@ from datasets import get_datasets
# NAS-Bench-201 related module or function
from models import CellStructure, get_cell_based_tiny_net
from nas_201_api import ArchResults, ResultsCount
from functions import pure_evaluate
from procedures import bench_pure_evaluate as pure_evaluate
def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict):
@@ -69,7 +67,6 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
return information
def GET_DataLoaders(workers):
torch.set_num_threads(workers)
@@ -137,7 +134,6 @@ def GET_DataLoaders(workers):
return loaders
def simplify(save_dir, meta_file, basestr, target_dir):
meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs'] # a list of architecture strings
@@ -221,7 +217,6 @@ def simplify(save_dir, meta_file, basestr, target_dir):
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
def merge_all(save_dir, meta_file, basestr):
meta_infos = torch.load(meta_file, map_location='cpu')
meta_archs = meta_infos['archs']
@@ -268,7 +263,6 @@ def merge_all(save_dir, meta_file, basestr):
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -280,7 +274,7 @@ if __name__ == '__main__':
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
args = parser.parse_args()
save_dir = Path( args.base_save_dir )
save_dir = Path(args.base_save_dir)
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
@@ -292,4 +286,4 @@ if __name__ == '__main__':
elif args.mode == 'merge':
merge_all(save_dir, meta_path, basestr)
else:
raise ValueError('invalid mode : {:}'.format(args.mode))
raise ValueError('invalid mode : {:}'.format(args.mode))