first commit

This commit is contained in:
CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

View File

@@ -0,0 +1,286 @@
import torch
import sys
import os
import shutil
import logging
import numpy as np
import random
sys.path.append('.')
import sampling
import datasets_nas
from models import cate
from models import digcn
from models import digcn_meta
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
from analysis.arch_functions import BasicArchMetricsMeta
from all_path import *
def get_sampling_fn_meta(config):
## Set SDE
if config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
sde = sde_lib.subVPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(
sigma_min=config.model.sigma_min,
sigma_max=config.model.sigma_max,
N=config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
## Get data normalizer inverse
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
## Get sampling function
sampling_shape = (config.eval.batch_size, config.data.max_node, config.data.n_vocab)
sampling_fn = sampling.get_sampling_fn(
config=config,
sde=sde,
shape=sampling_shape,
inverse_scaler=inverse_scaler,
eps=sampling_eps,
conditional=True,
data_name=config.sampling.check_dataname,
num_sample=config.model.num_sample)
return sampling_fn, sde
def get_score_model(config):
try:
score_config = torch.load(config.scorenet_ckpt_path)['config']
ckpt_path = config.scorenet_ckpt_path
except Exception:
# fall back to the default score-network checkpoint path
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
score_config = torch.load(config.scorenet_ckpt_path)['config']
ckpt_path = config.scorenet_ckpt_path
score_model = mutils.create_model(score_config)
score_ema = ExponentialMovingAverage(
score_model.parameters(), decay=score_config.model.ema_rate)
score_state = dict(
model=score_model, ema=score_ema, step=0, config=score_config)
score_state = restore_checkpoint(
ckpt_path, score_state,
device=config.device, resume=True)
score_ema.copy_to(score_model.parameters())
return score_model, score_ema, score_config
def get_surrogate(config):
surrogate_model = mutils.create_model(config)
return surrogate_model
def get_adj(except_inout=False):
_adj = np.asarray(
[[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]]
)
_adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
if except_inout: _adj = _adj[1:-1, 1:-1]
return _adj
def generate_archs_meta(
config,
sampling_fn,
score_model,
score_ema,
meta_surrogate_model,
num_samples,
args=None,
task=None,
patient_factor=20,
batch_size=256,):
metrics = BasicArchMetricsMeta()
## Get the adj and mask
adj_s = get_adj()
mask_s = aug_mask(adj_s)[0]
adj_c = get_adj()
mask_c = aug_mask(adj_c)[0]
assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
adj_s, mask_s, adj_c, mask_c = \
adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)
score_ema.copy_to(score_model.parameters())
score_model.eval()
meta_surrogate_model.eval()
c_scale = args.classifier_scale
num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor) if num_samples > batch_size else int(patient_factor)
round_idx = 0
all_samples = []
while round_idx < num_sampling_rounds:
round_idx += 1
sample = sampling_fn(score_model,
mask_s,
meta_surrogate_model,
classifier_scale=c_scale,
task=task)
quantized_sample = quantize(sample)
_, _, valid_arch_str, _ = metrics.compute_validity(quantized_sample)
if len(valid_arch_str) > 0: all_samples += valid_arch_str
# to sample various architectures
c_scale -= args.scale_step
seed = int(random.random() * 10000)
reset_seed(seed)
# stop sampling if we have enough samples
if (len(set(all_samples)) >= num_samples):
break
return list(set(all_samples))
def save_checkpoint(ckpt_dir, state, epoch, is_best):
saved_state = {}
for k in state:
if k in ['optimizer', 'model', 'ema']:
saved_state.update({k: state[k].state_dict()})
else:
saved_state.update({k: state[k]})
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
if is_best:
shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar'))
# remove all saved checkpoints except the best model
for ckpt_file in sorted(os.listdir(ckpt_dir)):
if not ckpt_file.startswith('checkpoint'):
continue
if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
os.remove(os.path.join(ckpt_dir, ckpt_file))
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
def floyed(r):
"""
Compute the transitive closure of a 0/1 reachability matrix
with a Floyd-Warshall-style triple loop.
:param r: a numpy NxN matrix with float 0,1
:return: a numpy NxN matrix with float 0,1
"""
if isinstance(r, torch.Tensor):
r = r.cpu().numpy()
N = r.shape[0]
for k in range(N):
for i in range(N):
for j in range(N):
if r[i, k] > 0 and r[k, j] > 0:
r[i, j] = 1
return r
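A minimal usage sketch (illustration only, not part of the file; `np` comes from the imports above): closing a 4-node chain 0 -> 1 -> 2 -> 3 connects every reachable pair.
chain = np.array([[0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]], dtype=float)
closed = floyed(chain.copy())  # copy: floyed mutates numpy inputs in place
assert closed[0, 2] == closed[0, 3] == closed[1, 3] == 1.0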
def aug_mask(adj, algo='floyed', data='NASBench201'):
if len(adj.shape) == 2:
adj = adj.unsqueeze(0)
if data.lower() in ['nasbench201', 'ofa']:
assert len(adj.shape) == 3
r = adj[0].clone().detach()
if algo == 'long_range':
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
elif algo == 'floyed':
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
else:
mask_i = r
masks = [mask_i] * adj.size(0)
return torch.stack(masks)
else:
masks = []
for r in adj:
if algo == 'long_range':
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
elif algo == 'floyed':
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
else:
mask_i = r
masks.append(mask_i)
return torch.stack(masks)
def long_range(r):
"""
Add a long-range edge k -> j whenever k reaches j through two hops
(k -> i -> j), scanning columns left to right.
:param r: a numpy NxN matrix with float 0,1
:return: a numpy NxN matrix with float 0,1
"""
if isinstance(r, torch.Tensor):
r = r.cpu().numpy()
N = r.shape[0]
for j in range(1, N):
in_to_j = [i for i, val in enumerate(r[:j, j]) if val > 0]
for i in in_to_j:
in_to_i = [k for k, val in enumerate(r[:i, i]) if val > 0]
for k in in_to_i:
r[k, j] = 1
return r
def quantize(x):
"""Convert the PyTorch tensor x of architecture matrices to a list of numpy arrays.
Args:
x: [Batch_size, Max_node, N_vocab]
"""
x_list = []
# discretization
x[x >= 0.5] = 1.
x[x < 0.5] = 0.
for i in range(x.shape[0]):
x_tmp = x[i]
x_tmp = x_tmp.cpu().numpy()
x_list.append(x_tmp)
return x_list
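A small sketch of what `quantize` returns (toy values, not from the repo): each sample is thresholded at 0.5 and converted to a numpy array.
batch = torch.tensor([[[0.2, 0.7], [0.9, 0.4]]])  # toy [1, 2, 2] logits
arrs = quantize(batch)
# arrs is a list with one 2x2 numpy array: [[0., 1.], [1., 0.]]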
def reset_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True

View File

@@ -0,0 +1,137 @@
import os
import wandb
import torch
import numpy as np
class Logger:
def __init__(
self,
log_dir=None,
write_textfile=True
):
self.log_dir = log_dir
self.write_textfile = write_textfile
self.logs_for_save = {}
self.logs = {}
if self.write_textfile:
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
def write_str(self, log_str):
self.f.write(log_str+'\n')
self.f.flush()
def update_config(self, v, is_args=False):
if is_args:
self.logs_for_save.update({'args': v})
else:
self.logs_for_save.update(v)
def write_log(self, element, step, return_log_dict=False):
log_str = f"{step} | "
log_dict = {}
for head, keys in element.items():
for k in keys:
if k in self.logs:
v = self.logs[k].avg
if k not in self.logs_for_save:
self.logs_for_save[k] = []
self.logs_for_save[k].append(v)
log_str += f'{k} {v} | '
log_dict[f'{head}/{k}'] = v
if self.write_textfile:
self.f.write(log_str+'\n')
self.f.flush()
if return_log_dict:
return log_dict
def save_log(self, name=None):
name = 'logs.pt' if name is None else name
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
def update(self, key, v, n=1):
if key not in self.logs:
self.logs[key] = AverageMeter()
self.logs[key].update(v, n)
def reset(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
for key in keys:
self.logs[key] = AverageMeter()
else:
self.logs[keys] = AverageMeter()
else:
for key in self.logs.keys():
if key not in except_keys:
self.logs[key] = AverageMeter()
def avg(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
else:
return self.logs[keys].avg
else:
avg_dict = {}
for key in self.logs.keys():
if key not in except_keys:
avg_dict[key] = self.logs[key].avg
return avg_dict
class AverageMeter(object):
"""
Computes and stores the average and current value
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
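A brief usage sketch (illustrative numbers): the meter keeps a weighted running average across batches.
meter = AverageMeter()
meter.update(0.5, n=32)  # e.g. mean loss 0.5 over a 32-sample batch
meter.update(0.3, n=32)
# meter.avg == 0.4, meter.count == 64, meter.val == 0.3 (latest value)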
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
metrics = {}
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
logits_per_x = logits_per_g.t().detach().cpu()
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
return metrics
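A hedged sanity check (toy identity embeddings, not real model outputs): when every graph embedding matches exactly one dataset embedding, each query ranks its pair first.
g_embeds = torch.eye(4)
x_embeds = torch.eye(4)
m = get_metrics(g_embeds, x_embeds, logit_scale=1.0)
# m['train_g_to_x_mean_rank'] == 1.0 and m['train_g_to_x_R@1'] == 1.0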

View File

@@ -0,0 +1,63 @@
"""
@author: Hayeon Lee
2020/02/19
Script for downloading, and reorganizing aircraft
for few shot classification
Run this file as follows:
python get_data.py
"""
import pickle
import os
import numpy as np
from tqdm import tqdm
import requests
import tarfile
from PIL import Image
import glob
import shutil
import collections
import sys
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = RAW_DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, 'fgvc-aircraft-2013b.tar.gz')
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file(
'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz',
file_name)
print("\nDownloading done.\n")
else:
print("fgvc-aircraft-2013b.tar.gz has already been downloaded. Did not download twice.\n")
untar_file_name = os.path.join(dir_path, 'aircraft')
if not os.path.exists(untar_file_name):
tarname = file_name
print("Untarring: {}".format(tarname))
tar = tarfile.open(tarname)
tar.extractall(untar_file_name)
tar.close()
else:
print(f"{untar_file_name} folder already exists. Did not untarring twice\n")
os.remove(file_name)

View File

@@ -0,0 +1,50 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
import sys
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = os.path.join(RAW_DATA_PATH, 'pets')
if not os.path.exists(dir_path):
os.makedirs(dir_path)
full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file(
'https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")
full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file(
'https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,47 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
DATA_PATH = "./data/transfer_nag"
dir_path = DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
def get_preprocessed_data(file_name, url):
print(f"Downloading {file_name} datasets\n")
full_name = os.path.join(dir_path, file_name)
download_file(url, full_name)
print("Downloading done.\n")
for file_name, url in [
('aircraftbylabel.pt', 'https://www.dropbox.com/s/mb66kitv20ykctp/aircraftbylabel.pt?dl=1'),
('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
]:
get_preprocessed_data(file_name, url)

View File

@@ -0,0 +1,130 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
print(f'==> The number of tasks for meta-training: {len(dataset)}')
loader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=collate_fn)
return loader
def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
print(f'==> Meta-Test dataset {data_name}')
loader = DataLoader(dataset=dataset,
batch_size=100,
shuffle=False,
num_workers=0)
return loader
class MetaTrainDatabase(Dataset):
def __init__(self, data_path, num_sample, is_pred=True):
self.mode = 'train'
self.acc_norm = True
self.num_sample = num_sample
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
mtr_data_path = os.path.join(
data_path, 'meta_train_tasks_predictor.pt')
idx_path = os.path.join(
data_path, 'meta_train_tasks_predictor_idx.pt')
data = torch.load(mtr_data_path)
self.acc = data['acc']
self.task = data['task']
self.graph = data['g']
random_idx_lst = torch.load(idx_path)
self.idx_lst = {}
self.idx_lst['valid'] = random_idx_lst[:400]
self.idx_lst['train'] = random_idx_lst[400:]
self.acc = torch.tensor(self.acc)
self.mean = torch.mean(self.acc[self.idx_lst['train']]).item()
self.std = torch.std(self.acc[self.idx_lst['train']]).item()
self.task_lst = torch.load(os.path.join(
data_path, 'meta_train_task_lst.pt'))
def set_mode(self, mode):
self.mode = mode
def __len__(self):
return len(self.idx_lst[self.mode])
def __getitem__(self, index):
data = []
ridx = self.idx_lst[self.mode]
tidx = self.task[ridx[index]]
classes = self.task_lst[tidx]
graph = self.graph[ridx[index]]
acc = self.acc[ridx[index]]
for cls in classes:
cx = self.x[cls-1][0]
perm = torch.randperm(len(cx))
data.append(cx[perm[:self.num_sample]])
x = torch.cat(data)
if self.acc_norm:
acc = ((acc - self.mean) / self.std) / 100.0
else:
acc = acc / 100.0
return x, graph, acc
class MetaTestDataset(Dataset):
def __init__(self, data_path, data_name, num_sample, num_class=None):
self.num_sample = num_sample
self.data_name = data_name
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'svhn': 10,
'aircraft': 30,
'pets': 37
}
if num_class is not None:
self.num_class = num_class
else:
self.num_class = num_class_dict[data_name]
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
def __len__(self):
# effectively unbounded: __getitem__ ignores the index and draws a fresh random sample
return 1000000
def __getitem__(self, index):
data = []
classes = list(range(self.num_class))
for cls in classes:
cx = self.x[cls][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
return x
def collate_fn(batch):
x = torch.stack([item[0] for item in batch])
graph = [item[1] for item in batch]
acc = torch.stack([item[2] for item in batch])
return [x, graph, acc]
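A minimal sketch of the batch layout `collate_fn` produces (the strings stand in for igraph objects): images and accuracies are stacked into tensors, graphs stay a Python list.
fake_batch = [(torch.zeros(3, 32, 32), 'g0', torch.tensor(0.8)),
              (torch.zeros(3, 32, 32), 'g1', torch.tensor(0.9))]
x, graphs, accs = collate_fn(fake_batch)
# x: [2, 3, 32, 32], graphs: ['g0', 'g1'], accs: tensor([0.8000, 0.9000])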

View File

@@ -0,0 +1,91 @@
import os
import sys
import random
import numpy as np
import argparse
import torch
import os
from nag import NAG
sys.path.append(os.getcwd())
save_path = "results"
data_path = os.path.join('MetaD2A_nas_bench_201', 'data')
def str2bool(v):
return str(v).lower() in ('t', 'true')
def get_parser():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('--seed', type=int, default=444)
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
parser.add_argument('--save-path', type=str, default=save_path, help='the path of save directory')
parser.add_argument('--data-path', type=str, default=data_path, help='the path of data directory')
parser.add_argument('--model-load-path', type=str, default='', help='')
parser.add_argument('--save-epoch', type=int, default=20, help='how many epochs to wait each time to save model states')
parser.add_argument('--max-epoch', type=int, default=1000, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=1024, help='batch size for generator')
parser.add_argument('--graph-data-name', default='nasbench201', help='graph dataset name')
parser.add_argument('--nvt', type=int, default=7, help='number of different node types, 7: NAS-Bench-201 including in/out node')
# set encoder
parser.add_argument('--num-sample', type=int, default=20, help='the number of images as input for set encoder')
# graph encoder
parser.add_argument('--hs', type=int, default=512, help='hidden size of GRUs')
parser.add_argument('--nz', type=int, default=56, help='the number of dimensions of latent vectors z')
# test
parser.add_argument('--test', action='store_true', default=True, help='turn on test mode')
parser.add_argument('--load-epoch', type=int, default=100, help='checkpoint epoch loaded for meta-test')
parser.add_argument('--data-name', type=str, default='pets', help='meta-test dataset name')
parser.add_argument('--trials', type=int, default=20)
parser.add_argument('--num-class', type=int, default=None, help='the number of class of dataset')
parser.add_argument('--num-gen-arch', type=int, default=500, help='the number of candidate architectures generated by the generator')
parser.add_argument('--train-arch', type=str2bool, default=True, help='whether to train the searched architecture')
parser.add_argument('--n_init', type=int, default=10)
parser.add_argument('--N', type=int, default=1)
# DiffusionNAG
parser.add_argument('--folder_name', type=str, default='debug')
parser.add_argument('--exp_name', type=str, default='')
parser.add_argument('--classifier_scale', type=float, default=10000., help='classifier scale')
parser.add_argument('--scale_step', type=float, default=300.)
parser.add_argument('--eval_batch_size', type=int, default=256)
parser.add_argument('--predictor', type=str, default='euler_maruyama', choices=['euler_maruyama', 'reverse_diffusion', 'none'])
parser.add_argument('--corrector', type=str, default='langevin', choices=['none', 'langevin'])
parser.add_argument('--patient_factor', type=int, default=20)
parser.add_argument('--n_gen_samples', type=int, default=10)
parser.add_argument('--multi_proc', type=str2bool, default=True)
args = parser.parse_args()
return args
def set_exp_name(args):
exp_name = f'./results/transfer_nag/{args.folder_name}/{args.data_name}'
os.makedirs(exp_name, exist_ok=True)
args.exp_name = exp_name
def main():
## Get arguments
args = get_parser()
## Set gpus and seeds
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)
random.seed(args.seed)
## Set experiment name
set_exp_name(args)
## Run
nag = NAG(args)
nag.meta_test()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,305 @@
from __future__ import print_function
import torch
import os
import gc
import sys
import numpy as np
import os
import subprocess
from nag_utils import mean_confidence_interval
from nag_utils import restore_checkpoint
from nag_utils import load_graph_config
from nag_utils import load_model
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from nas_bench_201 import train_single_model
from unnoised_model import MetaSurrogateUnnoisedModel
from diffusion.run_lib import generate_archs_meta
from diffusion.run_lib import get_sampling_fn_meta
from diffusion.run_lib import get_score_model
from diffusion.run_lib import get_surrogate
from loader import MetaTestDataset
from logger import Logger
from all_path import *
class NAG:
def __init__(self, args):
self.args = args
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
## Target dataset information
self.raw_data_path = RAW_DATA_PATH
self.data_path = DATA_PATH
self.data_name = args.data_name
self.num_class = args.num_class
self.num_sample = args.num_sample
graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
load_model(model=self.meta_surrogate_unnoised_model,
ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
self.meta_surrogate_unnoised_model.to(self.device)
## Load pre-trained meta-surrogate model
self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH
## Load score network model (base diffusion model)
self.load_diffusion_model(args=args)
## Check config
self.check_config()
## Set logger
self.logger = Logger(
log_dir=args.exp_name,
write_textfile=True
)
self.logger.update_config(args, is_args=True)
self.logger.write_str(str(vars(args)))
self.logger.write_str('-' * 100)
def check_config(self):
"""
Check if the configuration of the pre-trained score network model matches that of the meta surrogate model.
"""
scorenet_config = torch.load(self.config.scorenet_ckpt_path)['config']
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
assert scorenet_config.model.sigma_min == meta_surrogate_config.model.sigma_min
assert scorenet_config.model.sigma_max == meta_surrogate_config.model.sigma_max
assert scorenet_config.training.sde == meta_surrogate_config.training.sde
assert scorenet_config.training.continuous == meta_surrogate_config.training.continuous
assert scorenet_config.data.centered == meta_surrogate_config.data.centered
assert scorenet_config.data.max_node == meta_surrogate_config.data.max_node
assert scorenet_config.data.n_vocab == meta_surrogate_config.data.n_vocab
def forward(self, x, arch):
D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
y_pred = self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)
return y_pred
def meta_test(self):
if self.data_name == 'all':
for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
self.meta_test_per_dataset(data_name)
else:
self.meta_test_per_dataset(self.data_name)
def meta_test_per_dataset(self, data_name):
## Load NASBench201
self.nasbench201 = torch.load(NASBENCH201)
all_arch_str = np.array(self.nasbench201['arch']['str'])
## Load meta-test dataset
self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)
## Set save path
meta_test_path = os.path.join(META_TEST_PATH, data_name)
os.makedirs(meta_test_path, exist_ok=True)
f_arch_str = open(os.path.join(self.args.exp_name, 'architecture.txt'), 'w')
f_arch_acc = open(os.path.join(self.args.exp_name, 'accuracy.txt'), 'w')
## Generate architectures
gen_arch_str = self.get_gen_arch_str()
gen_arch_igraph = self.get_items(
full_target=self.nasbench201['arch']['igraph'],
full_source=self.nasbench201['arch']['str'],
source=gen_arch_str)
## Sort with unnoised meta-surrogate model
y_pred_all = []
self.meta_surrogate_unnoised_model.eval()
self.meta_surrogate_unnoised_model.to(self.device)
with torch.no_grad():
for arch_igraph in gen_arch_igraph:
x, g = self.collect_data(arch_igraph)
y_pred = self.forward(x, g)
y_pred = torch.mean(y_pred)
y_pred_all.append(y_pred.cpu().detach().item())
sorted_arch_lst = self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)
## Record the information of the architecture generated in sorted order
for arch_str in sorted_arch_lst:
f_arch_str.write(f'{arch_str}\n')
arch_idx_lst = [self.nasbench201['arch']['str'].index(i) for i in sorted_arch_lst]
arch_str_lst = []
arch_acc_lst = []
## Get the accuracy of the architecture
if 'cifar' in data_name:
sorted_acc_lst = self.get_items(
full_target=self.nasbench201['test-acc'][data_name],
full_source=self.nasbench201['arch']['str'],
source=sorted_arch_lst)
arch_str_lst += sorted_arch_lst
arch_acc_lst += sorted_acc_lst
for arch_idx, acc in zip(arch_idx_lst, sorted_acc_lst):
msg = f'Avg {acc:.4f} (%)'
f_arch_acc.write(msg + '\n')
else:
if self.args.multi_proc:
## Run multiple processes in parallel
run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
MAX_CAP = 5 # hard-coded for available GPUs
if len(arch_idx_lst) <= MAX_CAP:
arch_idx_lst_ = [arch_idx for arch_idx in arch_idx_lst if not os.path.exists(os.path.join(meta_test_path, str(arch_idx)))]
support_ = ','.join([str(i) for i in arch_idx_lst_])
num_split = int(3 * len(arch_idx_lst_))  # three runs per architecture, one per seed
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
subprocess.run([cmd], shell=True)
else:
arch_idx_lst_ = []
for j, arch_idx in enumerate(arch_idx_lst):
if not os.path.exists(os.path.join(meta_test_path, str(arch_idx))):
arch_idx_lst_.append(arch_idx)
if (len(arch_idx_lst_) == MAX_CAP) or (j == len(arch_idx_lst) - 1):
support_ = ','.join([str(i) for i in arch_idx_lst_])
num_split = int(3 * len(arch_idx_lst_))
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
subprocess.run([cmd], shell=True)
arch_idx_lst_ = []
# poll until every architecture's per-seed results have been written to disk
while True:
try:
acc_runs_lst = []
epoch = 199
seeds = (777, 888, 999)
for arch_idx in arch_idx_lst:
acc_runs = []
save_path_ = os.path.join(meta_test_path, str(arch_idx))
for seed in seeds:
result = torch.load(os.path.join(save_path_, f'seed-0{seed}.pth'))
acc_runs.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
acc_runs_lst.append(acc_runs)
break
except Exception:
pass  # results not ready yet; keep polling
for acc_runs in acc_runs_lst:
print(np.mean(acc_runs))
for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
for r, acc in enumerate(acc_runs):
msg = f'run {r+1} {acc:.2f} (%)'
f_arch_acc.write(msg + '\n')
m, h = mean_confidence_interval(acc_runs)
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
f_arch_acc.write(msg + '\n')
arch_acc_lst.append(np.mean(acc_runs))
arch_str_lst.append(all_arch_str[arch_idx])
else:
for arch_idx in arch_idx_lst:
acc_runs = self.train_single_arch(
data_name, self.nasbench201['arch']['str'][arch_idx], meta_test_path)
for r, acc in enumerate(acc_runs):
msg = f'run {r+1} {acc:.2f} (%)'
f_arch_acc.write(msg + '\n')
m, h = mean_confidence_interval(acc_runs)
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
f_arch_acc.write(msg + '\n')
arch_acc_lst.append(np.mean(acc_runs))
arch_str_lst.append(all_arch_str[arch_idx])
# Save results
results_path = os.path.join(self.args.exp_name, 'results.pt')
torch.save({
'arch_idx_lst': arch_idx_lst,
'arch_str_lst': arch_str_lst,
'arch_acc_lst': arch_acc_lst
}, results_path)
print(f">>> Save the results at {results_path}...")
def train_single_arch(self, data_name, arch_str, meta_test_path):
save_path = os.path.join(meta_test_path, arch_str)
seeds = (777, 888, 999)
train_single_model(save_dir=save_path,
workers=24,
datasets=[data_name],
xpaths=[f'{self.raw_data_path}/{data_name}'],
splits=[0],
use_less=False,
seeds=seeds,
model_str=arch_str,
arch_config={'channel': 16, 'num_cells': 5})
epoch = 199
test_acc_lst = []
for seed in seeds:
result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
return test_acc_lst
def sort_arch(self, data_name, y_pred_all, gen_arch_str):
_, sorted_idx = torch.sort(y_pred_all, descending=True)
sorted_gen_arch_str = [gen_arch_str[i] for i in sorted_idx]
return sorted_gen_arch_str
def collect_data_only(self):
x_batch = []
x_batch.append(self.test_dataset[0])
return torch.stack(x_batch).to(self.device)
def collect_data(self, arch_igraph):
x_batch, g_batch = [], []
for _ in range(10):
x_batch.append(self.test_dataset[0])
g_batch.append(arch_igraph)
return torch.stack(x_batch).to(self.device), g_batch
def get_items(self, full_target, full_source, source):
return [full_target[full_source.index(_)] for _ in source]
def load_diffusion_model(self, args):
self.config = torch.load('./configs/transfer_nag_config.pt')
self.config.device = torch.device('cuda')
self.config.data.label_list = ['meta-acc']
self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
self.config.sampling.classifier_scale = args.classifier_scale
self.config.eval.batch_size = args.eval_batch_size
self.config.sampling.predictor = args.predictor
self.config.sampling.corrector = args.corrector
self.config.sampling.check_dataname = self.data_name
self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
self.score_model, self.score_ema, self.score_config = get_score_model(self.config)
def get_gen_arch_str(self):
## Load meta-surrogate model
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
meta_surrogate_model = get_surrogate(meta_surrogate_config)
meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
meta_surrogate_state = restore_checkpoint(
self.meta_surrogate_ckpt_path,
meta_surrogate_state,
device=self.config.device,
resume=True)
## Get dataset embedding, x
with torch.no_grad():
x = self.collect_data_only()
## Generate architectures
generated_arch_str = generate_archs_meta(
config=self.config,
sampling_fn=self.sampling_fn,
score_model=self.score_model,
score_ema=self.score_ema,
meta_surrogate_model=meta_surrogate_model,
num_samples=self.args.n_gen_samples,
args=self.args,
task=x)
## Clean up
meta_surrogate_model = None
gc.collect()
return generated_arch_str

View File

@@ -0,0 +1,301 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import torch
import logging
def reset_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
def load_graph_config(graph_data_name, nvt, data_path):
if graph_data_name != 'nasbench201':
raise NotImplementedError(graph_data_name)
g_list = []
max_n = 0 # maximum number of nodes
ms = torch.load(data_path)['arch']['matrix']
for i in range(len(ms)):
g, n = decode_NAS_BENCH_201_8_to_igraph(ms[i])
max_n = max(max_n, n)
g_list.append((g, 0))
# number of different node types including in/out node
graph_config = {}
graph_config['num_vertex_type'] = nvt # original types + start/end types
graph_config['max_n'] = max_n # maximum number of nodes
graph_config['START_TYPE'] = 0 # predefined start vertex type
graph_config['END_TYPE'] = 1 # predefined end vertex type
return graph_config
def decode_NAS_BENCH_201_8_to_igraph(row):
if type(row) == str:
row = eval(row) # convert string to list of lists
n = len(row)
g = igraph.Graph(directed=True)
g.add_vertices(n)
for i, node in enumerate(row):
g.vs[i]['type'] = node[0]
if 0 < i < (n - 2):
g.add_edge(i, i + 1)  # chain edge: connect each intermediate node to the next
for j, edge in enumerate(node[1:]):
if edge == 1:
g.add_edge(j, i)
return g, n
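A toy illustration of the row encoding (a real NAS-Bench-201 cell has 8 nodes; this 3-node row is hypothetical): node[0] is the vertex type, node[1:] flags incoming edges from earlier nodes.
toy_row = [[0], [2, 1], [1, 0, 1]]  # start(type 0) -> op(type 2) -> end(type 1)
g, n = decode_NAS_BENCH_201_8_to_igraph(toy_row)
# n == 3; g has the directed edges (0, 1) and (1, 2)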
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
# first need to be a valid DAG computation graph
res = is_valid_DAG(g, START_TYPE, END_TYPE)
# in addition, node i must connect to node i+1
res = res and len(g.vs['type']) == 8
res = res and not (0 in g.vs['type'][1:-1])
res = res and not (1 in g.vs['type'][1:-1])
return res
def decode_igraph_to_NAS201_matrix(g):
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
for i, xy in enumerate(xys):
m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
return np.array(m)
def decode_igraph_to_NAS_BENCH_201_string(g):
if not is_valid_NAS201(g):
return None
m = decode_igraph_to_NAS201_matrix(g)
types = ['none', 'skip_connect', 'nor_conv_1x1',
'nor_conv_3x3', 'avg_pool_3x3']
return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
format(types[int(m[1][0])],
types[int(m[2][0])], types[int(m[2][1])],
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
res = g.is_dag()
n_start, n_end = 0, 0
for v in g.vs:
if v['type'] == START_TYPE:
n_start += 1
elif v['type'] == END_TYPE:
n_end += 1
if v.indegree() == 0 and v['type'] != START_TYPE:
return False
if v.outdegree() == 0 and v['type'] != END_TYPE:
return False
return res and n_start == 1 and n_end == 1
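A quick illustrative check (toy graph, `igraph` from the imports above): a 3-vertex chain with exactly one start node, one end node, and no isolated vertices is accepted.
g = igraph.Graph(directed=True)
g.add_vertices(3)
g.vs['type'] = [0, 2, 1]  # START_TYPE=0, one operation, END_TYPE=1
g.add_edges([(0, 1), (1, 2)])
assert is_valid_DAG(g)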
class Accumulator():
def __init__(self, *args):
self.args = args
self.argdict = {}
for i, arg in enumerate(args):
self.argdict[arg] = i
self.sums = [0] * len(args)
self.cnt = 0
def accum(self, val):
val = val if isinstance(val, list) else [val]
val = [v for v in val if v is not None]
assert (len(val) == len(self.args))
for i in range(len(val)):
if torch.is_tensor(val[i]):
val[i] = val[i].item()
self.sums[i] += val[i]
self.cnt += 1
def clear(self):
self.sums = [0] * len(self.args)
self.cnt = 0
def get(self, arg, avg=True):
i = self.argdict.get(arg, -1)
assert i != -1, f'unknown argument {arg}'
if avg:
return self.sums[i] / (self.cnt + 1e-8)
else:
return self.sums[i]
def print_(self, header=None, time=None,
logfile=None, do_not_print=[], as_int=[],
avg=True):
msg = '' if header is None else header + ': '
if time is not None:
msg += ('(%.3f secs), ' % time)
args = [arg for arg in self.args if arg not in do_not_print]
for arg in args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
if arg in as_int:
msg += ('%s %d, ' % (arg, int(val)))
else:
msg += ('%s %.4f, ' % (arg, val))
print(msg)
if logfile is not None:
logfile.write(msg + '\n')
logfile.flush()
def add_scalars(self, summary, header=None, tag_scalar=None,
step=None, avg=True, args=None):
for arg in self.args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
tag = f'{header}/{arg}' if header is not None else arg
if tag_scalar is not None:
summary.add_scalars(main_tag=tag,
tag_scalar_dict={tag_scalar: val},
global_step=step)
else:
summary.add_scalar(tag=tag,
scalar_value=val,
global_step=step)
class Log:
def __init__(self, args, logf, summary=None):
self.args = args
self.logf = logf
self.summary = summary
self.stime = time.time()
self.ep_sttime = None
def print(self, logger, epoch, tag=None, avg=True):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
print(msg)
self.logf.write(msg+'\n')
logger.print_(header=tag, logfile=self.logf, avg=avg)
if self.summary is not None:
logger.add_scalars(
self.summary, header=tag, step=epoch, avg=avg)
logger.clear()
def print_args(self):
argdict = vars(self.args)
print(argdict)
for k, v in argdict.items():
self.logf.write(k + ': ' + str(v) + '\n')
self.logf.write('\n')
def set_time(self):
self.stime = time.time()
def save_time_log(self):
ct = time.time() - self.stime
msg = f'({ct:6.2f}s) meta-training phase done'
print(msg)
self.logf.write(msg+'\n')
def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
self.logf.write(msg+'\n')
print(msg)
self.logf.flush()
# msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
# msg += f'time {time.time() - sttime:6.2f} '
if max_corr_dict is not None:
max_corr = max_corr_dict['corr']
max_loss = max_corr_dict['loss']
msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
msg += f'corr {corr:.4f} ({max_corr:.4f})'
else:
msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
self.logf.write(msg+'\n')
print(msg)
self.logf.flush()
def max_corr_log(self, max_corr_dict):
corr = max_corr_dict['corr']
loss = max_corr_dict['loss']
epoch = max_corr_dict['epoch']
msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
self.logf.write(msg+'\n')
print(msg)
self.logf.flush()
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
msg = f'[{tag}] Ep {epoch} loss {loss.item()/len(y):0.4f} '
if type(y_pred) == list:
msg += f'pacc {y_pred[0]:0.4f}'
msg += f'({y_pred[0]*100.0*acc_std+acc_mean:0.4f}) '
else:
msg += f'pacc {y_pred:0.4f}'
msg += f'({y_pred*100.0*acc_std+acc_mean:0.4f}) '
msg += f'acc {y[0]:0.4f}({y[0]*100*acc_std+acc_mean:0.4f})'
return msg
def load_model(model, ckpt_path):
model.cpu()
model.load_state_dict(torch.load(ckpt_path))
def save_model(epoch, model, model_path, max_corr=None):
print("==> save current model...")
if max_corr is not None:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, 'ckpt_max_corr.pt'))
else:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, f'ckpt_{epoch}.pt'))
def mean_confidence_interval(data, confidence=0.95):
a = 1.0 * np.array(data)
n = len(a)
m, se = np.mean(a), scipy.stats.sem(a)
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
return m, h
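A worked example with made-up accuracies: the half-width is h = t_{0.975, n-1} * SE.
accs = [92.1, 91.8, 92.4]
m, h = mean_confidence_interval(accs)
# m == 92.1; h ~ 0.745 (sample std 0.3, SE ~ 0.173, t_{0.975, df=2} ~ 4.303)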

View File

@@ -0,0 +1,6 @@
from pathlib import Path
import sys
dir_path = (Path(__file__).parent).resolve()
if str(dir_path) not in sys.path: sys.path.insert(0, str(dir_path))
from .architecture import train_single_model

View File

@@ -0,0 +1,173 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###############################################################
from functions import evaluate_for_seed
from nas_bench_201_models import CellStructure, CellArchitectures, get_search_spaces
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from nas_bench_201_datasets import get_datasets
from procedures import get_machine_info
from procedures import save_checkpoint, copy_checkpoint
from config_utils import load_config
from pathlib import Path
from copy import deepcopy
import os
import sys
import time
import torch
import random
import argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
NASBENCH201_CONFIG_PATH = os.path.join(
os.getcwd(), 'main_exp', 'transfer_nag')
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
arch_config, workers, logger):
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
all_infos = {'info': machine_info}
all_dataset_keys = []
# loop over all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
task = None
train_data, valid_data, xshape, class_num = get_datasets(
dataset, xpath, -1, task)
# load the configuration
if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
if use_less:
config_path = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
else:
config_path = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
p = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset))
if not os.path.exists(p):
import json
label_list = list(range(len(train_data)))
random.shuffle(label_list)
strlist = [str(label_list[i]) for i in range(len(label_list))]
split_data = {'train': ["int", strlist[:len(train_data) // 2]],
'valid': ["int", strlist[len(train_data) // 2:]]}
with open(p, 'w') as f:
f.write(json.dumps(split_data))
split_info = load_config(os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset)), None, None)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(
config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
shuffle=True, num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
shuffle=False, num_workers=workers, pin_memory=True)
test_splits = load_config(os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}-test-split.txt'.format(dataset)), None, None)
ValLoaders = {'ori-test': valid_loader,
'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
test_splits.xvalid),
num_workers=workers, pin_memory=True),
'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
test_splits.xtest),
num_workers=workers, pin_memory=True)
}
dataset_key = '{:}'.format(dataset)
if bool(split):
dataset_key = dataset_key + '-valid'
logger.log(
'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
dataset_key, config))
for key, value in ValLoaders.items():
logger.log(
'Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))
results = evaluate_for_seed(
arch_config, config, arch, train_loader, ValLoaders, seed, logger)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos['all_dataset_keys'] = all_dataset_keys
return all_infos
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
seeds, model_str, arch_config):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(workers)
save_dir = Path(save_dir)
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log(
'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
else:
try:
arch = CellStructure.str2structure(model_str)
except Exception:
raise ValueError(
'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
assert arch.check_valid_op(get_search_spaces(
'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
# assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
logger.log('arch_config : {:}'.format(arch_config))
start_time, seed_time = time.time(), AverageMeter()
for _is, seed in enumerate(seeds):
logger.log(
'\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
seed))
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
if to_save_name.exists():
logger.log(
'Find the existing file {:}, directly load!'.format(to_save_name))
checkpoint = torch.load(to_save_name)
else:
logger.log(
'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
seed, arch_config, workers, logger)
torch.save(checkpoint, to_save_name)
# log information
logger.log('{:}'.format(checkpoint['info']))
all_dataset_keys = checkpoint['all_dataset_keys']
for dataset_key in all_dataset_keys:
logger.log('\n{:} dataset : {:} {:}'.format(
'-' * 15, dataset_key, '-' * 15))
dataset_info = checkpoint[dataset_key]
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log('Flops = {:} MB, Params = {:} MB'.format(
dataset_info['flop'], dataset_info['param']))
logger.log('config : {:}'.format(dataset_info['config']))
logger.log('Training State (finish) = {:}'.format(
dataset_info['finish-train']))
last_epoch = dataset_info['total_epoch'] - 1
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format(convert_secs2time(
seed_time.avg * (len(seeds) - _is - 1), True))
logger.log(
'\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
need_time))
logger.close()

View File

@@ -0,0 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .configure_utils import load_config, dict2config#, configure2str
#from .basic_args import obtain_basic_args
#from .attention_args import obtain_attention_args
#from .random_baseline import obtain_RandomSearch_args
#from .cls_kd_args import obtain_cls_kd_args
#from .cls_init_args import obtain_cls_init_args
#from .search_single_args import obtain_search_single_args
#from .search_args import obtain_search_args
# for network pruning
#from .pruning_args import obtain_pruning_args

View File

@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float', 'none')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
if x.lower() != 'none':
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
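A short sketch of the ["type", value] encoding used by the .config JSON files further below (illustrative calls):
convert_param(["float", "0.1"])        # -> 0.1
convert_param(["int", ["200", "50"]])  # -> [200, 50]
convert_param(["bool", "1"])           # -> True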
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, 'log'): logger.log(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
content = { k: convert_param(v) for k,v in data.items()}
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
if isinstance(extra, dict): content = {**content, **extra}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return "\"{:}\"".format(x)
def gtype(x):
if isinstance(x, list): x = x[0]
if isinstance(x, str) : return 'str'
elif isinstance(x, bool) : return 'bool'
elif isinstance(x, int): return 'int'
elif isinstance(x, float): return 'float'
elif x is None : return 'none'
else: raise ValueError('invalid : {:}'.format(x))
def cvalue(x, xtype):
if isinstance(x, list): is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == 'bool' : temp = cstring(int(temp))
elif xtype == 'none': temp = cstring('None')
else : temp = cstring(temp)
temps.append( temp )
if is_list:
return "[{:}]".format( ', '.join( temps ) )
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
xstrings.append(string)
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath): os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write('{:}'.format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "50"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

View File

@@ -0,0 +1,153 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import time
import torch
from procedures import prepare_seed, get_optim_scheduler
from nasbench_utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_bench_201_models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies = []
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(non_blocking=True)
inputs = inputs.cuda(non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append(batch_time.val - data_time.val)
# record loss and accuracy
prec1, prec5 = obtain_accuracy(
logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2:
latencies = latencies[1:]  # drop the first (warm-up) batch latency
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train':
scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(non_blocking=True)
if mode == 'train':
optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == 'train':
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
'C': arch_config['channel'], 'N': arch_config['num_cells'],
'genotype': arch, 'num_classes': config.class_num}, None)
)
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
if 'ckpt_path' in arch_config.keys():
ckpt = torch.load(arch_config['ckpt_path'])
ckpt['classifier.weight'] = net.state_dict()['classifier.weight']
ckpt['classifier.bias'] = net.state_dict()['classifier.bias']
net.load_state_dict(ckpt)
flop, param = get_model_infos(net, config.xshape)
logger.log('Network : {:}'.format(net.get_message()), False)
logger.log(
'{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(
net.parameters(), config)
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
# network, criterion = torch.nn.DataParallel(net).to(torch.device(f"cuda:{device}")), criterion.to(torch.device(f"cuda:{device}"))
# start training
start_time, epoch_time, total_epoch = time.time(
), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {
}, {}, {}, {}, {}, {}
train_times, valid_times = {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
train_loss, train_acc1, train_acc5, train_tm = procedure(
train_loader, network, criterion, scheduler, optimizer, 'train')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
train_times[epoch] = train_tm
with torch.no_grad():
for key, xloader in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
xloader, network, criterion, None, None, 'valid')
valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
valid_times['{:}@{:}'.format(key, epoch)] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format(convert_secs2time(
epoch_time.avg * (total_epoch-epoch-1), True))
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(
time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
info_seed = {'flop': flop,
'param': param,
'channel': arch_config['channel'],
'num_cells': arch_config['num_cells'],
'config': config._asdict(),
'total_epoch': total_epoch,
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'train_times': train_times,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'valid_times': valid_times,
'net_state_dict': net.state_dict(),
'net_string': '{:}'.format(net),
'finish-train': True
}
return info_seed
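# Usage sketch (added; all names below are hypothetical):
#   arch_config = {'channel': 16, 'num_cells': 5}
#   info = evaluate_for_seed(arch_config, config, arch, train_loader,
#                            {'x-valid': valid_loader}, seed=777, logger=logger)
#   best_valid_top1 = max(info['valid_acc1es'].values())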

View File

@@ -0,0 +1,9 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# every package does not rely on pytorch or tensorflow
# I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib
from .logger import Logger#, PrintLogger
from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time

View File

@@ -0,0 +1,150 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from pathlib import Path
import importlib, warnings
import os, sys, time, numpy as np
if sys.version_info.major == 2: # Python 2.x
from StringIO import StringIO as BIO
else: # Python 3.x
from io import BytesIO as BIO
if importlib.util.find_spec('tensorflow'):
import tensorflow as tf
class PrintLogger(object):
def __init__(self):
"""Create a summary writer logging to log_dir."""
self.name = 'PrintLogger'
def log(self, string):
print (string)
def close(self):
print ('-'*30 + ' close printer ' + '-'*30)
class Logger(object):
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
"""Create a summary writer logging to log_dir."""
self.seed = int(seed)
self.log_dir = Path(log_dir)
self.model_dir = Path(log_dir) / 'checkpoint'
self.log_dir.mkdir (parents=True, exist_ok=True)
if create_model_dir:
self.model_dir.mkdir(parents=True, exist_ok=True)
#self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.use_tf = bool(use_tf)
self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) )))
#self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
self.logger_file = open(self.logger_path, 'w')
if self.use_tf:
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
else:
self.writer = None
def __repr__(self):
return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
def path(self, mode):
valids = ('model', 'best', 'info', 'log')
if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
elif mode == 'log' : return self.log_dir
else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
def extract_log(self):
return self.logger_file
def close(self):
self.logger_file.close()
if self.writer is not None:
self.writer.close()
def log(self, string, save=True, stdout=False):
if stdout:
sys.stdout.write(string); sys.stdout.flush()
else:
print (string)
if save:
self.logger_file.write('{:}\n'.format(string))
self.logger_file.flush()
def scalar_summary(self, tags, values, step):
"""Log a scalar variable."""
if not self.use_tf:
warnings.warn('use_tf is disabled but scalar_summary was called')
else:
assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
if not isinstance(tags, list):
tags, values = [tags], [values]
for tag, value in zip(tags, values):
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
self.writer.flush()
def image_summary(self, tag, images, step):
"""Log a list of images."""
import scipy
if not self.use_tf:
warnings.warn('use_tf is disabled but image_summary was called')
return
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
# BIO is StringIO on Python 2 and BytesIO on Python 3 (see the import above);
# the original referenced the un-imported names StringIO/BytesIO directly
s = BIO()
scipy.misc.toimage(img).save(s, format="png")
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
self.writer.flush()
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
if not self.use_tf: raise ValueError('Do not have tensorflow')
import tensorflow as tf
# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]
# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()

View File

@@ -0,0 +1,98 @@
import numpy as np
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
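# Worked example of the running average (added for clarity):
#   m = AverageMeter()
#   m.update(0.5, n=2)  # val=0.5, sum=1.0, count=2, avg=0.5
#   m.update(1.0, n=2)  # val=1.0, sum=3.0, count=4, avg=0.75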
class RecorderMeter(object):
"""Computes and stores the minimum loss value and its epoch index"""
def __init__(self, total_epoch):
self.reset(total_epoch)
def reset(self, total_epoch):
assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
self.total_epoch = total_epoch
self.current_epoch = 0
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_losses = self.epoch_losses - 1
self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
self.epoch_losses [idx, 0] = train_loss
self.epoch_losses [idx, 1] = val_loss
self.epoch_accuracy[idx, 0] = train_acc
self.epoch_accuracy[idx, 1] = val_acc
self.current_epoch = idx + 1
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
def max_accuracy(self, istrain):
if self.current_epoch <= 0: return 0
if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
else: return self.epoch_accuracy[:self.current_epoch, 1].max()
def plot_curve(self, save_path):
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
title = 'the accuracy/loss curve of train/val'
dpi = 100
width, height = 1600, 1000
legend_fontsize = 10
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
y_axis = np.zeros(self.total_epoch)
plt.xlim(0, self.total_epoch)
plt.ylim(0, 100)
interval_y = 5
interval_x = 5
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
plt.grid()
plt.title(title, fontsize=20)
plt.xlabel('the training epoch', fontsize=16)
plt.ylabel('accuracy', fontsize=16)
y_axis[:] = self.epoch_accuracy[:, 0]
plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_accuracy[:, 1]
plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 0]
plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 1]
plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
if save_path is not None:
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)

View File

@@ -0,0 +1,42 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import time, sys
import numpy as np
def time_for_file():
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_string_short():
ISOTIMEFORMAT='%Y%m%d'
string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_print(string, is_print=True):
if (is_print):
print('{} : {}'.format(time_string(), string))
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str:
return '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
else:
return need_hour, need_mins, need_secs
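# Example (added): 3725 seconds is 1 hour, 2 minutes, 5 seconds.
#   convert_secs2time(3725)       -> (1, 2, 5)
#   convert_secs2time(3725, True) -> '[01:02:05]'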
def print_log(print_string, log):
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
else:
print("{:}".format(print_string))
if log is not None:
log.write('{:}\n'.format(print_string))
log.flush()

View File

@@ -0,0 +1,4 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .get_dataset_with_transform import get_datasets

View File

@@ -0,0 +1,179 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import torch.utils.data as data
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
from PIL import Image
import os
import numpy as np
def make_dataset(dir, image_ids, targets):
assert (len(image_ids) == len(targets))
images = []
dir = os.path.expanduser(dir)
for i in range(len(image_ids)):
item = (os.path.join(dir, 'data', 'images',
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images
def find_classes(classes_file):
# read classes file, separating out image IDs and class names
image_ids = []
targets = []
with open(classes_file, 'r') as f:
for line in f:
split_line = line.split(' ')
image_ids.append(split_line[0])
targets.append(' '.join(split_line[1:]))
# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]
return (image_ids, targets, classes, class_to_idx)
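# Expected classes_file line format (reconstructed from the parsing above; the
# image id and label are illustrative):
#   '1025794 Boeing 737-700' -> image data/images/1025794.jpg, target 'Boeing 737-700'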
class FGVCAircraft(data.Dataset):
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
Args:
root (string): Root directory path to dataset.
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g. ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')
def __init__(self, root, class_type='variant', split='train', transform=None,
target_transform=None, loader=default_loader, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))
self.root = os.path.expanduser(root)
self.root = os.path.join(self.root, 'fgvc-aircraft-2013b')
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
samples = make_dataset(self.root, image_ids, targets)
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def _check_exists(self):
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
os.path.exists(self.classes_file)
def download(self):
"""Download the FGVC-Aircraft data if it doesn't exist already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
print('Downloading %s ... (may take a few minutes)' % self.url)
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
tar_name = self.url.rpartition('/')[-1]
tar_path = os.path.join(parent_dir, tar_name)
data = urllib.request.urlopen(self.url)
# download .tar.gz file
with open(tar_path, 'wb') as f:
f.write(data.read())
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
data_folder = tar_path[:-len('.tar.gz')]  # str.strip would drop characters, not the suffix
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
tar = tarfile.open(tar_path)
tar.extractall(parent_dir)
# if necessary, rename data folder to self.root
if not os.path.samefile(data_folder, self.root):
print('Renaming %s to %s ...' % (data_folder, self.root))
os.rename(data_folder, self.root)
# delete .tar.gz file
print('Deleting %s ...' % tar_path)
os.remove(tar_path)
print('Done!')
if __name__ == '__main__':
# smoke test: report the number of images in every split
for split in ('train', 'val', 'trainval', 'test'):
air = FGVCAircraft('/w14/dataset/fgvc-aircraft-2013b', class_type='manufacturer', split=split,
transform=None, target_transform=None, loader=default_loader, download=False)
print('{:} split : {:} images'.format(split, len(air)))

View File

@@ -0,0 +1,304 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
##################################################
import os
import sys
import torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from copy import deepcopy
from PIL import Image  # needed by Lighting.__call__ below
import random
import pdb
from .aircraft import FGVCAircraft
from .pets import PetDataset
from config_utils import load_config
# SearchDataset wraps the paired train/valid splits used by get_nas_search_loaders
# below; this import is assumed from the original NAS-Bench-201 codebase (the
# module name may differ in this repository)
from .SearchDatasetWrap import SearchDataset
Dataset2Class = {'cifar10': 10,
'cifar100': 100,
'mnist': 10,
'svhn': 10,
'aircraft': 30,
'pets': 37}
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
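# Usage sketch (added): CUTOUT operates on a CHW tensor, so place it after
# ToTensor(), as the CIFAR branch of get_datasets does below:
#   transforms.Compose([transforms.ToTensor(), CUTOUT(16)])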
imagenet_pca = {
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
'eigvec': np.asarray([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
class Lighting(object):
def __init__(self, alphastd,
eigval=imagenet_pca['eigval'],
eigvec=imagenet_pca['eigvec']):
self.alphastd = alphastd
assert eigval.shape == (3,)
assert eigvec.shape == (3, 3)
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0.:
return img
rnd = np.random.randn(3) * self.alphastd
rnd = rnd.astype('float32')
v = rnd
old_dtype = np.asarray(img).dtype
v = v * self.eigval
v = v.reshape((3, 1))
inc = np.dot(self.eigvec, v).reshape((3,))
img = np.add(img, inc)
if old_dtype == np.uint8:
img = np.clip(img, 0, 255)
img = Image.fromarray(img.astype(old_dtype), 'RGB')
return img
def __repr__(self):
return self.__class__.__name__ + '()'
def get_datasets(name, root, cutout, use_num_cls=None):
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('mnist'):
mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
elif name.startswith('svhn'):
mean, std = [0.4376821, 0.4437697, 0.47280442], [
0.19803012, 0.20101562, 0.19703614]
elif name.startswith('aircraft'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('pets'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
else:
raise TypeError("Unknow dataset : {:}".format(name))
# Data Argumentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
transforms.Normalize(mean, std)]
if cutout > 0:
lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('cub200'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('mnist'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean, std),
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean, std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('svhn'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('aircraft'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
xshape = (1, 3, 32, 32)
elif name.startswith('pets'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
xshape = (1, 3, 32, 32)
else:
raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10(
root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'cifar100':
train_data = dset.CIFAR100(
root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR100(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'mnist':
train_data = dset.MNIST(
root, train=True, transform=train_transform, download=True)
test_data = dset.MNIST(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 60000 and len(test_data) == 10000
elif name == 'svhn':
train_data = dset.SVHN(root, split='train',
transform=train_transform, download=True)
test_data = dset.SVHN(root, split='test',
transform=test_transform, download=True)
assert len(train_data) == 73257 and len(test_data) == 26032
elif name == 'aircraft':
train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
transform=train_transform, download=False)
test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
transform=test_transform, download=False)
assert len(train_data) == 6667 and len(test_data) == 3333
elif name == 'pets':
train_data = PetDataset(root, train=True, num_cl=37,
val_split=0.15, transforms=train_transform)
test_data = PetDataset(root, train=False, num_cl=37,
val_split=0.15, transforms=test_transform)
else:
raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name] if use_num_cls is None else len(
use_num_cls)
return train_data, test_data, xshape, class_num
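# Example call (added; '~/data' is a placeholder root, cutout length 16):
#   train_data, test_data, xshape, class_num = get_datasets('cifar10', '~/data', 16)
#   # xshape == (1, 3, 32, 32), class_num == 10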
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
if isinstance(batch_size, (list, tuple)):
batch, test_batch = batch_size
else:
batch, test_batch = batch_size, batch_size
if dataset == 'cifar10':
# split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(
'{:}/cifar-split.txt'.format(config_root), None, None)
# search over the proposed training and validation set
train_split, valid_split = cifar_split.train, cifar_split.valid
# logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
# To split data
xvalid_data = deepcopy(train_data)
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
xvalid_data.transforms = valid_data.transform
xvalid_data.transform = deepcopy(valid_data.transform)
search_data = SearchDataset(
dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
train_split),
num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
valid_split),
num_workers=workers, pin_memory=True)
elif dataset == 'cifar100':
cifar100_test_split = load_config(
'{:}/cifar100-test-split.txt'.format(config_root), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data)
search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
list(range(len(search_train_data))),
cifar100_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
elif dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
import json
label_list = list(range(len(valid_data)))
random.shuffle(label_list)
strlist = [str(label_list[i]) for i in range(len(label_list))]
split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
'xtest': ["int", strlist[len(valid_data) // 2:]]}
with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
f.write(json.dumps(split))
test_split = load_config(
'{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data)
search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
list(range(len(search_train_data))), test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
num_workers=workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
test_split.xvalid), num_workers=workers, pin_memory=True)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
return search_loader, train_loader, valid_loader
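# Usage sketch (added; assumes config_root holds the *-split.txt files):
#   search_loader, train_loader, valid_loader = get_nas_search_loaders(
#       train_data, test_data, 'cifar10', 'configs/nas-benchmark', (64, 128), workers=4)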

View File

@@ -0,0 +1,45 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
def load_image(filename):
img = Image.open(filename)
img = img.convert('RGB')
return img
class PetDataset(Dataset):
def __init__(self, root, train=True, num_cl=37, val_split=0.2, transforms=None):
self.data = torch.load(os.path.join(root,'{}{}.pth'.format('train' if train else 'test',
int(100*(1-val_split)) if train else int(100*val_split))))
self.len = len(self.data)
self.transform = transforms
def __getitem__(self, index):
img, label = self.data[index]
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return self.len
if __name__ == '__main__':
# Added
import torchvision.transforms as transforms
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose(
[transforms.Resize(256), transforms.RandomRotation(45), transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
test_transform = transforms.Compose(
[transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
root = '/w14/dataset/MetaGen/pets'
# get_pets is not defined in this file; build the datasets directly instead
train_data = PetDataset(root, train=True, num_cl=37, val_split=0.2, transforms=train_transform)
test_data = PetDataset(root, train=False, num_cl=37, val_split=0.2, transforms=test_transform)
print(len(train_data), len(test_data))

View File

@@ -0,0 +1,34 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
def additive_func(A, B):
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
C = min(A.size(1), B.size(1))
if A.size(1) == B.size(1):
return A + B
elif A.size(1) < B.size(1):
out = B.clone()
out[:,:C] += A
return out
else:
out = A.clone()
out[:,:C] += B
return out
def change_key(key, value):
def func(m):
if hasattr(m, key):
setattr(m, key, value)
return func
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks
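# Worked example (added): blocks are space-separated, channels dash-separated.
#   parse_channel_info('3-16 16-32 32-64') -> [[3, 16], [16, 32], [32, 64]]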

View File

@@ -0,0 +1,45 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from os import path as osp
from typing import List, Text
import torch
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces', \
'CellStructure', 'CellArchitectures'
]
# useful modules
from config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
if hasattr(config, 'genotype'):
genotype = config.genotype
elif hasattr(config, 'arch_str'):
genotype = CellStructure.str2structure(config.arch_str)
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name) -> List[Text]:
if xtype == 'cell' or xtype == 'tss': # The topology search space.
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
elif xtype == 'sss': # The size search space.
if name == 'nas-bench-301':
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
'numbers': 5}
else:
raise ValueError('Invalid name : {:}'.format(name))
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))
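# Example (added; values from the tables above):
#   get_search_spaces('cell', 'nas-bench-201')
#   -> ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']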

View File

@@ -0,0 +1,4 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork

View File

@@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i-1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride, True, True)
else:
layer = OPS[op_name](C_out, C_out, 1, True, True)
# import pdb; pdb.set_trace()
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )
self.node_IX.append( cur_index )
self.node_IN.append( cur_innod )
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
x = '{:}<-({:})'.format(i+1, ','.join(y))
laystr.append( x )
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
if not reduction:
nodes, concats = genotype['normal'], genotype['normal_concat']
else:
nodes, concats = genotype['reduce'], genotype['reduce_concat']
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = '{:}<-{:}'.format(i+2, j)
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = '{:}<-{:}'.format(i+2, j)
op = self.edges[ node_str ]
clist.append( op(states[j]) )
states.append( sum(clist) )
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@@ -0,0 +1,66 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )
C_prev = cell.out_dim
self._Layer= len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
'''
out2 = self.lastact(feature)
out = self.global_pooling( out2 )
out = out.view(out.size(0), -1)
out2 = out2.view(out2.size(0), -1)
logits = self.classifier(out)
return out2, logits
'''
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
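# Usage sketch (added; the architecture string is illustrative and the relative
# import path is assumed):
#   from ..cell_searchs import CellStructure
#   genotype = CellStructure.str2structure(
#       '|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|skip_connect~1|nor_conv_1x1~2|')
#   net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
#   features, logits = net(torch.randn(2, 3, 32, 32))  # logits shape: (2, 10)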

View File

@@ -0,0 +1,308 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201,
'nas-bench-301': NAS_BENCH_201,
'darts' : DARTS_SPACE}
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
)
def forward(self, x):
return self.op(x)
class DualSepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(DualSepConv, self).__init__()
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
def forward(self, x):
x = self.op_a(x)
x = self.op_b(x)
return x
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self):
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
return string
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return residual + basicblock
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def forward(self, inputs):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1: return x.mul(0.)
else : return x[:,:,::self.stride,::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x):
if self.stride == 2:
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
super().__init__()
self.part = part  # was hard-coded to 4, silently ignoring the argument
self.hidden = C_in // 3
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.local_conv_list = nn.ModuleList()
for i in range(self.part):
self.local_conv_list.append(
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
)
self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden)
# FactorizedReduce (defined above) also requires affine/track_running_stats
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2, True, True)
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1, True, True)
else: raise ValueError('Invalid Stride : {:}'.format(stride))
def forward(self, x):
batch, C, H, W = x.size()
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
IHs = [0]
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
local_feat_list = []
for i in range(self.part):
feature = x[:, :, IHs[i]:IHs[i+1], :]
xfeax = self.avg_pool(feature)
xfea = self.local_conv_list[i]( xfeax )
local_feat_list.append( xfea )
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
part_feature = part_feature.transpose(1,2).contiguous()
part_K = self.W_K(part_feature)
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
weight_att = torch.bmm(part_K, part_Q)
attention = torch.softmax(weight_att, dim=2)
aggregateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
features = []
for i in range(self.part):
feature = aggregateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
features.append( feature )
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x,features), dim=1)
outputs = self.last( final_fea )
return outputs
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob)
x.mul_(mask)
return x
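# Note (added): dividing by keep_prob before applying the Bernoulli mask keeps
# the expected activation unchanged: E[mask * x / keep_prob] = x.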
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
super(GDAS_Reduction_Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.multiplier = multiplier
self.reduction = True
self.ops1 = nn.ModuleList(
[nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True))])
self.ops2 = nn.ModuleList(
[nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=True))])
def forward(self, s0, s1, drop_prob = -1):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
X0 = self.ops1[0] (s0)
X1 = self.ops1[1] (s1)
if self.training and drop_prob > 0.:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
#X2 = self.ops2[0] (X0+X1)
X2 = self.ops2[0] (s0)
X3 = self.ops2[1] (s1)
if self.training and drop_prob > 0.:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)

View File

@@ -0,0 +1,26 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
# from .search_model_darts import TinyNetworkDarts
# from .search_model_gdas import TinyNetworkGDAS
# from .search_model_setn import TinyNetworkSETN
# from .search_model_enas import TinyNetworkENAS
# from .search_model_random import TinyNetworkRANDOM
# from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
# from .search_model_gdas_nasnet import NASNetworkGDAS
# from .search_model_darts_nasnet import NASNetworkDARTS
# nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
# "DARTS-V2": TinyNetworkDarts,
# "GDAS": TinyNetworkGDAS,
# "SETN": TinyNetworkSETN,
# "ENAS": TinyNetworkENAS,
# "RANDOM": TinyNetworkRANDOM,
# "generic": GenericNAS201Model}
# nasnet_super_nets = {"GDAS": NASNetworkGDAS,
# "DARTS": NASNetworkDARTS}

View File

@@ -0,0 +1,198 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append( [(func, i)] )
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append( xstring )
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
self.node_N.append( len(node_info) )
self.nodes.append( tuple(deepcopy(node_info)) )
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list( node_info )
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0: return None, False
genotypes.append( node_info )
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
string = '|{:}|'.format(string)
strings.append( string )
return '+'.join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == 'none' or nodes[xin] is False: x = False
else: x = True
sums.append( x )
nodes[i+1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
# this is used to identify the isomorphic cell, which requires the prior knowledge of operation
# two operations are special, i.e., none and skip_connect
nodes = {0: '0'}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = '('+nodes[xin]+')' + '@{:}'.format(op)
elif consider_zero:
if op == 'none' or nodes[xin] == '#': x = '#' # zero
elif op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
else:
if op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
cur_node.append(x)
nodes[i_node+1] = '+'.join( sorted(cur_node) )
return nodes[ len(self.nodes) ]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
if isinstance(xstr, Structure): return xstr
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return Structure( genotypes )
@staticmethod
def str2fullstructure(xstr, default_name='none'):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes= list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes: input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append( tuple(node_info) )
return Structure( genotypes )
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [ tuple(arch) ]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append( previous_arch + [tuple(cur_node)] )
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 1), ), # node-2
(('skip_connect', 0), ('skip_connect', 2))] # node-3
)
AllConv3x3_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
)
AllFull_CODE = Structure(
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
)
AllConv1x1_CODE = Structure(
[(('nor_conv_1x1', 0), ), # node-1
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
)
AllIdentity_CODE = Structure(
[(('skip_connect', 0), ), # node-1
(('skip_connect', 0), ('skip_connect', 1)), # node-2
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
)
architectures = {'resnet' : ResNet_CODE,
'all_c3x3': AllConv3x3_CODE,
'all_c1x1': AllConv1x1_CODE,
'all_idnt': AllIdentity_CODE,
'all_full': AllFull_CODE}
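if __name__ == '__main__':
  # hedged round-trip check (assumes this module's own imports, e.g. deepcopy
  # and get_combination, are present at the top of the file):
  code = ResNet_CODE.tostr()
  arch = Structure.str2structure(code)
  assert arch.tostr() == code and arch.check_valid()
  print(code)  # |nor_conv_3x3~0|+|nor_conv_3x3~1|+|skip_connect~0|skip_connect~2|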

View File

@@ -0,0 +1,167 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
super(InferCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
    self.message = 'InferCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
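# --- hedged usage sketch (illustrative; the relative import above means this
# module is meant to be imported from its package, not run directly) ---
# For a depth-20 basic-block net, the stem consumes xchannels[0:2] and each of
# the 9 blocks two further entries, so the width list needs 20 entries
# (the widths below are made up):
#   import torch
#   xchannels = [3] + [16] * 7 + [32] * 6 + [64] * 6
#   net = InferCifarResNet('ResNetBasicblock', 20, [3, 3, 3], xchannels, 10, False)
#   features, logits = net(torch.randn(2, 3, 32, 32))   # shapes [2, 64], [2, 10]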

View File

@@ -0,0 +1,150 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
else:
self.downsample = None
self.out_dim = planes*self.expansion
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
super(InferDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
    self.message = 'InferDepthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.channels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
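# --- hedged usage sketch (illustrative; same import caveat as the sibling files) ---
#   import torch
#   net = InferDepthCifarResNet('ResNetBasicblock', 20, [2, 2, 2], 10, False)
#   features, logits = net(torch.randn(2, 3, 32, 32))   # shapes [2, 64], [2, 10]
# xblocks=[2, 2, 2] keeps only the first two of the three blocks in each stage.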

View File

@@ -0,0 +1,160 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
super(InferWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
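# --- hedged usage sketch (illustrative) ---
# xchannels lists per-conv widths: 1 input + 1 stem + 9 blocks * 2 convs = 20
# entries for depth 20 with basic blocks (the widths below are made up):
#   net = InferWidthCifarResNet('ResNetBasicblock', 20,
#                               [3] + [16] * 7 + [32] * 6 + [64] * 6, 10, False)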

View File

@@ -0,0 +1,170 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
super(InferImagenetResNet, self).__init__()
    # the block type and per-stage layer counts specify the ImageNet model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
self.num_classes = num_classes
self.xchannels = xchannels
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
else:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 2
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
else : self.bn = None
if has_relu: self.relu = nn.ReLU6(inplace=True)
else : self.relu = None
def forward(self, x):
out = self.conv( x )
if self.bn: out = self.bn ( out )
if self.relu: out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, channels, stride, expand_ratio, additive):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
if len(channels) == 2:
layers = []
else:
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
layers.extend([
# dw
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
# pw-linear
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
])
self.conv = nn.Sequential(*layers)
self.additive = additive
if self.additive and channels[0] != channels[-1]:
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
else:
self.shortcut = None
self.out_dim = channels[-1]
def forward(self, x):
out = self.conv(x)
# if self.additive: return additive_func(out, x)
if self.shortcut: return out + self.shortcut(x)
else : return out
class InferMobileNetV2(nn.Module):
def __init__(self, num_classes, xchannels, xblocks, dropout):
super(InferMobileNetV2, self).__init__()
block = InvertedResidual
inverted_residual_setting = [
# t, c, n, s
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
xchannels = parse_channel_info(xchannels)
#for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n):
stride = s if i == 0 else 1
        additv = i > 0  # use the additive (residual) form for all but the first block of a stage
module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
last_channel_idx += 1
if i + 1 == xblocks[stage]:
out_channel = module.out_dim
for iiL in range(i+1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

View File

@@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn
from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell
class DynamicShapeTinyNet(nn.Module):
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
super(DynamicShapeTinyNet, self).__init__()
self._channels = channels
if len(channels) % 3 != 2:
raise ValueError('invalid number of layers : {:}'.format(len(channels)))
self._num_stage = N = len(channels) // 3
self.stem = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(channels[0]))
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = channels[0]
self.cells = nn.ModuleList()
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
else : cell = InferCell(genotype, c_prev, c_curr, 1)
self.cells.append( cell )
c_prev = cell.out_dim
self._num_layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
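# --- hedged note (illustrative) ---
# `channels` gives the output width of every layer: N inferred cells, a
# reduction block, N cells, a reduction block, then N cells, which is why
# len(channels) % 3 == 2 is required; e.g. [16, 32, 32, 64, 64] (N = 1).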

View File

@@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet

View File

@@ -0,0 +1,5 @@
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks
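# hedged example (made-up values): the string encodes space-separated layers,
# each a dash-separated channel group, e.g.
#   parse_channel_info('3-16 16-16-24') -> [[3, 16], [16, 16, 24]]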

View File

@@ -0,0 +1,2 @@
from .evaluation_utils import obtain_accuracy
from .flop_benchmark import get_model_infos

View File

@@ -0,0 +1,17 @@
import torch
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
# correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
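if __name__ == '__main__':
  # hedged self-check (values are made up): a batch of 4 over 3 classes
  logits = torch.tensor([[2.0, 1.0, 0.1],
                         [0.1, 2.0, 1.0],
                         [2.0, 0.1, 1.0],
                         [0.1, 1.0, 2.0]])
  target = torch.tensor([0, 1, 2, 2])
  top1, top3 = obtain_accuracy(logits, target, topk=(1, 3))
  print(top1.item(), top3.item())  # 75.0 100.0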

View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np
def count_parameters_in_MB(model):
  # the builtin sum is used here because np.sum over a generator is deprecated
  if isinstance(model, nn.Module):
    return sum(np.prod(v.size()) for v in model.parameters()) / 1e6
  else:
    return sum(np.prod(v.size()) for v in model) / 1e6
def get_model_infos(model, shape):
#model = copy.deepcopy( model )
model = add_flops_counting_methods(model)
#model = model.cuda()
model.eval()
#cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
with torch.no_grad():
_____ = model(cache_inputs)
FLOPs = compute_average_flops_cost( model ) / 1e6
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache()
model.apply( remove_hook_function )
return FLOPs, Param
# ---- Public functions
def add_flops_counting_methods( model ):
model.__batch_counter__ = 0
add_batch_counter_hook_function( model )
model.apply( add_flops_counter_variable_or_reset )
model.apply( add_flops_counter_hook_function )
return model
def compute_average_flops_cost(model):
"""
A method that will be available after add_flops_counting_methods() is called on a desired net object.
Returns current mean flops consumption per image.
"""
batches_count = model.__batch_counter__
flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or hasattr(module, 'calculate_flop_self'):
flops_sum += module.__flops__
return flops_sum / batches_count
# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
batch_size = inputs[0].size(0)
kernel_size = pool_module.kernel_size
out_C, output_height, output_width = output.shape[1:]
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
overall_flops = batch_size * xin * xout
if fc_module.bias is not None:
overall_flops += batch_size * xout
fc_module.__flops__ += overall_flops
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0)
outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def conv2d_flops_counter_hook(conv_module, inputs, output):
batch_size = inputs[0].size(0)
output_height, output_width = output.shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
active_elements_count = batch_size * output_height * output_width
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def batch_counter_hook(module, inputs, output):
# Can have multiple inputs, getting the first one
inputs = inputs[0]
batch_size = inputs.shape[0]
module.__batch_counter__ += batch_size
def add_batch_counter_hook_function(module):
if not hasattr(module, '__batch_counter_handle__'):
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def add_flops_counter_variable_or_reset(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
or hasattr(module, 'calculate_flop_self'):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if isinstance(module, torch.nn.Conv2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv2d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Conv1d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv1d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Linear):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(fc_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle
elif hasattr(module, 'calculate_flop_self'): # self-defined module
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
module.__flops_handle__ = handle
def remove_hook_function(module):
hookers = ['__batch_counter_handle__', '__flops_handle__']
for hooker in hookers:
if hasattr(module, hooker):
handle = getattr(module, hooker)
handle.remove()
  keys = ['__flops__', '__batch_counter__'] + hookers
for ckey in keys:
if hasattr(module, ckey): delattr(module, ckey)
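if __name__ == '__main__':
  # hedged smoke test (model and input shape are illustrative): count FLOPs
  # and parameters of a tiny convnet; both numbers are reported in millions.
  net = nn.Sequential(
      nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
  flops, params = get_model_infos(net, (1, 3, 32, 32))
  print('FLOPs={:.3f}M, Params={:.3f}MB'.format(flops, params))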

View File

@@ -0,0 +1,28 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
from .starts import prepare_seed #, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
'''
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders
def get_procedures(procedure):
from .basic_main import basic_train, basic_valid
from .search_main import search_train, search_valid
from .search_main_v2 import search_train_v2
from .simple_KD_main import simple_KD_train, simple_KD_valid
train_funcs = {'basic' : basic_train, \
'search': search_train,'Simple-KD': simple_KD_train, \
'search-v2': search_train_v2}
valid_funcs = {'basic' : basic_valid, \
'search': search_valid,'Simple-KD': simple_KD_valid, \
'search-v2': search_valid}
train_func = train_funcs[procedure]
valid_func = valid_funcs[procedure]
return train_func, valid_func
'''

View File

@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, warmup_epochs, epochs):
if not isinstance(optimizer, Optimizer):
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
self.current_iter = 0
def extra_repr(self):
return ''
def __repr__(self):
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
+ ', {:})'.format(self.extra_repr()))
def state_dict(self):
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def get_min_info(self):
lrs = self.get_lr()
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
def get_min_lr(self):
return min( self.get_lr() )
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
#if last_epoch < self.T_max:
#if last_epoch < self.max_epochs:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
#else:
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
elif self.current_epoch >= self.max_epochs:
lr = self.eta_min
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
idx = bisect_right(self.milestones, last_epoch)
lr = base_lr
for x in self.gammas[:idx]: lr *= x
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class LinearLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
self.max_LR = max_LR
self.min_LR = min_LR
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
lr = base_lr * (1-ratio)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
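# e.g. with num_classes=10 and epsilon=0.1, the true class receives target
# probability 0.9 + 0.1/10 = 0.91 and every other class 0.1/10 = 0.01.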
def get_optim_scheduler(parameters, config):
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
if config.optim == 'SGD':
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
elif config.optim == 'RMSprop':
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
else:
raise ValueError('invalid optim : {:}'.format(config.optim))
if config.scheduler == 'cos':
T_max = getattr(config, 'T_max', config.epochs)
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
elif config.scheduler == 'multistep':
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
elif config.scheduler == 'exponential':
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == 'linear':
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
else:
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
if config.criterion == 'Softmax':
criterion = torch.nn.CrossEntropyLoss()
elif config.criterion == 'SmoothSoftmax':
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
else:
raise ValueError('invalid criterion : {:}'.format(config.criterion))
return optim, scheduler, criterion
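if __name__ == '__main__':
  # hedged smoke test: the config is a made-up Namespace whose field names
  # follow the attributes read above; all values are illustrative only.
  from argparse import Namespace
  cfg = Namespace(optim='SGD', LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
                  scheduler='cos', warmup=5, epochs=100, eta_min=0.0,
                  criterion='Softmax')
  net = nn.Linear(4, 2)
  optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), cfg)
  scheduler.update(0, 0.0)
  print(scheduler.get_min_info(), criterion)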

View File

@@ -0,0 +1,64 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile
def prepare_seed(rand_seed):
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
torch.cuda.manual_seed_all(rand_seed)
def prepare_logger(xargs):
args = copy.deepcopy( xargs )
from log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log('Main Function with logger : {:}'.format(logger))
logger.log('Arguments : -------------------------------')
for name, value in args._get_kwargs():
logger.log('{:16} : {:}'.format(name, value))
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
logger.log("Pillow Version : {:}".format(PIL.__version__))
logger.log("PyTorch Version : {:}".format(torch.__version__))
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
return logger
def get_machine_info():
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
info+= "\nPillow Version : {:}".format(PIL.__version__)
info+= "\nPyTorch Version : {:}".format(torch.__version__)
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
if 'CUDA_VISIBLE_DEVICES' in os.environ:
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
else:
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
return info
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(filename))
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(filename), 'save filename : {:} failed, which is not found.'.format(filename)
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
return filename
def copy_checkpoint(src, dst, logger):
if osp.isfile(dst):
if hasattr(logger, 'log'): logger.log('Find {:} exist, delete is at first before saving'.format(dst))
os.remove(dst)
copyfile(src, dst)
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))

View File

@@ -0,0 +1,83 @@
from torch.multiprocessing import Process
import os
from absl import app, flags
import sys
import torch
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from nas_bench_201 import train_single_model
from all_path import NASBENCH201
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_split", 15, "The number of splits")
flags.DEFINE_list("arch_idx_lst", None, "arch index list")
flags.DEFINE_list("arch_str_lst", None, "arch str list")
flags.DEFINE_string("meta_test_path", None, "meta test path")
flags.DEFINE_string("data_name", None, "data_name")
flags.DEFINE_string("raw_data_path", None, "raw_data_path")
def run_single_process(rank, seed, arch_idx, meta_test_path, data_name,
raw_data_path, num_split=15, backend="nccl"):
# 8 GPUs
device = ['0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7',
'0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7'][rank]
os.environ["CUDA_VISIBLE_DEVICES"] = device
save_path = os.path.join(meta_test_path, str(arch_idx))
    if isinstance(seed, int):
        seeds = [seed]
    elif isinstance(seed, (list, tuple)):
        seeds = list(seed)
    else:
        raise TypeError('invalid type of seed : {:}'.format(type(seed)))
nasbench201 = torch.load(NASBENCH201)
arch_str = nasbench201['arch']['str'][arch_idx]
os.makedirs(save_path, exist_ok=True)
train_single_model(save_dir=save_path,
workers=24,
datasets=[data_name],
xpaths=[f'{raw_data_path}/{data_name}'],
splits=[0],
use_less=False,
seeds=seeds,
model_str=arch_str,
arch_config={'channel': 16, 'num_cells': 5})
def run_multi_process(argv):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "1234"
os.environ["WANDB_SILENT"] = "true"
processes = []
arch_idx_lst = [int(i) for i in FLAGS.arch_idx_lst]
seeds = [777, 888, 999] * len(arch_idx_lst)
arch_idx_lst_ = []
for i in arch_idx_lst:
arch_idx_lst_ += [i] * 3
for arch_idx in arch_idx_lst:
os.makedirs(os.path.join(FLAGS.meta_test_path, str(arch_idx)), exist_ok=True)
for rank in range(FLAGS.num_split):
arch_idx = arch_idx_lst_[rank]
seed = seeds[rank]
p = Process(target=run_single_process, args=(rank,
seed,
arch_idx,
FLAGS.meta_test_path,
FLAGS.data_name,
FLAGS.raw_data_path))
p.start()
processes.append(p)
    for p in processes:
        p.join()
    # join() already blocks until every child process exits, so no busy-wait is needed
    print("All processes have completed.")
if __name__ == "__main__":
app.run(run_multi_process)
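# hedged example invocation (the script filename, paths and indices below are
# illustrative; only the flags themselves are defined above):
#   python this_script.py --num_split 3 --arch_idx_lst 0 \
#     --meta_test_path /tmp/meta_test --data_name cifar10 --raw_data_path /tmp/data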

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from set_encoder.setenc_modules import *
class SetPool(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
super(SetPool, self).__init__()
if 'sab' in mode: # [32, 400, 128]
self.enc = nn.Sequential(
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
else: # [32, 400, 128]
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
if 'PF' in mode: # [32, 1, 501]
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
nn.Linear(dim_hidden, dim_output))
elif 'P' in mode:
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
else: # torch.Size([32, 1, 501])
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
# "", sm, sab, sabsm
def forward(self, X):
x1 = self.enc(X)
x2 = self.dec(x1)
return x2
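if __name__ == '__main__':
  # hedged smoke test (assumes the set_encoder package is importable; sizes
  # are illustrative): pool 400 512-d items per sample into a single vector.
  import torch
  pool = SetPool(dim_input=512, num_outputs=1, dim_output=56,
                 dim_hidden=56, num_heads=4, mode='sabPF')
  x = torch.randn(2, 400, 512)
  print(pool(x).shape)  # torch.Size([2, 1, 56])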

View File

@@ -0,0 +1,67 @@
#####################################################################################
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MAB(nn.Module):
    """Multihead Attention Block: attention of Q over K, plus a feed-forward residual."""
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)
        # Split the feature dimension into num_heads chunks and stack them along
        # the batch dimension, so all heads run in a single bmm.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class SAB(nn.Module):
    """Set Attention Block: self-attention within a set, i.e. MAB(X, X)."""
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)


class ISAB(nn.Module):
    """Induced Set Attention Block: attend through num_inds learned inducing points."""
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class PMA(nn.Module):
    """Pooling by Multihead Attention: pool a set onto num_seeds learned seed vectors."""
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)
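
For a quick sanity check of the block interfaces above, a minimal sketch (shapes are illustrative; the feature dimension must be divisible by num_heads):

import torch
from set_encoder.setenc_modules import SAB, ISAB, PMA

x = torch.randn(8, 100, 64)                                   # 8 sets of 100 elements
sab = SAB(dim_in=64, dim_out=64, num_heads=4)                 # full O(n^2) self-attention
isab = ISAB(dim_in=64, dim_out=64, num_heads=4, num_inds=16)  # O(n*m) via 16 inducing points
pma = PMA(dim=64, num_heads=4, num_seeds=1)                   # pool each set to one vector
print(sab(x).shape, isab(x).shape, pma(x).shape)              # [8,100,64] [8,100,64] [8,1,64]

ISAB is the practical choice for large sets: attention runs through the num_inds learned inducing points instead of pairwise over all elements.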

View File

@@ -0,0 +1,243 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn

from set_encoder.setenc_models import SetPool


class MetaSurrogateUnnoisedModel(nn.Module):
    def __init__(self, args, graph_config):
        super(MetaSurrogateUnnoisedModel, self).__init__()
        self.max_n = graph_config['max_n']  # maximum number of vertices
        self.nvt = args.nvt  # number of vertex types
        self.START_TYPE = graph_config['START_TYPE']
        self.END_TYPE = graph_config['END_TYPE']
        self.hs = args.hs  # hidden state size of each vertex
        self.nz = args.nz  # size of latent representation z
        self.gs = args.hs  # size of graph state
        self.bidir = True  # whether to use bidirectional encoding
        self.vid = True
        self.device = None
        self.input_type = 'DG'  # predict from dataset ('D') and graph ('G') embeddings
        self.num_sample = args.num_sample
        if self.vid:
            self.vs = self.hs + self.max_n  # vertex state size = hidden state + vid
        else:
            self.vs = self.hs
        # 1. encoding-related
        self.grue_forward = nn.GRUCell(self.nvt, self.hs)  # forward encoder GRU
        self.grue_backward = nn.GRUCell(self.nvt, self.hs)  # backward encoder GRU
        self.fc1 = nn.Linear(self.gs, self.nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, self.nz)  # latent logvar
        # 2. gate-related
        self.gate_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.gate_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.mapper_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )  # disable bias to ensure padded zeros are also mapped to zeros
        self.mapper_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )
        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                nn.Linear(self.hs * 2, self.hs),
            )
            self.hg_unify = nn.Sequential(
                nn.Linear(self.gs * 2, self.gs),
            )
        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)
        # 5. predictor
        self.intra_setpool = SetPool(dim_input=512,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.inter_setpool = SetPool(dim_input=self.nz,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.set_fc = nn.Sequential(
            nn.Linear(512, self.nz),
            nn.ReLU())
        input_dim = 0
        if 'D' in self.input_type:
            input_dim += self.nz
        if 'G' in self.input_type:
            input_dim += self.nz
        self.pred_fc = nn.Sequential(
            nn.Linear(input_dim, self.hs),
            nn.Tanh(),
            nn.Linear(self.hs, 1)
        )
        self.mseloss = nn.MSELoss(reduction='sum')

    def predict(self, D_mu, G_mu):
        # concatenate dataset and graph embeddings, then regress a scalar score
        input_vec = []
        if 'D' in self.input_type:
            input_vec.append(D_mu)
        if 'G' in self.input_type:
            input_vec.append(G_mu)
        input_vec = torch.cat(input_vec, dim=1)
        return self.pred_fc(input_vec)

    def get_device(self):
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        # get zero hidden states
        return torch.zeros(n, length).to(self.get_device())

    def _get_zero_hidden(self, n=1):
        return self._get_zeros(n, self.hs)  # get a zero hidden state

    def _one_hot(self, idx, length):
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(
                1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(
                1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, gate=None, mapper=None):
        # propagate messages to vertex index v for all graphs in G;
        # return the new messages (states) at v
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None:
            idx = [i for i, g in enumerate(G) if g.vcount() > v]
            H = H[idx]
        v_types = [g.vs[v]['type'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.predecessors(v), self.max_n)
                        for g in G]
            if gate is None:
                gate, mapper = self.gate_forward, self.mapper_forward
        if self.vid:
            # append each neighbor's one-hot vertex id to its hidden state
            H_pred = [[torch.cat([x[i], y[i:i + 1]], 1)
                       for i in range(len(x))] for x, y in zip(H_pred, vids)]
        # if H is not provided, use the gated sum of v's predecessors' states
        # as the input hidden state
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of predecessors
            if max_n_pred == 0:
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]  # pad all to the same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i + 1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
        # perform a series of _propagate_to steps starting from v following a
        # topological order; assume the original vertex indices are already
        # topologically ordered
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0,
                                reverse=reverse)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse)
        return Hv

    def _get_graph_state(self, G, decode=False):
        # get the graph states: when decoding, use the last generated vertex's
        # state; when encoding, concatenate the ending vertex's forward state
        # with the starting vertex's backward state, then unify the size
        Hg = []
        for g in G:
            hg = g.vs[g.vcount() - 1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)
        return Hg

    def set_encode(self, X):
        # encode a batch of datasets: pool the samples within each class
        # (intra), then pool the class prototypes into one dataset embedding
        # (inter)
        proto_batch = []
        for x in X:
            cls_protos = self.intra_setpool(
                x.view(-1, self.num_sample, 512)).squeeze(1)
            proto_batch.append(
                self.inter_setpool(cls_protos.unsqueeze(0)))
        v = torch.stack(proto_batch).squeeze()
        return v

    def graph_encode(self, G):
        # encode graphs G into latent vectors
        if type(G) != list:
            G = [G]
        self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
                             reverse=False)
        if self.bidir:
            self._propagate_from(G, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True)
        Hg = self._get_graph_state(G)
        mu = self.fc1(Hg)
        # logvar = self.fc2(Hg)
        return mu  # , logvar

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        # return z ~ N(mu, std)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu
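
A minimal sketch of the intended call pattern for this surrogate (all hyperparameter values below are illustrative stand-ins, not the repo's configs; igraph DAGs with per-vertex 'type' attributes in topological order are assumed, as in D-VAE):

import argparse
import igraph
import torch

args = argparse.Namespace(nvt=8, hs=56, nz=56, num_sample=20)
graph_config = {'max_n': 8, 'START_TYPE': 0, 'END_TYPE': 1}
model = MetaSurrogateUnnoisedModel(args, graph_config)

g = igraph.Graph(directed=True)
g.add_vertices(8)
g.vs['type'] = [0, 2, 3, 4, 5, 6, 7, 1]      # START, ..., END; topologically ordered
g.add_edges([(i, i + 1) for i in range(7)])  # a simple chain DAG

D_mu = model.set_encode([torch.randn(3 * args.num_sample, 512)]).view(1, -1)
G_mu = model.graph_encode([g])               # [1, nz] graph embedding
score = model.predict(D_mu, G_mu)            # [1, 1] predicted scalar

Note that set_encode's trailing squeeze() collapses a single-dataset batch down to a 1-D tensor, which is why this sketch reshapes D_mu with view(1, -1) before predict concatenates along dim=1.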

View File

@@ -0,0 +1,33 @@
import os
import logging
import random

import numpy as np
import torch


def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            if k in ['optimizer', 'model', 'ema']:
                # stateful objects are restored via load_state_dict
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state


def reset_seed(seed):
    # seed every RNG the pipeline touches, for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True