update code styles
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
# python ./exps/vis/test.py
|
||||
import os, sys, random
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
def test_nas_api():
|
||||
from nas_102_api import ArchResults
|
||||
@@ -72,7 +74,40 @@ def test_auto_grad():
|
||||
s_grads = torch.autograd.grad(grads, net.parameters())
|
||||
second_order_grads.append( s_grads )
|
||||
|
||||
|
||||
def test_one_shot_model(ckpath, use_train):
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from config_utils import load_config, dict2config
|
||||
from utils.nas_utils import evaluate_one_shot
|
||||
use_train = int(use_train) > 0
|
||||
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
|
||||
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
|
||||
print ('ckpath : {:}'.format(ckpath))
|
||||
ckp = torch.load(ckpath)
|
||||
xargs = ckp['args']
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
||||
if xargs.dataset == 'cifar10':
|
||||
cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
xvalid_data = deepcopy(train_data)
|
||||
xvalid_data.transform = valid_data.transform
|
||||
valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
|
||||
else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet))
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': True}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.load_state_dict( ckp['search_model'] )
|
||||
search_model = search_model.cuda()
|
||||
api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth')
|
||||
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_nas_api()
|
||||
#for i in range(200): plot('{:04d}'.format(i))
|
||||
test_auto_grad()
|
||||
#test_auto_grad()
|
||||
test_one_shot_model(sys.argv[1], sys.argv[2])
|
||||
|
Reference in New Issue
Block a user