Upgrade API of NAS-Bench-201
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
import os, sys, time, argparse, collections
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
@@ -15,8 +14,7 @@ from datasets import get_datasets
|
||||
# NAS-Bench-201 related module or function
|
||||
from models import CellStructure, get_cell_based_tiny_net
|
||||
from nas_201_api import ArchResults, ResultsCount
|
||||
from functions import pure_evaluate
|
||||
|
||||
from procedures import bench_pure_evaluate as pure_evaluate
|
||||
|
||||
|
||||
def create_result_count(used_seed, dataset, arch_config, results, dataloader_dict):
|
||||
@@ -69,7 +67,6 @@ def account_one_arch(arch_index, arch_str, checkpoints, datasets, dataloader_dic
|
||||
return information
|
||||
|
||||
|
||||
|
||||
def GET_DataLoaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
@@ -137,7 +134,6 @@ def GET_DataLoaders(workers):
|
||||
return loaders
|
||||
|
||||
|
||||
|
||||
def simplify(save_dir, meta_file, basestr, target_dir):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs'] # a list of architecture strings
|
||||
@@ -221,7 +217,6 @@ def simplify(save_dir, meta_file, basestr, target_dir):
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
|
||||
|
||||
|
||||
|
||||
def merge_all(save_dir, meta_file, basestr):
|
||||
meta_infos = torch.load(meta_file, map_location='cpu')
|
||||
meta_archs = meta_infos['archs']
|
||||
@@ -268,7 +263,6 @@ def merge_all(save_dir, meta_file, basestr):
|
||||
print ('Save {:} / {:} architecture results into {:}.'.format(len(evaluated_indexes), meta_num_archs, save_file_name))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS-BENCH-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
@@ -280,7 +274,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--num_cells' , type=int, default=5, help='The number of cells in one stage.')
|
||||
args = parser.parse_args()
|
||||
|
||||
save_dir = Path( args.base_save_dir )
|
||||
save_dir = Path(args.base_save_dir)
|
||||
meta_path = save_dir / 'meta-node-{:}.pth'.format(args.max_node)
|
||||
assert save_dir.exists(), 'invalid save dir path : {:}'.format(save_dir)
|
||||
assert meta_path.exists(), 'invalid saved meta path : {:}'.format(meta_path)
|
||||
@@ -292,4 +286,4 @@ if __name__ == '__main__':
|
||||
elif args.mode == 'merge':
|
||||
merge_all(save_dir, meta_path, basestr)
|
||||
else:
|
||||
raise ValueError('invalid mode : {:}'.format(args.mode))
|
||||
raise ValueError('invalid mode : {:}'.format(args.mode))
|
Reference in New Issue
Block a user