v2
This commit is contained in:
206
search.py
206
search.py
@@ -1,35 +1,49 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import nasspace
|
||||
import datasets
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
from scores import get_score_func
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from tqdm import trange
|
||||
from statistics import mean
|
||||
import time
|
||||
from utils import add_dropout
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS Without Training')
|
||||
parser.add_argument('--data_loc', default='../datasets/cifar', type=str, help='dataset folder')
|
||||
parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth',
|
||||
parser.add_argument('--data_loc', default='../cifardata/', type=str, help='dataset folder')
|
||||
parser.add_argument('--api_loc', default='../NAS-Bench-201-v1_0-e61699.pth',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--save_loc', default='results', type=str, help='folder to save results')
|
||||
parser.add_argument('--batch_size', default=256, type=int)
|
||||
parser.add_argument('--save_loc', default='results/ICML', type=str, help='folder to save results')
|
||||
parser.add_argument('--save_string', default='naswot', type=str, help='prefix of results file')
|
||||
parser.add_argument('--score', default='hook_logdet', type=str, help='the score to evaluate')
|
||||
parser.add_argument('--nasspace', default='nasbench201', type=str, help='the nas search space to use')
|
||||
parser.add_argument('--batch_size', default=128, type=int)
|
||||
parser.add_argument('--kernel', action='store_true')
|
||||
parser.add_argument('--dropout', action='store_true')
|
||||
parser.add_argument('--repeat', default=1, type=int, help='how often to repeat a single image with a batch')
|
||||
parser.add_argument('--augtype', default='none', type=str, help='which perturbations to use')
|
||||
parser.add_argument('--sigma', default=0.05, type=float, help='noise level if augtype is "gaussnoise"')
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
parser.add_argument('--seed', default=1, type=int)
|
||||
parser.add_argument('--init', default='', type=str)
|
||||
parser.add_argument('--trainval', action='store_true')
|
||||
parser.add_argument('--activations', action='store_true')
|
||||
parser.add_argument('--cosine', action='store_true')
|
||||
parser.add_argument('--dataset', default='cifar10', type=str)
|
||||
parser.add_argument('--n_samples', default=100, type=int)
|
||||
parser.add_argument('--n_runs', default=500, type=int)
|
||||
parser.add_argument('--stem_out_channels', default=16, type=int, help='output channels of stem convolution (nasbench101)')
|
||||
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules (nasbench101)')
|
||||
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack (nasbench101)')
|
||||
parser.add_argument('--num_labels', default=1, type=int, help='#classes (nasbench101)')
|
||||
|
||||
args = parser.parse_args()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.datasets as datasets
|
||||
import torch.optim as optim
|
||||
|
||||
from models import get_cell_based_tiny_net
|
||||
|
||||
# Reproducibility
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
@@ -37,120 +51,140 @@ random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from datasets import get_datasets
|
||||
from config_utils import load_config
|
||||
from nas_201_api import NASBench201API as API
|
||||
|
||||
def get_batch_jacobian(net, x, target, to, device, args=None):
|
||||
def get_batch_jacobian(net, x, target, device, args=None):
|
||||
net.zero_grad()
|
||||
|
||||
x.requires_grad_(True)
|
||||
|
||||
_, y = net(x)
|
||||
|
||||
y, ints = net(x)
|
||||
y.backward(torch.ones_like(y))
|
||||
jacob = x.grad.detach()
|
||||
|
||||
return jacob, target.detach()
|
||||
|
||||
|
||||
def eval_score(jacob, labels=None):
|
||||
corrs = np.corrcoef(jacob)
|
||||
v, _ = np.linalg.eig(corrs)
|
||||
k = 1e-5
|
||||
return -np.sum(np.log(v + k) + 1./(v + k))
|
||||
|
||||
return jacob, target.detach(), y.detach(), ints.detach()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(device)
|
||||
THE_START = time.time()
|
||||
api = API(args.api_loc)
|
||||
|
||||
searchspace = nasspace.get_search_space(args)
|
||||
train_loader = datasets.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)
|
||||
os.makedirs(args.save_loc, exist_ok=True)
|
||||
|
||||
train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, cutout=0)
|
||||
|
||||
if args.dataset == 'cifar10':
|
||||
acc_type = 'ori-test'
|
||||
val_acc_type = 'x-valid'
|
||||
|
||||
else:
|
||||
acc_type = 'x-test'
|
||||
val_acc_type = 'x-valid'
|
||||
|
||||
if args.trainval:
|
||||
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
|
||||
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
|
||||
|
||||
else:
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
|
||||
num_workers=0, pin_memory=True)
|
||||
|
||||
times = []
|
||||
chosen = []
|
||||
acc = []
|
||||
val_acc = []
|
||||
topscores = []
|
||||
|
||||
dset = args.dataset if not args.trainval else 'cifar10-valid'
|
||||
|
||||
order_fn = np.nanargmax
|
||||
|
||||
|
||||
if args.dataset == 'cifar10':
|
||||
acc_type = 'ori-test'
|
||||
val_acc_type = 'x-valid'
|
||||
else:
|
||||
acc_type = 'x-test'
|
||||
val_acc_type = 'x-valid'
|
||||
|
||||
|
||||
|
||||
runs = trange(args.n_runs, desc='acc: ')
|
||||
for N in runs:
|
||||
start = time.time()
|
||||
indices = np.random.randint(0,15625,args.n_samples)
|
||||
indices = np.random.randint(0,len(searchspace),args.n_samples)
|
||||
scores = []
|
||||
|
||||
npstate = np.random.get_state()
|
||||
ranstate = random.getstate()
|
||||
torchstate = torch.random.get_rng_state()
|
||||
for arch in indices:
|
||||
|
||||
data_iterator = iter(train_loader)
|
||||
x, target = next(data_iterator)
|
||||
x, target = x.to(device), target.to(device)
|
||||
|
||||
config = api.get_net_config(arch, args.dataset)
|
||||
config['num_classes'] = 1
|
||||
|
||||
network = get_cell_based_tiny_net(config) # create the network from configuration
|
||||
network = network.to(device)
|
||||
|
||||
jacobs, labels= get_batch_jacobian(network, x, target, 1, device, args)
|
||||
jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
|
||||
|
||||
try:
|
||||
s = eval_score(jacobs, labels)
|
||||
uid = searchspace[arch]
|
||||
network = searchspace.get_network(uid)
|
||||
network.to(device)
|
||||
if args.dropout:
|
||||
add_dropout(network, args.sigma)
|
||||
if args.init != '':
|
||||
init_network(network, args.init)
|
||||
if 'hook_' in args.score:
|
||||
network.K = np.zeros((args.batch_size, args.batch_size))
|
||||
def counting_forward_hook(module, inp, out):
|
||||
try:
|
||||
if not module.visited_backwards:
|
||||
return
|
||||
if isinstance(inp, tuple):
|
||||
inp = inp[0]
|
||||
inp = inp.view(inp.size(0), -1)
|
||||
x = (inp > 0).float()
|
||||
K = x @ x.t()
|
||||
K2 = (1.-x) @ (1.-x.t())
|
||||
network.K = network.K + K.cpu().numpy() + K2.cpu().numpy()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def counting_backward_hook(module, inp, out):
|
||||
module.visited_backwards = True
|
||||
|
||||
|
||||
for name, module in network.named_modules():
|
||||
if 'ReLU' in str(type(module)):
|
||||
#hooks[name] = module.register_forward_hook(counting_hook)
|
||||
module.register_forward_hook(counting_forward_hook)
|
||||
module.register_backward_hook(counting_backward_hook)
|
||||
|
||||
random.setstate(ranstate)
|
||||
np.random.set_state(npstate)
|
||||
torch.set_rng_state(torchstate)
|
||||
|
||||
data_iterator = iter(train_loader)
|
||||
x, target = next(data_iterator)
|
||||
x2 = torch.clone(x)
|
||||
x2 = x2.to(device)
|
||||
x, target = x.to(device), target.to(device)
|
||||
jacobs, labels, y, out = get_batch_jacobian(network, x, target, device, args)
|
||||
|
||||
if args.kernel:
|
||||
s = get_score_func(args.score)(out, labels)
|
||||
elif 'hook_' in args.score:
|
||||
network(x2.to(device))
|
||||
s = get_score_func(args.score)(network.K, target)
|
||||
elif args.repeat < args.batch_size:
|
||||
s = get_score_func(args.score)(jacobs, labels, args.repeat)
|
||||
else:
|
||||
s = get_score_func(args.score)(jacobs, labels)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
s = np.nan
|
||||
|
||||
s = 0.
|
||||
|
||||
scores.append(s)
|
||||
|
||||
#print(len(scores))
|
||||
#print(scores)
|
||||
#print(order_fn(scores))
|
||||
|
||||
|
||||
|
||||
best_arch = indices[order_fn(scores)]
|
||||
info = api.query_by_index(best_arch)
|
||||
uid = searchspace[best_arch]
|
||||
topscores.append(scores[order_fn(scores)])
|
||||
chosen.append(best_arch)
|
||||
acc.append(info.get_metrics(dset, acc_type)['accuracy'])
|
||||
#acc.append(searchspace.get_accuracy(uid, acc_type, args.trainval))
|
||||
acc.append(searchspace.get_final_accuracy(uid, acc_type, False))
|
||||
|
||||
if not args.dataset == 'cifar10' or args.trainval:
|
||||
val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])
|
||||
val_acc.append(searchspace.get_final_accuracy(uid, val_acc_type, args.trainval))
|
||||
# val_acc.append(info.get_metrics(dset, val_acc_type)['accuracy'])
|
||||
|
||||
times.append(time.time()-start)
|
||||
runs.set_description(f"acc: {mean(acc if not args.trainval else val_acc):.2f}%")
|
||||
runs.set_description(f"acc: {mean(acc):.2f}% time:{mean(times):.2f}")
|
||||
|
||||
print(f"Final mean test accuracy: {np.mean(acc)}")
|
||||
if len(val_acc) > 1:
|
||||
print(f"Final mean validation accuracy: {np.mean(val_acc)}")
|
||||
#if len(val_acc) > 1:
|
||||
# print(f"Final mean validation accuracy: {np.mean(val_acc)}")
|
||||
|
||||
state = {'accs': acc,
|
||||
'val_accs': val_acc,
|
||||
'chosen': chosen,
|
||||
'times': times,
|
||||
'topscores': topscores,
|
||||
}
|
||||
|
||||
dset = args.dataset if not args.trainval else 'cifar10-valid'
|
||||
fname = f"{args.save_loc}/{dset}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
|
||||
dset = args.dataset if not (args.trainval and args.dataset == 'cifar10') else 'cifar10-valid'
|
||||
fname = f"{args.save_loc}/{args.save_string}_{args.score}_{args.nasspace}_{dset}_{args.kernel}_{args.dropout}_{args.augtype}_{args.sigma}_{args.repeat}_{args.batch_size}_{args.n_runs}_{args.n_samples}_{args.seed}.t7"
|
||||
torch.save(state, fname)
|
||||
|
Reference in New Issue
Block a user