update code styles
This commit is contained in:
@@ -8,7 +8,6 @@ from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import matplotlib
|
||||
@@ -498,6 +497,8 @@ def show_nas_sharing_w(api, dataset, subset, vis_save_dir, file_name, y_lims, x_
|
||||
|
||||
def get_accs(xdata):
|
||||
epochs, xresults = xdata['epoch'], []
|
||||
metrics = api.arch2infos_full[ api.random() ].get_metrics(dataset, subset, None, False)
|
||||
xresults.append( metrics['accuracy'] )
|
||||
for iepoch in range(epochs):
|
||||
genotype = xdata['genotypes'][iepoch]
|
||||
index = api.query_index_by_arch(genotype)
|
||||
@@ -547,7 +548,6 @@ if __name__ == '__main__':
|
||||
#visualize_relative_ranking(vis_save_dir)
|
||||
|
||||
api = API(args.api_path)
|
||||
"""
|
||||
for x_maxs in [50, 250]:
|
||||
show_nas_sharing_w(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
@@ -555,11 +555,12 @@ if __name__ == '__main__':
|
||||
show_nas_sharing_w(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
show_nas_sharing_w(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-plot.pdf', (0, 100,10), x_maxs)
|
||||
just_show(api)
|
||||
"""
|
||||
just_show(api)
|
||||
plot_results_nas(api, 'cifar10-valid' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar10' , 'ori-test', vis_save_dir, 'nas-com.pdf', (85,95, 1))
|
||||
plot_results_nas(api, 'cifar100' , 'x-valid' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'cifar100' , 'x-test' , vis_save_dir, 'nas-com.pdf', (55,75, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-valid' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
plot_results_nas(api, 'ImageNet16-120', 'x-test' , vis_save_dir, 'nas-com.pdf', (35,50, 3))
|
||||
"""
|
||||
|
@@ -10,7 +10,6 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Categorical
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
|
@@ -121,9 +121,19 @@ def main(xargs):
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'cifar100':
|
||||
raise ValueError('not support yet : {:}'.format(xargs.dataset))
|
||||
elif xargs.dataset.startswith('ImageNet16'):
|
||||
raise ValueError('not support yet : {:}'.format(xargs.dataset))
|
||||
cifar100_test_split = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
elif xargs.dataset == 'ImageNet16-120':
|
||||
imagenet_test_split = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(xargs.dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=config.batch_size, shuffle=True , num_workers=xargs.workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=xargs.workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
|
||||
@@ -168,7 +178,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
@@ -230,7 +240,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--config_path', type=str, help='The config paths.')
|
||||
parser.add_argument('--config_path', type=str, help='The config path.')
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
|
@@ -181,8 +181,8 @@ def main(xargs):
|
||||
logger.log('Load split file from {:}'.format(split_Fpath))
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(xargs.dataset))
|
||||
config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
#config_path = 'configs/nas-benchmark/algos/DARTS.config'
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# To split data
|
||||
train_data_v2 = deepcopy(train_data)
|
||||
train_data_v2.transform = valid_data.transform
|
||||
@@ -233,7 +233,7 @@ def main(xargs):
|
||||
logger.log("=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(last_info, start_epoch))
|
||||
else:
|
||||
logger.log("=> do not find the last-info file : {:}".format(last_info))
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {}
|
||||
start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: search_model.genotype()}
|
||||
|
||||
# start training
|
||||
start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
|
||||
@@ -297,6 +297,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||
# channels and number-of-cells
|
||||
parser.add_argument('--config_path', type=str, help='The config path.')
|
||||
parser.add_argument('--search_space_name', type=str, help='The search space name.')
|
||||
parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
|
||||
parser.add_argument('--channel', type=int, help='The number of channels.')
|
||||
|
@@ -3,7 +3,7 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import os, sys, time, glob, random, argparse
|
||||
import os, sys, time, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
@@ -11,7 +11,7 @@ import torch.nn as nn
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, dict2config, configure2str
|
||||
from config_utils import load_config, dict2config
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
|
||||
from utils import get_model_infos, obtain_accuracy
|
||||
|
@@ -1,12 +1,14 @@
|
||||
# python ./exps/vis/test.py
|
||||
import os, sys, random
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
|
||||
from nas_102_api import NASBench102API as API
|
||||
|
||||
def test_nas_api():
|
||||
from nas_102_api import ArchResults
|
||||
@@ -72,7 +74,40 @@ def test_auto_grad():
|
||||
s_grads = torch.autograd.grad(grads, net.parameters())
|
||||
second_order_grads.append( s_grads )
|
||||
|
||||
|
||||
def test_one_shot_model(ckpath, use_train):
|
||||
from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from config_utils import load_config, dict2config
|
||||
from utils.nas_utils import evaluate_one_shot
|
||||
use_train = int(use_train) > 0
|
||||
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
|
||||
#ckpath = 'output/search-cell-nas-bench-102/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
|
||||
print ('ckpath : {:}'.format(ckpath))
|
||||
ckp = torch.load(ckpath)
|
||||
xargs = ckp['args']
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
|
||||
if xargs.dataset == 'cifar10':
|
||||
cifar_split = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
|
||||
xvalid_data = deepcopy(train_data)
|
||||
xvalid_data.transform = valid_data.transform
|
||||
valid_loader= torch.utils.data.DataLoader(xvalid_data, batch_size=2048, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid), num_workers=12, pin_memory=True)
|
||||
else: raise ValueError('invalid dataset : {:}'.format(xargs.dataseet))
|
||||
search_space = get_search_spaces('cell', xargs.search_space_name)
|
||||
model_config = dict2config({'name': 'SETN', 'C': xargs.channel, 'N': xargs.num_cells,
|
||||
'max_nodes': xargs.max_nodes, 'num_classes': class_num,
|
||||
'space' : search_space,
|
||||
'affine' : False, 'track_running_stats': True}, None)
|
||||
search_model = get_cell_based_tiny_net(model_config)
|
||||
search_model.load_state_dict( ckp['search_model'] )
|
||||
search_model = search_model.cuda()
|
||||
api = API('/home/dxy/.torch/NAS-Bench-102-v1_0-e61699.pth')
|
||||
archs, probs, accuracies = evaluate_one_shot(search_model, valid_loader, api, use_train)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#test_nas_api()
|
||||
#for i in range(200): plot('{:04d}'.format(i))
|
||||
test_auto_grad()
|
||||
#test_auto_grad()
|
||||
test_one_shot_model(sys.argv[1], sys.argv[2])
|
||||
|
Reference in New Issue
Block a user