first commit
NAS-Bench-201/main_exp/diffusion/run_lib.py  (new file, 286 lines)
@@ -0,0 +1,286 @@
import os
import sys
import random
# shutil / logging are used by save_checkpoint / restore_checkpoint below;
# import them explicitly rather than relying on the star imports further down.
import shutil
import logging

import numpy as np
import torch

sys.path.append('.')
import sampling
import datasets_nas
from models import cate
from models import digcn
from models import digcn_meta
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
from analysis.arch_functions import BasicArchMetricsMeta
from all_path import *

def get_sampling_fn_meta(config):
|
||||
## Set SDE
|
||||
if config.training.sde.lower() == 'vpsde':
|
||||
sde = sde_lib.VPSDE(
|
||||
beta_min=config.model.beta_min,
|
||||
beta_max=config.model.beta_max,
|
||||
N=config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif config.training.sde.lower() == 'subvpsde':
|
||||
sde = sde_lib.subVPSDE(
|
||||
beta_min=config.model.beta_min,
|
||||
beta_max=config.model.beta_max,
|
||||
N=config.model.num_scales)
|
||||
sampling_eps = 1e-3
|
||||
elif config.training.sde.lower() == 'vesde':
|
||||
sde = sde_lib.VESDE(
|
||||
sigma_min=config.model.sigma_min,
|
||||
sigma_max=config.model.sigma_max,
|
||||
N=config.model.num_scales)
|
||||
sampling_eps = 1e-5
|
||||
else:
|
||||
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
|
||||
|
||||
## Get data normalizer inverse
|
||||
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
|
||||
|
||||
## Get sampling function
|
||||
sampling_shape = (config.eval.batch_size, config.data.max_node, config.data.n_vocab)
|
||||
sampling_fn = sampling.get_sampling_fn(
|
||||
config=config,
|
||||
sde=sde,
|
||||
shape=sampling_shape,
|
||||
inverse_scaler=inverse_scaler,
|
||||
eps=sampling_eps,
|
||||
conditional=True,
|
||||
data_name=config.sampling.check_dataname,
|
||||
num_sample=config.model.num_sample)
|
||||
|
||||
return sampling_fn, sde
|
||||
|
||||
|
||||
def get_score_model(config):
    try:
        score_config = torch.load(config.scorenet_ckpt_path)['config']
        ckpt_path = config.scorenet_ckpt_path
    except Exception:
        # fall back to the default checkpoint path if the configured one cannot be loaded
        config.scorenet_ckpt_path = SCORENET_CKPT_PATH
        score_config = torch.load(config.scorenet_ckpt_path)['config']
        ckpt_path = config.scorenet_ckpt_path

    score_model = mutils.create_model(score_config)
    score_ema = ExponentialMovingAverage(
        score_model.parameters(), decay=score_config.model.ema_rate)
    score_state = dict(
        model=score_model, ema=score_ema, step=0, config=score_config)
    score_state = restore_checkpoint(
        ckpt_path, score_state,
        device=config.device, resume=True)
    score_ema.copy_to(score_model.parameters())
    return score_model, score_ema, score_config

def get_surrogate(config):
|
||||
surrogate_model = mutils.create_model(config)
|
||||
return surrogate_model
|
||||
|
||||
|
||||
def get_adj(except_inout=False):
|
||||
_adj = np.asarray(
|
||||
[[0, 1, 1, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0]]
|
||||
)
|
||||
_adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
|
||||
if except_inout: _adj = _adj[1:-1, 1:-1]
|
||||
return _adj
|
||||
|
||||
|
||||
def generate_archs_meta(
|
||||
config,
|
||||
sampling_fn,
|
||||
score_model,
|
||||
score_ema,
|
||||
meta_surrogate_model,
|
||||
num_samples,
|
||||
args=None,
|
||||
task=None,
|
||||
patient_factor=20,
|
||||
batch_size=256,):
|
||||
|
||||
metrics = BasicArchMetricsMeta()
|
||||
|
||||
## Get the adj and mask
|
||||
adj_s = get_adj()
|
||||
mask_s = aug_mask(adj_s)[0]
|
||||
adj_c = get_adj()
|
||||
mask_c = aug_mask(adj_c)[0]
|
||||
assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
|
||||
adj_s, mask_s, adj_c, mask_c = \
|
||||
adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)
|
||||
|
||||
score_ema.copy_to(score_model.parameters())
|
||||
score_model.eval()
|
||||
meta_surrogate_model.eval()
|
||||
c_scale = args.classifier_scale
|
||||
|
||||
num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor) if num_samples > batch_size else int(patient_factor)
n_round = 0  # renamed from `round` to avoid shadowing the built-in
all_samples = []
while n_round < num_sampling_rounds:
    n_round += 1
    sample = sampling_fn(score_model,
                         mask_s,
                         meta_surrogate_model,
                         classifier_scale=c_scale,
                         task=task)
    quantized_sample = quantize(sample)
    _, _, valid_arch_str, _ = metrics.compute_validity(quantized_sample)
    if len(valid_arch_str) > 0:
        all_samples += valid_arch_str
    # lower the guidance scale and reseed so later rounds sample more diverse architectures
    c_scale -= args.scale_step
    seed = int(random.random() * 10000)
    reset_seed(seed)
    # stop sampling once enough unique architectures have been collected
    if len(set(all_samples)) >= num_samples:
        break

return list(set(all_samples))
|
||||
|
||||
|
||||
def save_checkpoint(ckpt_dir, state, epoch, is_best):
|
||||
saved_state = {}
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
saved_state.update({k: state[k].state_dict()})
|
||||
else:
|
||||
saved_state.update({k: state[k]})
|
||||
os.makedirs(ckpt_dir, exist_ok=True)
|
||||
torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
|
||||
if is_best:
|
||||
shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar'))
|
||||
|
||||
# remove the ckpt except is_best state
|
||||
for ckpt_file in sorted(os.listdir(ckpt_dir)):
|
||||
if not ckpt_file.startswith('checkpoint'):
|
||||
continue
|
||||
if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
|
||||
os.remove(os.path.join(ckpt_dir, ckpt_file))
|
||||
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, resume=False):
|
||||
if not resume:
|
||||
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
|
||||
return state
|
||||
elif not os.path.exists(ckpt_dir):
|
||||
if not os.path.exists(os.path.dirname(ckpt_dir)):
|
||||
os.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
state[k].load_state_dict(loaded_state[k])
|
||||
else:
|
||||
state[k] = loaded_state[k]
|
||||
return state
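# Illustrative usage sketch of the checkpoint convention shared by save_checkpoint /
# restore_checkpoint above: keys named 'model', 'ema', or 'optimizer' go through
# .state_dict()/.load_state_dict(), everything else is copied as-is. The path and the
# Linear module below are hypothetical, for illustration only.
#
# import torch
# model = torch.nn.Linear(4, 2)                  # hypothetical stand-in for the score network
# state = {'model': model, 'step': 0}
# torch.save({'model': model.state_dict(), 'step': 7}, '/tmp/checkpoint_demo.pth.tar')
# state = restore_checkpoint('/tmp/checkpoint_demo.pth.tar', state, device='cpu', resume=True)
# print(state['step'])                           # -> 7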
|
||||
|
||||
|
||||
def floyed(r):
    """Floyd-Warshall-style transitive closure of a binary adjacency matrix.

    :param r: an NxN numpy (or torch) matrix with float 0/1 entries
    :return: an NxN numpy matrix with float 0/1 entries, where r[i, j] = 1
        whenever node j is reachable from node i
    """
    # the misspelled name is kept on purpose: callers select it via algo='floyed'
    if isinstance(r, torch.Tensor):
        r = r.cpu().numpy()
    N = r.shape[0]
    for k in range(N):
        for i in range(N):
            for j in range(N):
                if r[i, k] > 0 and r[k, j] > 0:
                    r[i, j] = 1
    return r

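# Illustrative sketch (hypothetical usage, not part of the file): the closure of the fixed
# 8-node cell graph returned by get_adj() above, where node 0 appears to be the input,
# node 7 the output, and nodes 1-6 the operation slots.
#
# import numpy as np
# adj = np.asarray(
#     [[0, 1, 1, 1, 0, 0, 0, 0],
#      [0, 0, 0, 0, 1, 1, 0, 0],
#      [0, 0, 0, 0, 0, 0, 1, 0],
#      [0, 0, 0, 0, 0, 0, 0, 1],
#      [0, 0, 0, 0, 0, 0, 1, 0],
#      [0, 0, 0, 0, 0, 0, 0, 1],
#      [0, 0, 0, 0, 0, 0, 0, 1],
#      [0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32)
# closure = floyed(adj.copy())
# print(closure[0])   # -> [0. 1. 1. 1. 1. 1. 1. 1.]  (every node is reachable from the input)
# aug_mask() below stacks this closure into the mask tensor passed to the sampling function.
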
def aug_mask(adj, algo='floyed', data='NASBench201'):
|
||||
if len(adj.shape) == 2:
|
||||
adj = adj.unsqueeze(0)
|
||||
|
||||
if data.lower() in ['nasbench201', 'ofa']:
|
||||
assert len(adj.shape) == 3
|
||||
r = adj[0].clone().detach()
|
||||
if algo == 'long_range':
|
||||
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
|
||||
elif algo == 'floyed':
|
||||
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
|
||||
else:
|
||||
mask_i = r
|
||||
masks = [mask_i] * adj.size(0)
|
||||
return torch.stack(masks)
|
||||
else:
|
||||
masks = []
|
||||
for r in adj:
|
||||
if algo == 'long_range':
|
||||
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
|
||||
elif algo == 'floyed':
|
||||
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
|
||||
else:
|
||||
mask_i = r
|
||||
masks.append(mask_i)
|
||||
return torch.stack(masks)
|
||||
|
||||
|
||||
def long_range(r):
|
||||
"""
|
||||
:param r: a numpy NxN matrix with float 0,1
|
||||
:return: a numpy NxN matrix with float 0,1
|
||||
"""
|
||||
# r = np.array(r)
|
||||
if type(r) == torch.Tensor:
|
||||
r = r.cpu().numpy()
|
||||
N = r.shape[0]
|
||||
for j in range(1, N):
|
||||
col_j = r[:, j][:j]
|
||||
in_to_j = [i for i, val in enumerate(col_j) if val > 0]
|
||||
if len(in_to_j) > 0:
|
||||
for i in in_to_j:
|
||||
col_i = r[:, i][:i]
|
||||
in_to_i = [i for i, val in enumerate(col_i) if val > 0]
|
||||
if len(in_to_i) > 0:
|
||||
for k in in_to_i:
|
||||
r[k, j] = 1
|
||||
return r
|
||||
|
||||
|
||||
def quantize(x):
|
||||
"""Covert the PyTorch tensor x, adj matrices to numpy array.
|
||||
|
||||
Args:
|
||||
x: [Batch_size, Max_node, N_vocab]
|
||||
"""
|
||||
x_list = []
|
||||
|
||||
# discretization
|
||||
x[x >= 0.5] = 1.
|
||||
x[x < 0.5] = 0.
|
||||
|
||||
for i in range(x.shape[0]):
|
||||
x_tmp = x[i]
|
||||
x_tmp = x_tmp.cpu().numpy()
|
||||
x_list.append(x_tmp)
|
||||
|
||||
return x_list
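# Illustrative sketch (hypothetical shapes; the real input is [batch_size, max_node, n_vocab],
# and note that x is thresholded in place):
#
# import torch
# x = torch.tensor([[[0.9, 0.2, 0.4, 0.7],
#                    [0.1, 0.6, 0.5, 0.3]]])   # batch of 1, 2 nodes, 4 vocab entries
# ops = quantize(x)
# print(len(ops))   # -> 1
# print(ops[0])     # -> [[1. 0. 0. 1.]
#                   #     [0. 1. 1. 0.]]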
|
||||
|
||||
|
||||
def reset_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
NAS-Bench-201/main_exp/logger.py  (new file, 137 lines)
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import wandb
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(
|
||||
self,
|
||||
log_dir=None,
|
||||
write_textfile=True
|
||||
):
|
||||
|
||||
self.log_dir = log_dir
|
||||
self.write_textfile = write_textfile
|
||||
|
||||
self.logs_for_save = {}
|
||||
self.logs = {}
|
||||
|
||||
if self.write_textfile:
|
||||
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
|
||||
|
||||
|
||||
def write_str(self, log_str):
|
||||
self.f.write(log_str+'\n')
|
||||
self.f.flush()
|
||||
|
||||
|
||||
def update_config(self, v, is_args=False):
|
||||
if is_args:
|
||||
self.logs_for_save.update({'args': v})
|
||||
else:
|
||||
self.logs_for_save.update(v)
|
||||
|
||||
|
||||
def write_log(self, element, step, return_log_dict=False):
|
||||
log_str = f"{step} | "
|
||||
log_dict = {}
|
||||
for head, keys in element.items():
|
||||
for k in keys:
|
||||
if k in self.logs:
|
||||
v = self.logs[k].avg
|
||||
if not k in self.logs_for_save:
|
||||
self.logs_for_save[k] = []
|
||||
self.logs_for_save[k].append(v)
|
||||
log_str += f'{k} {v}| '
|
||||
log_dict[f'{head}/{k}'] = v
|
||||
|
||||
if self.write_textfile:
|
||||
self.f.write(log_str+'\n')
|
||||
self.f.flush()
|
||||
|
||||
if return_log_dict:
|
||||
return log_dict
|
||||
|
||||
|
||||
def save_log(self, name=None):
|
||||
name = 'logs.pt' if name is None else name
|
||||
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
|
||||
|
||||
|
||||
def update(self, key, v, n=1):
|
||||
if not key in self.logs:
|
||||
self.logs[key] = AverageMeter()
|
||||
self.logs[key].update(v, n)
|
||||
|
||||
|
||||
def reset(self, keys=None, except_keys=[]):
|
||||
if keys is not None:
|
||||
if isinstance(keys, list):
|
||||
for key in keys:
|
||||
self.logs[key] = AverageMeter()
|
||||
else:
|
||||
self.logs[keys] = AverageMeter()
|
||||
else:
|
||||
for key in self.logs.keys():
|
||||
if not key in except_keys:
|
||||
self.logs[key] = AverageMeter()
|
||||
|
||||
|
||||
def avg(self, keys=None, except_keys=[]):
|
||||
if keys is not None:
|
||||
if isinstance(keys, list):
|
||||
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
|
||||
else:
|
||||
return self.logs[keys].avg
|
||||
else:
|
||||
avg_dict = {}
|
||||
for key in self.logs.keys():
|
||||
if not key in except_keys:
|
||||
avg_dict[key] = self.logs[key].avg
|
||||
return avg_dict
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
|
||||
metrics = {}
|
||||
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
|
||||
logits_per_x = logits_per_g.t().detach().cpu()
|
||||
|
||||
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
|
||||
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
|
||||
|
||||
for name, logit in logits.items():
|
||||
ranking = torch.argsort(logit, descending=True)
|
||||
preds = torch.where(ranking == ground_truth)[1]
|
||||
preds = preds.detach().cpu().numpy()
|
||||
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
|
||||
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
|
||||
for k in [1, 5, 10]:
|
||||
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
|
||||
|
||||
return metrics
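# Illustrative sanity check (hypothetical values): with four perfectly aligned graph/dataset
# embeddings, every query retrieves its own pair at rank 1 in both directions.
#
# g_embeds = torch.eye(4)
# x_embeds = torch.eye(4)
# m = get_metrics(g_embeds, x_embeds, logit_scale=1.0, prefix='toy')
# print(m['toy_g_to_x_mean_rank'])   # -> 1.0
# print(m['toy_g_to_x_R@1'])         # -> 1.0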
|
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
@author: Hayeon Lee
|
||||
2020/02/19
|
||||
Script for downloading and reorganizing the Aircraft dataset
for few-shot classification.
|
||||
Run this file as follows:
|
||||
python get_data.py
|
||||
"""
|
||||
|
||||
import os
import sys
import glob
import shutil
import pickle
import tarfile
import collections

import numpy as np
import requests
from tqdm import tqdm
from PIL import Image

sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from all_path import RAW_DATA_PATH
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm( unit="B", total=int( r.headers['Content-Length'] ) )
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update (len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
dir_path = RAW_DATA_PATH
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
file_name = os.path.join(dir_path, 'fgvc-aircraft-2013b.tar.gz')
|
||||
|
||||
if not os.path.exists(file_name):
|
||||
print(f"Downloading {file_name}\n")
|
||||
download_file(
|
||||
'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz',
|
||||
file_name)
|
||||
print("\nDownloading done.\n")
|
||||
else:
|
||||
print("fgvc-aircraft-2013b.tar.gz has already been downloaded. Did not download twice.\n")
|
||||
|
||||
untar_file_name = os.path.join(dir_path, 'aircraft')
|
||||
if not os.path.exists(untar_file_name):
|
||||
tarname = file_name
|
||||
print("Untarring: {}".format(tarname))
|
||||
tar = tarfile.open(tarname)
|
||||
tar.extractall(untar_file_name)
|
||||
tar.close()
|
||||
else:
|
||||
print(f"{untar_file_name} folder already exists. Did not untarring twice\n")
|
||||
os.remove(file_name)
|
NAS-Bench-201/main_exp/transfer_nag/get_files/get_pets.py  (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import zipfile
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from all_path import RAW_DATA_PATH
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update(len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
|
||||
dir_path = os.path.join(RAW_DATA_PATH, 'pets')
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
full_name = os.path.join(dir_path, 'test15.pth')
|
||||
if not os.path.exists(full_name):
|
||||
print(f"Downloading {full_name}\n")
|
||||
download_file(
|
||||
'https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
|
||||
print("Downloading done.\n")
|
||||
else:
|
||||
print(f"{full_name} has already been downloaded. Did not download twice.\n")
|
||||
|
||||
full_name = os.path.join(dir_path, 'train85.pth')
|
||||
if not os.path.exists(full_name):
|
||||
print(f"Downloading {full_name}\n")
|
||||
download_file(
|
||||
'https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
|
||||
print("Downloading done.\n")
|
||||
else:
|
||||
print(f"{full_name} has already been downloaded. Did not download twice.\n")
|
@@ -0,0 +1,47 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
|
||||
|
||||
DATA_PATH = "./data/transfer_nag"
|
||||
dir_path = DATA_PATH
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm( unit="B", total=int( r.headers['Content-Length'] ) )
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update (len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
|
||||
def get_preprocessed_data(file_name, url):
|
||||
print(f"Downloading {file_name} datasets\n")
|
||||
full_name = os.path.join(dir_path, file_name)
|
||||
download_file(url, full_name)
|
||||
print("Downloading done.\n")
|
||||
|
||||
|
||||
for file_name, url in [
|
||||
('aircraftbylabel.pt', 'https://www.dropbox.com/s/mb66kitv20ykctp/aircraftbylabel.pt?dl=1'),
|
||||
('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
|
||||
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
|
||||
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
|
||||
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
|
||||
]:
|
||||
|
||||
get_preprocessed_data(file_name, url)
|
NAS-Bench-201/main_exp/transfer_nag/loader.py  (new file, 130 lines)
@@ -0,0 +1,130 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
|
||||
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
|
||||
print(f'==> The number of tasks for meta-training: {len(dataset)}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn)
|
||||
return loader
|
||||
|
||||
|
||||
def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
# MetaTestDataset expects (data_path, data_name, num_sample, num_class); the original call
# passed num_class in the num_sample slot, so take num_sample explicitly here.
dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
|
||||
print(f'==> Meta-Test dataset {data_name}')
|
||||
|
||||
loader = DataLoader(dataset=dataset,
|
||||
batch_size=100,
|
||||
shuffle=False,
|
||||
num_workers=0)
|
||||
return loader
|
||||
|
||||
|
||||
class MetaTrainDatabase(Dataset):
|
||||
def __init__(self, data_path, num_sample, is_pred=True):
|
||||
self.mode = 'train'
|
||||
self.acc_norm = True
|
||||
self.num_sample = num_sample
|
||||
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
|
||||
|
||||
mtr_data_path = os.path.join(
|
||||
data_path, 'meta_train_tasks_predictor.pt')
|
||||
idx_path = os.path.join(
|
||||
data_path, 'meta_train_tasks_predictor_idx.pt')
|
||||
data = torch.load(mtr_data_path)
|
||||
self.acc = data['acc']
|
||||
self.task = data['task']
|
||||
self.graph = data['g']
|
||||
|
||||
random_idx_lst = torch.load(idx_path)
|
||||
self.idx_lst = {}
|
||||
self.idx_lst['valid'] = random_idx_lst[:400]
|
||||
self.idx_lst['train'] = random_idx_lst[400:]
|
||||
self.acc = torch.tensor(self.acc)
|
||||
self.mean = torch.mean(self.acc[self.idx_lst['train']]).item()
|
||||
self.std = torch.std(self.acc[self.idx_lst['train']]).item()
|
||||
self.task_lst = torch.load(os.path.join(
|
||||
data_path, 'meta_train_task_lst.pt'))
|
||||
|
||||
|
||||
def set_mode(self, mode):
|
||||
self.mode = mode
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.idx_lst[self.mode])
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
ridx = self.idx_lst[self.mode]
|
||||
tidx = self.task[ridx[index]]
|
||||
classes = self.task_lst[tidx]
|
||||
graph = self.graph[ridx[index]]
|
||||
acc = self.acc[ridx[index]]
|
||||
for cls in classes:
|
||||
cx = self.x[cls-1][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
if self.acc_norm:
|
||||
acc = ((acc - self.mean) / self.std) / 100.0
|
||||
else:
|
||||
acc = acc / 100.0
|
||||
return x, graph, acc
|
||||
|
||||
|
||||
class MetaTestDataset(Dataset):
|
||||
def __init__(self, data_path, data_name, num_sample, num_class=None):
|
||||
self.num_sample = num_sample
|
||||
self.data_name = data_name
|
||||
|
||||
num_class_dict = {
|
||||
'cifar100': 100,
|
||||
'cifar10': 10,
|
||||
'mnist': 10,
|
||||
'svhn': 10,
|
||||
'aircraft': 30,
|
||||
'pets': 37
|
||||
}
|
||||
|
||||
if num_class is not None:
|
||||
self.num_class = num_class
|
||||
else:
|
||||
self.num_class = num_class_dict[data_name]
|
||||
|
||||
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = list(range(self.num_class))
|
||||
for cls in classes:
|
||||
cx = self.x[cls][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
return x
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
x = torch.stack([item[0] for item in batch])
|
||||
graph = [item[1] for item in batch]
|
||||
acc = torch.stack([item[2] for item in batch])
|
||||
return [x, graph, acc]
|
NAS-Bench-201/main_exp/transfer_nag/main.py  (new file, 91 lines)
@@ -0,0 +1,91 @@
|
||||
import os
import sys
import random
import argparse

import numpy as np
import torch

sys.path.append(os.getcwd())
from nag import NAG
|
||||
save_path = "results"
|
||||
data_path = os.path.join('MetaD2A_nas_bench_201', 'data')
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ['t', 'true', True]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('--seed', type=int, default=444)
|
||||
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
|
||||
parser.add_argument('--save-path', type=str, default=save_path, help='the path of save directory')
|
||||
parser.add_argument('--data-path', type=str, default=data_path, help='the path of the data directory')
|
||||
parser.add_argument('--model-load-path', type=str, default='', help='')
|
||||
parser.add_argument('--save-epoch', type=int, default=20, help='how many epochs to wait each time to save model states')
|
||||
parser.add_argument('--max-epoch', type=int, default=1000, help='number of epochs to train')
|
||||
parser.add_argument('--batch_size', type=int, default=1024, help='batch size for generator')
|
||||
parser.add_argument('--graph-data-name', default='nasbench201', help='graph dataset name')
|
||||
parser.add_argument('--nvt', type=int, default=7, help='number of different node types, 7: NAS-Bench-201 including in/out node')
|
||||
# set encoder
|
||||
parser.add_argument('--num-sample', type=int, default=20, help='the number of images as input for set encoder')
|
||||
# graph encoder
|
||||
parser.add_argument('--hs', type=int, default=512, help='hidden size of GRUs')
|
||||
parser.add_argument('--nz', type=int, default=56, help='the number of dimensions of latent vectors z')
|
||||
# test
|
||||
parser.add_argument('--test', action='store_true', default=True, help='turn on test mode')
|
||||
parser.add_argument('--load-epoch', type=int, default=100, help='checkpoint epoch loaded for meta-test')
|
||||
parser.add_argument('--data-name', type=str, default='pets', help='meta-test dataset name')
|
||||
parser.add_argument('--trials', type=int, default=20)
|
||||
parser.add_argument('--num-class', type=int, default=None, help='the number of class of dataset')
|
||||
parser.add_argument('--num-gen-arch', type=int, default=500, help='the number of candidate architectures generated by the generator')
|
||||
parser.add_argument('--train-arch', type=str2bool, default=True, help='whether to train the searched architecture')
|
||||
parser.add_argument('--n_init', type=int, default=10)
|
||||
parser.add_argument('--N', type=int, default=1)
|
||||
# DiffusionNAG
|
||||
parser.add_argument('--folder_name', type=str, default='debug')
|
||||
parser.add_argument('--exp_name', type=str, default='')
|
||||
parser.add_argument('--classifier_scale', type=float, default=10000., help='classifier scale')
|
||||
parser.add_argument('--scale_step', type=float, default=300.)
|
||||
parser.add_argument('--eval_batch_size', type=int, default=256)
|
||||
parser.add_argument('--predictor', type=str, default='euler_maruyama', choices=['euler_maruyama', 'reverse_diffusion', 'none'])
|
||||
parser.add_argument('--corrector', type=str, default='langevin', choices=['none', 'langevin'])
|
||||
parser.add_argument('--patient_factor', type=int, default=20)
|
||||
parser.add_argument('--n_gen_samples', type=int, default=10)
|
||||
parser.add_argument('--multi_proc', type=str2bool, default=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def set_exp_name(args):
|
||||
exp_name = f'./results/transfer_nag/{args.folder_name}/{args.data_name}'
|
||||
os.makedirs(exp_name, exist_ok=True)
|
||||
args.exp_name = exp_name
|
||||
|
||||
|
||||
def main():
|
||||
## Get arguments
|
||||
args = get_parser()
|
||||
|
||||
## Set gpus and seeds
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
## Set experiment name
|
||||
set_exp_name(args)
|
||||
|
||||
## Run
|
||||
nag = NAG(args)
|
||||
nag.meta_test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
NAS-Bench-201/main_exp/transfer_nag/nag.py  (new file, 305 lines)
@@ -0,0 +1,305 @@
|
||||
from __future__ import print_function

import os
import gc
import sys
import subprocess

import numpy as np
import torch
|
||||
|
||||
from nag_utils import mean_confidence_interval
|
||||
from nag_utils import restore_checkpoint
|
||||
from nag_utils import load_graph_config
|
||||
from nag_utils import load_model
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from nas_bench_201 import train_single_model
|
||||
from unnoised_model import MetaSurrogateUnnoisedModel
|
||||
from diffusion.run_lib import generate_archs_meta
|
||||
from diffusion.run_lib import get_sampling_fn_meta
|
||||
from diffusion.run_lib import get_score_model
|
||||
from diffusion.run_lib import get_surrogate
|
||||
from loader import MetaTestDataset
|
||||
from logger import Logger
|
||||
from all_path import *
|
||||
|
||||
|
||||
class NAG:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
## Target dataset information
|
||||
self.raw_data_path = RAW_DATA_PATH
|
||||
self.data_path = DATA_PATH
|
||||
self.data_name = args.data_name
|
||||
self.num_class = args.num_class
|
||||
self.num_sample = args.num_sample
|
||||
|
||||
graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
|
||||
self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
|
||||
load_model(model=self.meta_surrogate_unnoised_model,
|
||||
ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
|
||||
self.meta_surrogate_unnoised_model.to(self.device)
|
||||
|
||||
## Load pre-trained meta-surrogate model
|
||||
self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH
|
||||
|
||||
## Load score network model (base diffusion model)
|
||||
self.load_diffusion_model(args=args)
|
||||
|
||||
## Check config
|
||||
self.check_config()
|
||||
|
||||
## Set logger
|
||||
self.logger = Logger(
|
||||
log_dir=args.exp_name,
|
||||
write_textfile=True
|
||||
)
|
||||
self.logger.update_config(args, is_args=True)
|
||||
self.logger.write_str(str(vars(args)))
|
||||
self.logger.write_str('-' * 100)
|
||||
|
||||
|
||||
def check_config(self):
|
||||
"""
|
||||
Check if the configuration of the pre-trained score network model matches that of the meta surrogate model.
|
||||
"""
|
||||
scorenet_config = torch.load(self.config.scorenet_ckpt_path)['config']
|
||||
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
|
||||
assert scorenet_config.model.sigma_min == meta_surrogate_config.model.sigma_min
|
||||
assert scorenet_config.model.sigma_max == meta_surrogate_config.model.sigma_max
|
||||
assert scorenet_config.training.sde == meta_surrogate_config.training.sde
|
||||
assert scorenet_config.training.continuous == meta_surrogate_config.training.continuous
|
||||
assert scorenet_config.data.centered == meta_surrogate_config.data.centered
|
||||
assert scorenet_config.data.max_node == meta_surrogate_config.data.max_node
|
||||
assert scorenet_config.data.n_vocab == meta_surrogate_config.data.n_vocab
|
||||
|
||||
|
||||
def forward(self, x, arch):
|
||||
D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
|
||||
G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
|
||||
y_pred = self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)
|
||||
return y_pred
|
||||
|
||||
|
||||
def meta_test(self):
|
||||
if self.data_name == 'all':
|
||||
for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
|
||||
self.meta_test_per_dataset(data_name)
|
||||
else:
|
||||
self.meta_test_per_dataset(self.data_name)
|
||||
|
||||
|
||||
def meta_test_per_dataset(self, data_name):
|
||||
## Load NASBench201
|
||||
self.nasbench201 = torch.load(NASBENCH201)
|
||||
all_arch_str = np.array(self.nasbench201['arch']['str'])
|
||||
|
||||
## Load meta-test dataset
|
||||
self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)
|
||||
|
||||
## Set save path
|
||||
meta_test_path = os.path.join(META_TEST_PATH, data_name)
|
||||
os.makedirs(meta_test_path, exist_ok=True)
|
||||
f_arch_str = open(os.path.join(self.args.exp_name, 'architecture.txt'), 'w')
|
||||
f_arch_acc = open(os.path.join(self.args.exp_name, 'accuracy.txt'), 'w')
|
||||
|
||||
## Generate architectures
|
||||
gen_arch_str = self.get_gen_arch_str()
|
||||
gen_arch_igraph = self.get_items(
|
||||
full_target=self.nasbench201['arch']['igraph'],
|
||||
full_source=self.nasbench201['arch']['str'],
|
||||
source=gen_arch_str)
|
||||
|
||||
## Sort with unnoised meta-surrogate model
|
||||
y_pred_all = []
|
||||
self.meta_surrogate_unnoised_model.eval()
|
||||
self.meta_surrogate_unnoised_model.to(self.device)
|
||||
with torch.no_grad():
|
||||
for arch_igraph in gen_arch_igraph:
|
||||
x, g = self.collect_data(arch_igraph)
|
||||
y_pred = self.forward(x, g)
|
||||
y_pred = torch.mean(y_pred)
|
||||
y_pred_all.append(y_pred.cpu().detach().item())
|
||||
sorted_arch_lst = self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)
|
||||
|
||||
## Record the information of the architecture generated in sorted order
|
||||
for _, arch_str in enumerate(sorted_arch_lst):
|
||||
f_arch_str.write(f'{arch_str}\n')
|
||||
arch_idx_lst = [self.nasbench201['arch']['str'].index(i) for i in sorted_arch_lst]
|
||||
arch_str_lst = []
|
||||
arch_acc_lst = []
|
||||
|
||||
## Get the accuracy of the architecture
|
||||
if 'cifar' in data_name:
|
||||
sorted_acc_lst = self.get_items(
|
||||
full_target=self.nasbench201['test-acc'][data_name],
|
||||
full_source=self.nasbench201['arch']['str'],
|
||||
source=sorted_arch_lst)
|
||||
arch_str_lst += sorted_arch_lst
|
||||
arch_acc_lst += sorted_acc_lst
|
||||
for arch_idx, acc in zip(arch_idx_lst, sorted_acc_lst):
|
||||
msg = f'Avg {acc:4f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
else:
|
||||
if self.args.multi_proc:
|
||||
## Run multiple processes in parallel
|
||||
run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
|
||||
MAX_CAP = 5 # hard-coded for available GPUs
|
||||
if len(arch_idx_lst) <= MAX_CAP:
|
||||
arch_idx_lst_ = [arch_idx for arch_idx in arch_idx_lst if not os.path.exists(os.path.join(meta_test_path, str(arch_idx)))]
|
||||
support_ = ','.join([str(i) for i in arch_idx_lst_])
|
||||
num_split = int(3 * len(arch_idx_lst_)) # why 3? => running for 3 seeds
|
||||
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
|
||||
subprocess.run([cmd], shell=True)
|
||||
else:
|
||||
arch_idx_lst_ = []
|
||||
for j, arch_idx in enumerate(arch_idx_lst):
|
||||
if not os.path.exists(os.path.join(meta_test_path, str(arch_idx))):
|
||||
arch_idx_lst_.append(arch_idx)
|
||||
if (len(arch_idx_lst_) == MAX_CAP) or (j == len(arch_idx_lst) - 1):
|
||||
support_ = ','.join([str(i) for i in arch_idx_lst_])
|
||||
num_split = int(3 * len(arch_idx_lst_))
|
||||
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
|
||||
subprocess.run([cmd], shell=True)
|
||||
arch_idx_lst_ = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
acc_runs_lst = []
|
||||
epoch = 199
|
||||
seeds = (777, 888, 999)
|
||||
for arch_idx in arch_idx_lst:
|
||||
acc_runs = []
|
||||
save_path_ = os.path.join(meta_test_path, str(arch_idx))
|
||||
for seed in seeds:
|
||||
result = torch.load(os.path.join(save_path_, f'seed-0{seed}.pth'))
|
||||
acc_runs.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
|
||||
acc_runs_lst.append(acc_runs)
|
||||
break
|
||||
except Exception:
    # results for some architectures are not written yet; keep polling
    pass
|
||||
for i in acc_runs_lst:
    print(np.mean(i))
|
||||
for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
|
||||
for r, acc in enumerate(acc_runs):
|
||||
msg = f'run {r+1} {acc:.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
m, h = mean_confidence_interval(acc_runs)
|
||||
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
arch_acc_lst.append(np.mean(acc_runs))
|
||||
arch_str_lst.append(all_arch_str[arch_idx])
|
||||
|
||||
else:
|
||||
for arch_idx in arch_idx_lst:
|
||||
acc_runs = self.train_single_arch(
|
||||
data_name, self.nasbench201['arch']['str'][arch_idx], meta_test_path)  # arch strings live under ['arch']['str'] elsewhere in this file
|
||||
for r, acc in enumerate(acc_runs):
|
||||
msg = f'run {r+1} {acc:.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
m, h = mean_confidence_interval(acc_runs)
|
||||
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
|
||||
f_arch_acc.write(msg + '\n')
|
||||
arch_acc_lst.append(np.mean(acc_runs))
|
||||
arch_str_lst.append(all_arch_str[arch_idx])
|
||||
|
||||
# Save results
|
||||
results_path = os.path.join(self.args.exp_name, 'results.pt')
|
||||
torch.save({
|
||||
'arch_idx_lst': arch_idx_lst,
|
||||
'arch_str_lst': arch_str_lst,
|
||||
'arch_acc_lst': arch_acc_lst
|
||||
}, results_path)
|
||||
print(f">>> Save the results at {results_path}...")
|
||||
|
||||
|
||||
def train_single_arch(self, data_name, arch_str, meta_test_path):
|
||||
save_path = os.path.join(meta_test_path, arch_str)
|
||||
seeds = (777, 888, 999)
|
||||
train_single_model(save_dir=save_path,
|
||||
workers=24,
|
||||
datasets=[data_name],
|
||||
xpaths=[f'{self.raw_data_path}/{data_name}'],
|
||||
splits=[0],
|
||||
use_less=False,
|
||||
seeds=seeds,
|
||||
model_str=arch_str,
|
||||
arch_config={'channel': 16, 'num_cells': 5})
|
||||
epoch = 199
|
||||
test_acc_lst = []
|
||||
for seed in seeds:
|
||||
result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
|
||||
test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
|
||||
return test_acc_lst
|
||||
|
||||
|
||||
def sort_arch(self, data_name, y_pred_all, gen_arch_str):
    _, sorted_idx = torch.sort(y_pred_all, descending=True)
    sorted_gen_arch_str = [gen_arch_str[i] for i in sorted_idx]
    return sorted_gen_arch_str
|
||||
|
||||
|
||||
def collect_data_only(self):
|
||||
x_batch = []
|
||||
x_batch.append(self.test_dataset[0])
|
||||
return torch.stack(x_batch).to(self.device)
|
||||
|
||||
|
||||
def collect_data(self, arch_igraph):
|
||||
x_batch, g_batch = [], []
|
||||
for _ in range(10):
|
||||
x_batch.append(self.test_dataset[0])
|
||||
g_batch.append(arch_igraph)
|
||||
return torch.stack(x_batch).to(self.device), g_batch
|
||||
|
||||
|
||||
def get_items(self, full_target, full_source, source):
|
||||
return [full_target[full_source.index(_)] for _ in source]
|
||||
|
||||
|
||||
def load_diffusion_model(self, args):
|
||||
self.config = torch.load('./configs/transfer_nag_config.pt')
|
||||
self.config.device = torch.device('cuda')
|
||||
self.config.data.label_list = ['meta-acc']
|
||||
self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
|
||||
self.config.sampling.classifier_scale = args.classifier_scale
|
||||
self.config.eval.batch_size = args.eval_batch_size
|
||||
self.config.sampling.predictor = args.predictor
|
||||
self.config.sampling.corrector = args.corrector
|
||||
self.config.sampling.check_dataname = self.data_name
|
||||
self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
|
||||
self.score_model, self.score_ema, self.score_config = get_score_model(self.config)
|
||||
|
||||
|
||||
def get_gen_arch_str(self):
|
||||
## Load meta-surrogate model
|
||||
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
|
||||
meta_surrogate_model = get_surrogate(meta_surrogate_config)
|
||||
meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
|
||||
meta_surrogate_state = restore_checkpoint(
|
||||
self.meta_surrogate_ckpt_path,
|
||||
meta_surrogate_state,
|
||||
device=self.config.device,
|
||||
resume=True)
|
||||
|
||||
## Get dataset embedding, x
|
||||
with torch.no_grad():
|
||||
x = self.collect_data_only()
|
||||
|
||||
## Generate architectures
|
||||
generated_arch_str = generate_archs_meta(
|
||||
config=self.config,
|
||||
sampling_fn=self.sampling_fn,
|
||||
score_model=self.score_model,
|
||||
score_ema=self.score_ema,
|
||||
meta_surrogate_model=meta_surrogate_model,
|
||||
num_samples=self.args.n_gen_samples,
|
||||
args=self.args,
|
||||
task=x)
|
||||
|
||||
## Clean up
|
||||
meta_surrogate_model = None
|
||||
gc.collect()
|
||||
|
||||
return generated_arch_str
|
NAS-Bench-201/main_exp/transfer_nag/nag_utils.py  (new file, 301 lines)
@@ -0,0 +1,301 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import igraph
|
||||
import random
|
||||
import numpy as np
|
||||
import scipy.stats
|
||||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
def reset_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def restore_checkpoint(ckpt_dir, state, device, resume=False):
|
||||
if not resume:
|
||||
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
|
||||
return state
|
||||
elif not os.path.exists(ckpt_dir):
|
||||
if not os.path.exists(os.path.dirname(ckpt_dir)):
|
||||
os.makedirs(os.path.dirname(ckpt_dir))
|
||||
logging.warning(f"No checkpoint found at {ckpt_dir}. "
|
||||
f"Returned the same state as input")
|
||||
return state
|
||||
else:
|
||||
loaded_state = torch.load(ckpt_dir, map_location=device)
|
||||
for k in state:
|
||||
if k in ['optimizer', 'model', 'ema']:
|
||||
state[k].load_state_dict(loaded_state[k])
|
||||
else:
|
||||
state[k] = loaded_state[k]
|
||||
return state
|
||||
|
||||
|
||||
def load_graph_config(graph_data_name, nvt, data_path):
|
||||
if graph_data_name != 'nasbench201':  # 'is not' compares identity, not string equality
|
||||
raise NotImplementedError(graph_data_name)
|
||||
g_list = []
|
||||
max_n = 0 # maximum number of nodes
|
||||
ms = torch.load(data_path)['arch']['matrix']
|
||||
for i in range(len(ms)):
|
||||
g, n = decode_NAS_BENCH_201_8_to_igraph(ms[i])
|
||||
max_n = max(max_n, n)
|
||||
g_list.append((g, 0))
|
||||
# number of different node types including in/out node
|
||||
graph_config = {}
|
||||
graph_config['num_vertex_type'] = nvt # original types + start/end types
|
||||
graph_config['max_n'] = max_n # maximum number of nodes
|
||||
graph_config['START_TYPE'] = 0 # predefined start vertex type
|
||||
graph_config['END_TYPE'] = 1 # predefined end vertex type
|
||||
|
||||
return graph_config
|
||||
|
||||
|
||||
def decode_NAS_BENCH_201_8_to_igraph(row):
|
||||
if type(row) == str:
|
||||
row = eval(row) # convert string to list of lists
|
||||
n = len(row)
|
||||
g = igraph.Graph(directed=True)
|
||||
g.add_vertices(n)
|
||||
for i, node in enumerate(row):
|
||||
g.vs[i]['type'] = node[0]
|
||||
if i < (n - 2) and i > 0:
|
||||
g.add_edge(i, i + 1) # always connect from last node
|
||||
for j, edge in enumerate(node[1:]):
|
||||
if edge == 1:
|
||||
g.add_edge(j, i)
|
||||
return g, n
|
||||
|
||||
|
||||
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
|
||||
# first need to be a valid DAG computation graph
|
||||
res = is_valid_DAG(g, START_TYPE, END_TYPE)
|
||||
# in addition, node i must connect to node i+1
|
||||
res = res and len(g.vs['type']) == 8
|
||||
res = res and not (0 in g.vs['type'][1:-1])
|
||||
res = res and not (1 in g.vs['type'][1:-1])
|
||||
return res
|
||||
|
||||
|
||||
def decode_igraph_to_NAS201_matrix(g):
|
||||
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
|
||||
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
|
||||
for i, xy in enumerate(xys):
|
||||
m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
|
||||
import numpy
|
||||
return numpy.array(m)
|
||||
|
||||
|
||||
def decode_igraph_to_NAS_BENCH_201_string(g):
|
||||
if not is_valid_NAS201(g):
|
||||
return None
|
||||
m = decode_igraph_to_NAS201_matrix(g)
|
||||
types = ['none', 'skip_connect', 'nor_conv_1x1',
|
||||
'nor_conv_3x3', 'avg_pool_3x3']
|
||||
return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
|
||||
format(types[int(m[1][0])],
|
||||
types[int(m[2][0])], types[int(m[2][1])],
|
||||
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
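# The resulting string follows the usual NAS-Bench-201 notation: one operation per cell
# edge, written as op~input_node. Illustrative sketch of the final formatting step
# (the op indices below are arbitrary):
#
# types = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
# ops = [3, 1, 2, 0, 4, 3]   # ops for edges (1~0), (2~0), (2~1), (3~0), (3~1), (3~2)
# arch = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(*[types[i] for i in ops])
# print(arch)
# -> |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|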
|
||||
|
||||
|
||||
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
|
||||
res = g.is_dag()
|
||||
n_start, n_end = 0, 0
|
||||
for v in g.vs:
|
||||
if v['type'] == START_TYPE:
|
||||
n_start += 1
|
||||
elif v['type'] == END_TYPE:
|
||||
n_end += 1
|
||||
if v.indegree() == 0 and v['type'] != START_TYPE:
|
||||
return False
|
||||
if v.outdegree() == 0 and v['type'] != END_TYPE:
|
||||
return False
|
||||
return res and n_start == 1 and n_end == 1
|
||||
|
||||
|
||||
class Accumulator():
|
||||
def __init__(self, *args):
|
||||
self.args = args
|
||||
self.argdict = {}
|
||||
for i, arg in enumerate(args):
|
||||
self.argdict[arg] = i
|
||||
self.sums = [0] * len(args)
|
||||
self.cnt = 0
|
||||
|
||||
def accum(self, val):
|
||||
val = [val] if type(val) is not list else val
|
||||
val = [v for v in val if v is not None]
|
||||
assert (len(val) == len(self.args))
|
||||
for i in range(len(val)):
|
||||
if torch.is_tensor(val[i]):
|
||||
val[i] = val[i].item()
|
||||
self.sums[i] += val[i]
|
||||
self.cnt += 1
|
||||
|
||||
def clear(self):
|
||||
self.sums = [0] * len(self.args)
|
||||
self.cnt = 0
|
||||
|
||||
def get(self, arg, avg=True):
|
||||
i = self.argdict.get(arg, -1)
|
||||
assert i != -1  # 'is not' on an int literal checks identity, not value
|
||||
if avg:
|
||||
return self.sums[i] / (self.cnt + 1e-8)
|
||||
else:
|
||||
return self.sums[i]
|
||||
|
||||
def print_(self, header=None, time=None,
|
||||
logfile=None, do_not_print=[], as_int=[],
|
||||
avg=True):
|
||||
msg = '' if header is None else header + ': '
|
||||
if time is not None:
|
||||
msg += ('(%.3f secs), ' % time)
|
||||
|
||||
args = [arg for arg in self.args if arg not in do_not_print]
|
||||
arg = []
|
||||
for arg in args:
|
||||
val = self.sums[self.argdict[arg]]
|
||||
if avg:
|
||||
val /= (self.cnt + 1e-8)
|
||||
if arg in as_int:
|
||||
msg += ('%s %d, ' % (arg, int(val)))
|
||||
else:
|
||||
msg += ('%s %.4f, ' % (arg, val))
|
||||
print(msg)
|
||||
|
||||
if logfile is not None:
|
||||
logfile.write(msg + '\n')
|
||||
logfile.flush()
|
||||
|
||||
def add_scalars(self, summary, header=None, tag_scalar=None,
|
||||
step=None, avg=True, args=None):
|
||||
for arg in self.args:
|
||||
val = self.sums[self.argdict[arg]]
|
||||
if avg:
|
||||
val /= (self.cnt + 1e-8)
|
||||
else:
|
||||
val = val
|
||||
tag = f'{header}/{arg}' if header is not None else arg
|
||||
if tag_scalar is not None:
|
||||
summary.add_scalars(main_tag=tag,
|
||||
tag_scalar_dict={tag_scalar: val},
|
||||
global_step=step)
|
||||
else:
|
||||
summary.add_scalar(tag=tag,
|
||||
scalar_value=val,
|
||||
global_step=step)
|
||||
|
||||
|
||||
class Log:
|
||||
def __init__(self, args, logf, summary=None):
|
||||
self.args = args
|
||||
self.logf = logf
|
||||
self.summary = summary
|
||||
self.stime = time.time()
|
||||
self.ep_sttime = None
|
||||
|
||||
def print(self, logger, epoch, tag=None, avg=True):
|
||||
if tag == 'train':
|
||||
ct = time.time() - self.ep_sttime
|
||||
tt = time.time() - self.stime
|
||||
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
|
||||
print(msg)
|
||||
self.logf.write(msg+'\n')
|
||||
logger.print_(header=tag, logfile=self.logf, avg=avg)
|
||||
|
||||
if self.summary is not None:
|
||||
logger.add_scalars(
|
||||
self.summary, header=tag, step=epoch, avg=avg)
|
||||
logger.clear()
|
||||
|
||||
def print_args(self):
|
||||
argdict = vars(self.args)
|
||||
print(argdict)
|
||||
for k, v in argdict.items():
|
||||
self.logf.write(k + ': ' + str(v) + '\n')
|
||||
self.logf.write('\n')
|
||||
|
||||
def set_time(self):
|
||||
self.stime = time.time()
|
||||
|
||||
def save_time_log(self):
|
||||
ct = time.time() - self.stime
|
||||
msg = f'({ct:6.2f}s) meta-training phase done'
|
||||
print(msg)
|
||||
self.logf.write(msg+'\n')
|
||||
|
||||
def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
|
||||
if tag == 'train':
|
||||
ct = time.time() - self.ep_sttime
|
||||
tt = time.time() - self.stime
|
||||
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
|
||||
self.logf.write(msg+'\n')
|
||||
print(msg)
|
||||
self.logf.flush()
|
||||
# msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
|
||||
# msg += f'time {time.time() - sttime:6.2f} '
|
||||
if max_corr_dict is not None:
|
||||
max_corr = max_corr_dict['corr']
|
||||
max_loss = max_corr_dict['loss']
|
||||
msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
|
||||
msg += f'corr {corr:.4f} ({max_corr:.4f})'
|
||||
else:
|
||||
msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
|
||||
self.logf.write(msg+'\n')
|
||||
print(msg)
|
||||
self.logf.flush()
|
||||
|
||||
def max_corr_log(self, max_corr_dict):
|
||||
corr = max_corr_dict['corr']
|
||||
loss = max_corr_dict['loss']
|
||||
epoch = max_corr_dict['epoch']
|
||||
msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
|
||||
self.logf.write(msg+'\n')
|
||||
print(msg)
|
||||
self.logf.flush()
|
||||
|
||||
|
||||
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
|
||||
msg = f'[{tag}] Ep {epoch} loss {loss.item()/len(y):0.4f} '
|
||||
if type(y_pred) == list:
|
||||
msg += f'pacc {y_pred[0]:0.4f}'
|
||||
msg += f'({y_pred[0]*100.0*acc_std+acc_mean:0.4f}) '
|
||||
else:
|
||||
msg += f'pacc {y_pred:0.4f}'
|
||||
msg += f'({y_pred*100.0*acc_std+acc_mean:0.4f}) '
|
||||
msg += f'acc {y[0]:0.4f}({y[0]*100*acc_std+acc_mean:0.4f})'
|
||||
return msg
|
||||
|
||||
|
||||
def load_model(model, ckpt_path):
|
||||
model.cpu()
|
||||
model.load_state_dict(torch.load(ckpt_path))
|
||||
|
||||
|
||||
def save_model(epoch, model, model_path, max_corr=None):
|
||||
print("==> save current model...")
|
||||
if max_corr is not None:
|
||||
torch.save(model.cpu().state_dict(),
|
||||
os.path.join(model_path, 'ckpt_max_corr.pt'))
|
||||
else:
|
||||
torch.save(model.cpu().state_dict(),
|
||||
os.path.join(model_path, f'ckpt_{epoch}.pt'))
|
||||
|
||||
|
||||
def mean_confidence_interval(data, confidence=0.95):
|
||||
a = 1.0 * np.array(data)
|
||||
n = len(a)
|
||||
m, se = np.mean(a), scipy.stats.sem(a)
|
||||
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
|
||||
return m, h
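# Quick illustration (hypothetical values): the 95% confidence interval this returns,
# matching the "Avg m+-h (%)" lines that nag.py writes for the three training seeds.
#
# accs = [91.2, 90.8, 91.5]           # e.g. test accuracies from seeds 777/888/999
# m, h = mean_confidence_interval(accs)
# print(f'Avg {m:.2f}+-{h:.2f} (%)')  # -> Avg 91.17+-0.87 (%)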
|
@@ -0,0 +1,6 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
dir_path = (Path(__file__).parent).resolve()
|
||||
if str(dir_path) not in sys.path: sys.path.insert(0, str(dir_path))
|
||||
|
||||
from .architecture import train_single_model
|
@@ -0,0 +1,173 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
###############################################################
|
||||
from functions import evaluate_for_seed
|
||||
from nas_bench_201_models import CellStructure, CellArchitectures, get_search_spaces
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
from nas_bench_201_datasets import get_datasets
|
||||
from procedures import get_machine_info
|
||||
from procedures import save_checkpoint, copy_checkpoint
|
||||
from config_utils import load_config
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
import argparse
|
||||
from PIL import ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
NASBENCH201_CONFIG_PATH = os.path.join(
|
||||
os.getcwd(), 'main_exp', 'transfer_nag')
|
||||
|
||||
|
||||
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
|
||||
arch_config, workers, logger):
|
||||
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
|
||||
all_infos = {'info': machine_info}
|
||||
all_dataset_keys = []
|
||||
# look all the datasets
|
||||
for dataset, xpath, split in zip(datasets, xpaths, splits):
|
||||
# train valid data
|
||||
task = None
|
||||
train_data, valid_data, xshape, class_num = get_datasets(
|
||||
dataset, xpath, -1, task)
|
||||
|
||||
# load the configuration
|
||||
if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
|
||||
if use_less:
|
||||
config_path = os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
|
||||
else:
|
||||
config_path = os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
|
||||
|
||||
p = os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset))
|
||||
if not os.path.exists(p):
|
||||
import json
|
||||
label_list = list(range(len(train_data)))
|
||||
random.shuffle(label_list)
|
||||
strlist = [str(label_list[i]) for i in range(len(label_list))]
|
||||
splited = {'train': ["int", strlist[:len(train_data) // 2]],
|
||||
'valid': ["int", strlist[len(train_data) // 2:]]}
|
||||
with open(p, 'w') as f:
|
||||
f.write(json.dumps(splited))
|
||||
split_info = load_config(os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset)), None, None)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
|
||||
config = load_config(
|
||||
config_path, {'class_num': class_num, 'xshape': xshape}, logger)
|
||||
# data loader
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
|
||||
shuffle=True, num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||
shuffle=False, num_workers=workers, pin_memory=True)
|
||||
splits = load_config(os.path.join(
|
||||
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}-test-split.txt'.format(dataset)), None, None)
|
||||
ValLoaders = {'ori-test': valid_loader,
|
||||
'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
splits.xvalid),
|
||||
num_workers=workers, pin_memory=True),
|
||||
'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
splits.xtest),
|
||||
num_workers=workers, pin_memory=True)
|
||||
}
|
||||
dataset_key = '{:}'.format(dataset)
|
||||
if bool(split):
|
||||
dataset_key = dataset_key + '-valid'
|
||||
logger.log(
|
||||
'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
|
||||
format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
|
||||
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
|
||||
dataset_key, config))
|
||||
for key, value in ValLoaders.items():
|
||||
logger.log(
|
||||
'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
|
||||
|
||||
results = evaluate_for_seed(
|
||||
arch_config, config, arch, train_loader, ValLoaders, seed, logger)
|
||||
all_infos[dataset_key] = results
|
||||
all_dataset_keys.append(dataset_key)
|
||||
all_infos['all_dataset_keys'] = all_dataset_keys
|
||||
return all_infos
|
||||
|
||||
|
||||
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
|
||||
seeds, model_str, arch_config):
|
||||
assert torch.cuda.is_available(), 'CUDA is not available.'
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
logger = Logger(str(save_dir), 0, False)
|
||||
|
||||
if model_str in CellArchitectures:
|
||||
arch = CellArchitectures[model_str]
|
||||
logger.log(
|
||||
'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
|
||||
else:
|
||||
try:
|
||||
arch = CellStructure.str2structure(model_str)
|
||||
except:
|
||||
raise ValueError(
|
||||
'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
|
||||
|
||||
assert arch.check_valid_op(get_search_spaces(
|
||||
'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
|
||||
# assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
|
||||
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
|
||||
logger.log('arch_config : {:}'.format(arch_config))
|
||||
|
||||
start_time, seed_time = time.time(), AverageMeter()
|
||||
for _is, seed in enumerate(seeds):
|
||||
logger.log(
|
||||
'\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
|
||||
seed))
|
||||
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
|
||||
if to_save_name.exists():
|
||||
logger.log(
|
||||
'Found the existing file {:}, loading it directly!'.format(to_save_name))
|
||||
checkpoint = torch.load(to_save_name)
|
||||
else:
|
||||
logger.log(
|
||||
'Did not find the existing file {:}, training and evaluating!'.format(to_save_name))
|
||||
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
|
||||
seed, arch_config, workers, logger)
|
||||
torch.save(checkpoint, to_save_name)
|
||||
# log information
|
||||
logger.log('{:}'.format(checkpoint['info']))
|
||||
all_dataset_keys = checkpoint['all_dataset_keys']
|
||||
for dataset_key in all_dataset_keys:
|
||||
logger.log('\n{:} dataset : {:} {:}'.format(
|
||||
'-' * 15, dataset_key, '-' * 15))
|
||||
dataset_info = checkpoint[dataset_key]
|
||||
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
|
||||
logger.log('Flops = {:} MB, Params = {:} MB'.format(
|
||||
dataset_info['flop'], dataset_info['param']))
|
||||
logger.log('config : {:}'.format(dataset_info['config']))
|
||||
logger.log('Training State (finish) = {:}'.format(
|
||||
dataset_info['finish-train']))
|
||||
last_epoch = dataset_info['total_epoch'] - 1
|
||||
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
|
||||
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
|
||||
# measure elapsed time
|
||||
seed_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format(convert_secs2time(
|
||||
seed_time.avg * (len(seeds) - _is - 1), True))
|
||||
logger.log(
|
||||
'\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
|
||||
need_time))
|
||||
logger.close()
|
@@ -0,0 +1,13 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .configure_utils import load_config, dict2config#, configure2str
|
||||
#from .basic_args import obtain_basic_args
|
||||
#from .attention_args import obtain_attention_args
|
||||
#from .random_baseline import obtain_RandomSearch_args
|
||||
#from .cls_kd_args import obtain_cls_kd_args
|
||||
#from .cls_init_args import obtain_cls_init_args
|
||||
#from .search_single_args import obtain_search_single_args
|
||||
#from .search_args import obtain_search_args
|
||||
# for network pruning
|
||||
#from .pruning_args import obtain_pruning_args
|
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
import os, json
|
||||
from os import path as osp
|
||||
from pathlib import Path
|
||||
from collections import namedtuple
|
||||
|
||||
support_types = ('str', 'int', 'bool', 'float', 'none')
|
||||
|
||||
|
||||
def convert_param(original_lists):
|
||||
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
|
||||
ctype, value = original_lists[0], original_lists[1]
|
||||
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
|
||||
is_list = isinstance(value, list)
|
||||
if not is_list: value = [value]
|
||||
outs = []
|
||||
for x in value:
|
||||
if ctype == 'int':
|
||||
x = int(x)
|
||||
elif ctype == 'str':
|
||||
x = str(x)
|
||||
elif ctype == 'bool':
|
||||
x = bool(int(x))
|
||||
elif ctype == 'float':
|
||||
x = float(x)
|
||||
elif ctype == 'none':
|
||||
if x.lower() != 'none':
|
||||
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
|
||||
x = None
|
||||
else:
|
||||
raise TypeError('Does not know this type : {:}'.format(ctype))
|
||||
outs.append(x)
|
||||
if not is_list: outs = outs[0]
|
||||
return outs
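# --- illustrative usage sketch (not part of the original file) ---
# convert_param turns the ["type", value-or-list-of-values] pairs used by the
# *.config JSON files into typed Python values; the inputs below are hypothetical.
if __name__ == '__main__':
    assert convert_param(["int", "200"]) == 200
    assert convert_param(["float", ["0.1", "0.0005"]]) == [0.1, 0.0005]
    assert convert_param(["bool", "1"]) is True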
|
||||
|
||||
|
||||
def load_config(path, extra, logger):
|
||||
path = str(path)
|
||||
if hasattr(logger, 'log'): logger.log(path)
|
||||
assert os.path.exists(path), 'Can not find {:}'.format(path)
|
||||
# Reading data back
|
||||
with open(path, 'r') as f:
|
||||
data = json.load(f)
|
||||
content = { k: convert_param(v) for k,v in data.items()}
|
||||
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
|
||||
if isinstance(extra, dict): content = {**content, **extra}
|
||||
Arguments = namedtuple('Configure', ' '.join(content.keys()))
|
||||
content = Arguments(**content)
|
||||
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
|
||||
return content
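# --- illustrative usage sketch (not part of the original file) ---
# load_config parses a JSON config whose values are ["type", value] pairs (see the
# nas-benchmark *.config files) and returns a namedtuple; 'extra' entries such as
# class_num/xshape are merged in. The temporary file below is a hypothetical stand-in.
if __name__ == '__main__':
    import tempfile
    with tempfile.NamedTemporaryFile('w', suffix='.config', delete=False) as tmp:
        json.dump({"epochs": ["int", "200"], "LR": ["float", "0.1"]}, tmp)
    cfg = load_config(tmp.name, {'class_num': 10, 'xshape': (1, 3, 32, 32)}, None)
    print(cfg.epochs, cfg.LR, cfg.class_num)  # 200 0.1 10
    os.remove(tmp.name)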
|
||||
|
||||
|
||||
def configure2str(config, xpath=None):
|
||||
if not isinstance(config, dict):
|
||||
config = config._asdict()
|
||||
def cstring(x):
|
||||
return "\"{:}\"".format(x)
|
||||
def gtype(x):
|
||||
if isinstance(x, list): x = x[0]
|
||||
if isinstance(x, str) : return 'str'
|
||||
elif isinstance(x, bool) : return 'bool'
|
||||
elif isinstance(x, int): return 'int'
|
||||
elif isinstance(x, float): return 'float'
|
||||
elif x is None : return 'none'
|
||||
else: raise ValueError('invalid : {:}'.format(x))
|
||||
def cvalue(x, xtype):
|
||||
if isinstance(x, list): is_list = True
|
||||
else:
|
||||
is_list, x = False, [x]
|
||||
temps = []
|
||||
for temp in x:
|
||||
if xtype == 'bool' : temp = cstring(int(temp))
|
||||
elif xtype == 'none': temp = cstring('None')
|
||||
else : temp = cstring(temp)
|
||||
temps.append( temp )
|
||||
if is_list:
|
||||
return "[{:}]".format( ', '.join( temps ) )
|
||||
else:
|
||||
return temps[0]
|
||||
|
||||
xstrings = []
|
||||
for key, value in config.items():
|
||||
xtype = gtype(value)
|
||||
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
|
||||
xstrings.append(string)
|
||||
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
|
||||
if xpath is not None:
|
||||
parent = Path(xpath).resolve().parent
|
||||
parent.mkdir(parents=True, exist_ok=True)
|
||||
if osp.isfile(xpath): os.remove(xpath)
|
||||
with open(xpath, "w") as text_file:
|
||||
text_file.write('{:}'.format(Fstring))
|
||||
return Fstring
|
||||
|
||||
|
||||
def dict2config(xdict, logger):
|
||||
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
|
||||
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
|
||||
content = Arguments(**xdict)
|
||||
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
|
||||
return content
|
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "200"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "50"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "200"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"scheduler": ["str", "cos"],
|
||||
"eta_min" : ["float", "0.0"],
|
||||
"epochs" : ["int", "200"],
|
||||
"warmup" : ["int", "0"],
|
||||
"optim" : ["str", "SGD"],
|
||||
"LR" : ["float", "0.1"],
|
||||
"decay" : ["float", "0.0005"],
|
||||
"momentum" : ["float", "0.9"],
|
||||
"nesterov" : ["bool", "1"],
|
||||
"criterion": ["str", "Softmax"],
|
||||
"batch_size": ["int", "256"]
|
||||
}
|
153
NAS-Bench-201/main_exp/transfer_nag/nas_bench_201/functions.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import time
|
||||
import torch
|
||||
from procedures import prepare_seed, get_optim_scheduler
|
||||
from nasbench_utils import get_model_infos, obtain_accuracy
|
||||
from config_utils import dict2config
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from nas_bench_201_models import get_cell_based_tiny_net
|
||||
|
||||
|
||||
__all__ = ['evaluate_for_seed', 'pure_evaluate']
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies = []
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append(batch_time.val - data_time.val)
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(
|
||||
logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2:
|
||||
latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == 'train':
|
||||
network.train()
|
||||
elif mode == 'valid':
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == 'train':
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
if mode == 'train':
|
||||
optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == 'train':
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
|
||||
'C': arch_config['channel'], 'N': arch_config['num_cells'],
|
||||
'genotype': arch, 'num_classes': config.class_num}, None)
|
||||
)
|
||||
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
if 'ckpt_path' in arch_config.keys():
|
||||
ckpt = torch.load(arch_config['ckpt_path'])
|
||||
ckpt['classifier.weight'] = net.state_dict()['classifier.weight']
|
||||
ckpt['classifier.bias'] = net.state_dict()['classifier.bias']
|
||||
net.load_state_dict(ckpt)
|
||||
|
||||
flop, param = get_model_infos(net, config.xshape)
|
||||
logger.log('Network : {:}'.format(net.get_message()), False)
|
||||
logger.log(
|
||||
'{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
|
||||
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(
|
||||
net.parameters(), config)
|
||||
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
|
||||
# network, criterion = torch.nn.DataParallel(net).to(torch.device(f"cuda:{device}")), criterion.to(torch.device(f"cuda:{device}"))
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(
|
||||
), AverageMeter(), config.epochs + config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {
|
||||
}, {}, {}, {}, {}, {}
|
||||
train_times, valid_times = {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(
|
||||
train_loader, network, criterion, scheduler, optimizer, 'train')
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times[epoch] = train_tm
|
||||
with torch.no_grad():
|
||||
for key, xloader in valid_loaders.items():
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
|
||||
xloader, network, criterion, None, None, 'valid')
|
||||
valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
|
||||
valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
|
||||
valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
|
||||
valid_times['{:}@{:}'.format(key, epoch)] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = 'Time Left: {:}'.format(convert_secs2time(
|
||||
epoch_time.avg * (total_epoch-epoch-1), True))
|
||||
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(
|
||||
time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
|
||||
info_seed = {'flop': flop,
|
||||
'param': param,
|
||||
'channel': arch_config['channel'],
|
||||
'num_cells': arch_config['num_cells'],
|
||||
'config': config._asdict(),
|
||||
'total_epoch': total_epoch,
|
||||
'train_losses': train_losses,
|
||||
'train_acc1es': train_acc1es,
|
||||
'train_acc5es': train_acc5es,
|
||||
'train_times': train_times,
|
||||
'valid_losses': valid_losses,
|
||||
'valid_acc1es': valid_acc1es,
|
||||
'valid_acc5es': valid_acc5es,
|
||||
'valid_times': valid_times,
|
||||
'net_state_dict': net.state_dict(),
|
||||
'net_string': '{:}'.format(net),
|
||||
'finish-train': True
|
||||
}
|
||||
return info_seed
|
@@ -0,0 +1,9 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# none of these utilities rely on pytorch or tensorflow
|
||||
# all dependencies are listed here: os, sys, time, numpy, (possibly) matplotlib
|
||||
from .logger import Logger#, PrintLogger
|
||||
from .meter import AverageMeter
|
||||
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time
|
||||
|
@@ -0,0 +1,150 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from pathlib import Path
|
||||
import importlib, warnings
|
||||
import os, sys, time, numpy as np
|
||||
if sys.version_info.major == 2: # Python 2.x
|
||||
from StringIO import StringIO as BIO
|
||||
else: # Python 3.x
|
||||
from io import BytesIO as BIO
|
||||
|
||||
if importlib.util.find_spec('tensorflow'):
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class PrintLogger(object):
|
||||
|
||||
def __init__(self):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.name = 'PrintLogger'
|
||||
|
||||
def log(self, string):
|
||||
print (string)
|
||||
|
||||
def close(self):
|
||||
print ('-'*30 + ' close printer ' + '-'*30)
|
||||
|
||||
|
||||
class Logger(object):
|
||||
|
||||
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.seed = int(seed)
|
||||
self.log_dir = Path(log_dir)
|
||||
self.model_dir = Path(log_dir) / 'checkpoint'
|
||||
self.log_dir.mkdir (parents=True, exist_ok=True)
|
||||
if create_model_dir:
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
#self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
|
||||
|
||||
self.use_tf = bool(use_tf)
|
||||
self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) )))
|
||||
#self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
|
||||
self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
|
||||
self.logger_file = open(self.logger_path, 'w')
|
||||
|
||||
if self.use_tf:
|
||||
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
|
||||
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
|
||||
else:
|
||||
self.writer = None
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def path(self, mode):
|
||||
valids = ('model', 'best', 'info', 'log')
|
||||
if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
|
||||
elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
|
||||
elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
|
||||
elif mode == 'log' : return self.log_dir
|
||||
else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
|
||||
|
||||
def extract_log(self):
|
||||
return self.logger_file
|
||||
|
||||
def close(self):
|
||||
self.logger_file.close()
|
||||
if self.writer is not None:
|
||||
self.writer.close()
|
||||
|
||||
def log(self, string, save=True, stdout=False):
|
||||
if stdout:
|
||||
sys.stdout.write(string); sys.stdout.flush()
|
||||
else:
|
||||
print (string)
|
||||
if save:
|
||||
self.logger_file.write('{:}\n'.format(string))
|
||||
self.logger_file.flush()
|
||||
|
||||
def scalar_summary(self, tags, values, step):
|
||||
"""Log a scalar variable."""
|
||||
if not self.use_tf:
|
||||
warnings.warn('use_tf is not enabled, but scalar_summary was called')
|
||||
else:
|
||||
assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
|
||||
if not isinstance(tags, list):
|
||||
tags, values = [tags], [values]
|
||||
for tag, value in zip(tags, values):
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
|
||||
def image_summary(self, tag, images, step):
|
||||
"""Log a list of images."""
|
||||
import scipy
|
||||
if not self.use_tf:
|
||||
warnings.warn('use_tf is not enabled, but image_summary was called')
|
||||
return
|
||||
|
||||
img_summaries = []
|
||||
for i, img in enumerate(images):
|
||||
# Write the image to a string
|
||||
s = BIO()  # BytesIO (Py3) / StringIO (Py2) alias imported at the top of this file
|
||||
scipy.misc.toimage(img).save(s, format="png")
|
||||
|
||||
# Create an Image object
|
||||
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
|
||||
height=img.shape[0],
|
||||
width=img.shape[1])
|
||||
# Create a Summary value
|
||||
img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=img_summaries)
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
||||
|
||||
def histo_summary(self, tag, values, step, bins=1000):
|
||||
"""Log a histogram of the tensor of values."""
|
||||
if not self.use_tf: raise ValueError('Do not have tensorflow')
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a histogram using numpy
|
||||
counts, bin_edges = np.histogram(values, bins=bins)
|
||||
|
||||
# Fill the fields of the histogram proto
|
||||
hist = tf.HistogramProto()
|
||||
hist.min = float(np.min(values))
|
||||
hist.max = float(np.max(values))
|
||||
hist.num = int(np.prod(values.shape))
|
||||
hist.sum = float(np.sum(values))
|
||||
hist.sum_squares = float(np.sum(values**2))
|
||||
|
||||
# Drop the start of the first bin
|
||||
bin_edges = bin_edges[1:]
|
||||
|
||||
# Add bin edges and counts
|
||||
for edge in bin_edges:
|
||||
hist.bucket_limit.append(edge)
|
||||
for c in counts:
|
||||
hist.bucket.append(c)
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()
|
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
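# --- illustrative usage sketch (not part of the original file) ---
# AverageMeter keeps a running weighted average, e.g. a per-batch loss weighted by batch size.
if __name__ == '__main__':
    meter = AverageMeter()
    meter.update(0.50, n=32)  # batch of 32 samples with loss 0.50
    meter.update(0.30, n=64)  # batch of 64 samples with loss 0.30
    print(meter)              # avg == (0.50*32 + 0.30*64) / 96 ~= 0.3667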
|
||||
|
||||
|
||||
class RecorderMeter(object):
|
||||
"""Computes and stores the minimum loss value and its epoch index"""
|
||||
def __init__(self, total_epoch):
|
||||
self.reset(total_epoch)
|
||||
|
||||
def reset(self, total_epoch):
|
||||
assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
|
||||
self.total_epoch = total_epoch
|
||||
self.current_epoch = 0
|
||||
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
|
||||
self.epoch_losses = self.epoch_losses - 1
|
||||
self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
|
||||
self.epoch_accuracy= self.epoch_accuracy
|
||||
|
||||
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
|
||||
assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
|
||||
self.epoch_losses [idx, 0] = train_loss
|
||||
self.epoch_losses [idx, 1] = val_loss
|
||||
self.epoch_accuracy[idx, 0] = train_acc
|
||||
self.epoch_accuracy[idx, 1] = val_acc
|
||||
self.current_epoch = idx + 1
|
||||
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
|
||||
|
||||
def max_accuracy(self, istrain):
|
||||
if self.current_epoch <= 0: return 0
|
||||
if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
|
||||
else: return self.epoch_accuracy[:self.current_epoch, 1].max()
|
||||
|
||||
def plot_curve(self, save_path):
|
||||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
title = 'the accuracy/loss curve of train/val'
|
||||
dpi = 100
|
||||
width, height = 1600, 1000
|
||||
legend_fontsize = 10
|
||||
figsize = width / float(dpi), height / float(dpi)
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
|
||||
y_axis = np.zeros(self.total_epoch)
|
||||
|
||||
plt.xlim(0, self.total_epoch)
|
||||
plt.ylim(0, 100)
|
||||
interval_y = 5
|
||||
interval_x = 5
|
||||
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
|
||||
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
|
||||
plt.grid()
|
||||
plt.title(title, fontsize=20)
|
||||
plt.xlabel('the training epoch', fontsize=16)
|
||||
plt.ylabel('accuracy', fontsize=16)
|
||||
|
||||
y_axis[:] = self.epoch_accuracy[:, 0]
|
||||
plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_accuracy[:, 1]
|
||||
plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
|
||||
y_axis[:] = self.epoch_losses[:, 0]
|
||||
plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
y_axis[:] = self.epoch_losses[:, 1]
|
||||
plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
|
||||
plt.legend(loc=4, fontsize=legend_fontsize)
|
||||
|
||||
if save_path is not None:
|
||||
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||
print ('---- save figure {} into {}'.format(title, save_path))
|
||||
plt.close(fig)
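# --- illustrative usage sketch (not part of the original file) ---
# RecorderMeter tracks per-epoch train/val loss and accuracy; update() returns whether
# the given epoch achieved the best validation accuracy so far.
if __name__ == '__main__':
    rec = RecorderMeter(total_epoch=2)
    rec.update(0, train_loss=1.2, train_acc=55.0, val_loss=1.1, val_acc=52.0)
    is_best = rec.update(1, train_loss=0.9, train_acc=70.0, val_loss=0.8, val_acc=66.0)
    print(is_best, rec.max_accuracy(istrain=False))  # True 66.0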
|
@@ -0,0 +1,42 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import time, sys
|
||||
import numpy as np
|
||||
|
||||
def time_for_file():
|
||||
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
|
||||
return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
|
||||
def time_string():
|
||||
ISOTIMEFORMAT='%Y-%m-%d %X'
|
||||
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
return string
|
||||
|
||||
def time_string_short():
|
||||
ISOTIMEFORMAT='%Y%m%d'
|
||||
string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
|
||||
return string
|
||||
|
||||
def time_print(string, is_print=True):
|
||||
if (is_print):
|
||||
print('{} : {}'.format(time_string(), string))
|
||||
|
||||
def convert_secs2time(epoch_time, return_str=False):
|
||||
need_hour = int(epoch_time / 3600)
|
||||
need_mins = int((epoch_time - 3600*need_hour) / 60)
|
||||
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
|
||||
if return_str:
|
||||
need_str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
return need_str
|
||||
else:
|
||||
return need_hour, need_mins, need_secs
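# --- illustrative usage sketch (not part of the original file) ---
# convert_secs2time either formats a duration as [HH:MM:SS] or returns its parts.
if __name__ == '__main__':
    print(convert_secs2time(3725, return_str=True))   # [01:02:05]
    print(convert_secs2time(3725, return_str=False))  # (1, 2, 5)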
|
||||
|
||||
def print_log(print_string, log):
|
||||
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
|
||||
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
|
||||
else:
|
||||
print("{:}".format(print_string))
|
||||
if log is not None:
|
||||
log.write('{:}\n'.format(print_string))
|
||||
log.flush()
|
@@ -0,0 +1,4 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .get_dataset_with_transform import get_datasets
|
@@ -0,0 +1,179 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from __future__ import print_function
|
||||
import torch.utils.data as data
|
||||
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
|
||||
from PIL import Image
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_dataset(dir, image_ids, targets):
|
||||
assert (len(image_ids) == len(targets))
|
||||
images = []
|
||||
dir = os.path.expanduser(dir)
|
||||
for i in range(len(image_ids)):
|
||||
item = (os.path.join(dir, 'data', 'images',
|
||||
'%s.jpg' % image_ids[i]), targets[i])
|
||||
images.append(item)
|
||||
return images
|
||||
|
||||
|
||||
def find_classes(classes_file):
|
||||
# read classes file, separating out image IDs and class names
|
||||
image_ids = []
|
||||
targets = []
|
||||
f = open(classes_file, 'r')
|
||||
for line in f:
|
||||
split_line = line.split(' ')
|
||||
image_ids.append(split_line[0])
|
||||
targets.append(' '.join(split_line[1:]))
|
||||
f.close()
|
||||
|
||||
# index class names
|
||||
classes = np.unique(targets)
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
targets = [class_to_idx[c] for c in targets]
|
||||
|
||||
return (image_ids, targets, classes, class_to_idx)
|
||||
|
||||
|
||||
class FGVCAircraft(data.Dataset):
|
||||
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
|
||||
Args:
|
||||
root (string): Root directory path to dataset.
|
||||
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
|
||||
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
|
||||
transform (callable, optional): A function/transform that takes in a PIL image
|
||||
and returns a transformed version. E.g. ``transforms.RandomCrop``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
loader (callable, optional): A function to load an image given its path.
|
||||
download (bool, optional): If true, downloads the dataset from the internet and
|
||||
puts it in the root directory. If dataset is already downloaded, it is not
|
||||
downloaded again.
|
||||
"""
|
||||
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
|
||||
class_types = ('variant', 'family', 'manufacturer')
|
||||
splits = ('train', 'val', 'trainval', 'test')
|
||||
|
||||
def __init__(self, root, class_type='variant', split='train', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False):
|
||||
if split not in self.splits:
|
||||
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
|
||||
split, ', '.join(self.splits),
|
||||
))
|
||||
if class_type not in self.class_types:
|
||||
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
|
||||
class_type, ', '.join(self.class_types),
|
||||
))
|
||||
self.root = os.path.expanduser(root)
|
||||
self.root = os.path.join(self.root, 'fgvc-aircraft-2013b')
|
||||
self.class_type = class_type
|
||||
self.split = split
|
||||
self.classes_file = os.path.join(self.root, 'data',
|
||||
'images_%s_%s.txt' % (self.class_type, self.split))
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
|
||||
samples = make_dataset(self.root, image_ids, targets)
|
||||
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self.loader = loader
|
||||
|
||||
self.samples = samples
|
||||
self.classes = classes
|
||||
self.class_to_idx = class_to_idx
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
Returns:
|
||||
tuple: (sample, target) where target is class_index of the target class.
|
||||
"""
|
||||
|
||||
path, target = self.samples[index]
|
||||
sample = self.loader(path)
|
||||
if self.transform is not None:
|
||||
sample = self.transform(sample)
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return sample, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __repr__(self):
|
||||
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
|
||||
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
|
||||
fmt_str += ' Root Location: {}\n'.format(self.root)
|
||||
tmp = ' Transforms (if any): '
|
||||
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
||||
tmp = ' Target Transforms (if any): '
|
||||
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
|
||||
return fmt_str
|
||||
|
||||
def _check_exists(self):
|
||||
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
|
||||
os.path.exists(self.classes_file)
|
||||
|
||||
def download(self):
|
||||
"""Download the FGVC-Aircraft data if it doesn't exist already."""
|
||||
from six.moves import urllib
|
||||
import tarfile
|
||||
|
||||
if self._check_exists():
|
||||
return
|
||||
|
||||
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
|
||||
print('Downloading %s ... (may take a few minutes)' % self.url)
|
||||
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
|
||||
tar_name = self.url.rpartition('/')[-1]
|
||||
tar_path = os.path.join(parent_dir, tar_name)
|
||||
data = urllib.request.urlopen(self.url)
|
||||
|
||||
# download .tar.gz file
|
||||
with open(tar_path, 'wb') as f:
|
||||
f.write(data.read())
|
||||
|
||||
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
|
||||
data_folder = tar_path[:-len('.tar.gz')]  # str.strip removes a character set, not a suffix
|
||||
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
|
||||
tar = tarfile.open(tar_path)
|
||||
tar.extractall(parent_dir)
|
||||
|
||||
# if necessary, rename data folder to self.root
|
||||
if not os.path.samefile(data_folder, self.root):
|
||||
print('Renaming %s to %s ...' % (data_folder, self.root))
|
||||
os.rename(data_folder, self.root)
|
||||
|
||||
# delete .tar.gz file
|
||||
print('Deleting %s ...' % tar_path)
|
||||
os.remove(tar_path)
|
||||
|
||||
print('Done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='train', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
print(len(air))
|
||||
print(len(air))
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='val', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split='trainval', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
print(len(air))
|
||||
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b/', class_type='manufacturer', split='test', transform=None,
|
||||
target_transform=None, loader=default_loader, download=False)
|
||||
print(len(air))
|
||||
import pdb;
|
||||
pdb.set_trace()
|
||||
print(len(air))
|
@@ -0,0 +1,304 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
|
||||
##################################################
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
from copy import deepcopy
|
||||
from PIL import Image  # required by Lighting.__call__ (Image.fromarray)
|
||||
import random
|
||||
import pdb
|
||||
from .aircraft import FGVCAircraft
|
||||
from .pets import PetDataset
# SearchDataset (used by get_nas_search_loaders below) is assumed to be provided by the
# accompanying SearchDatasetWrap module, as in the original NAS-Bench-201 code.
from .SearchDatasetWrap import SearchDataset
|
||||
from config_utils import load_config
|
||||
|
||||
Dataset2Class = {'cifar10': 10,
|
||||
'cifar100': 100,
|
||||
'mnist': 10,
|
||||
'svhn': 10,
|
||||
'aircraft': 30,
|
||||
'pets': 37}
|
||||
|
||||
|
||||
class CUTOUT(object):
|
||||
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def __call__(self, img):
|
||||
h, w = img.size(1), img.size(2)
|
||||
mask = np.ones((h, w), np.float32)
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1: y2, x1: x2] = 0.
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
img *= mask
|
||||
return img
|
||||
|
||||
|
||||
imagenet_pca = {
|
||||
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
|
||||
'eigvec': np.asarray([
|
||||
[-0.5675, 0.7192, 0.4009],
|
||||
[-0.5808, -0.0045, -0.8140],
|
||||
[-0.5836, -0.6948, 0.4203],
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
class Lighting(object):
|
||||
def __init__(self, alphastd,
|
||||
eigval=imagenet_pca['eigval'],
|
||||
eigvec=imagenet_pca['eigvec']):
|
||||
self.alphastd = alphastd
|
||||
assert eigval.shape == (3,)
|
||||
assert eigvec.shape == (3, 3)
|
||||
self.eigval = eigval
|
||||
self.eigvec = eigvec
|
||||
|
||||
def __call__(self, img):
|
||||
if self.alphastd == 0.:
|
||||
return img
|
||||
rnd = np.random.randn(3) * self.alphastd
|
||||
rnd = rnd.astype('float32')
|
||||
v = rnd
|
||||
old_dtype = np.asarray(img).dtype
|
||||
v = v * self.eigval
|
||||
v = v.reshape((3, 1))
|
||||
inc = np.dot(self.eigvec, v).reshape((3,))
|
||||
img = np.add(img, inc)
|
||||
if old_dtype == np.uint8:
|
||||
img = np.clip(img, 0, 255)
|
||||
img = Image.fromarray(img.astype(old_dtype), 'RGB')
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
|
||||
|
||||
|
||||
def get_datasets(name, root, cutout, use_num_cls=None):
|
||||
if name == 'cifar10':
|
||||
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
|
||||
std = [x / 255 for x in [63.0, 62.1, 66.7]]
|
||||
elif name == 'cifar100':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
elif name.startswith('mnist'):
|
||||
mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
|
||||
elif name.startswith('svhn'):
|
||||
mean, std = [0.4376821, 0.4437697, 0.47280442], [
|
||||
0.19803012, 0.20101562, 0.19703614]
|
||||
elif name.startswith('aircraft'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
elif name.startswith('pets'):
|
||||
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
# Data Augmentation
|
||||
if name == 'cifar10' or name == 'cifar100':
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std)]
|
||||
if cutout > 0:
|
||||
lists += [CUTOUT(cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('cub200'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('mnist'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
|
||||
transforms.Normalize(mean, std),
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
|
||||
transforms.Normalize(mean, std)
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('svhn'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('aircraft'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std),
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
elif name.startswith('pets'):
|
||||
train_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std)
|
||||
])
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=mean, std=std),
|
||||
])
|
||||
xshape = (1, 3, 32, 32)
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
if name == 'cifar10':
|
||||
train_data = dset.CIFAR10(
|
||||
root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10(
|
||||
root, train=False, transform=test_transform, download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'cifar100':
|
||||
train_data = dset.CIFAR100(
|
||||
root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(
|
||||
root, train=False, transform=test_transform, download=True)
|
||||
assert len(train_data) == 50000 and len(test_data) == 10000
|
||||
elif name == 'mnist':
|
||||
train_data = dset.MNIST(
|
||||
root, train=True, transform=train_transform, download=True)
|
||||
test_data = dset.MNIST(
|
||||
root, train=False, transform=test_transform, download=True)
|
||||
assert len(train_data) == 60000 and len(test_data) == 10000
|
||||
elif name == 'svhn':
|
||||
train_data = dset.SVHN(root, split='train',
|
||||
transform=train_transform, download=True)
|
||||
test_data = dset.SVHN(root, split='test',
|
||||
transform=test_transform, download=True)
|
||||
assert len(train_data) == 73257 and len(test_data) == 26032
|
||||
elif name == 'aircraft':
|
||||
train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
|
||||
transform=train_transform, download=False)
|
||||
test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
|
||||
transform=test_transform, download=False)
|
||||
assert len(train_data) == 6667 and len(test_data) == 3333
|
||||
elif name == 'pets':
|
||||
train_data = PetDataset(root, train=True, num_cl=37,
|
||||
val_split=0.15, transforms=train_transform)
|
||||
test_data = PetDataset(root, train=False, num_cl=37,
|
||||
val_split=0.15, transforms=test_transform)
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(name))
|
||||
|
||||
class_num = Dataset2Class[name] if use_num_cls is None else len(
|
||||
use_num_cls)
|
||||
return train_data, test_data, xshape, class_num
|
||||
|
||||
|
||||
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
|
||||
if isinstance(batch_size, (list, tuple)):
|
||||
batch, test_batch = batch_size
|
||||
else:
|
||||
batch, test_batch = batch_size, batch_size
|
||||
if dataset == 'cifar10':
|
||||
# split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
cifar_split = load_config(
|
||||
'{:}/cifar-split.txt'.format(config_root), None, None)
|
||||
# search over the proposed training and validation set
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
# logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
|
||||
# To split data
|
||||
xvalid_data = deepcopy(train_data)
|
||||
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
|
||||
xvalid_data.transforms = valid_data.transform
|
||||
xvalid_data.transform = deepcopy(valid_data.transform)
|
||||
search_data = SearchDataset(
|
||||
dataset, train_data, train_split, valid_split)
|
||||
# data loader
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
|
||||
pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
train_split),
|
||||
num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
valid_split),
|
||||
num_workers=workers, pin_memory=True)
|
||||
elif dataset == 'cifar100':
|
||||
cifar100_test_split = load_config(
|
||||
'{:}/cifar100-test-split.txt'.format(config_root), None, None)
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data)
|
||||
search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
|
||||
list(range(len(search_train_data))),
|
||||
cifar100_test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
|
||||
pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
|
||||
pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
|
||||
elif dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
|
||||
if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
|
||||
import json
|
||||
label_list = list(range(len(valid_data)))
|
||||
random.shuffle(label_list)
|
||||
strlist = [str(label_list[i]) for i in range(len(label_list))]
|
||||
split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
|
||||
'xtest': ["int", strlist[len(valid_data) // 2:]]}
|
||||
with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
|
||||
f.write(json.dumps(split))
|
||||
test_split = load_config(
|
||||
'{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
|
||||
|
||||
search_train_data = train_data
|
||||
search_valid_data = deepcopy(valid_data)
|
||||
search_valid_data.transform = train_data.transform
|
||||
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
|
||||
list(range(len(search_train_data))), test_split.xvalid)
|
||||
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
|
||||
num_workers=workers, pin_memory=True)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
|
||||
num_workers=workers, pin_memory=True)
|
||||
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(
|
||||
test_split.xvalid), num_workers=workers, pin_memory=True)
|
||||
else:
|
||||
raise ValueError('invalid dataset : {:}'.format(dataset))
|
||||
return search_loader, train_loader, valid_loader
|
@@ -0,0 +1,45 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
import torch
|
||||
from glob import glob
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def load_image(filename):
|
||||
img = Image.open(filename)
|
||||
img = img.convert('RGB')
|
||||
return img
|
||||
|
||||
class PetDataset(Dataset):
|
||||
def __init__(self, root, train=True, num_cl=37, val_split=0.2, transforms=None):
|
||||
self.data = torch.load(os.path.join(root,'{}{}.pth'.format('train' if train else 'test',
|
||||
int(100*(1-val_split)) if train else int(100*val_split))))
|
||||
self.len = len(self.data)
|
||||
self.transform = transforms
|
||||
def __getitem__(self, index):
|
||||
img, label = self.data[index]
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Added
|
||||
import torchvision.transforms as transforms
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose(
|
||||
[transforms.Resize(256), transforms.RandomRotation(45), transforms.CenterCrop(224),
|
||||
transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
|
||||
test_transform = transforms.Compose(
|
||||
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
|
||||
root = '/w14/dataset/MetaGen/pets'
|
||||
# PetDataset is used directly here; a separate get_pets helper is not defined in this file.
train_data = PetDataset(root, train=True, num_cl=37, val_split=0.2, transforms=train_transform)
test_data = PetDataset(root, train=False, num_cl=37, val_split=0.2, transforms=test_transform)
|
||||
import pdb;
|
||||
pdb.set_trace()
|
@@ -0,0 +1,34 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def additive_func(A, B):
|
||||
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
|
||||
C = min(A.size(1), B.size(1))
|
||||
if A.size(1) == B.size(1):
|
||||
return A + B
|
||||
elif A.size(1) < B.size(1):
|
||||
out = B.clone()
|
||||
out[:,:C] += A
|
||||
return out
|
||||
else:
|
||||
out = A.clone()
|
||||
out[:,:C] += B
|
||||
return out
|
||||
|
||||
|
||||
def change_key(key, value):
|
||||
def func(m):
|
||||
if hasattr(m, key):
|
||||
setattr(m, key, value)
|
||||
return func
|
||||
|
||||
|
||||
def parse_channel_info(xstring):
|
||||
blocks = xstring.split(' ')
|
||||
blocks = [x.split('-') for x in blocks]
|
||||
blocks = [[int(_) for _ in x] for x in blocks]
|
||||
return blocks
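# --- illustrative usage sketch (not part of the original file) ---
# additive_func adds two feature maps that may differ in channel width, and
# parse_channel_info turns a space-separated channel spec string into nested ints.
if __name__ == '__main__':
    A, B = torch.zeros(2, 3), torch.ones(2, 5)
    print(additive_func(A, B).shape)            # torch.Size([2, 5]); A is added onto the first 3 channels
    print(parse_channel_info('8-16 16-32-32'))  # [[8, 16], [16, 32, 32]]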
|
@@ -0,0 +1,45 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from os import path as osp
|
||||
from typing import List, Text
|
||||
import torch
|
||||
|
||||
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces', \
|
||||
'CellStructure', 'CellArchitectures'
|
||||
]
|
||||
|
||||
# useful modules
|
||||
from config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
if config.name == 'infer.tiny':
|
||||
from .cell_infers import TinyNetwork
|
||||
if hasattr(config, 'genotype'):
|
||||
genotype = config.genotype
|
||||
elif hasattr(config, 'arch_str'):
|
||||
genotype = CellStructure.str2structure(config.arch_str)
|
||||
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
|
||||
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
|
||||
else:
|
||||
raise ValueError('invalid network name : {:}'.format(config.name))
|
||||
|
||||
|
||||
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
|
||||
def get_search_spaces(xtype, name) -> List[Text]:
|
||||
if xtype == 'cell' or xtype == 'tss': # The topology search space.
|
||||
from .cell_operations import SearchSpaceNames
|
||||
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
|
||||
return SearchSpaceNames[name]
|
||||
elif xtype == 'sss': # The size search space.
|
||||
if name == 'nas-bench-301':
|
||||
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
|
||||
'numbers': 5}
|
||||
else:
|
||||
raise ValueError('Invalid name : {:}'.format(name))
|
||||
else:
|
||||
raise ValueError('invalid search-space type is {:}'.format(xtype))
|
@@ -0,0 +1,4 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from .tiny_network import TinyNetwork
|
@@ -0,0 +1,122 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_in, C_out, stride):
|
||||
super(InferCell, self).__init__()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.node_IN = []
|
||||
self.node_IX = []
|
||||
self.genotype = deepcopy(genotype)
|
||||
for i in range(1, len(genotype)):
|
||||
node_info = genotype[i-1]
|
||||
cur_index = []
|
||||
cur_innod = []
|
||||
for (op_name, op_in) in node_info:
|
||||
if op_in == 0:
|
||||
layer = OPS[op_name](C_in , C_out, stride, True, True)
|
||||
else:
|
||||
layer = OPS[op_name](C_out, C_out, 1, True, True)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
cur_index.append( len(self.layers) )
|
||||
cur_innod.append( op_in )
|
||||
self.layers.append( layer )
|
||||
self.node_IX.append( cur_index )
|
||||
self.node_IN.append( cur_innod )
|
||||
self.nodes = len(genotype)
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
|
||||
def extra_repr(self):
|
||||
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
|
||||
laystr = []
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
|
||||
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
|
||||
x = '{:}<-({:})'.format(i+1, ','.join(y))
|
||||
laystr.append( x )
|
||||
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
|
||||
|
||||
def forward(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
|
||||
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
|
||||
nodes.append( node_feature )
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetInferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
|
||||
super(NASNetInferCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
|
||||
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
|
||||
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
|
||||
|
||||
if not reduction:
|
||||
nodes, concats = genotype['normal'], genotype['normal_concat']
|
||||
else:
|
||||
nodes, concats = genotype['reduce'], genotype['reduce_concat']
|
||||
self._multiplier = len(concats)
|
||||
self._concats = concats
|
||||
self._steps = len(nodes)
|
||||
self._nodes = nodes
|
||||
self.edges = nn.ModuleDict()
|
||||
for i, node in enumerate(nodes):
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
node_str = '{:}<-{:}'.format(i+2, j)
|
||||
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
|
||||
|
||||
# [TODO] support drop_prob in this function.
|
||||
def forward(self, s0, s1, unused_drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i, node in enumerate(self._nodes):
|
||||
clist = []
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
node_str = '{:}<-{:}'.format(i+2, j)
|
||||
op = self.edges[ node_str ]
|
||||
clist.append( op(states[j]) )
|
||||
states.append( sum(clist) )
|
||||
return torch.cat([states[x] for x in self._concats], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0),-1))
|
||||
return x
|
@@ -0,0 +1,66 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .cells import InferCell
|
||||
|
||||
|
||||
# The macro structure for architectures in NAS-Bench-201
|
||||
class TinyNetwork(nn.Module):
|
||||
|
||||
def __init__(self, C, N, genotype, num_classes):
|
||||
super(TinyNetwork, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C))
|
||||
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev = C
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
|
||||
else:
|
||||
cell = InferCell(genotype, C_prev, C_curr, 1)
|
||||
self.cells.append( cell )
|
||||
C_prev = cell.out_dim
|
||||
self._Layer= len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
'''
|
||||
out2 = self.lastact(feature)
|
||||
out = self.global_pooling( out2 )
|
||||
out = out.view(out.size(0), -1)
|
||||
out2 = out2.view(out2.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
return out2, logits
|
||||
|
||||
'''
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
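# Illustrative usage sketch (assumes CellStructure is the Structure class exported by the
# sibling genotypes module, as done in this package's __init__.py):
#
#   genotype = CellStructure.str2structure(
#       '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|')
#   net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
#   out, logits = net(torch.randn(2, 3, 32, 32))   # logits.shape == (2, 10)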
@@ -0,0 +1,308 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
|
||||
|
||||
OPS = {
|
||||
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
|
||||
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
|
||||
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
|
||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
|
||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
|
||||
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
|
||||
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
|
||||
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
|
||||
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
|
||||
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
|
||||
|
||||
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
|
||||
'nas-bench-201': NAS_BENCH_201,
|
||||
'nas-bench-301': NAS_BENCH_201,
|
||||
'darts' : DARTS_SPACE}
|
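# Illustrative usage sketch: every entry in OPS shares the constructor signature
# (C_in, C_out, stride, affine, track_running_stats), so an operation can be built
# uniformly from its name, e.g.
#
#   op = OPS['nor_conv_3x3'](16, 16, 1, True, True)   # ReLU -> Conv3x3 -> BN
#   y  = op(torch.randn(2, 16, 32, 32))               # y.shape == (2, 16, 32, 32)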
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
|
||||
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DualSepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
|
||||
super(DualSepConv, self).__init__()
|
||||
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
|
||||
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.op_a(x)
|
||||
x = self.op_b(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride, affine=True):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
|
||||
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
|
||||
if stride == 2:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
|
||||
elif inplanes != planes:
|
||||
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.in_dim = inplanes
|
||||
self.out_dim = planes
|
||||
self.stride = stride
|
||||
self.num_conv = 2
|
||||
|
||||
def extra_repr(self):
|
||||
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
return string
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
return residual + basicblock
|
||||
|
||||
|
||||
class POOLING(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
|
||||
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.preprocess: x = self.preprocess(inputs)
|
||||
else : x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.stride = stride
|
||||
self.is_zero = True
|
||||
|
||||
def forward(self, x):
|
||||
if self.C_in == self.C_out:
|
||||
if self.stride == 1: return x.mul(0.)
|
||||
else : return x[:,:,::self.stride,::self.stride].mul(0.)
|
||||
else:
|
||||
shape = list(x.shape)
|
||||
shape[1] = self.C_out
|
||||
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return zeros
|
||||
|
||||
def extra_repr(self):
|
||||
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
self.stride = stride
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
if stride == 2:
|
||||
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
|
||||
C_outs = [C_out // 2, C_out - C_out // 2]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(2):
|
||||
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 1:
|
||||
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
|
||||
else:
|
||||
raise ValueError('Invalid stride : {:}'.format(stride))
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 2:
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
|
||||
else:
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|
||||
|
||||
|
||||
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
|
||||
class PartAwareOp(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, part=4):
|
||||
super().__init__()
|
||||
self.part = part  # respect the requested number of parts instead of hard-coding 4
|
||||
self.hidden = C_in // 3
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.local_conv_list = nn.ModuleList()
|
||||
for i in range(self.part):
|
||||
self.local_conv_list.append(
|
||||
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
|
||||
)
|
||||
self.W_K = nn.Linear(self.hidden, self.hidden)
|
||||
self.W_Q = nn.Linear(self.hidden, self.hidden)
|
||||
|
||||
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, affine=True, track_running_stats=True)
|
||||
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, affine=True, track_running_stats=True)
|
||||
else: raise ValueError('Invalid Stride : {:}'.format(stride))
|
||||
|
||||
def forward(self, x):
|
||||
batch, C, H, W = x.size()
|
||||
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
|
||||
IHs = [0]
|
||||
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
|
||||
local_feat_list = []
|
||||
for i in range(self.part):
|
||||
feature = x[:, :, IHs[i]:IHs[i+1], :]
|
||||
xfeax = self.avg_pool(feature)
|
||||
xfea = self.local_conv_list[i]( xfeax )
|
||||
local_feat_list.append( xfea )
|
||||
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
|
||||
part_feature = part_feature.transpose(1,2).contiguous()
|
||||
part_K = self.W_K(part_feature)
|
||||
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
|
||||
weight_att = torch.bmm(part_K, part_Q)
|
||||
attention = torch.softmax(weight_att, dim=2)
|
||||
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
|
||||
features = []
|
||||
for i in range(self.part):
|
||||
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
|
||||
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
|
||||
features.append( feature )
|
||||
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
|
||||
final_fea = torch.cat((x,features), dim=1)
|
||||
outputs = self.last( final_fea )
|
||||
return outputs
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1. - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = torch.div(x, keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
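# Illustrative note: drop_path zeroes out whole samples with probability drop_prob and
# rescales the survivors by 1/keep_prob, so the expected activation is unchanged, e.g.
#
#   y = drop_path(torch.ones(8, 4, 1, 1), 0.25)
#   # on average 2 of the 8 samples become all-zero; the others are scaled to 1/0.75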
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours
|
||||
class GDAS_Reduction_Cell(nn.Module):
|
||||
|
||||
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
|
||||
super(GDAS_Reduction_Cell, self).__init__()
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.reduction = True
|
||||
self.ops1 = nn.ModuleList(
|
||||
[nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
|
||||
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True)),
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
|
||||
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C, affine=True))])
|
||||
|
||||
self.ops2 = nn.ModuleList(
|
||||
[nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=1, padding=1),
|
||||
nn.BatchNorm2d(C, affine=True)),
|
||||
nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(C, affine=True))])
|
||||
|
||||
def forward(self, s0, s1, drop_prob = -1):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
X0 = self.ops1[0] (s0)
|
||||
X1 = self.ops1[1] (s1)
|
||||
if self.training and drop_prob > 0.:
|
||||
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
|
||||
|
||||
X2 = self.ops2[0](X0 + X1)  # ops2[0] has stride 1, so feed the already-reduced X0+X1 to keep all four branches at the same spatial size
|
||||
X3 = self.ops2[1] (s1)
|
||||
if self.training and drop_prob > 0.:
|
||||
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
|
||||
return torch.cat([X0, X1, X2, X3], dim=1)
|
@@ -0,0 +1,26 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
# The macro structure is defined in NAS-Bench-201
|
||||
# from .search_model_darts import TinyNetworkDarts
|
||||
# from .search_model_gdas import TinyNetworkGDAS
|
||||
# from .search_model_setn import TinyNetworkSETN
|
||||
# from .search_model_enas import TinyNetworkENAS
|
||||
# from .search_model_random import TinyNetworkRANDOM
|
||||
# from .generic_model import GenericNAS201Model
|
||||
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
|
||||
# NASNet-based macro structure
|
||||
# from .search_model_gdas_nasnet import NASNetworkGDAS
|
||||
# from .search_model_darts_nasnet import NASNetworkDARTS
|
||||
|
||||
|
||||
# nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
|
||||
# "DARTS-V2": TinyNetworkDarts,
|
||||
# "GDAS": TinyNetworkGDAS,
|
||||
# "SETN": TinyNetworkSETN,
|
||||
# "ENAS": TinyNetworkENAS,
|
||||
# "RANDOM": TinyNetworkRANDOM,
|
||||
# "generic": GenericNAS201Model}
|
||||
|
||||
# nasnet_super_nets = {"GDAS": NASNetworkGDAS,
|
||||
# "DARTS": NASNetworkDARTS}
|
@@ -0,0 +1,198 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def get_combination(space, num):
|
||||
combs = []
|
||||
for i in range(num):
|
||||
if i == 0:
|
||||
for func in space:
|
||||
combs.append( [(func, i)] )
|
||||
else:
|
||||
new_combs = []
|
||||
for string in combs:
|
||||
for func in space:
|
||||
xstring = string + [(func, i)]
|
||||
new_combs.append( xstring )
|
||||
combs = new_combs
|
||||
return combs
|
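# Illustrative usage sketch: get_combination enumerates every assignment of an operation
# to each of `num` incoming edges, keeping the edge index next to the chosen op, e.g.
#
#   get_combination(['A', 'B'], 2)
#   # -> [[('A', 0), ('A', 1)], [('A', 0), ('B', 1)], [('B', 0), ('A', 1)], [('B', 0), ('B', 1)]]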
||||
|
||||
|
||||
class Structure:
|
||||
|
||||
def __init__(self, genotype):
|
||||
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
|
||||
self.node_num = len(genotype) + 1
|
||||
self.nodes = []
|
||||
self.node_N = []
|
||||
for idx, node_info in enumerate(genotype):
|
||||
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
|
||||
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
|
||||
for node_in in node_info:
|
||||
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
|
||||
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
|
||||
self.node_N.append( len(node_info) )
|
||||
self.nodes.append( tuple(deepcopy(node_info)) )
|
||||
|
||||
def tolist(self, remove_str):
|
||||
# Convert this class to a list; if remove_str is 'none', the 'none' operations are removed.
|
||||
# Note that the in-nodes of each node are re-ordered (sorted) in this function.
|
||||
# Return the genotype list and a success flag [if unsuccessful, the genotype is not a valid connectivity].
|
||||
genotypes = []
|
||||
for node_info in self.nodes:
|
||||
node_info = list( node_info )
|
||||
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
||||
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
||||
if len(node_info) == 0: return None, False
|
||||
genotypes.append( node_info )
|
||||
return genotypes, True
|
||||
|
||||
def node(self, index):
|
||||
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
|
||||
return self.nodes[index]
|
||||
|
||||
def tostr(self):
|
||||
strings = []
|
||||
for node_info in self.nodes:
|
||||
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
|
||||
string = '|{:}|'.format(string)
|
||||
strings.append( string )
|
||||
return '+'.join(strings)
|
||||
|
||||
def check_valid(self):
|
||||
nodes = {0: True}
|
||||
for i, node_info in enumerate(self.nodes):
|
||||
sums = []
|
||||
for op, xin in node_info:
|
||||
if op == 'none' or nodes[xin] is False: x = False
|
||||
else: x = True
|
||||
sums.append( x )
|
||||
nodes[i+1] = sum(sums) > 0
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def to_unique_str(self, consider_zero=False):
|
||||
# This is used to identify isomorphic cells, which requires prior knowledge of the operations;
|
||||
# two operations are special, i.e., none and skip_connect
|
||||
nodes = {0: '0'}
|
||||
for i_node, node_info in enumerate(self.nodes):
|
||||
cur_node = []
|
||||
for op, xin in node_info:
|
||||
if consider_zero is None:
|
||||
x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
elif consider_zero:
|
||||
if op == 'none' or nodes[xin] == '#': x = '#' # zero
|
||||
elif op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
else:
|
||||
if op == 'skip_connect': x = nodes[xin]
|
||||
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node+1] = '+'.join( sorted(cur_node) )
|
||||
return nodes[ len(self.nodes) ]
|
||||
|
||||
def check_valid_op(self, op_names):
|
||||
for node_info in self.nodes:
|
||||
for inode_edge in node_info:
|
||||
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
|
||||
if inode_edge[0] not in op_names: return False
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes) + 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.nodes[index]
|
||||
|
||||
@staticmethod
|
||||
def str2structure(xstr):
|
||||
if isinstance(xstr, Structure): return xstr
|
||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
nodestrs = xstr.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
inputs = ( xi.split('~') for xi in inputs )
|
||||
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append( input_infos )
|
||||
return Structure( genotypes )
|
||||
|
||||
@staticmethod
|
||||
def str2fullstructure(xstr, default_name='none'):
|
||||
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
|
||||
nodestrs = xstr.split('+')
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != '', node_str.split('|')))
|
||||
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
|
||||
inputs = ( xi.split('~') for xi in inputs )
|
||||
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
|
||||
all_in_nodes= list(x[1] for x in input_infos)
|
||||
for j in range(i):
|
||||
if j not in all_in_nodes: input_infos.append((default_name, j))
|
||||
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
||||
genotypes.append( tuple(node_info) )
|
||||
return Structure( genotypes )
|
||||
|
||||
@staticmethod
|
||||
def gen_all(search_space, num, return_ori):
|
||||
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
|
||||
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
|
||||
all_archs = get_combination(search_space, 1)
|
||||
for i, arch in enumerate(all_archs):
|
||||
all_archs[i] = [ tuple(arch) ]
|
||||
|
||||
for inode in range(2, num):
|
||||
cur_nodes = get_combination(search_space, inode)
|
||||
new_all_archs = []
|
||||
for previous_arch in all_archs:
|
||||
for cur_node in cur_nodes:
|
||||
new_all_archs.append( previous_arch + [tuple(cur_node)] )
|
||||
all_archs = new_all_archs
|
||||
if return_ori:
|
||||
return all_archs
|
||||
else:
|
||||
return [Structure(x) for x in all_archs]
|
||||
|
||||
|
||||
|
||||
ResNet_CODE = Structure(
|
||||
[(('nor_conv_3x3', 0), ), # node-1
|
||||
(('nor_conv_3x3', 1), ), # node-2
|
||||
(('skip_connect', 0), ('skip_connect', 2))] # node-3
|
||||
)
|
||||
|
||||
AllConv3x3_CODE = Structure(
|
||||
[(('nor_conv_3x3', 0), ), # node-1
|
||||
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
|
||||
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
|
||||
)
|
||||
|
||||
AllFull_CODE = Structure(
|
||||
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
|
||||
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
|
||||
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
|
||||
)
|
||||
|
||||
AllConv1x1_CODE = Structure(
|
||||
[(('nor_conv_1x1', 0), ), # node-1
|
||||
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
|
||||
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
|
||||
)
|
||||
|
||||
AllIdentity_CODE = Structure(
|
||||
[(('skip_connect', 0), ), # node-1
|
||||
(('skip_connect', 0), ('skip_connect', 1)), # node-2
|
||||
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
|
||||
)
|
||||
|
||||
architectures = {'resnet' : ResNet_CODE,
|
||||
'all_c3x3': AllConv3x3_CODE,
|
||||
'all_c1x1': AllConv1x1_CODE,
|
||||
'all_idnt': AllIdentity_CODE,
|
||||
'all_full': AllFull_CODE}
|
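# Minimal self-check sketch (illustrative): round-trip an architecture string through
# Structure and exercise the validity / canonical-string helpers defined above.
if __name__ == '__main__':
    _arch = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_1x1~2|'
    _structure = Structure.str2structure(_arch)
    assert _structure.tostr() == _arch      # parsing and printing are inverse operations
    assert _structure.check_valid()         # the output node is reachable through non-'none' operations
    print(len(_structure), _structure.to_unique_str(consider_zero=True))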
@@ -0,0 +1,167 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferCifarResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
|
||||
super(InferCifarResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'ResNetBasicblock':
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == 'ResNetBottleneck':
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
|
||||
|
||||
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL+1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
@@ -0,0 +1,150 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
|
||||
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
|
||||
elif inplanes != planes*self.expansion:
|
||||
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes*self.expansion
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferDepthCifarResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
|
||||
super(InferDepthCifarResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'ResNetBasicblock':
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == 'ResNetBottleneck':
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
|
||||
|
||||
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
|
||||
self.num_classes = num_classes
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
self.channels = [16]
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2**stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append( module.out_dim )
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.channels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
@@ -0,0 +1,160 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferWidthCifarResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
|
||||
super(InferWidthCifarResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'ResNetBasicblock':
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == 'ResNetBottleneck':
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
|
||||
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
@@ -0,0 +1,170 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
num_conv = 1
|
||||
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else : self.avg = None
|
||||
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
|
||||
if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg : out = self.avg( inputs )
|
||||
else : out = inputs
|
||||
conv = self.conv( out )
|
||||
if self.bn : out = self.bn( conv )
|
||||
else : out = conv
|
||||
if self.relu: out = self.relu( out )
|
||||
else : out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
|
||||
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
#self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
|
||||
class InferImagenetResNet(nn.Module):
|
||||
|
||||
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
|
||||
super(InferImagenetResNet, self).__init__()
|
||||
|
||||
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == 'BasicBlock':
|
||||
block = ResNetBasicblock
|
||||
elif block_name == 'Bottleneck':
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError('invalid block : {:}'.format(block_name))
|
||||
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
|
||||
|
||||
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 1
|
||||
else:
|
||||
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
|
||||
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
|
||||
last_channel_idx = 2
|
||||
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append ( module )
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL+1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer( x )
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
@@ -0,0 +1,122 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
|
||||
from torch import nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import parse_channel_info
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
|
||||
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
|
||||
else : self.bn = None
|
||||
if has_relu: self.relu = nn.ReLU6(inplace=True)
|
||||
else : self.relu = None
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv( x )
|
||||
if self.bn: out = self.bn ( out )
|
||||
if self.relu: out = self.relu( out )
|
||||
return out
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, channels, stride, expand_ratio, additive):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
|
||||
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
|
||||
|
||||
if len(channels) == 2:
|
||||
layers = []
|
||||
else:
|
||||
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
|
||||
# pw-linear
|
||||
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
self.additive = additive
|
||||
if self.additive and channels[0] != channels[-1]:
|
||||
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
|
||||
else:
|
||||
self.shortcut = None
|
||||
self.out_dim = channels[-1]
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
# if self.additive: return additive_func(out, x)
|
||||
if self.shortcut: return out + self.shortcut(x)
|
||||
else : return out
|
||||
|
||||
|
||||
class InferMobileNetV2(nn.Module):
|
||||
def __init__(self, num_classes, xchannels, xblocks, dropout):
|
||||
super(InferMobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16 , 1, 1],
|
||||
[6, 24 , 2, 2],
|
||||
[6, 32 , 3, 2],
|
||||
[6, 64 , 4, 2],
|
||||
[6, 96 , 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
|
||||
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
|
||||
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
|
||||
xchannels = parse_channel_info(xchannels)
|
||||
#for i, chs in enumerate(xchannels):
|
||||
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
|
||||
self.xchannels = xchannels
|
||||
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
|
||||
# building first layer
|
||||
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
|
||||
last_channel_idx = 1
|
||||
|
||||
# building inverted residual blocks
|
||||
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
additv = True if i > 0 else False
|
||||
module = block(self.xchannels[last_channel_idx], stride, t, additv)
|
||||
features.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
|
||||
last_channel_idx += 1
|
||||
if i + 1 == xblocks[stage]:
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(i+1, n):
|
||||
last_channel_idx += 1
|
||||
self.xchannels[last_channel_idx][0] = module.out_dim
|
||||
break
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
|
||||
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
self.apply( initialize_resnet )
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
features = self.features(inputs)
|
||||
vectors = features.mean([2, 3])
|
||||
predicts = self.classifier(vectors)
|
||||
return features, predicts
|
@@ -0,0 +1,58 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
from typing import List, Text, Any
|
||||
import torch.nn as nn
|
||||
from models.cell_operations import ResNetBasicblock
|
||||
from models.cell_infers.cells import InferCell
|
||||
|
||||
|
||||
class DynamicShapeTinyNet(nn.Module):
|
||||
|
||||
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
|
||||
super(DynamicShapeTinyNet, self).__init__()
|
||||
self._channels = channels
|
||||
if len(channels) % 3 != 2:
|
||||
raise ValueError('invalid number of layers : {:}'.format(len(channels)))
|
||||
self._num_stage = N = len(channels) // 3
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(channels[0]))
|
||||
|
||||
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
c_prev = channels[0]
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
|
||||
if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
|
||||
else : cell = InferCell(genotype, c_prev, c_curr, 1)
|
||||
self.cells.append( cell )
|
||||
c_prev = cell.out_dim
|
||||
self._num_layer = len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(c_prev, num_classes)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling( out )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
@@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet
@@ -0,0 +1,5 @@
def parse_channel_info(xstring):
  blocks = xstring.split(' ')
  blocks = [x.split('-') for x in blocks]
  blocks = [[int(_) for _ in x] for x in blocks]
  return blocks
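A minimal usage sketch of parse_channel_info defined above; the channel string is a made-up example, not a value from this repository.

xstring = '3-16 16-24-24 24-32'   # hypothetical: three layers described as in-out (or in-hidden-out) channels
blocks = parse_channel_info(xstring)
print(blocks)                     # [[3, 16], [16, 24, 24], [24, 32]]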
@@ -0,0 +1,2 @@
from .evaluation_utils import obtain_accuracy
from .flop_benchmark import get_model_infos
@@ -0,0 +1,17 @@
import torch

def obtain_accuracy(output, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size))
  return res
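A small sanity check for obtain_accuracy (the tensors are illustrative only): when every row's largest logit matches its target, both top-1 and top-3 precision are 100%.

import torch

logits = torch.tensor([[2.0, 0.1, 0.1],
                       [0.1, 2.0, 0.1],
                       [0.1, 0.1, 2.0],
                       [2.0, 0.1, 0.1]])   # 4 samples, 3 classes
targets = torch.tensor([0, 1, 2, 0])
top1, top3 = obtain_accuracy(logits, targets, topk=(1, 3))
print(top1.item(), top3.item())            # 100.0 100.0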
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def count_parameters_in_MB(model):
  # use the Python builtin sum: np.sum over a generator is deprecated / unreliable across numpy versions
  if isinstance(model, nn.Module):
    return sum(np.prod(v.size()) for v in model.parameters()) / 1e6
  else:
    return sum(np.prod(v.size()) for v in model) / 1e6
|
||||
|
||||
|
||||
def get_model_infos(model, shape):
|
||||
#model = copy.deepcopy( model )
|
||||
|
||||
model = add_flops_counting_methods(model)
|
||||
#model = model.cuda()
|
||||
model.eval()
|
||||
|
||||
#cache_inputs = torch.zeros(*shape).cuda()
|
||||
#cache_inputs = torch.zeros(*shape)
|
||||
cache_inputs = torch.rand(*shape)
|
||||
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
|
||||
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
|
||||
with torch.no_grad():
|
||||
_____ = model(cache_inputs)
|
||||
FLOPs = compute_average_flops_cost( model ) / 1e6
|
||||
Param = count_parameters_in_MB(model)
|
||||
|
||||
if hasattr(model, 'auxiliary_param'):
|
||||
aux_params = count_parameters_in_MB(model.auxiliary_param())
|
||||
print ('The auxiliary params of this model is : {:}'.format(aux_params))
|
||||
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
|
||||
Param = Param - aux_params
|
||||
|
||||
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
||||
torch.cuda.empty_cache()
|
||||
model.apply( remove_hook_function )
|
||||
return FLOPs, Param
|
||||
|
||||
|
||||
# ---- Public functions
|
||||
def add_flops_counting_methods( model ):
|
||||
model.__batch_counter__ = 0
|
||||
add_batch_counter_hook_function( model )
|
||||
model.apply( add_flops_counter_variable_or_reset )
|
||||
model.apply( add_flops_counter_hook_function )
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def compute_average_flops_cost(model):
|
||||
"""
|
||||
A method that will be available after add_flops_counting_methods() is called on a desired net object.
|
||||
Returns current mean flops consumption per image.
|
||||
"""
|
||||
batches_count = model.__batch_counter__
|
||||
flops_sum = 0
|
||||
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
|
||||
for module in model.modules():
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
|
||||
or isinstance(module, torch.nn.Conv1d) \
|
||||
or hasattr(module, 'calculate_flop_self'):
|
||||
flops_sum += module.__flops__
|
||||
return flops_sum / batches_count
|
||||
|
||||
|
||||
# ---- Internal functions
|
||||
def pool_flops_counter_hook(pool_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
kernel_size = pool_module.kernel_size
|
||||
out_C, output_height, output_width = output.shape[1:]
|
||||
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
|
||||
|
||||
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
|
||||
pool_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def self_calculate_flops_counter_hook(self_module, inputs, output):
|
||||
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
|
||||
self_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def fc_flops_counter_hook(fc_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
xin, xout = fc_module.in_features, fc_module.out_features
|
||||
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
|
||||
overall_flops = batch_size * xin * xout
|
||||
if fc_module.bias is not None:
|
||||
overall_flops += batch_size * xout
|
||||
fc_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
|
||||
batch_size = inputs[0].size(0)
|
||||
outL = outputs.shape[-1]
|
||||
[kernel] = conv_module.kernel_size
|
||||
in_channels = conv_module.in_channels
|
||||
out_channels = conv_module.out_channels
|
||||
groups = conv_module.groups
|
||||
conv_per_position_flops = kernel * in_channels * out_channels / groups
|
||||
|
||||
active_elements_count = batch_size * outL
|
||||
overall_flops = conv_per_position_flops * active_elements_count
|
||||
|
||||
if conv_module.bias is not None:
|
||||
overall_flops += out_channels * active_elements_count
|
||||
conv_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def conv2d_flops_counter_hook(conv_module, inputs, output):
|
||||
batch_size = inputs[0].size(0)
|
||||
output_height, output_width = output.shape[2:]
|
||||
|
||||
kernel_height, kernel_width = conv_module.kernel_size
|
||||
in_channels = conv_module.in_channels
|
||||
out_channels = conv_module.out_channels
|
||||
groups = conv_module.groups
|
||||
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
|
||||
|
||||
active_elements_count = batch_size * output_height * output_width
|
||||
overall_flops = conv_per_position_flops * active_elements_count
|
||||
|
||||
if conv_module.bias is not None:
|
||||
overall_flops += out_channels * active_elements_count
|
||||
conv_module.__flops__ += overall_flops
|
||||
|
||||
|
||||
def batch_counter_hook(module, inputs, output):
|
||||
# Can have multiple inputs, getting the first one
|
||||
inputs = inputs[0]
|
||||
batch_size = inputs.shape[0]
|
||||
module.__batch_counter__ += batch_size
|
||||
|
||||
|
||||
def add_batch_counter_hook_function(module):
|
||||
if not hasattr(module, '__batch_counter_handle__'):
|
||||
handle = module.register_forward_hook(batch_counter_hook)
|
||||
module.__batch_counter_handle__ = handle
|
||||
|
||||
|
||||
def add_flops_counter_variable_or_reset(module):
|
||||
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
|
||||
or isinstance(module, torch.nn.Conv1d) \
|
||||
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
|
||||
or hasattr(module, 'calculate_flop_self'):
|
||||
module.__flops__ = 0
|
||||
|
||||
|
||||
def add_flops_counter_hook_function(module):
|
||||
if isinstance(module, torch.nn.Conv2d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(conv2d_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.Conv1d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(conv1d_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.Linear):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(fc_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(pool_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
elif hasattr(module, 'calculate_flop_self'): # self-defined module
|
||||
if not hasattr(module, '__flops_handle__'):
|
||||
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
|
||||
module.__flops_handle__ = handle
|
||||
|
||||
|
||||
def remove_hook_function(module):
  hookers = ['__batch_counter_handle__', '__flops_handle__']
  for hooker in hookers:
    if hasattr(module, hooker):
      handle = getattr(module, hooker)
      handle.remove()
  keys = ['__flops__', '__batch_counter__'] + hookers  # drop the counters and the hook handles
  for ckey in keys:
    if hasattr(module, ckey): delattr(module, ckey)
|
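A hedged usage sketch of the FLOP counter on a throwaway network (the toy module and input shape below are assumptions for illustration); get_model_infos returns FLOPs and parameter count, both in millions.

import torch
import torch.nn as nn

toy = nn.Sequential(                      # toy network, illustration only
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10))
flops, params = get_model_infos(toy, shape=(1, 3, 32, 32))
print('{:.3f} MFLOPs, {:.3f}M params'.format(flops, params))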
@@ -0,0 +1,28 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
from .starts import prepare_seed #, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
'''
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders

def get_procedures(procedure):
  from .basic_main import basic_train, basic_valid
  from .search_main import search_train, search_valid
  from .search_main_v2 import search_train_v2
  from .simple_KD_main import simple_KD_train, simple_KD_valid

  train_funcs = {'basic': basic_train, 'search': search_train,
                 'Simple-KD': simple_KD_train, 'search-v2': search_train_v2}
  valid_funcs = {'basic': basic_valid, 'search': search_valid,
                 'Simple-KD': simple_KD_valid, 'search-v2': search_valid}

  train_func = train_funcs[procedure]
  valid_func = valid_funcs[procedure]
  return train_func, valid_func
'''
@@ -0,0 +1,204 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from bisect import bisect_right
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class _LRScheduler(object):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault('initial_lr', group['lr'])
|
||||
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return ''
|
||||
|
||||
def __repr__(self):
|
||||
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
|
||||
+ ', {:})'.format(self.extra_repr()))
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
|
||||
|
||||
def get_min_lr(self):
|
||||
return min( self.get_lr() )
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
#if last_epoch < self.T_max:
|
||||
#if last_epoch < self.max_epochs:
|
||||
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
|
||||
#else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]: lr *= x
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
lr = base_lr * (self.gamma ** last_epoch)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
|
||||
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
|
||||
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
|
||||
lr = base_lr * (1-ratio)
|
||||
else:
|
||||
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
|
||||
lrs.append( lr )
|
||||
return lrs
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
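A brief sketch of exercising CrossEntropyLabelSmooth above on random data (batch size and class count are arbitrary).

import torch

criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
logits = torch.randn(4, 10)             # un-normalized predictions for 4 samples
targets = torch.randint(0, 10, (4,))    # integer class labels
loss = criterion(logits, targets)
print(loss.item())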
|
||||
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
|
||||
if config.optim == 'SGD':
|
||||
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
|
||||
elif config.optim == 'RMSprop':
|
||||
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
|
||||
else:
|
||||
raise ValueError('invalid optim : {:}'.format(config.optim))
|
||||
|
||||
if config.scheduler == 'cos':
|
||||
T_max = getattr(config, 'T_max', config.epochs)
|
||||
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
|
||||
elif config.scheduler == 'multistep':
|
||||
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
|
||||
elif config.scheduler == 'exponential':
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == 'linear':
|
||||
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
|
||||
else:
|
||||
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
|
||||
|
||||
if config.criterion == 'Softmax':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == 'SmoothSoftmax':
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError('invalid criterion : {:}'.format(config.criterion))
|
||||
return optim, scheduler, criterion
|
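A hedged end-to-end sketch of get_optim_scheduler; the SimpleNamespace config below is an ad-hoc stand-in for the repository's own config objects, and all field values are illustrative.

import torch.nn as nn
from types import SimpleNamespace

net = nn.Linear(16, 4)                  # toy model, illustration only
config = SimpleNamespace(
    optim='SGD', LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
    scheduler='cos', warmup=5, epochs=100, eta_min=0.0,
    criterion='Softmax')
optim, scheduler, criterion = get_optim_scheduler(net.parameters(), config)

scheduler.update(cur_epoch=0, cur_iter=0.0)   # epoch index plus fractional progress within the epoch
print(scheduler.get_min_info())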
@@ -0,0 +1,64 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, torch, random, PIL, copy, numpy as np
|
||||
from os import path as osp
|
||||
from shutil import copyfile
|
||||
|
||||
|
||||
def prepare_seed(rand_seed):
|
||||
random.seed(rand_seed)
|
||||
np.random.seed(rand_seed)
|
||||
torch.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed(rand_seed)
|
||||
torch.cuda.manual_seed_all(rand_seed)
|
||||
|
||||
|
||||
def prepare_logger(xargs):
|
||||
args = copy.deepcopy( xargs )
|
||||
from log_utils import Logger
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log('Main Function with logger : {:}'.format(logger))
|
||||
logger.log('Arguments : -------------------------------')
|
||||
for name, value in args._get_kwargs():
|
||||
logger.log('{:16} : {:}'.format(name, value))
|
||||
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
|
||||
logger.log("Pillow Version : {:}".format(PIL.__version__))
|
||||
logger.log("PyTorch Version : {:}".format(torch.__version__))
|
||||
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
|
||||
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
|
||||
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
|
||||
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
|
||||
return logger
|
||||
|
||||
|
||||
def get_machine_info():
|
||||
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
|
||||
info+= "\nPillow Version : {:}".format(PIL.__version__)
|
||||
info+= "\nPyTorch Version : {:}".format(torch.__version__)
|
||||
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
|
||||
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
|
||||
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ:
|
||||
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
|
||||
else:
|
||||
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
|
||||
return info
|
||||
|
||||
|
||||
def save_checkpoint(state, filename, logger):
|
||||
if osp.isfile(filename):
|
||||
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
|
||||
os.remove(filename)
|
||||
torch.save(state, filename)
|
||||
assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
|
||||
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
|
||||
return filename
|
||||
|
||||
|
||||
def copy_checkpoint(src, dst, logger):
|
||||
if osp.isfile(dst):
|
||||
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
|
||||
os.remove(dst)
|
||||
copyfile(src, dst)
|
||||
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))
|
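A short sketch of the helpers above, assuming it runs in the same module context; the /tmp paths are placeholders only.

import torch

prepare_seed(777)                        # seed python / numpy / torch RNGs
print(get_machine_info())                # version and CUDA summary

state = {'epoch': 0, 'model': torch.nn.Linear(2, 2).state_dict()}
ckpt = save_checkpoint(state, '/tmp/example-checkpoint.pth', logger=None)
copy_checkpoint(ckpt, '/tmp/example-checkpoint-copy.pth', logger=None)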
83
NAS-Bench-201/main_exp/transfer_nag/run_multi_proc.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from torch.multiprocessing import Process
|
||||
import os
|
||||
from absl import app, flags
|
||||
import sys
|
||||
import torch
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from nas_bench_201 import train_single_model
|
||||
from all_path import NASBENCH201
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_split", 15, "The number of splits")
|
||||
flags.DEFINE_list("arch_idx_lst", None, "arch index list")
|
||||
flags.DEFINE_list("arch_str_lst", None, "arch str list")
|
||||
flags.DEFINE_string("meta_test_path", None, "meta test path")
|
||||
flags.DEFINE_string("data_name", None, "data_name")
|
||||
flags.DEFINE_string("raw_data_path", None, "raw_data_path")
|
||||
|
||||
|
||||
def run_single_process(rank, seed, arch_idx, meta_test_path, data_name,
|
||||
raw_data_path, num_split=15, backend="nccl"):
|
||||
# 8 GPUs
|
||||
device = ['0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7',
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7'][rank]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = device
|
||||
|
||||
save_path = os.path.join(meta_test_path, str(arch_idx))
|
||||
if type(seed) == int:
|
||||
seeds = [seed]
|
||||
elif type(seed) in [list, tuple]:
|
||||
seeds = seed
|
||||
|
||||
nasbench201 = torch.load(NASBENCH201)
|
||||
arch_str = nasbench201['arch']['str'][arch_idx]
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
train_single_model(save_dir=save_path,
|
||||
workers=24,
|
||||
datasets=[data_name],
|
||||
xpaths=[f'{raw_data_path}/{data_name}'],
|
||||
splits=[0],
|
||||
use_less=False,
|
||||
seeds=seeds,
|
||||
model_str=arch_str,
|
||||
arch_config={'channel': 16, 'num_cells': 5})
|
||||
|
||||
|
||||
def run_multi_process(argv):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "1234"
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
processes = []
|
||||
|
||||
arch_idx_lst = [int(i) for i in FLAGS.arch_idx_lst]
|
||||
seeds = [777, 888, 999] * len(arch_idx_lst)
|
||||
arch_idx_lst_ = []
|
||||
for i in arch_idx_lst:
|
||||
arch_idx_lst_ += [i] * 3
|
||||
|
||||
for arch_idx in arch_idx_lst:
|
||||
os.makedirs(os.path.join(FLAGS.meta_test_path, str(arch_idx)), exist_ok=True)
|
||||
|
||||
for rank in range(FLAGS.num_split):
|
||||
arch_idx = arch_idx_lst_[rank]
|
||||
seed = seeds[rank]
|
||||
p = Process(target=run_single_process, args=(rank,
|
||||
seed,
|
||||
arch_idx,
|
||||
FLAGS.meta_test_path,
|
||||
FLAGS.data_name,
|
||||
FLAGS.raw_data_path))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
    for p in processes:
        p.join()

    # join() already blocks until every child process has exited, so no extra
    # busy-wait over p.is_alive() is needed here.
    print("All processes have completed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(run_multi_process)
|
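A standalone sketch of the rank-to-(architecture, seed) mapping built by run_multi_process above; the architecture indices are placeholders.

arch_idx_lst = [10, 20, 30]                      # hypothetical architecture indices
seeds = [777, 888, 999] * len(arch_idx_lst)
arch_idx_lst_ = []
for i in arch_idx_lst:
    arch_idx_lst_ += [i] * 3
for rank in range(len(arch_idx_lst_)):
    print(rank, arch_idx_lst_[rank], seeds[rank])  # ranks 0-2 -> arch 10, 3-5 -> arch 20, 6-8 -> arch 30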
@@ -0,0 +1,38 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from set_encoder.setenc_modules import *
|
||||
|
||||
|
||||
class SetPool(nn.Module):
|
||||
def __init__(self, dim_input, num_outputs, dim_output,
|
||||
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
|
||||
super(SetPool, self).__init__()
|
||||
if 'sab' in mode: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
|
||||
else: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
|
||||
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
|
||||
if 'PF' in mode: # [32, 1, 501]
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
elif 'P' in mode:
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
|
||||
else: # torch.Size([32, 1, 501])
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
# "", sm, sab, sabsm
|
||||
|
||||
def forward(self, X):
|
||||
x1 = self.enc(X)
|
||||
x2 = self.dec(x1)
|
||||
return x2
|
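A shape-level sketch of SetPool (batch and set sizes are arbitrary): with mode='sabPF' the SAB encoder keeps the set dimension and the PMA decoder pools it down to num_outputs vectors of size dim_output.

import torch

pool = SetPool(dim_input=512, num_outputs=1, dim_output=56,
               dim_hidden=128, num_heads=4, mode='sabPF')
X = torch.randn(8, 20, 512)       # 8 sets, 20 elements each, 512-d features
print(pool(X).shape)              # torch.Size([8, 1, 56])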
@@ -0,0 +1,67 @@
|
||||
#####################################################################################
|
||||
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MAB(nn.Module):
|
||||
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
|
||||
super(MAB, self).__init__()
|
||||
self.dim_V = dim_V
|
||||
self.num_heads = num_heads
|
||||
self.fc_q = nn.Linear(dim_Q, dim_V)
|
||||
self.fc_k = nn.Linear(dim_K, dim_V)
|
||||
self.fc_v = nn.Linear(dim_K, dim_V)
|
||||
if ln:
|
||||
self.ln0 = nn.LayerNorm(dim_V)
|
||||
self.ln1 = nn.LayerNorm(dim_V)
|
||||
self.fc_o = nn.Linear(dim_V, dim_V)
|
||||
|
||||
def forward(self, Q, K):
|
||||
Q = self.fc_q(Q)
|
||||
K, V = self.fc_k(K), self.fc_v(K)
|
||||
|
||||
dim_split = self.dim_V // self.num_heads
|
||||
Q_ = torch.cat(Q.split(dim_split, 2), 0)
|
||||
K_ = torch.cat(K.split(dim_split, 2), 0)
|
||||
V_ = torch.cat(V.split(dim_split, 2), 0)
|
||||
|
||||
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
|
||||
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
|
||||
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
|
||||
O = O + F.relu(self.fc_o(O))
|
||||
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
|
||||
return O
|
||||
|
||||
class SAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, ln=False):
|
||||
super(SAB, self).__init__()
|
||||
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(X, X)
|
||||
|
||||
class ISAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
|
||||
super(ISAB, self).__init__()
|
||||
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
|
||||
nn.init.xavier_uniform_(self.I)
|
||||
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
|
||||
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
|
||||
return self.mab1(X, H)
|
||||
|
||||
class PMA(nn.Module):
|
||||
def __init__(self, dim, num_heads, num_seeds, ln=False):
|
||||
super(PMA, self).__init__()
|
||||
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
|
||||
nn.init.xavier_uniform_(self.S)
|
||||
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
|
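A quick shape sketch of the blocks above (sizes are arbitrary): SAB and ISAB preserve the set size, while PMA reduces it to the number of seed vectors.

import torch

x = torch.randn(2, 50, 64)                              # 2 sets of 50 elements, 64-d each
sab = SAB(dim_in=64, dim_out=128, num_heads=4)
isab = ISAB(dim_in=64, dim_out=128, num_heads=4, num_inds=16)
pma = PMA(dim=128, num_heads=4, num_seeds=1)

print(sab(x).shape)        # torch.Size([2, 50, 128])
print(isab(x).shape)       # torch.Size([2, 50, 128])
print(pma(sab(x)).shape)   # torch.Size([2, 1, 128])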
243
NAS-Bench-201/main_exp/transfer_nag/unnoised_model.py
Normal file
@@ -0,0 +1,243 @@
|
||||
######################################################################################
|
||||
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
from torch import nn
|
||||
from set_encoder.setenc_models import SetPool
|
||||
|
||||
|
||||
class MetaSurrogateUnnoisedModel(nn.Module):
|
||||
def __init__(self, args, graph_config):
|
||||
super(MetaSurrogateUnnoisedModel, self).__init__()
|
||||
self.max_n = graph_config['max_n'] # maximum number of vertices
|
||||
self.nvt = args.nvt # number of vertex types
|
||||
self.START_TYPE = graph_config['START_TYPE']
|
||||
self.END_TYPE = graph_config['END_TYPE']
|
||||
self.hs = args.hs # hidden state size of each vertex
|
||||
self.nz = args.nz # size of latent representation z
|
||||
self.gs = args.hs # size of graph state
|
||||
self.bidir = True # whether to use bidirectional encoding
|
||||
self.vid = True
|
||||
self.device = None
|
||||
self.input_type = 'DG'
|
||||
self.num_sample = args.num_sample
|
||||
|
||||
if self.vid:
|
||||
self.vs = self.hs + self.max_n # vertex state size = hidden state + vid
|
||||
else:
|
||||
self.vs = self.hs
|
||||
|
||||
# 0. encoding-related
|
||||
self.grue_forward = nn.GRUCell(self.nvt, self.hs) # encoder GRU
|
||||
self.grue_backward = nn.GRUCell(
|
||||
self.nvt, self.hs) # backward encoder GRU
|
||||
self.fc1 = nn.Linear(self.gs, self.nz) # latent mean
|
||||
self.fc2 = nn.Linear(self.gs, self.nz) # latent logvar
|
||||
|
||||
# 2. gate-related
|
||||
self.gate_forward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.gate_backward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.mapper_forward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs, bias=False),
|
||||
) # disable bias to ensure padded zeros also mapped to zeros
|
||||
self.mapper_backward = nn.Sequential(
|
||||
nn.Linear(self.vs, self.hs, bias=False),
|
||||
)
|
||||
|
||||
# 3. bidir-related, to unify sizes
|
||||
if self.bidir:
|
||||
self.hv_unify = nn.Sequential(
|
||||
nn.Linear(self.hs * 2, self.hs),
|
||||
)
|
||||
self.hg_unify = nn.Sequential(
|
||||
nn.Linear(self.gs * 2, self.gs),
|
||||
)
|
||||
|
||||
# 4. other
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.tanh = nn.Tanh()
|
||||
self.logsoftmax1 = nn.LogSoftmax(1)
|
||||
|
||||
# 6. predictor
|
||||
np = self.gs
|
||||
self.intra_setpool = SetPool(dim_input=512,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.inter_setpool = SetPool(dim_input=self.nz,
|
||||
num_outputs=1,
|
||||
dim_output=self.nz,
|
||||
dim_hidden=self.nz,
|
||||
mode='sabPF')
|
||||
self.set_fc = nn.Sequential(
|
||||
nn.Linear(512, self.nz),
|
||||
nn.ReLU())
|
||||
|
||||
input_dim = 0
|
||||
if 'D' in self.input_type:
|
||||
input_dim += self.nz
|
||||
if 'G' in self.input_type:
|
||||
input_dim += self.nz
|
||||
|
||||
self.pred_fc = nn.Sequential(
|
||||
nn.Linear(input_dim, self.hs),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.hs, 1)
|
||||
)
|
||||
self.mseloss = nn.MSELoss(reduction='sum')
|
||||
|
||||
def predict(self, D_mu, G_mu):
|
||||
input_vec = []
|
||||
if 'D' in self.input_type:
|
||||
input_vec.append(D_mu)
|
||||
if 'G' in self.input_type:
|
||||
input_vec.append(G_mu)
|
||||
input_vec = torch.cat(input_vec, dim=1)
|
||||
return self.pred_fc(input_vec)
|
||||
|
||||
def get_device(self):
|
||||
if self.device is None:
|
||||
self.device = next(self.parameters()).device
|
||||
return self.device
|
||||
|
||||
def _get_zeros(self, n, length):
|
||||
# get a zero hidden state
|
||||
return torch.zeros(n, length).to(self.get_device())
|
||||
|
||||
def _get_zero_hidden(self, n=1):
|
||||
return self._get_zeros(n, self.hs) # get a zero hidden state
|
||||
|
||||
def _one_hot(self, idx, length):
|
||||
if type(idx) in [list, range]:
|
||||
if idx == []:
|
||||
return None
|
||||
idx = torch.LongTensor(idx).unsqueeze(0).t()
|
||||
x = torch.zeros((len(idx), length)).scatter_(
|
||||
1, idx, 1).to(self.get_device())
|
||||
else:
|
||||
idx = torch.LongTensor([idx]).unsqueeze(0)
|
||||
x = torch.zeros((1, length)).scatter_(
|
||||
1, idx, 1).to(self.get_device())
|
||||
return x
|
||||
|
||||
def _gated(self, h, gate, mapper):
|
||||
return gate(h) * mapper(h)
|
||||
|
||||
def _collate_fn(self, G):
|
||||
return [g.copy() for g in G]
|
||||
|
||||
def _propagate_to(self, G, v, propagator, H=None, reverse=False, gate=None, mapper=None):
|
||||
# propagate messages to vertex index v for all graphs in G
|
||||
# return the new messages (states) at v
|
||||
G = [g for g in G if g.vcount() > v]
|
||||
if len(G) == 0:
|
||||
return
|
||||
if H is not None:
|
||||
idx = [i for i, g in enumerate(G) if g.vcount() > v]
|
||||
H = H[idx]
|
||||
v_types = [g.vs[v]['type'] for g in G]
|
||||
X = self._one_hot(v_types, self.nvt)
|
||||
if reverse:
|
||||
H_name = 'H_backward' # name of the hidden states attribute
|
||||
H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
|
||||
if self.vid:
|
||||
vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
|
||||
gate, mapper = self.gate_backward, self.mapper_backward
|
||||
else:
|
||||
H_name = 'H_forward' # name of the hidden states attribute
|
||||
H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
|
||||
if self.vid:
|
||||
vids = [self._one_hot(g.predecessors(v), self.max_n)
|
||||
for g in G]
|
||||
if gate is None:
|
||||
gate, mapper = self.gate_forward, self.mapper_forward
|
||||
if self.vid:
|
||||
H_pred = [[torch.cat([x[i], y[i:i + 1]], 1)
|
||||
for i in range(len(x))] for x, y in zip(H_pred, vids)]
|
||||
# if h is not provided, use gated sum of v's predecessors' states as the input hidden state
|
||||
if H is None:
|
||||
# maximum number of predecessors
|
||||
max_n_pred = max([len(x) for x in H_pred])
|
||||
if max_n_pred == 0:
|
||||
H = self._get_zero_hidden(len(G))
|
||||
else:
|
||||
H_pred = [torch.cat(h_pred +
|
||||
[self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
|
||||
for h_pred in H_pred] # pad all to same length
|
||||
H_pred = torch.cat(H_pred, 0) # batch * max_n_pred * vs
|
||||
H = self._gated(H_pred, gate, mapper).sum(1) # batch * hs
|
||||
Hv = propagator(X, H)
|
||||
for i, g in enumerate(G):
|
||||
g.vs[v][H_name] = Hv[i:i + 1]
|
||||
return Hv
|
||||
|
||||
def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
|
||||
# perform a series of propagation_to steps starting from v following a topo order
|
||||
# assume the original vertex indices are in a topological order
|
||||
if reverse:
|
||||
prop_order = range(v, -1, -1)
|
||||
else:
|
||||
prop_order = range(v, self.max_n)
|
||||
Hv = self._propagate_to(G, v, propagator, H0,
|
||||
reverse=reverse) # the initial vertex
|
||||
for v_ in prop_order[1:]:
|
||||
self._propagate_to(G, v_, propagator, reverse=reverse)
|
||||
return Hv
|
||||
|
||||
def _get_graph_state(self, G, decode=False):
|
||||
# get the graph states
|
||||
# when decoding, use the last generated vertex's state as the graph state
|
||||
# when encoding, use the ending vertex state or unify the starting and ending vertex states
|
||||
Hg = []
|
||||
for g in G:
|
||||
hg = g.vs[g.vcount() - 1]['H_forward']
|
||||
if self.bidir and not decode: # decoding never uses backward propagation
|
||||
hg_b = g.vs[0]['H_backward']
|
||||
hg = torch.cat([hg, hg_b], 1)
|
||||
Hg.append(hg)
|
||||
Hg = torch.cat(Hg, 0)
|
||||
if self.bidir and not decode:
|
||||
Hg = self.hg_unify(Hg)
|
||||
return Hg
|
||||
|
||||
def set_encode(self, X):
|
||||
proto_batch = []
|
||||
for x in X:
|
||||
cls_protos = self.intra_setpool(
|
||||
x.view(-1, self.num_sample, 512)).squeeze(1)
|
||||
proto_batch.append(
|
||||
self.inter_setpool(cls_protos.unsqueeze(0)))
|
||||
v = torch.stack(proto_batch).squeeze()
|
||||
return v
|
||||
|
||||
def graph_encode(self, G):
|
||||
# encode graphs G into latent vectors
|
||||
if type(G) != list:
|
||||
G = [G]
|
||||
self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
|
||||
reverse=False)
|
||||
if self.bidir:
|
||||
self._propagate_from(G, self.max_n - 1, self.grue_backward,
|
||||
H0=self._get_zero_hidden(len(G)), reverse=True)
|
||||
Hg = self._get_graph_state(G)
|
||||
mu = self.fc1(Hg)
|
||||
# logvar = self.fc2(Hg)
|
||||
return mu # , logvar
|
||||
|
||||
def reparameterize(self, mu, logvar, eps_scale=0.01):
|
||||
# return z ~ N(mu, std)
|
||||
if self.training:
|
||||
std = logvar.mul(0.5).exp_()
|
||||
eps = torch.randn_like(std) * eps_scale
|
||||
return eps.mul(std).add_(mu)
|
||||
else:
|
||||
return mu
|
33
NAS-Bench-201/main_exp/utils.py
Normal file
@@ -0,0 +1,33 @@
import os
import logging
import torch
import numpy as np
import random

def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        if not os.path.exists(os.path.dirname(ckpt_dir)):
            os.makedirs(os.path.dirname(ckpt_dir))
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            if k in ['optimizer', 'model', 'ema']:
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state


def reset_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
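A minimal sketch of the two helpers above; the checkpoint path is a placeholder, and with resume=False the state is returned untouched after the target directory is created.

import torch

reset_seed(777)

model = torch.nn.Linear(4, 2)
state = dict(model=model, step=0)   # objects with load_state_dict go under 'model' / 'ema' / 'optimizer'
state = restore_checkpoint('/tmp/checkpoints/example.pth', state,
                           device=torch.device('cpu'), resume=False)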