first commit

Author: CownowAn
Date: 2024-03-15 14:38:51 +00:00
Commit: bc2ed1304f
321 changed files with 44802 additions and 0 deletions


@@ -0,0 +1,8 @@
SCORENET_CKPT_PATH="./checkpoints/scorenet/checkpoint.pth.tar"
META_SURROGATE_CKPT_PATH="./checkpoints/meta_surrogate/checkpoint.pth.tar"
META_SURROGATE_UNNOISED_CKPT_PATH="./checkpoints/meta_surrogate/unnoised_checkpoint.pth.tar"
NASBENCH201="./data/transfer_nag/nasbench201.pt"
NASBENCH201_INFO="./data/transfer_nag/nasbench201_info.pt"
META_TEST_PATH="./data/transfer_nag/test"
RAW_DATA_PATH="./data/raw_data"
DATA_PATH="./data/transfer_nag"
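
These path constants are consumed throughout the commit via plain imports (e.g. `from all_path import *` in the analysis.arch_functions module). Below is a minimal sketch of how a downstream script might use one of them; this is an illustration, not part of the commit, and it assumes the referenced data file has been placed at that path.

# Illustrative only: load the NAS-Bench-201 metadata referenced above.
import torch

from all_path import NASBENCH201_INFO

info = torch.load(NASBENCH201_INFO)   # dict with keys such as 'str', 'flops', 'params', 'latency'
print(len(info['str']))               # number of architectures (15625 for NAS-Bench-201)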


@@ -0,0 +1,347 @@
import numpy as np
import torch
from all_path import *
class BasicArchMetrics(object):
def __init__(self, train_ds=None, train_arch_str_list=None):
if train_ds is None:
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
else:
self.ops_decoder = train_ds.ops_decoder
self.nasbench201 = torch.load(NASBENCH201_INFO)
self.train_arch_str_list = train_arch_str_list
def compute_validity(self, generated):
START_TYPE = self.ops_decoder.index('input')
END_TYPE = self.ops_decoder.index('output')
valid = []
valid_arch_str = []
all_arch_str = []
for x in generated:
is_valid, error_types = is_valid_NAS201_x(x, START_TYPE, END_TYPE)
if is_valid:
valid.append(x)
arch_str = decode_x_to_NAS_BENCH_201_string(x, self.ops_decoder)
valid_arch_str.append(arch_str)
else:
arch_str = None
all_arch_str.append(arch_str)
validity = 0 if len(generated) == 0 else (len(valid)/len(generated))
return valid, validity, valid_arch_str, all_arch_str
def compute_uniqueness(self, valid_arch_str):
return list(set(valid_arch_str)), len(set(valid_arch_str)) / len(valid_arch_str)
def compute_novelty(self, unique):
num_novel = 0
novel = []
if self.train_arch_str_list is None:
print("Dataset arch_str is None, novelty computation skipped")
return 1, 1
for arch_str in unique:
if arch_str not in self.train_arch_str_list:
novel.append(arch_str)
num_novel += 1
return novel, num_novel / len(unique)
def evaluate(self, generated, check_dataname='cifar10'):
valid, validity, valid_arch_str, all_arch_str = self.compute_validity(generated)
if validity > 0:
unique, uniqueness = self.compute_uniqueness(valid_arch_str)
if self.train_arch_str_list is not None:
_, novelty = self.compute_novelty(unique)
else:
novelty = -1.0
else:
novelty = -1.0
uniqueness = 0.0
unique = []
if uniqueness > 0.:
arch_idx_list, flops_list, params_list, latency_list = list(), list(), list(), list()
for arch in unique:
arch_index, flops, params, latency = \
get_arch_acc_info(self.nasbench201, arch=arch, dataname=check_dataname)
arch_idx_list.append(arch_index)
flops_list.append(flops)
params_list.append(params)
latency_list.append(latency)
else:
arch_idx_list, flops_list, params_list, latency_list = [-1], [0], [0], [0]
return ([validity, uniqueness, novelty],
unique,
dict(arch_idx_list=arch_idx_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
all_arch_str)
class BasicArchMetricsMeta(object):
def __init__(self, train_ds=None, train_arch_str_list=None):
if train_ds is None:
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
else:
self.ops_decoder = train_ds.ops_decoder
self.nasbench201 = torch.load(NASBENCH201_INFO)
self.train_arch_str_list = train_arch_str_list
def compute_validity(self, generated):
START_TYPE = self.ops_decoder.index('input')
END_TYPE = self.ops_decoder.index('output')
valid = []
valid_arch_str = []
all_arch_str = []
error_types = []
for x in generated:
is_valid, error_type = is_valid_NAS201_x(x, START_TYPE, END_TYPE)
if is_valid:
valid.append(x)
arch_str = decode_x_to_NAS_BENCH_201_string(x, self.ops_decoder)
valid_arch_str.append(arch_str)
else:
arch_str = None
error_types.append(error_type)
all_arch_str.append(arch_str)
# exceptional case
validity = 0 if len(generated) == 0 else (len(valid)/len(generated))
if len(valid) == 0:
validity = 0
valid_arch_str = []
return valid, validity, valid_arch_str, all_arch_str
def compute_uniqueness(self, valid_arch_str):
return list(set(valid_arch_str)), len(set(valid_arch_str)) / len(valid_arch_str)
def compute_novelty(self, unique):
num_novel = 0
novel = []
if self.train_arch_str_list is None:
print("Dataset arch_str is None, novelty computation skipped")
return 1, 1
for arch_str in unique:
if arch_str not in self.train_arch_str_list:
novel.append(arch_str)
num_novel += 1
return novel, num_novel / len(unique)
def evaluate(self, generated, check_dataname='cifar10'):
valid, validity, valid_arch_str, all_arch_str = self.compute_validity(generated)
if validity > 0:
unique, uniqueness = self.compute_uniqueness(valid_arch_str)
if self.train_arch_str_list is not None:
_, novelty = self.compute_novelty(unique)
else:
novelty = -1.0
else:
novelty = -1.0
uniqueness = 0.0
unique = []
if uniqueness > 0.:
arch_idx_list, flops_list, params_list, latency_list = list(), list(), list(), list()
for arch in unique:
arch_index, flops, params, latency = \
get_arch_acc_info_meta(self.nasbench201, arch=arch, dataname=check_dataname)
arch_idx_list.append(arch_index)
flops_list.append(flops)
params_list.append(params)
latency_list.append(latency)
else:
arch_idx_list, flops_list, params_list, latency_list = [-1], [0], [0], [0]
return ([validity, uniqueness, novelty],
unique,
dict(arch_idx_list=arch_idx_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
all_arch_str)
def get_arch_acc_info(nasbench201, arch, dataname='cifar10'):
arch_index = nasbench201['str'].index(arch)
flops = nasbench201['flops'][dataname][arch_index]
params = nasbench201['params'][dataname][arch_index]
latency = nasbench201['latency'][dataname][arch_index]
return arch_index, flops, params, latency
def get_arch_acc_info_meta(nasbench201, arch, dataname='cifar10'):
arch_index = nasbench201['str'].index(arch)
flops = nasbench201['flops'][dataname][arch_index]
params = nasbench201['params'][dataname][arch_index]
latency = nasbench201['latency'][dataname][arch_index]
return arch_index, flops, params, latency
def decode_igraph_to_NAS_BENCH_201_string(g):
if not is_valid_NAS201(g):
return None
m = decode_igraph_to_NAS201_matrix(g)
types = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
format(types[int(m[1][0])],
types[int(m[2][0])], types[int(m[2][1])],
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
def decode_igraph_to_NAS201_matrix(g):
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
for i, xy in enumerate(xys):
m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
    return np.array(m)
def decode_x_to_NAS_BENCH_201_matrix(x):
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
for i, xy in enumerate(xys):
# m[xy[0]][xy[1]] = int(torch.argmax(torch.tensor(x[i+1])).item()) - 2
m[xy[0]][xy[1]] = int(torch.argmax(torch.tensor(x[i+1])).item())
    return np.array(m)
def decode_x_to_NAS_BENCH_201_string(x, ops_decoder):
    """Decode an [8, 7] one-hot operation matrix into a NAS-Bench-201 architecture string.
    Args:
        x (torch.Tensor): one-hot operation matrix of shape [8, 7]
    Returns:
        arch_str (str), or None if x is not a valid NAS-Bench-201 encoding
    """
    is_valid, _ = is_valid_NAS201_x(x)
    if not is_valid:
        return None
    m = decode_x_to_NAS_BENCH_201_matrix(x)
    types = ops_decoder
    arch_str = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
        format(types[int(m[1][0])],
               types[int(m[2][0])], types[int(m[2][1])],
               types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
    return arch_str
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
res = g.is_dag()
n_start, n_end = 0, 0
for v in g.vs:
if v['type'] == START_TYPE:
n_start += 1
elif v['type'] == END_TYPE:
n_end += 1
if v.indegree() == 0 and v['type'] != START_TYPE:
return False
if v.outdegree() == 0 and v['type'] != END_TYPE:
return False
return res and n_start == 1 and n_end == 1
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
# first need to be a valid DAG computation graph
res = is_valid_DAG(g, START_TYPE, END_TYPE)
# in addition, node i must connect to node i+1
res = res and len(g.vs['type']) == 8
res = res and not (START_TYPE in g.vs['type'][1:-1])
res = res and not (END_TYPE in g.vs['type'][1:-1])
return res
def check_single_node_type(x):
for x_elem in x:
if int(np.sum(x_elem)) != 1:
return False
return True
def check_start_end_nodes(x, START_TYPE, END_TYPE):
if x[0][START_TYPE] != 1:
return False
if x[-1][END_TYPE] != 1:
return False
return True
def check_interm_node_types(x, START_TYPE, END_TYPE):
for x_elem in x[1:-1]:
if x_elem[START_TYPE] == 1:
return False
if x_elem[END_TYPE] == 1:
return False
return True
ERORR_NB201 = {
'MULTIPLE_NODE_TYPES': 1,
'No_START_END': 2,
'INTERM_START_END': 3,
'NO_ERROR': -1
}
def is_valid_NAS201_x(x, START_TYPE=0, END_TYPE=1):
# first need to be a valid DAG computation graph
assert len(x.shape) == 2
if not check_single_node_type(x):
return False, ERORR_NB201['MULTIPLE_NODE_TYPES']
if not check_start_end_nodes(x, START_TYPE, END_TYPE):
return False, ERORR_NB201['No_START_END']
if not check_interm_node_types(x, START_TYPE, END_TYPE):
return False, ERORR_NB201['INTERM_START_END']
return True, ERORR_NB201['NO_ERROR']
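
To make the [8, 7] one-hot encoding concrete, here is a small worked example (an illustration, not part of the commit) that builds a valid architecture matrix and decodes it with the helpers above; the import path follows the analysis.arch_functions module used later in this commit.

import numpy as np

from analysis.arch_functions import is_valid_NAS201_x, decode_x_to_NAS_BENCH_201_string

ops_decoder = ['input', 'output', 'none', 'skip_connect',
               'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

# 8 nodes x 7 op types: row 0 is 'input', row 7 is 'output',
# rows 1-6 each pick one operation (here index 3 = 'skip_connect').
x = np.zeros((8, 7), dtype=np.float32)
x[0, 0] = 1.0      # input node
x[7, 1] = 1.0      # output node
x[1:7, 3] = 1.0    # the six edges of the NAS-Bench-201 cell

is_valid, error_code = is_valid_NAS201_x(x)                  # (True, -1)
arch_str = decode_x_to_NAS_BENCH_201_string(x, ops_decoder)
# '|skip_connect~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|skip_connect~1|skip_connect~2|'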
def compute_arch_metrics(arch_list,
train_arch_str_list,
train_ds,
check_dataname='cifar10'):
metrics = BasicArchMetrics(train_ds, train_arch_str_list)
arch_metrics = metrics.evaluate(arch_list, check_dataname=check_dataname)
all_arch_str = arch_metrics[-1]
return arch_metrics, all_arch_str
def compute_arch_metrics_meta(arch_list,
train_arch_str_list,
train_ds,
check_dataname='cifar10'):
metrics = BasicArchMetricsMeta(train_ds, train_arch_str_list)
arch_metrics = metrics.evaluate(arch_list, check_dataname=check_dataname)
return arch_metrics
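
A minimal sketch of how this evaluation entry point might be driven (an illustration, not part of the commit); it assumes the file at NASBENCH201_INFO exists, and passes train_ds=None so the default ops_decoder is used and novelty falls back to -1.0.

from analysis.arch_functions import compute_arch_metrics

# `generated` is assumed to be a list of [8, 7] one-hot arrays, e.g. sampler output after quantization.
arch_metrics, all_arch_str = compute_arch_metrics(
    arch_list=generated,
    train_arch_str_list=None,   # no training archs -> novelty is reported as -1.0
    train_ds=None,              # fall back to the default ops_decoder
    check_dataname='cifar10')

(validity, uniqueness, novelty), unique_archs, prop_dict, _ = arch_metrics
print(validity, uniqueness, len(unique_archs), prop_dict['flops_list'][:3])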


@@ -0,0 +1,77 @@
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
import torch.nn as nn
class SamplingArchMetrics(nn.Module):
def __init__(self,
config,
train_ds,
exp_name,):
super().__init__()
self.exp_name = exp_name
self.train_ds = train_ds
self.train_arch_str_list = train_ds.arch_str_list_
def forward(self,
arch_list: list,
this_sample_dir,
check_dataname='cifar10'):
arch_metrics, all_arch_str = compute_arch_metrics(arch_list=arch_list,
train_arch_str_list=self.train_arch_str_list,
train_ds=self.train_ds,
check_dataname=check_dataname)
valid_unique_arch = arch_metrics[1] # arch_str
valid_unique_arch_prop_dict = arch_metrics[2] # flops, params, latency
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
for i in range(len(valid_unique_arch)):
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
textfile.writelines(valid_unique_arch)
textfile.close()
return arch_metrics
class SamplingArchMetricsMeta(nn.Module):
def __init__(self,
config,
train_ds,
exp_name):
super().__init__()
self.exp_name = exp_name
self.train_ds = train_ds
self.search_space = config.data.name
self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]
def forward(self,
arch_list: list,
this_sample_dir,
check_dataname='cifar10'):
arch_metrics = compute_arch_metrics_meta(arch_list=arch_list,
train_arch_str_list=self.train_arch_str_list,
train_ds=self.train_ds,
check_dataname=check_dataname)
valid_unique_arch = arch_metrics[1] # arch_str
valid_unique_arch_prop_dict = arch_metrics[2] # flops, params, latency
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
for i in range(len(valid_unique_arch)):
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
textfile.writelines(valid_unique_arch)
textfile.close()
return arch_metrics


@@ -0,0 +1,72 @@
"""Evaluate trained score network"""
import ml_collections
import torch
from all_path import SCORENET_CKPT_PATH
def get_config():
config = ml_collections.ConfigDict()
# general
config.folder_name = 'test'
config.model_type = 'scorenet'
config.task = 'eval_scorenet'
config.exp_name = None
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
config.resume = False
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
# training
config.training = training = ml_collections.ConfigDict()
training.sde = 'vesde'
training.continuous = True
training.reduce_mean = True
training.noised = True
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'langevin'
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.16
# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.batch_size = 256
evaluate.enable_sampling = True
evaluate.num_samples = 256
# data
config.data = data = ml_collections.ConfigDict()
data.centered = True
data.dequantization = False
data.root = '../data/transfer_nag/nasbench201_info.pt'
data.name = 'NASBench201'
data.split_ratio = 1.0
data.dataset_idx = 'random' # 'sorted' | 'random'
data.max_node = 8
data.n_vocab = 7 # number of operations
data.START_TYPE = 0
data.END_TYPE = 1
data.num_graphs = 15625
data.num_channels = 1
data.label_list = ['test-acc']
data.tg_dataset = 'cifar10'
# aug_mask
data.aug_mask_algo = 'floyd' # 'long_range' | 'floyd'
# model
config.model = model = ml_collections.ConfigDict()
model.num_scales = 1000
model.beta_min = 0.1
model.beta_max = 5.0
model.sigma_min = 0.1
model.sigma_max = 5.0
return config
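
For reference, a short sketch (an illustration, not part of the commit) of how the SDE-related fields of this config are consumed elsewhere in the commit: with training.sde = 'vesde', a VE SDE is built from model.sigma_min, model.sigma_max, and model.num_scales.

import sde_lib   # module referenced elsewhere in this commit

config = get_config()
if config.training.sde.lower() == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,   # 0.1
                        sigma_max=config.model.sigma_max,   # 5.0
                        N=config.model.num_scales)          # 1000
    sampling_eps = 1e-5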


@@ -0,0 +1,125 @@
"""Training PGSN on Community Small Dataset with GraphGDP"""
import ml_collections
import torch
from all_path import SCORENET_CKPT_PATH
from all_path import NASBENCH201_INFO
def get_config():
config = ml_collections.ConfigDict()
# config.search_space = None
# general
config.folder_name = 'test'
config.model_type = 'meta_surrogate'
config.task = 'tr_meta_surrogate'
config.exp_name = None
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
config.resume = False
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
# training
config.training = training = ml_collections.ConfigDict()
training.sde = 'vesde'
training.continuous = True
training.reduce_mean = True
training.noised = True
training.batch_size = 256
training.eval_batch_size = 100
training.n_iters = 10000
training.snapshot_freq = 500
training.log_freq = 100
training.eval_freq = 100
training.snapshot_sampling = True
training.likelihood_weighting = False
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'langevin'
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.16
# for conditional sampling
sampling.classifier_scale = 10000.0
sampling.regress = True
sampling.labels = 'max'
sampling.weight_ratio = False
sampling.weight_scheduling = False
sampling.check_dataname = 'cifar10'
# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.batch_size = 512
evaluate.enable_sampling = True
evaluate.num_samples = 1024
# data
config.data = data = ml_collections.ConfigDict()
data.centered = True
data.dequantization = False
data.root = NASBENCH201_INFO
data.name = 'NASBench201'
data.max_node = 8
data.n_vocab = 7
data.START_TYPE = 0
data.END_TYPE = 1
data.num_channels = 1
data.label_list = ['meta-acc']
# aug_mask
data.aug_mask_algo = 'floyd' # 'long_range' | 'floyd'
# model
config.model = model = ml_collections.ConfigDict()
model.name = 'MetaNeuralPredictor'
model.ema_rate = 0.9999
model.normalization = 'GroupNorm'
model.nonlinearity = 'swish'
model.nf = 128
model.num_gnn_layers = 4
model.size_cond = False
model.embedding_type = 'positional'
model.rw_depth = 16
model.graph_layer = 'PosTransLayer'
model.edge_th = -1.
model.heads = 8
model.attn_clamp = False
# meta-predictor
model.input_type = 'DA'
model.hs = 32
model.nz = 56
model.num_sample = 20
model.num_scales = 1000
model.beta_min = 0.1
model.beta_max = 5.0
model.sigma_min = 0.1
model.sigma_max = 5.0
model.dropout = 0.1
# graph encoder
config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
graph_encoder.initial_hidden = 7
graph_encoder.gcn_hidden = 144
graph_encoder.gcn_layers = 4
graph_encoder.linear_hidden = 128
# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 0.001
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 1000
optim.grad_clip = 1.
return config


@@ -0,0 +1,113 @@
"""Training Score Network"""
import ml_collections
import torch
def get_config():
config = ml_collections.ConfigDict()
# general
config.folder_name = 'test'
config.model_type = 'scorenet'
config.task = 'tr_scorenet'
config.exp_name = None
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
config.resume = False
config.resume_ckpt_path = ''
# training
config.training = training = ml_collections.ConfigDict()
training.sde = 'vesde'
training.continuous = True
training.reduce_mean = True
training.batch_size = 256
training.eval_batch_size = 1000
training.n_iters = 250000
training.snapshot_freq = 10000
training.log_freq = 200
training.eval_freq = 10000
training.snapshot_sampling = True
training.likelihood_weighting = False
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'langevin'
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.16
# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.batch_size = 1024
evaluate.enable_sampling = True
evaluate.num_samples = 1024
# data
config.data = data = ml_collections.ConfigDict()
data.centered = True
data.dequantization = False
data.root = '../data/transfer_nag/nasbench201_info.pt'
data.name = 'NASBench201'
data.split_ratio = 1.0
data.dataset_idx = 'random' # 'sorted' | 'random'
data.max_node = 8
data.n_vocab = 7 # number of operations
data.START_TYPE = 0
data.END_TYPE = 1
data.num_graphs = 15625
data.num_channels = 1
data.label_list = None
data.tg_dataset = None
# aug_mask
data.aug_mask_algo = 'floyd' # 'long_range' | 'floyd'
# model
config.model = model = ml_collections.ConfigDict()
model.name = 'CATE'
model.ema_rate = 0.9999
model.normalization = 'GroupNorm'
model.nonlinearity = 'swish'
model.nf = 128
model.num_gnn_layers = 4
model.size_cond = False
model.embedding_type = 'positional'
model.rw_depth = 16
model.graph_layer = 'PosTransLayer'
model.edge_th = -1.
model.heads = 8
model.attn_clamp = False
# for pos emb
model.pos_enc_type = 2
model.num_scales = 1000
model.sigma_min = 0.1
model.sigma_max = 5.0
model.dropout = 0.1
# graph encoder
config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
graph_encoder.n_layers = 12
graph_encoder.d_model = 64
graph_encoder.n_head = 8
graph_encoder.d_ff = 128
graph_encoder.dropout = 0.1
graph_encoder.n_vocab = 7
# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 2e-5
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 1000
optim.grad_clip = 1.
return config


@@ -0,0 +1,469 @@
from __future__ import print_function
import torch
import os
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset
from analysis.arch_functions import decode_x_to_NAS_BENCH_201_matrix, decode_x_to_NAS_BENCH_201_string
from all_path import *
def get_data_scaler(config):
"""Data normalizer. Assume data are always in [0, 1]."""
if config.data.centered:
# Rescale to [-1, 1]
return lambda x: x * 2. - 1.
else:
return lambda x: x
def get_data_inverse_scaler(config):
"""Inverse data normalizer."""
if config.data.centered:
# Rescale [-1, 1] to [0, 1]
return lambda x: (x + 1.) / 2.
else:
return lambda x: x
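
A quick round-trip check of the two helpers above (an illustration, not part of the commit); `config` is assumed to be any of the ml_collections configs in this commit with config.data.centered = True.

import torch

scaler = get_data_scaler(config)             # [0, 1] -> [-1, 1]
inverse_scaler = get_data_inverse_scaler(config)

x = torch.rand(4, 8, 7)                      # dummy batch with values in [0, 1]
assert torch.allclose(inverse_scaler(scaler(x)), x)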
def is_triu(mat):
is_triu_ = np.allclose(mat, np.triu(mat))
return is_triu_
def get_dataset(config):
train_dataset = NASBench201Dataset(
data_path=NASBENCH201_INFO,
mode='train')
eval_dataset = NASBench201Dataset(
data_path=NASBENCH201_INFO,
mode='eval')
test_dataset = NASBench201Dataset(
data_path=NASBENCH201_INFO,
mode='test')
return train_dataset, eval_dataset, test_dataset
def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
train_loader = DataLoader(dataset=train_dataset,
batch_size=config.training.batch_size,
shuffle=True,
collate_fn=None)
eval_loader = DataLoader(dataset=eval_dataset,
batch_size=config.training.batch_size,
shuffle=False,
collate_fn=None)
test_loader = DataLoader(dataset=test_dataset,
batch_size=config.training.batch_size,
shuffle=False,
collate_fn=None)
return train_loader, eval_loader, test_loader
class NASBench201Dataset(Dataset):
def __init__(
self,
data_path,
split_ratio=1.0,
mode='train',
label_list=None,
tg_dataset=None):
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
# ---------- entire dataset ---------- #
self.data = torch.load(data_path)
# ---------- igraph ---------- #
self.igraph_list = self.data['g']
# ---------- x ---------- #
self.x_list = self.data['x']
# ---------- adj ---------- #
adj = self.get_adj()
self.adj_list = [adj] * len(self.igraph_list)
# ---------- matrix ---------- #
self.matrix_list = self.data['matrix']
# ---------- arch_str ---------- #
self.arch_str_list = self.data['str']
# ---------- labels ---------- #
self.label_list = label_list
if self.label_list is not None:
self.val_acc_list = self.data['val-acc'][tg_dataset]
self.test_acc_list = self.data['test-acc'][tg_dataset]
self.flops_list = self.data['flops'][tg_dataset]
self.params_list = self.data['params'][tg_dataset]
self.latency_list = self.data['latency'][tg_dataset]
# ----------- split dataset ---------- #
self.ds_idx = list(torch.load(DATA_PATH + '/ridx.pt'))
self.split_ratio = split_ratio
num_train = int(len(self.x_list) * self.split_ratio)
num_test = len(self.x_list) - num_train
# ----------- compute mean and std w/ training dataset ---------- #
if self.label_list is not None:
self.train_idx_list = self.ds_idx[:num_train]
print('>>> Computing mean and std of the training set...')
LABEL_TO_MEAN_STD = defaultdict(dict)
assert type(self.label_list) == list, f"self.label_list is {type(self.label_list)}"
for label in self.label_list:
if label == 'val-acc':
self.val_acc_list_tr = [self.val_acc_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.val_acc_list_tr))
elif label == 'test-acc':
self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
elif label == 'flops':
self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
elif label == 'params':
self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
elif label == 'latency':
self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
else:
raise ValueError
self.mode = mode
if self.mode in ['train']:
self.idx_list = self.ds_idx[:num_train]
elif self.mode in ['eval']:
if num_test == 0:
self.idx_list = self.ds_idx[:100]
else:
self.idx_list = self.ds_idx[:num_test]
elif self.mode in ['test']:
if num_test == 0:
self.idx_list = self.ds_idx[15000:]
else:
self.idx_list = self.ds_idx[num_train:]
self.igraph_list_ = [self.igraph_list[i] for i in self.idx_list]
self.x_list_ = [self.x_list[i] for i in self.idx_list]
self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
self.matrix_list_ = [self.matrix_list[i] for i in self.idx_list]
self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]
if self.label_list is not None:
assert type(self.label_list) == list
for label in self.label_list:
if label == 'val-acc':
self.val_acc_list_ = [self.val_acc_list[i] for i in self.idx_list]
self.val_acc_list_ = self.normalize(self.val_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'test-acc':
self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'flops':
self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'params':
self.params_list_ = [self.params_list[i] for i in self.idx_list]
self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'latency':
self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
else:
raise ValueError
def normalize(self, original, mean, std):
return [(i-mean)/std for i in original]
# def get_not_connect_prev_adj(self):
def get_adj(self):
adj = np.asarray(
[[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]]
)
adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
return adj
@property
def adj(self):
return self.adj_list_[0]
def mask(self, algo='floyd'):
from utils import aug_mask
return aug_mask(self.adj, algo=algo)[0]
def get_unnoramlized_entire_data(self, label, tg_dataset):
entire_val_acc_list = self.data['val-acc'][tg_dataset]
entire_test_acc_list = self.data['test-acc'][tg_dataset]
entire_flops_list = self.data['flops'][tg_dataset]
entire_params_list = self.data['params'][tg_dataset]
entire_latency_list = self.data['latency'][tg_dataset]
if label == 'val-acc':
return entire_val_acc_list
elif label == 'test-acc':
return entire_test_acc_list
elif label == 'flops':
return entire_flops_list
elif label == 'params':
return entire_params_list
elif label == 'latency':
return entire_latency_list
else:
raise ValueError
def get_unnoramlized_data(self, label, tg_dataset):
entire_val_acc_list = self.data['val-acc'][tg_dataset]
entire_test_acc_list = self.data['test-acc'][tg_dataset]
entire_flops_list = self.data['flops'][tg_dataset]
entire_params_list = self.data['params'][tg_dataset]
entire_latency_list = self.data['latency'][tg_dataset]
if label == 'val-acc':
return [entire_val_acc_list[i] for i in self.idx_list]
elif label == 'test-acc':
return [entire_test_acc_list[i] for i in self.idx_list]
elif label == 'flops':
return [entire_flops_list[i] for i in self.idx_list]
elif label == 'params':
return [entire_params_list[i] for i in self.idx_list]
elif label == 'latency':
return [entire_latency_list[i] for i in self.idx_list]
else:
raise ValueError
def __len__(self):
return len(self.x_list_)
def __getitem__(self, index):
label_dict = {}
if self.label_list is not None:
assert type(self.label_list) == list
for label in self.label_list:
if label == 'val-acc':
label_dict[f"{label}"] = self.val_acc_list_[index]
elif label == 'test-acc':
label_dict[f"{label}"] = self.test_acc_list_[index]
elif label == 'flops':
label_dict[f"{label}"] = self.flops_list_[index]
elif label == 'params':
label_dict[f"{label}"] = self.params_list_[index]
elif label == 'latency':
label_dict[f"{label}"] = self.latency_list_[index]
else:
raise ValueError
return self.x_list_[index], self.adj_list_[index], label_dict
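
A short sketch (an illustration, not part of the commit) of how the dataset and loaders above might be instantiated for score-network training; it assumes the NAS-Bench-201 info file and the ridx.pt split indices exist under the paths from all_path.py, and that `config` is e.g. the score-network training config.

train_ds, eval_ds, test_ds = get_dataset(config)
train_loader, eval_loader, test_loader = get_dataloader(config, train_ds, eval_ds, test_ds)

x, adj, label_dict = next(iter(train_loader))   # x: [B, 8, 7] one-hot ops, adj: [B, 8, 8]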
# ---------- Meta-Dataset ---------- #
def get_meta_dataset(config):
train_dataset = MetaTrainDatabase(
data_path=DATA_PATH,
num_sample=config.model.num_sample,
label_list=config.data.label_list,
mode='train')
eval_dataset = MetaTrainDatabase(
data_path=DATA_PATH,
num_sample=config.model.num_sample,
label_list=config.data.label_list,
mode='eval')
test_dataset = None
return train_dataset, eval_dataset, test_dataset
def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
train_loader = DataLoader(dataset=train_dataset,
batch_size=config.training.batch_size,
shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset,
batch_size=config.training.batch_size,
shuffle=False)
test_loader = None
return train_loader, eval_loader, test_loader
class MetaTrainDatabase(Dataset):
def __init__(
self,
data_path,
num_sample,
label_list,
mode='train'):
self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
self.mode = mode
self.acc_norm = True
self.num_sample = num_sample
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
mtr_data_path = os.path.join(data_path, 'meta_train_tasks_predictor.pt')
idx_path = os.path.join(data_path, 'meta_train_tasks_predictor_idx.pt')
data = torch.load(mtr_data_path)
self.acc_list = data['acc']
self.task = data['task']
# ---------- igraph ---------- #
self.igraph_list = data['g']
# ---------- x ---------- #
self.x_list = data['x']
# ---------- adj ---------- #
adj = self.get_adj()
self.adj_list = [adj] * len(self.igraph_list)
# ---------- matrix ----------- #
if 'matrix' in data:
self.matrix_list = data['matrix']
else:
self.matrix_list = [decode_x_to_NAS_BENCH_201_matrix(i) for i in self.x_list]
# ---------- arch_str ---------- #
if 'str' in data:
self.arch_str_list = data['str']
else:
self.arch_str_list = [decode_x_to_NAS_BENCH_201_string(i, self.ops_decoder) for i in self.x_list]
# ---------- label ---------- #
self.label_list = label_list
if self.label_list is not None:
self.flops_list = torch.tensor(data['flops'])
self.params_list = torch.tensor(data['params'])
self.latency_list = torch.tensor(data['latency'])
random_idx_lst = torch.load(idx_path)
self.idx_lst = {}
self.idx_lst['eval'] = random_idx_lst[:400]
self.idx_lst['train'] = random_idx_lst[400:]
self.acc_list = torch.tensor(self.acc_list)
self.mean = torch.mean(self.acc_list[self.idx_lst['train']]).item()
self.std = torch.std(self.acc_list[self.idx_lst['train']]).item()
self.task_lst = torch.load(os.path.join(data_path, 'meta_train_task_lst.pt'))
def get_adj(self):
adj = np.asarray(
[[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]]
)
adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
return adj
@property
def adj(self):
return self.adj_list[0]
def mask(self, algo='floyd'):
from utils import aug_mask
return aug_mask(self.adj, algo=algo)[0]
def set_mode(self, mode):
self.mode = mode
def __len__(self):
return len(self.idx_lst[self.mode])
def __getitem__(self, index):
data = []
ridx = self.idx_lst[self.mode]
tidx = self.task[ridx[index]]
classes = self.task_lst[tidx]
# ---------- igraph -----------
graph = self.igraph_list[ridx[index]]
# ---------- x -----------
x = self.x_list[ridx[index]]
# ---------- adj ----------
adj = self.adj_list[ridx[index]]
acc = self.acc_list[ridx[index]]
        for cls in classes:
            cx = self.x[cls-1][0]
            # use a separate name so the task index list `ridx` above is not overwritten
            perm = torch.randperm(len(cx))
            data.append(cx[perm[:self.num_sample]])
task = torch.cat(data)
if self.acc_norm:
acc = ((acc- self.mean) / self.std) / 100.0
else:
acc = acc / 100.0
label_dict = {}
if self.label_list is not None:
assert type(self.label_list) == list
for label in self.label_list:
if label == 'meta-acc':
label_dict[f"{label}"] = acc
elif label == 'flops':
label_dict[f"{label}"] = self.flops_list[ridx[index]]
elif label == 'params':
label_dict[f"{label}"] = self.params_list[ridx[index]]
elif label == 'latency':
label_dict[f"{label}"] = self.latency_list[ridx[index]]
else:
raise ValueError
return x, adj, label_dict, task
class MetaTestDataset(Dataset):
def __init__(self, data_path, data_name, num_sample, num_class=None):
self.num_sample = num_sample
self.data_name = data_name
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'aircraft': 30,
'pets': 37
}
if num_class is not None:
self.num_class = num_class
else:
self.num_class = num_class_dict[data_name]
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
def __len__(self):
return 1000000
def __getitem__(self, index):
data = []
classes = list(range(self.num_class))
for cls in classes:
cx = self.x[cls][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
return x
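
A usage sketch for the two meta datasets above (an illustration, not part of the commit); it assumes the preprocessed files referenced in the constructors (meta_train_tasks_predictor.pt, imgnet32bylabel.pt, cifar10bylabel.pt, ...) already exist under DATA_PATH.

from all_path import DATA_PATH

test_ds = MetaTestDataset(data_path=DATA_PATH, data_name='cifar10', num_sample=20)
task = test_ds[0]                        # 20 images per class for all 10 CIFAR-10 classes
print(task.shape)

train_ds = MetaTrainDatabase(data_path=DATA_PATH, num_sample=20,
                             label_list=['meta-acc'], mode='train')
x, adj, label_dict, task = train_ds[0]   # one architecture plus its task (image) samples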

NAS-Bench-201/logger.py (new file)

@@ -0,0 +1,137 @@
import os
import wandb
import torch
import numpy as np
class Logger:
def __init__(
self,
log_dir=None,
write_textfile=True
):
self.log_dir = log_dir
self.write_textfile = write_textfile
self.logs_for_save = {}
self.logs = {}
if self.write_textfile:
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
def write_str(self, log_str):
self.f.write(log_str+'\n')
self.f.flush()
def update_config(self, v, is_args=False):
if is_args:
self.logs_for_save.update({'args': v})
else:
self.logs_for_save.update(v)
def write_log(self, element, step, return_log_dict=False):
log_str = f"{step} | "
log_dict = {}
for head, keys in element.items():
for k in keys:
if k in self.logs:
v = self.logs[k].avg
if not k in self.logs_for_save:
self.logs_for_save[k] = []
self.logs_for_save[k].append(v)
log_str += f'{k} {v}| '
log_dict[f'{head}/{k}'] = v
if self.write_textfile:
self.f.write(log_str+'\n')
self.f.flush()
if return_log_dict:
return log_dict
def save_log(self, name=None):
name = 'logs.pt' if name is None else name
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
def update(self, key, v, n=1):
if not key in self.logs:
self.logs[key] = AverageMeter()
self.logs[key].update(v, n)
def reset(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
for key in keys:
self.logs[key] = AverageMeter()
else:
self.logs[keys] = AverageMeter()
else:
for key in self.logs.keys():
if not key in except_keys:
self.logs[key] = AverageMeter()
def avg(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
else:
return self.logs[keys].avg
else:
avg_dict = {}
for key in self.logs.keys():
if not key in except_keys:
avg_dict[key] = self.logs[key].avg
return avg_dict
class AverageMeter(object):
"""
Computes and stores the average and current value
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
metrics = {}
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
logits_per_x = logits_per_g.t().detach().cpu()
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
return metrics
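
To illustrate what get_metrics reports (an illustration, not part of the commit): with perfectly aligned graph and image embeddings, every query ranks its own pair first, so the mean rank is 1 and R@1 is 1.0.

import torch

g_embeds = torch.eye(4)    # 4 graph embeddings
x_embeds = torch.eye(4)    # 4 matching image embeddings
m = get_metrics(g_embeds, x_embeds, logit_scale=1.0, prefix='train')
# m['train_g_to_x_mean_rank'] == 1.0 and m['train_g_to_x_R@1'] == 1.0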

NAS-Bench-201/losses.py (new file)

@@ -0,0 +1,369 @@
"""All functions related to loss computation and optimization."""
import torch
import torch.optim as optim
import numpy as np
from models import utils as mutils
from sde_lib import VPSDE, VESDE
def get_optimizer(config, params):
"""Return a flax optimizer object based on `config`."""
if config.optim.optimizer == 'Adam':
optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
weight_decay=config.optim.weight_decay)
else:
raise NotImplementedError(
f'Optimizer {config.optim.optimizer} not supported yet!'
)
return optimizer
def optimization_manager(config):
"""Return an optimize_fn based on `config`."""
def optimize_fn(optimizer, params, step, lr=config.optim.lr,
warmup=config.optim.warmup,
grad_clip=config.optim.grad_clip):
"""Optimize with warmup and gradient clipping (disabled if negative)."""
if warmup > 0:
for g in optimizer.param_groups:
g['lr'] = lr * np.minimum(step / warmup, 1.0)
if grad_clip >= 0:
torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
optimizer.step()
return optimize_fn
def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
adj, mask = batch
score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
t = torch.rand(adj.shape[0], device=adj.device) * (sde.T - eps) + eps
z = torch.randn_like(adj) # [B, C, N, N]
z = torch.tril(z, -1)
z = z + z.transpose(2, 3)
mean, std = sde.marginal_prob(adj, t)
mean = torch.tril(mean, -1)
mean = mean + mean.transpose(2, 3)
perturbed_data = mean + std[:, None, None, None] * z
score = score_fn(perturbed_data, t, mask=mask)
mask = torch.tril(mask, -1)
mask = mask + mask.transpose(2, 3)
        mask = mask.reshape(mask.shape[0], -1)  # lower-triangular part of the adj matrices
if not likelihood_weighting:
losses = torch.square(score * std[:, None, None, None] + z)
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
else:
losses = 0.5 * torch.sum(losses * mask, dim=-1)
loss = losses.mean()
else:
g2 = sde.sde(torch.zeros_like(adj), t)[1] ** 2
losses = torch.square(score + z / std[:, None, None, None])
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
else:
losses = 0.5 * torch.sum(losses * mask, dim=-1)
loss = (losses * g2).mean()
return loss
return loss_fn
def get_sde_loss_fn_nas(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
x, adj, mask = batch
score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
        z = torch.randn_like(x)  # [B, max_node, n_vocab]
mean, std = sde.marginal_prob(x, t)
perturbed_data = mean + std[:, None, None] * z
score = score_fn(perturbed_data, t, mask)
if not likelihood_weighting:
losses = torch.square(score * std[:, None, None] + z)
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
losses = torch.mean(losses, dim=-1)
else:
losses = 0.5 * torch.sum(losses, dim=-1)
loss = losses.mean()
else:
g2 = sde.sde(torch.zeros_like(x), t)[1] ** 2
losses = torch.square(score + z / std[:, None, None])
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
losses = torch.mean(losses, dim=-1)
else:
losses = 0.5 * torch.sum(losses, dim=-1)
loss = (losses * g2).mean()
return loss
return loss_fn
def get_step_fn(sde,
train,
optimize_fn=None,
reduce_mean=False,
continuous=True,
likelihood_weighting=False,
data='NASBench201'):
"""Create a one-step training/evaluation function.
Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
      (`sde_lib.SDE`, `sde_lib.SDE`) representing the forward node SDE and edge SDE.
optimize_fn: An optimization function.
reduce_mean: If `True`, average the loss across data dimensions.
Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
Returns:
A one-step function for training or evaluation.
"""
if continuous:
if data in ['NASBench201', 'ofa']:
loss_fn = get_sde_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting)
else:
raise NotImplementedError(f"Data {data} (search space) is not supported yet.")
else:
raise NotImplementedError(f"Discrete training for {sde.__class__.__name__} is not implemented.")
def step_fn(state, batch):
"""Running one step of training or evaluation.
For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and
jit-compiled together for faster execution.
Args:
state: A dictionary of training information, containing the score model, optimizer,
EMA status, and number of optimization steps.
      batch: A mini-batch of training/evaluation data, including the mini-batch adjacency matrices and mask.
Returns:
loss: The average loss value of this state.
"""
model = state['model']
if train:
optimizer = state['optimizer']
optimizer.zero_grad()
loss = loss_fn(model, batch)
loss.backward()
optimize_fn(optimizer, model.parameters(), step=state['step'])
state['step'] += 1
state['ema'].update(model.parameters())
else:
with torch.no_grad():
ema = state['ema']
ema.store(model.parameters())
ema.copy_to(model.parameters())
loss = loss_fn(model, batch)
ema.restore(model.parameters())
return loss
return step_fn
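
A condensed sketch (an illustration, not part of the commit) of how a score-network training step might be wired together from the pieces above; `config`, `score_model`, and the batch tensors (x: [B, 8, 7] one-hot ops, adj/mask: [B, 8, 8]) are assumed to come from the configs, models, and datasets elsewhere in this commit.

import sde_lib
from losses import get_optimizer, optimization_manager, get_step_fn
from models.ema import ExponentialMovingAverage

sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                    sigma_max=config.model.sigma_max,
                    N=config.model.num_scales)
optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
state = dict(model=score_model, optimizer=optimizer, ema=ema, step=0)

train_step_fn = get_step_fn(sde, train=True,
                            optimize_fn=optimization_manager(config),
                            reduce_mean=config.training.reduce_mean,
                            continuous=config.training.continuous,
                            likelihood_weighting=config.training.likelihood_weighting)
loss = train_step_fn(state, (x, adj, mask))   # one optimization step on a mini-batch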
# ------------------- predictor -------------------
def get_meta_predictor_loss_fn_nas(sde,
train,
reduce_mean=True,
continuous=True,
likelihood_weighting=True,
eps=1e-5,
label_list=None,
noised=True):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
x, adj, mask, extra, task = batch
predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
if noised:
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
            z = torch.randn_like(x)  # [B, max_node, n_vocab]
mean, std = sde.marginal_prob(x, t)
perturbed_data = mean + std[:, None, None] * z
pred = predictor_fn(perturbed_data, t, mask, task)
else:
t = eps * torch.ones(x.shape[0], device=adj.device)
pred = predictor_fn(x, t, mask, task)
labels = extra[f"{label_list[-1]}"]
labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
loss = torch.nn.MSELoss()(pred, labels)
return loss, pred, labels
return loss_fn
def get_step_fn_predictor(sde,
train,
optimize_fn=None,
reduce_mean=False,
continuous=True,
likelihood_weighting=False,
data='NASBench201',
label_list=None,
noised=True):
"""Create a one-step training/evaluation function.
Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
      (`sde_lib.SDE`, `sde_lib.SDE`) representing the forward node SDE and edge SDE.
optimize_fn: An optimization function.
reduce_mean: If `True`, average the loss across data dimensions.
Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
Returns:
A one-step function for training or evaluation.
"""
if continuous:
if data in ['NASBench201', 'ofa']:
loss_fn = get_meta_predictor_loss_fn_nas(sde,
train,
reduce_mean=reduce_mean,
continuous=True,
likelihood_weighting=likelihood_weighting,
label_list=label_list,
noised=noised)
else:
raise NotImplementedError(f"Data {data} (search space) is not supported yet.")
else:
raise NotImplementedError(f"Discrete training for {sde.__class__.__name__} is not implemented.")
def step_fn(state, batch):
"""Running one step of training or evaluation.
For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and
jit-compiled together for faster execution.
Args:
state: A dictionary of training information, containing the score model, optimizer,
EMA status, and number of optimization steps.
      batch: A mini-batch of training/evaluation data, including the mini-batch adjacency matrices and mask.
Returns:
loss: The average loss value of this state.
"""
model = state['model']
if train:
model.train()
optimizer = state['optimizer']
optimizer.zero_grad()
loss, pred, labels = loss_fn(model, batch)
loss.backward()
optimize_fn(optimizer, model.parameters(), step=state['step'])
state['step'] += 1
else:
model.eval()
with torch.no_grad():
loss, pred, labels = loss_fn(model, batch)
return loss, pred, labels
return step_fn

NAS-Bench-201/main.py (new file)

@@ -0,0 +1,37 @@
"""Training and evaluation"""
import run_lib
from absl import app, flags
from ml_collections.config_flags import config_flags
import logging
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
'config', None, 'Training configuration.', lock_config=True
)
config_flags.DEFINE_config_file(
'classifier_config_nf', None, 'Training configuration.', lock_config=True
)
flags.DEFINE_enum('mode', None, ['train', 'eval'],
'Running mode: train or eval')
def main(argv):
## Set random seed
run_lib.set_random_seed(FLAGS.config)
if FLAGS.mode == 'train':
logger = logging.getLogger()
logger.setLevel('INFO')
run_lib.train(FLAGS.config)
elif FLAGS.mode == 'eval':
run_lib.evaluate(FLAGS.config)
else:
raise ValueError(f"Mode {FLAGS.mode} not recognized.")
if __name__ == '__main__':
app.run(main)


@@ -0,0 +1,286 @@
import torch
import sys
import os
import shutil
import logging
import numpy as np
import random
sys.path.append('.')
import sampling
import datasets_nas
from models import cate
from models import digcn
from models import digcn_meta
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
from analysis.arch_functions import BasicArchMetricsMeta
from all_path import *
def get_sampling_fn_meta(config):
## Set SDE
if config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
sde = sde_lib.subVPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(
sigma_min=config.model.sigma_min,
sigma_max=config.model.sigma_max,
N=config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
## Get data normalizer inverse
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
## Get sampling function
sampling_shape = (config.eval.batch_size, config.data.max_node, config.data.n_vocab)
sampling_fn = sampling.get_sampling_fn(
config=config,
sde=sde,
shape=sampling_shape,
inverse_scaler=inverse_scaler,
eps=sampling_eps,
conditional=True,
data_name=config.sampling.check_dataname,
num_sample=config.model.num_sample)
return sampling_fn, sde
def get_score_model(config):
try:
score_config = torch.load(config.scorenet_ckpt_path)['config']
ckpt_path = config.scorenet_ckpt_path
    except Exception:
config.scorenet_ckpt_path = SCORENET_CKPT_PATH
score_config = torch.load(config.scorenet_ckpt_path)['config']
ckpt_path = config.scorenet_ckpt_path
score_model = mutils.create_model(score_config)
score_ema = ExponentialMovingAverage(
score_model.parameters(), decay=score_config.model.ema_rate)
score_state = dict(
model=score_model, ema=score_ema, step=0, config=score_config)
score_state = restore_checkpoint(
ckpt_path, score_state,
device=config.device, resume=True)
score_ema.copy_to(score_model.parameters())
return score_model, score_ema, score_config
def get_surrogate(config):
surrogate_model = mutils.create_model(config)
return surrogate_model
def get_adj(except_inout=False):
_adj = np.asarray(
[[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]]
)
_adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
if except_inout: _adj = _adj[1:-1, 1:-1]
return _adj
def generate_archs_meta(
config,
sampling_fn,
score_model,
score_ema,
meta_surrogate_model,
num_samples,
args=None,
task=None,
patient_factor=20,
batch_size=256,):
metrics = BasicArchMetricsMeta()
## Get the adj and mask
adj_s = get_adj()
mask_s = aug_mask(adj_s)[0]
adj_c = get_adj()
mask_c = aug_mask(adj_c)[0]
assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
adj_s, mask_s, adj_c, mask_c = \
adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)
score_ema.copy_to(score_model.parameters())
score_model.eval()
meta_surrogate_model.eval()
c_scale = args.classifier_scale
num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor) if num_samples > batch_size else int(patient_factor)
round = 0
all_samples = []
    while round < num_sampling_rounds:
round += 1
sample = sampling_fn(score_model,
mask_s,
meta_surrogate_model,
classifier_scale=c_scale,
task=task)
quantized_sample = quantize(sample)
_, _, valid_arch_str, _ = metrics.compute_validity(quantized_sample)
if len(valid_arch_str) > 0: all_samples += valid_arch_str
# to sample various architectures
c_scale -= args.scale_step
seed = int(random.random() * 10000)
reset_seed(seed)
# stop sampling if we have enough samples
if (len(set(all_samples)) >= num_samples):
break
return list(set(all_samples))
def save_checkpoint(ckpt_dir, state, epoch, is_best):
saved_state = {}
for k in state:
if k in ['optimizer', 'model', 'ema']:
saved_state.update({k: state[k].state_dict()})
else:
saved_state.update({k: state[k]})
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
if is_best:
shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar'))
# remove the ckpt except is_best state
for ckpt_file in sorted(os.listdir(ckpt_dir)):
if not ckpt_file.startswith('checkpoint'):
continue
if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
os.remove(os.path.join(ckpt_dir, ckpt_file))
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
def floyed(r):
    """Floyd-Warshall-style transitive closure of a 0/1 adjacency matrix.
    :param r: an NxN numpy array (or torch tensor) with 0/1 entries (numpy inputs are modified in place)
    :return: an NxN numpy array with 0/1 entries, where entry (i, j) is 1 iff j is reachable from i
    """
if type(r) == torch.Tensor:
r = r.cpu().numpy()
N = r.shape[0]
for k in range(N):
for i in range(N):
for j in range(N):
if r[i, k] > 0 and r[k, j] > 0:
r[i, j] = 1
return r
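
A tiny worked example of the transitive closure (an illustration, not part of the commit): for a three-node chain 0 -> 1 -> 2, floyed adds the long-range edge 0 -> 2 while keeping the original edges.

import numpy as np

chain = np.array([[0., 1., 0.],
                  [0., 0., 1.],
                  [0., 0., 0.]])
closed = floyed(chain.copy())   # copy, since floyed modifies numpy inputs in place
# closed == [[0, 1, 1],
#            [0, 0, 1],
#            [0, 0, 0]]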
def aug_mask(adj, algo='floyed', data='NASBench201'):
if len(adj.shape) == 2:
adj = adj.unsqueeze(0)
if data.lower() in ['nasbench201', 'ofa']:
assert len(adj.shape) == 3
r = adj[0].clone().detach()
if algo == 'long_range':
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
elif algo == 'floyed':
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
else:
mask_i = r
masks = [mask_i] * adj.size(0)
return torch.stack(masks)
else:
masks = []
for r in adj:
if algo == 'long_range':
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
elif algo == 'floyed':
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
else:
mask_i = r
masks.append(mask_i)
return torch.stack(masks)
def long_range(r):
    """Add long-range edges: connect the predecessors of each node's predecessors directly to that node, in place.
    :param r: an NxN numpy array (or torch tensor) with 0/1 entries
    :return: an NxN numpy array with 0/1 entries
    """
# r = np.array(r)
if type(r) == torch.Tensor:
r = r.cpu().numpy()
N = r.shape[0]
for j in range(1, N):
col_j = r[:, j][:j]
in_to_j = [i for i, val in enumerate(col_j) if val > 0]
if len(in_to_j) > 0:
for i in in_to_j:
col_i = r[:, i][:i]
in_to_i = [i for i, val in enumerate(col_i) if val > 0]
if len(in_to_i) > 0:
for k in in_to_i:
r[k, j] = 1
return r
def quantize(x):
    """Convert the sampled PyTorch tensor x into a list of per-architecture numpy arrays.
    Values are thresholded at 0.5 (note: x is modified in place).
    Args:
        x: [Batch_size, Max_node, N_vocab]
    """
x_list = []
# discretization
x[x >= 0.5] = 1.
x[x < 0.5] = 0.
for i in range(x.shape[0]):
x_tmp = x[i]
x_tmp = x_tmp.cpu().numpy()
x_list.append(x_tmp)
return x_list
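
For example (an illustration, not part of the commit), quantize thresholds a sampled batch at 0.5 and returns one numpy array per architecture; note that the input tensor is modified in place.

import torch

sample = torch.rand(2, 8, 7)    # e.g. raw sampler output with values in [0, 1]
arch_list = quantize(sample)    # list of two (8, 7) numpy arrays with entries in {0., 1.}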
def reset_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True


@@ -0,0 +1,137 @@
import os
import wandb
import torch
import numpy as np
class Logger:
def __init__(
self,
log_dir=None,
write_textfile=True
):
self.log_dir = log_dir
self.write_textfile = write_textfile
self.logs_for_save = {}
self.logs = {}
if self.write_textfile:
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
def write_str(self, log_str):
self.f.write(log_str+'\n')
self.f.flush()
def update_config(self, v, is_args=False):
if is_args:
self.logs_for_save.update({'args': v})
else:
self.logs_for_save.update(v)
def write_log(self, element, step, return_log_dict=False):
log_str = f"{step} | "
log_dict = {}
for head, keys in element.items():
for k in keys:
if k in self.logs:
v = self.logs[k].avg
if k not in self.logs_for_save:
self.logs_for_save[k] = []
self.logs_for_save[k].append(v)
log_str += f'{k} {v} | '
log_dict[f'{head}/{k}'] = v
if self.write_textfile:
self.f.write(log_str+'\n')
self.f.flush()
if return_log_dict:
return log_dict
def save_log(self, name=None):
name = 'logs.pt' if name is None else name
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
def update(self, key, v, n=1):
if key not in self.logs:
self.logs[key] = AverageMeter()
self.logs[key].update(v, n)
def reset(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
for key in keys:
self.logs[key] = AverageMeter()
else:
self.logs[keys] = AverageMeter()
else:
for key in self.logs.keys():
if key not in except_keys:
self.logs[key] = AverageMeter()
def avg(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
else:
return self.logs[keys].avg
else:
avg_dict = {}
for key in self.logs.keys():
if key not in except_keys:
avg_dict[key] = self.logs[key].avg
return avg_dict
class AverageMeter(object):
"""
Computes and stores the average and current value
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
metrics = {}
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
logits_per_x = logits_per_g.t().detach().cpu()
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
return metrics
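# Illustrative sketch (hypothetical inputs): with perfectly aligned graph and
# dataset embeddings the matching pair is always ranked first, so the mean rank
# is 1 and R@1 is 1.0 in both retrieval directions.
def _get_metrics_example():
    g_embeds = torch.eye(4)
    x_embeds = torch.eye(4)
    metrics = get_metrics(g_embeds, x_embeds, logit_scale=1.0, prefix='demo')
    assert metrics['demo_g_to_x_R@1'] == 1.0
    assert metrics['demo_x_to_g_mean_rank'] == 1.0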

View File

@@ -0,0 +1,63 @@
"""
@author: Hayeon Lee
2020/02/19
Script for downloading, and reorganizing aircraft
for few shot classification
Run this file as follows:
python get_data.py
"""
import pickle
import os
import numpy as np
from tqdm import tqdm
import requests
import tarfile
from PIL import Image
import glob
import shutil
import collections
import sys
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = RAW_DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, 'fgvc-aircraft-2013b.tar.gz')
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file(
'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz',
file_name)
print("\nDownloading done.\n")
else:
print("fgvc-aircraft-2013b.tar.gz has already been downloaded. Did not download twice.\n")
untar_file_name = os.path.join(dir_path, 'aircraft')
if not os.path.exists(untar_file_name):
tarname = file_name
print("Untarring: {}".format(tarname))
tar = tarfile.open(tarname)
tar.extractall(untar_file_name)
tar.close()
else:
print(f"{untar_file_name} folder already exists. Did not untarring twice\n")
os.remove(file_name)

View File

@@ -0,0 +1,50 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
import sys
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = os.path.join(RAW_DATA_PATH, 'pets')
if not os.path.exists(dir_path):
os.makedirs(dir_path)
full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file(
'https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")
full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file(
'https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,47 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
DATA_PATH = "./data/transfer_nag"
dir_path = DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
def get_preprocessed_data(file_name, url):
print(f"Downloading {file_name} datasets\n")
full_name = os.path.join(dir_path, file_name)
download_file(url, full_name)
print("Downloading done.\n")
for file_name, url in [
('aircraftbylabel.pt', 'https://www.dropbox.com/s/mb66kitv20ykctp/aircraftbylabel.pt?dl=1'),
('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
]:
get_preprocessed_data(file_name, url)

View File

@@ -0,0 +1,130 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
print(f'==> The number of tasks for meta-training: {len(dataset)}')
loader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=collate_fn)
return loader
def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
print(f'==> Meta-Test dataset {data_name}')
loader = DataLoader(dataset=dataset,
batch_size=100,
shuffle=False,
num_workers=0)
return loader
class MetaTrainDatabase(Dataset):
def __init__(self, data_path, num_sample, is_pred=True):
self.mode = 'train'
self.acc_norm = True
self.num_sample = num_sample
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
mtr_data_path = os.path.join(
data_path, 'meta_train_tasks_predictor.pt')
idx_path = os.path.join(
data_path, 'meta_train_tasks_predictor_idx.pt')
data = torch.load(mtr_data_path)
self.acc = data['acc']
self.task = data['task']
self.graph = data['g']
random_idx_lst = torch.load(idx_path)
self.idx_lst = {}
self.idx_lst['valid'] = random_idx_lst[:400]
self.idx_lst['train'] = random_idx_lst[400:]
self.acc = torch.tensor(self.acc)
self.mean = torch.mean(self.acc[self.idx_lst['train']]).item()
self.std = torch.std(self.acc[self.idx_lst['train']]).item()
self.task_lst = torch.load(os.path.join(
data_path, 'meta_train_task_lst.pt'))
def set_mode(self, mode):
self.mode = mode
def __len__(self):
return len(self.idx_lst[self.mode])
def __getitem__(self, index):
data = []
ridx = self.idx_lst[self.mode]
tidx = self.task[ridx[index]]
classes = self.task_lst[tidx]
graph = self.graph[ridx[index]]
acc = self.acc[ridx[index]]
for cls in classes:
cx = self.x[cls-1][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
if self.acc_norm:
acc = ((acc - self.mean) / self.std) / 100.0
else:
acc = acc / 100.0
return x, graph, acc
class MetaTestDataset(Dataset):
def __init__(self, data_path, data_name, num_sample, num_class=None):
self.num_sample = num_sample
self.data_name = data_name
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'svhn': 10,
'aircraft': 30,
'pets': 37
}
if num_class is not None:
self.num_class = num_class
else:
self.num_class = num_class_dict[data_name]
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
def __len__(self):
return 1000000
def __getitem__(self, index):
data = []
classes = list(range(self.num_class))
for cls in classes:
cx = self.x[cls][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
return x
def collate_fn(batch):
x = torch.stack([item[0] for item in batch])
graph = [item[1] for item in batch]
acc = torch.stack([item[2] for item in batch])
return [x, graph, acc]
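# Illustrative sketch with dummy items: each batch element is an
# (images, graph, accuracy) triple as produced by MetaTrainDatabase; strings
# stand in here for the igraph objects.
def _collate_fn_example():
    batch = [(torch.zeros(2, 3), 'arch-0', torch.tensor(0.5)),
             (torch.ones(2, 3), 'arch-1', torch.tensor(0.7))]
    x, graph, acc = collate_fn(batch)
    assert x.shape == (2, 2, 3)
    assert graph == ['arch-0', 'arch-1']
    assert acc.shape == (2,)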

View File

@@ -0,0 +1,91 @@
import os
import sys
import random
import numpy as np
import argparse
import torch
import os
from nag import NAG
sys.path.append(os.getcwd())
save_path = "results"
data_path = os.path.join('MetaD2A_nas_bench_201', 'data')
def str2bool(v):
return v.lower() in ['t', 'true']
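# Illustrative note: argparse passes option values as strings, so the boolean
# flags below accept 't'/'true' (any case) as True and everything else as False.
def _str2bool_example():
    assert str2bool('True') is True
    assert str2bool('t') is True
    assert str2bool('no') is False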
def get_parser():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('--seed', type=int, default=444)
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
parser.add_argument('--save-path', type=str, default=save_path, help='the path of save directory')
parser.add_argument('--data-path', type=str, default=data_path, help='the path of data directory')
parser.add_argument('--model-load-path', type=str, default='', help='the path of a model checkpoint to load')
parser.add_argument('--save-epoch', type=int, default=20, help='how many epochs to wait each time to save model states')
parser.add_argument('--max-epoch', type=int, default=1000, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=1024, help='batch size for generator')
parser.add_argument('--graph-data-name', default='nasbench201', help='graph dataset name')
parser.add_argument('--nvt', type=int, default=7, help='number of different node types, 7: NAS-Bench-201 including in/out node')
# set encoder
parser.add_argument('--num-sample', type=int, default=20, help='the number of images as input for set encoder')
# graph encoder
parser.add_argument('--hs', type=int, default=512, help='hidden size of GRUs')
parser.add_argument('--nz', type=int, default=56, help='the number of dimensions of latent vectors z')
# test
parser.add_argument('--test', action='store_true', default=True, help='turn on test mode')
parser.add_argument('--load-epoch', type=int, default=100, help='checkpoint epoch loaded for meta-test')
parser.add_argument('--data-name', type=str, default='pets', help='meta-test dataset name')
parser.add_argument('--trials', type=int, default=20)
parser.add_argument('--num-class', type=int, default=None, help='the number of class of dataset')
parser.add_argument('--num-gen-arch', type=int, default=500, help='the number of candidate architectures generated by the generator')
parser.add_argument('--train-arch', type=str2bool, default=True, help='whether to train the searched architecture')
parser.add_argument('--n_init', type=int, default=10)
parser.add_argument('--N', type=int, default=1)
# DiffusionNAG
parser.add_argument('--folder_name', type=str, default='debug')
parser.add_argument('--exp_name', type=str, default='')
parser.add_argument('--classifier_scale', type=float, default=10000., help='classifier scale')
parser.add_argument('--scale_step', type=float, default=300.)
parser.add_argument('--eval_batch_size', type=int, default=256)
parser.add_argument('--predictor', type=str, default='euler_maruyama', choices=['euler_maruyama', 'reverse_diffusion', 'none'])
parser.add_argument('--corrector', type=str, default='langevin', choices=['none', 'langevin'])
parser.add_argument('--patient_factor', type=int, default=20)
parser.add_argument('--n_gen_samples', type=int, default=10)
parser.add_argument('--multi_proc', type=str2bool, default=True)
args = parser.parse_args()
return args
def set_exp_name(args):
exp_name = f'./results/transfer_nag/{args.folder_name}/{args.data_name}'
os.makedirs(exp_name, exist_ok=True)
args.exp_name = exp_name
def main():
## Get arguments
args = get_parser()
## Set gpus and seeds
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)
random.seed(args.seed)
## Set experiment name
set_exp_name(args)
## Run
nag = NAG(args)
nag.meta_test()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,305 @@
from __future__ import print_function
import torch
import os
import gc
import sys
import numpy as np
import os
import subprocess
from nag_utils import mean_confidence_interval
from nag_utils import restore_checkpoint
from nag_utils import load_graph_config
from nag_utils import load_model
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from nas_bench_201 import train_single_model
from unnoised_model import MetaSurrogateUnnoisedModel
from diffusion.run_lib import generate_archs_meta
from diffusion.run_lib import get_sampling_fn_meta
from diffusion.run_lib import get_score_model
from diffusion.run_lib import get_surrogate
from loader import MetaTestDataset
from logger import Logger
from all_path import *
class NAG:
def __init__(self, args):
self.args = args
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
## Target dataset information
self.raw_data_path = RAW_DATA_PATH
self.data_path = DATA_PATH
self.data_name = args.data_name
self.num_class = args.num_class
self.num_sample = args.num_sample
graph_config = load_graph_config(args.graph_data_name, args.nvt, NASBENCH201)
self.meta_surrogate_unnoised_model = MetaSurrogateUnnoisedModel(args, graph_config)
load_model(model=self.meta_surrogate_unnoised_model,
ckpt_path=META_SURROGATE_UNNOISED_CKPT_PATH)
self.meta_surrogate_unnoised_model.to(self.device)
## Load pre-trained meta-surrogate model
self.meta_surrogate_ckpt_path = META_SURROGATE_CKPT_PATH
## Load score network model (base diffusion model)
self.load_diffusion_model(args=args)
## Check config
self.check_config()
## Set logger
self.logger = Logger(
log_dir=args.exp_name,
write_textfile=True
)
self.logger.update_config(args, is_args=True)
self.logger.write_str(str(vars(args)))
self.logger.write_str('-' * 100)
def check_config(self):
"""
Check if the configuration of the pre-trained score network model matches that of the meta surrogate model.
"""
scorenet_config = torch.load(self.config.scorenet_ckpt_path)['config']
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
assert scorenet_config.model.sigma_min == meta_surrogate_config.model.sigma_min
assert scorenet_config.model.sigma_max == meta_surrogate_config.model.sigma_max
assert scorenet_config.training.sde == meta_surrogate_config.training.sde
assert scorenet_config.training.continuous == meta_surrogate_config.training.continuous
assert scorenet_config.data.centered == meta_surrogate_config.data.centered
assert scorenet_config.data.max_node == meta_surrogate_config.data.max_node
assert scorenet_config.data.n_vocab == meta_surrogate_config.data.n_vocab
def forward(self, x, arch):
D_mu = self.meta_surrogate_unnoised_model.set_encode(x.to(self.device))
G_mu = self.meta_surrogate_unnoised_model.graph_encode(arch)
y_pred = self.meta_surrogate_unnoised_model.predict(D_mu, G_mu)
return y_pred
def meta_test(self):
if self.data_name == 'all':
for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
self.meta_test_per_dataset(data_name)
else:
self.meta_test_per_dataset(self.data_name)
def meta_test_per_dataset(self, data_name):
## Load NASBench201
self.nasbench201 = torch.load(NASBENCH201)
all_arch_str = np.array(self.nasbench201['arch']['str'])
## Load meta-test dataset
self.test_dataset = MetaTestDataset(self.data_path, data_name, self.num_sample, self.num_class)
## Set save path
meta_test_path = os.path.join(META_TEST_PATH, data_name)
os.makedirs(meta_test_path, exist_ok=True)
f_arch_str = open(os.path.join(self.args.exp_name, 'architecture.txt'), 'w')
f_arch_acc = open(os.path.join(self.args.exp_name, 'accuracy.txt'), 'w')
## Generate architectures
gen_arch_str = self.get_gen_arch_str()
gen_arch_igraph = self.get_items(
full_target=self.nasbench201['arch']['igraph'],
full_source=self.nasbench201['arch']['str'],
source=gen_arch_str)
## Sort with unnoised meta-surrogate model
y_pred_all = []
self.meta_surrogate_unnoised_model.eval()
self.meta_surrogate_unnoised_model.to(self.device)
with torch.no_grad():
for arch_igraph in gen_arch_igraph:
x, g = self.collect_data(arch_igraph)
y_pred = self.forward(x, g)
y_pred = torch.mean(y_pred)
y_pred_all.append(y_pred.cpu().detach().item())
sorted_arch_lst = self.sort_arch(data_name, torch.tensor(y_pred_all), gen_arch_str)
## Record the information of the architecture generated in sorted order
for _, arch_str in enumerate(sorted_arch_lst):
f_arch_str.write(f'{arch_str}\n')
arch_idx_lst = [self.nasbench201['arch']['str'].index(i) for i in sorted_arch_lst]
arch_str_lst = []
arch_acc_lst = []
## Get the accuracy of the architecture
if 'cifar' in data_name:
sorted_acc_lst = self.get_items(
full_target=self.nasbench201['test-acc'][data_name],
full_source=self.nasbench201['arch']['str'],
source=sorted_arch_lst)
arch_str_lst += sorted_arch_lst
arch_acc_lst += sorted_acc_lst
for arch_idx, acc in zip(arch_idx_lst, sorted_acc_lst):
msg = f'Avg {acc:.4f} (%)'
f_arch_acc.write(msg + '\n')
else:
if self.args.multi_proc:
## Run multiple processes in parallel
run_file = os.path.join(os.getcwd(), 'main_exp', 'transfer_nag', 'run_multi_proc.py')
MAX_CAP = 5 # hard-coded for available GPUs
if len(arch_idx_lst) <= MAX_CAP:
arch_idx_lst_ = [arch_idx for arch_idx in arch_idx_lst if not os.path.exists(os.path.join(meta_test_path, str(arch_idx)))]
support_ = ','.join([str(i) for i in arch_idx_lst_])
num_split = int(3 * len(arch_idx_lst_))  # 3 processes per architecture, one per seed
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
subprocess.run([cmd], shell=True)
else:
arch_idx_lst_ = []
for j, arch_idx in enumerate(arch_idx_lst):
if not os.path.exists(os.path.join(meta_test_path, str(arch_idx))):
arch_idx_lst_.append(arch_idx)
if (len(arch_idx_lst_) == MAX_CAP) or (j == len(arch_idx_lst) - 1):
support_ = ','.join([str(i) for i in arch_idx_lst_])
num_split = int(3 * len(arch_idx_lst_))
cmd = f"python {run_file} --num_split {num_split} --arch_idx_lst {support_} --meta_test_path {meta_test_path} --data_name {data_name} --raw_data_path {self.raw_data_path}"
subprocess.run([cmd], shell=True)
arch_idx_lst_ = []
# Poll until the result files written by the spawned training processes exist.
while True:
try:
acc_runs_lst = []
epoch = 199
seeds = (777, 888, 999)
for arch_idx in arch_idx_lst:
acc_runs = []
save_path_ = os.path.join(meta_test_path, str(arch_idx))
for seed in seeds:
result = torch.load(os.path.join(save_path_, f'seed-0{seed}.pth'))
acc_runs.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
acc_runs_lst.append(acc_runs)
break
except:
pass
for i in acc_runs_lst: print(np.mean(i))
for arch_idx, acc_runs in zip(arch_idx_lst, acc_runs_lst):
for r, acc in enumerate(acc_runs):
msg = f'run {r+1} {acc:.2f} (%)'
f_arch_acc.write(msg + '\n')
m, h = mean_confidence_interval(acc_runs)
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
f_arch_acc.write(msg + '\n')
arch_acc_lst.append(np.mean(acc_runs))
arch_str_lst.append(all_arch_str[arch_idx])
else:
for arch_idx in arch_idx_lst:
acc_runs = self.train_single_arch(
data_name, self.nasbench201['arch']['str'][arch_idx], meta_test_path)
for r, acc in enumerate(acc_runs):
msg = f'run {r+1} {acc:.2f} (%)'
f_arch_acc.write(msg + '\n')
m, h = mean_confidence_interval(acc_runs)
msg = f'Avg {m:.2f}+-{h.item():.2f} (%)'
f_arch_acc.write(msg + '\n')
arch_acc_lst.append(np.mean(acc_runs))
arch_str_lst.append(all_arch_str[arch_idx])
# Save results
results_path = os.path.join(self.args.exp_name, 'results.pt')
torch.save({
'arch_idx_lst': arch_idx_lst,
'arch_str_lst': arch_str_lst,
'arch_acc_lst': arch_acc_lst
}, results_path)
print(f">>> Save the results at {results_path}...")
def train_single_arch(self, data_name, arch_str, meta_test_path):
save_path = os.path.join(meta_test_path, arch_str)
seeds = (777, 888, 999)
train_single_model(save_dir=save_path,
workers=24,
datasets=[data_name],
xpaths=[f'{self.raw_data_path}/{data_name}'],
splits=[0],
use_less=False,
seeds=seeds,
model_str=arch_str,
arch_config={'channel': 16, 'num_cells': 5})
epoch = 199
test_acc_lst = []
for seed in seeds:
result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
return test_acc_lst
def sort_arch(self, data_name, y_pred_all, gen_arch_str):
_, sorted_idx = torch.sort(y_pred_all, descending=True)
sorted_gen_arch_str = [gen_arch_str[i] for i in sorted_idx]
return sorted_gen_arch_str
def collect_data_only(self):
x_batch = []
x_batch.append(self.test_dataset[0])
return torch.stack(x_batch).to(self.device)
def collect_data(self, arch_igraph):
x_batch, g_batch = [], []
for _ in range(10):
x_batch.append(self.test_dataset[0])
g_batch.append(arch_igraph)
return torch.stack(x_batch).to(self.device), g_batch
def get_items(self, full_target, full_source, source):
return [full_target[full_source.index(_)] for _ in source]
def load_diffusion_model(self, args):
self.config = torch.load('./configs/transfer_nag_config.pt')
self.config.device = torch.device('cuda')
self.config.data.label_list = ['meta-acc']
self.config.scorenet_ckpt_path = SCORENET_CKPT_PATH
self.config.sampling.classifier_scale = args.classifier_scale
self.config.eval.batch_size = args.eval_batch_size
self.config.sampling.predictor = args.predictor
self.config.sampling.corrector = args.corrector
self.config.sampling.check_dataname = self.data_name
self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
self.score_model, self.score_ema, self.score_config = get_score_model(self.config)
def get_gen_arch_str(self):
## Load meta-surrogate model
meta_surrogate_config = torch.load(self.meta_surrogate_ckpt_path)['config']
meta_surrogate_model = get_surrogate(meta_surrogate_config)
meta_surrogate_state = dict(model=meta_surrogate_model, step=0, config=meta_surrogate_config)
meta_surrogate_state = restore_checkpoint(
self.meta_surrogate_ckpt_path,
meta_surrogate_state,
device=self.config.device,
resume=True)
## Get dataset embedding, x
with torch.no_grad():
x = self.collect_data_only()
## Generate architectures
generated_arch_str = generate_archs_meta(
config=self.config,
sampling_fn=self.sampling_fn,
score_model=self.score_model,
score_ema=self.score_ema,
meta_surrogate_model=meta_surrogate_model,
num_samples=self.args.n_gen_samples,
args=self.args,
task=x)
## Clean up
meta_surrogate_model = None
gc.collect()
return generated_arch_str

View File

@@ -0,0 +1,301 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import torch
import logging
def reset_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
def load_graph_config(graph_data_name, nvt, data_path):
if graph_data_name != 'nasbench201':
raise NotImplementedError(graph_data_name)
g_list = []
max_n = 0 # maximum number of nodes
ms = torch.load(data_path)['arch']['matrix']
for i in range(len(ms)):
g, n = decode_NAS_BENCH_201_8_to_igraph(ms[i])
max_n = max(max_n, n)
g_list.append((g, 0))
# number of different node types including in/out node
graph_config = {}
graph_config['num_vertex_type'] = nvt # original types + start/end types
graph_config['max_n'] = max_n # maximum number of nodes
graph_config['START_TYPE'] = 0 # predefined start vertex type
graph_config['END_TYPE'] = 1 # predefined end vertex type
return graph_config
def decode_NAS_BENCH_201_8_to_igraph(row):
if type(row) == str:
row = eval(row) # convert string to list of lists
n = len(row)
g = igraph.Graph(directed=True)
g.add_vertices(n)
for i, node in enumerate(row):
g.vs[i]['type'] = node[0]
if i < (n - 2) and i > 0:
g.add_edge(i, i + 1) # always connect from last node
for j, edge in enumerate(node[1:]):
if edge == 1:
g.add_edge(j, i)
return g, n
def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
# first need to be a valid DAG computation graph
res = is_valid_DAG(g, START_TYPE, END_TYPE)
# in addition, node i must connect to node i+1
res = res and len(g.vs['type']) == 8
res = res and not (0 in g.vs['type'][1:-1])
res = res and not (1 in g.vs['type'][1:-1])
return res
def decode_igraph_to_NAS201_matrix(g):
m = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
for i, xy in enumerate(xys):
m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
return np.array(m)
def decode_igraph_to_NAS_BENCH_201_string(g):
if not is_valid_NAS201(g):
return None
m = decode_igraph_to_NAS201_matrix(g)
types = ['none', 'skip_connect', 'nor_conv_1x1',
'nor_conv_3x3', 'avg_pool_3x3']
return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.\
format(types[int(m[1][0])],
types[int(m[2][0])], types[int(m[2][1])],
types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
res = g.is_dag()
n_start, n_end = 0, 0
for v in g.vs:
if v['type'] == START_TYPE:
n_start += 1
elif v['type'] == END_TYPE:
n_end += 1
if v.indegree() == 0 and v['type'] != START_TYPE:
return False
if v.outdegree() == 0 and v['type'] != END_TYPE:
return False
return res and n_start == 1 and n_end == 1
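# Illustrative sketch: a minimal two-node graph (input -> output) satisfies
# is_valid_DAG, which requires exactly one start node, exactly one end node,
# and no dangling vertices.
def _is_valid_DAG_example():
    g = igraph.Graph(directed=True)
    g.add_vertices(2)
    g.vs[0]['type'] = 0  # START_TYPE
    g.vs[1]['type'] = 1  # END_TYPE
    g.add_edge(0, 1)
    assert is_valid_DAG(g, START_TYPE=0, END_TYPE=1)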
class Accumulator():
def __init__(self, *args):
self.args = args
self.argdict = {}
for i, arg in enumerate(args):
self.argdict[arg] = i
self.sums = [0] * len(args)
self.cnt = 0
def accum(self, val):
val = [val] if type(val) is not list else val
val = [v for v in val if v is not None]
assert (len(val) == len(self.args))
for i in range(len(val)):
if torch.is_tensor(val[i]):
val[i] = val[i].item()
self.sums[i] += val[i]
self.cnt += 1
def clear(self):
self.sums = [0] * len(self.args)
self.cnt = 0
def get(self, arg, avg=True):
i = self.argdict.get(arg, -1)
assert i != -1
if avg:
return self.sums[i] / (self.cnt + 1e-8)
else:
return self.sums[i]
def print_(self, header=None, time=None,
logfile=None, do_not_print=[], as_int=[],
avg=True):
msg = '' if header is None else header + ': '
if time is not None:
msg += ('(%.3f secs), ' % time)
args = [arg for arg in self.args if arg not in do_not_print]
for arg in args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
if arg in as_int:
msg += ('%s %d, ' % (arg, int(val)))
else:
msg += ('%s %.4f, ' % (arg, val))
print(msg)
if logfile is not None:
logfile.write(msg + '\n')
logfile.flush()
def add_scalars(self, summary, header=None, tag_scalar=None,
step=None, avg=True, args=None):
for arg in self.args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
tag = f'{header}/{arg}' if header is not None else arg
if tag_scalar is not None:
summary.add_scalars(main_tag=tag,
tag_scalar_dict={tag_scalar: val},
global_step=step)
else:
summary.add_scalar(tag=tag,
scalar_value=val,
global_step=step)
class Log:
def __init__(self, args, logf, summary=None):
self.args = args
self.logf = logf
self.summary = summary
self.stime = time.time()
self.ep_sttime = None
def print(self, logger, epoch, tag=None, avg=True):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
print(msg)
self.logf.write(msg+'\n')
logger.print_(header=tag, logfile=self.logf, avg=avg)
if self.summary is not None:
logger.add_scalars(
self.summary, header=tag, step=epoch, avg=avg)
logger.clear()
def print_args(self):
argdict = vars(self.args)
print(argdict)
for k, v in argdict.items():
self.logf.write(k + ': ' + str(v) + '\n')
self.logf.write('\n')
def set_time(self):
self.stime = time.time()
def save_time_log(self):
ct = time.time() - self.stime
msg = f'({ct:6.2f}s) meta-training phase done'
print(msg)
self.logf.write(msg+'\n')
def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
self.logf.write(msg+'\n')
print(msg)
self.logf.flush()
# msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
# msg += f'time {time.time() - sttime:6.2f} '
if max_corr_dict is not None:
max_corr = max_corr_dict['corr']
max_loss = max_corr_dict['loss']
msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
msg += f'corr {corr:.4f} ({max_corr:.4f})'
else:
msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
self.logf.write(msg+'\n')
print(msg)
self.logf.flush()
def max_corr_log(self, max_corr_dict):
corr = max_corr_dict['corr']
loss = max_corr_dict['loss']
epoch = max_corr_dict['epoch']
msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
self.logf.write(msg+'\n')
print(msg)
self.logf.flush()
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
msg = f'[{tag}] Ep {epoch} loss {loss.item()/len(y):0.4f} '
if type(y_pred) == list:
msg += f'pacc {y_pred[0]:0.4f}'
msg += f'({y_pred[0]*100.0*acc_std+acc_mean:0.4f}) '
else:
msg += f'pacc {y_pred:0.4f}'
msg += f'({y_pred*100.0*acc_std+acc_mean:0.4f}) '
msg += f'acc {y[0]:0.4f}({y[0]*100*acc_std+acc_mean:0.4f})'
return msg
def load_model(model, ckpt_path):
model.cpu()
model.load_state_dict(torch.load(ckpt_path))
def save_model(epoch, model, model_path, max_corr=None):
print("==> save current model...")
if max_corr is not None:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, 'ckpt_max_corr.pt'))
else:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, f'ckpt_{epoch}.pt'))
def mean_confidence_interval(data, confidence=0.95):
a = 1.0 * np.array(data)
n = len(a)
m, se = np.mean(a), scipy.stats.sem(a)
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
return m, h
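# Illustrative sketch: three accuracy runs yield their mean and the half-width of a
# 95% confidence interval (this is how per-architecture results are reported in nag.py).
def _mean_confidence_interval_example():
    m, h = mean_confidence_interval([70.0, 72.0, 74.0])
    assert abs(m - 72.0) < 1e-8 and h > 0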

View File

@@ -0,0 +1,6 @@
from pathlib import Path
import sys
dir_path = (Path(__file__).parent).resolve()
if str(dir_path) not in sys.path: sys.path.insert(0, str(dir_path))
from .architecture import train_single_model

View File

@@ -0,0 +1,173 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###############################################################
from functions import evaluate_for_seed
from nas_bench_201_models import CellStructure, CellArchitectures, get_search_spaces
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from nas_bench_201_datasets import get_datasets
from procedures import get_machine_info
from procedures import save_checkpoint, copy_checkpoint
from config_utils import load_config
from pathlib import Path
from copy import deepcopy
import os
import sys
import time
import torch
import random
import argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
NASBENCH201_CONFIG_PATH = os.path.join(
os.getcwd(), 'main_exp', 'transfer_nag')
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
arch_config, workers, logger):
machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
all_infos = {'info': machine_info}
all_dataset_keys = []
# look all the datasets
for dataset, xpath, split in zip(datasets, xpaths, splits):
# train valid data
task = None
train_data, valid_data, xshape, class_num = get_datasets(
dataset, xpath, -1, task)
# load the configuration
if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
if use_less:
config_path = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
else:
config_path = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}.config'.format(dataset))
p = os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset))
if not os.path.exists(p):
import json
label_list = list(range(len(train_data)))
random.shuffle(label_list)
strlist = [str(label_list[i]) for i in range(len(label_list))]
splited = {'train': ["int", strlist[:len(train_data) // 2]],
'valid': ["int", strlist[len(train_data) // 2:]]}
with open(p, 'w') as f:
f.write(json.dumps(splited))
split_info = load_config(os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset)), None, None)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
config = load_config(
config_path, {'class_num': class_num, 'xshape': xshape}, logger)
# data loader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
shuffle=True, num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
shuffle=False, num_workers=workers, pin_memory=True)
splits = load_config(os.path.join(
NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{}-test-split.txt'.format(dataset)), None, None)
ValLoaders = {'ori-test': valid_loader,
'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
splits.xvalid),
num_workers=workers, pin_memory=True),
'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
splits.xtest),
num_workers=workers, pin_memory=True)
}
dataset_key = '{:}'.format(dataset)
if bool(split):
dataset_key = dataset_key + '-valid'
logger.log(
'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
dataset_key, config))
for key, value in ValLoaders.items():
logger.log(
'Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
results = evaluate_for_seed(
arch_config, config, arch, train_loader, ValLoaders, seed, logger)
all_infos[dataset_key] = results
all_dataset_keys.append(dataset_key)
all_infos['all_dataset_keys'] = all_dataset_keys
return all_infos
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
seeds, model_str, arch_config):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(workers)
save_dir = Path(save_dir)
logger = Logger(str(save_dir), 0, False)
if model_str in CellArchitectures:
arch = CellArchitectures[model_str]
logger.log(
'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
else:
try:
arch = CellStructure.str2structure(model_str)
except:
raise ValueError(
'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
assert arch.check_valid_op(get_search_spaces(
'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
# assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
logger.log('Start train-evaluate {:}'.format(arch.tostr()))
logger.log('arch_config : {:}'.format(arch_config))
start_time, seed_time = time.time(), AverageMeter()
for _is, seed in enumerate(seeds):
logger.log(
'\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
seed))
to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
if to_save_name.exists():
logger.log(
'Find the existing file {:}, directly load!'.format(to_save_name))
checkpoint = torch.load(to_save_name)
else:
logger.log(
'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
seed, arch_config, workers, logger)
torch.save(checkpoint, to_save_name)
# log information
logger.log('{:}'.format(checkpoint['info']))
all_dataset_keys = checkpoint['all_dataset_keys']
for dataset_key in all_dataset_keys:
logger.log('\n{:} dataset : {:} {:}'.format(
'-' * 15, dataset_key, '-' * 15))
dataset_info = checkpoint[dataset_key]
# logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
logger.log('Flops = {:} MB, Params = {:} MB'.format(
dataset_info['flop'], dataset_info['param']))
logger.log('config : {:}'.format(dataset_info['config']))
logger.log('Training State (finish) = {:}'.format(
dataset_info['finish-train']))
last_epoch = dataset_info['total_epoch'] - 1
train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
# measure elapsed time
seed_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format(convert_secs2time(
seed_time.avg * (len(seeds) - _is - 1), True))
logger.log(
'\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
need_time))
logger.close()

View File

@@ -0,0 +1,13 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .configure_utils import load_config, dict2config#, configure2str
#from .basic_args import obtain_basic_args
#from .attention_args import obtain_attention_args
#from .random_baseline import obtain_RandomSearch_args
#from .cls_kd_args import obtain_cls_kd_args
#from .cls_init_args import obtain_cls_init_args
#from .search_single_args import obtain_search_single_args
#from .search_args import obtain_search_args
# for network pruning
#from .pruning_args import obtain_pruning_args

View File

@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple
support_types = ('str', 'int', 'bool', 'float', 'none')
def convert_param(original_lists):
assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
ctype, value = original_lists[0], original_lists[1]
assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
is_list = isinstance(value, list)
if not is_list: value = [value]
outs = []
for x in value:
if ctype == 'int':
x = int(x)
elif ctype == 'str':
x = str(x)
elif ctype == 'bool':
x = bool(int(x))
elif ctype == 'float':
x = float(x)
elif ctype == 'none':
if x.lower() != 'none':
raise ValueError('For the none type, the value must be none instead of {:}'.format(x))
x = None
else:
raise TypeError('Does not know this type : {:}'.format(ctype))
outs.append(x)
if not is_list: outs = outs[0]
return outs
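# Illustrative sketch: every entry in the nas-benchmark *.config files is a
# ["type", value] pair; convert_param casts the raw value(s) to that type.
def _convert_param_example():
    assert convert_param(["int", "200"]) == 200
    assert convert_param(["bool", "1"]) is True
    assert convert_param(["float", "0.0005"]) == 0.0005
    assert convert_param(["str", ["cos", "step"]]) == ["cos", "step"]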
def load_config(path, extra, logger):
path = str(path)
if hasattr(logger, 'log'): logger.log(path)
assert os.path.exists(path), 'Can not find {:}'.format(path)
# Reading data back
with open(path, 'r') as f:
data = json.load(f)
content = { k: convert_param(v) for k,v in data.items()}
assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
if isinstance(extra, dict): content = {**content, **extra}
Arguments = namedtuple('Configure', ' '.join(content.keys()))
content = Arguments(**content)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
def configure2str(config, xpath=None):
if not isinstance(config, dict):
config = config._asdict()
def cstring(x):
return "\"{:}\"".format(x)
def gtype(x):
if isinstance(x, list): x = x[0]
if isinstance(x, str) : return 'str'
elif isinstance(x, bool) : return 'bool'
elif isinstance(x, int): return 'int'
elif isinstance(x, float): return 'float'
elif x is None : return 'none'
else: raise ValueError('invalid : {:}'.format(x))
def cvalue(x, xtype):
if isinstance(x, list): is_list = True
else:
is_list, x = False, [x]
temps = []
for temp in x:
if xtype == 'bool' : temp = cstring(int(temp))
elif xtype == 'none': temp = cstring('None')
else : temp = cstring(temp)
temps.append( temp )
if is_list:
return "[{:}]".format( ', '.join( temps ) )
else:
return temps[0]
xstrings = []
for key, value in config.items():
xtype = gtype(value)
string = ' {:20s} : [{:8s}, {:}]'.format(cstring(key), cstring(xtype), cvalue(value, xtype))
xstrings.append(string)
Fstring = '{\n' + ',\n'.join(xstrings) + '\n}'
if xpath is not None:
parent = Path(xpath).resolve().parent
parent.mkdir(parents=True, exist_ok=True)
if osp.isfile(xpath): os.remove(xpath)
with open(xpath, "w") as text_file:
text_file.write('{:}'.format(Fstring))
return Fstring
def dict2config(xdict, logger):
assert isinstance(xdict, dict), 'invalid type : {:}'.format( type(xdict) )
Arguments = namedtuple('Configure', ' '.join(xdict.keys()))
content = Arguments(**xdict)
if hasattr(logger, 'log'): logger.log('{:}'.format(content))
return content
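# Illustrative sketch: dict2config turns a plain dict into an immutable
# namedtuple-style configuration object whose keys become attributes.
def _dict2config_example():
    cfg = dict2config({'LR': 0.1, 'epochs': 200}, logger=None)
    assert cfg.LR == 0.1 and cfg.epochs == 200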

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "50"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
{
"scheduler": ["str", "cos"],
"eta_min" : ["float", "0.0"],
"epochs" : ["int", "200"],
"warmup" : ["int", "0"],
"optim" : ["str", "SGD"],
"LR" : ["float", "0.1"],
"decay" : ["float", "0.0005"],
"momentum" : ["float", "0.9"],
"nesterov" : ["bool", "1"],
"criterion": ["str", "Softmax"],
"batch_size": ["int", "256"]
}

View File

@@ -0,0 +1,153 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
import time
import torch
from procedures import prepare_seed, get_optim_scheduler
from nasbench_utils import get_model_infos, obtain_accuracy
from config_utils import dict2config
from log_utils import AverageMeter, time_string, convert_secs2time
from nas_bench_201_models import get_cell_based_tiny_net
__all__ = ['evaluate_for_seed', 'pure_evaluate']
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
latencies = []
network.eval()
with torch.no_grad():
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
targets = targets.cuda(non_blocking=True)
inputs = inputs.cuda(non_blocking=True)
data_time.update(time.time() - end)
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
batch_time.update(time.time() - end)
if batch is None or batch == inputs.size(0):
batch = inputs.size(0)
latencies.append(batch_time.val - data_time.val)
# record loss and accuracy
prec1, prec5 = obtain_accuracy(
logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
end = time.time()
if len(latencies) > 2:
latencies = latencies[1:]
return losses.avg, top1.avg, top5.avg, latencies
def procedure(xloader, network, criterion, scheduler, optimizer, mode):
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
if mode == 'train':
network.train()
elif mode == 'valid':
network.eval()
else:
raise ValueError("The mode is not right : {:}".format(mode))
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
for i, (inputs, targets) in enumerate(xloader):
if mode == 'train':
scheduler.update(None, 1.0 * i / len(xloader))
targets = targets.cuda(non_blocking=True)
if mode == 'train':
optimizer.zero_grad()
# forward
features, logits = network(inputs)
loss = criterion(logits, targets)
# backward
if mode == 'train':
loss.backward()
optimizer.step()
# record loss and accuracy
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# count time
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
'C': arch_config['channel'], 'N': arch_config['num_cells'],
'genotype': arch, 'num_classes': config.class_num}, None)
)
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
if 'ckpt_path' in arch_config.keys():
ckpt = torch.load(arch_config['ckpt_path'])
ckpt['classifier.weight'] = net.state_dict()['classifier.weight']
ckpt['classifier.bias'] = net.state_dict()['classifier.bias']
net.load_state_dict(ckpt)
flop, param = get_model_infos(net, config.xshape)
logger.log('Network : {:}'.format(net.get_message()), False)
logger.log(
'{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(
net.parameters(), config)
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
# network, criterion = torch.nn.DataParallel(net).to(torch.device(f"cuda:{device}")), criterion.to(torch.device(f"cuda:{device}"))
# start training
start_time, epoch_time, total_epoch = time.time(
), AverageMeter(), config.epochs + config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {
}, {}, {}, {}, {}, {}
train_times, valid_times = {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
train_loss, train_acc1, train_acc5, train_tm = procedure(
train_loader, network, criterion, scheduler, optimizer, 'train')
train_losses[epoch] = train_loss
train_acc1es[epoch] = train_acc1
train_acc5es[epoch] = train_acc5
train_times[epoch] = train_tm
with torch.no_grad():
for key, xloader in valid_loaders.items():
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
xloader, network, criterion, None, None, 'valid')
valid_losses['{:}@{:}'.format(key, epoch)] = valid_loss
valid_acc1es['{:}@{:}'.format(key, epoch)] = valid_acc1
valid_acc5es['{:}@{:}'.format(key, epoch)] = valid_acc5
valid_times['{:}@{:}'.format(key, epoch)] = valid_tm
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = 'Time Left: {:}'.format(convert_secs2time(
epoch_time.avg * (total_epoch-epoch-1), True))
logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(
time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
info_seed = {'flop': flop,
'param': param,
'channel': arch_config['channel'],
'num_cells': arch_config['num_cells'],
'config': config._asdict(),
'total_epoch': total_epoch,
'train_losses': train_losses,
'train_acc1es': train_acc1es,
'train_acc5es': train_acc5es,
'train_times': train_times,
'valid_losses': valid_losses,
'valid_acc1es': valid_acc1es,
'valid_acc5es': valid_acc5es,
'valid_times': valid_times,
'net_state_dict': net.state_dict(),
'net_string': '{:}'.format(net),
'finish-train': True
}
return info_seed

View File

@@ -0,0 +1,9 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# none of these utilities relies on pytorch or tensorflow
# I tried to list all dependencies here: os, sys, time, numpy, (possibly) matplotlib
from .logger import Logger#, PrintLogger
from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time

View File

@@ -0,0 +1,150 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from pathlib import Path
import importlib, warnings
import os, sys, time, numpy as np
if sys.version_info.major == 2: # Python 2.x
from StringIO import StringIO as BIO
else: # Python 3.x
from io import BytesIO as BIO
if importlib.util.find_spec('tensorflow'):
import tensorflow as tf
class PrintLogger(object):
def __init__(self):
"""Create a summary writer logging to log_dir."""
self.name = 'PrintLogger'
def log(self, string):
print (string)
def close(self):
print ('-'*30 + ' close printer ' + '-'*30)
class Logger(object):
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
"""Create a summary writer logging to log_dir."""
self.seed = int(seed)
self.log_dir = Path(log_dir)
self.model_dir = Path(log_dir) / 'checkpoint'
self.log_dir.mkdir (parents=True, exist_ok=True)
if create_model_dir:
self.model_dir.mkdir(parents=True, exist_ok=True)
#self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.use_tf = bool(use_tf)
self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) )))
#self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time())))
self.logger_file = open(self.logger_path, 'w')
if self.use_tf:
self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
else:
self.writer = None
def __repr__(self):
return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__))
def path(self, mode):
valids = ('model', 'best', 'info', 'log')
if mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed)
elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed)
elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed)
elif mode == 'log' : return self.log_dir
else: raise TypeError('Unknown mode = {:}, valid modes = {:}'.format(mode, valids))
def extract_log(self):
return self.logger_file
def close(self):
self.logger_file.close()
if self.writer is not None:
self.writer.close()
def log(self, string, save=True, stdout=False):
if stdout:
sys.stdout.write(string); sys.stdout.flush()
else:
print (string)
if save:
self.logger_file.write('{:}\n'.format(string))
self.logger_file.flush()
def scalar_summary(self, tags, values, step):
"""Log a scalar variable."""
if not self.use_tf:
warnings.warn('use_tf is not enabled, but scalar_summary was called')
else:
assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values))
if not isinstance(tags, list):
tags, values = [tags], [values]
for tag, value in zip(tags, values):
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
self.writer.flush()
def image_summary(self, tag, images, step):
"""Log a list of images."""
import scipy
if not self.use_tf:
warnings.warn('use_tf is not enabled, but image_summary was called')
return
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
s = BIO()  # BytesIO (StringIO under Python 2), imported above as BIO
scipy.misc.toimage(img).save(s, format="png")
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
self.writer.flush()
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
if not self.use_tf: raise ValueError('Do not have tensorflow')
import tensorflow as tf
# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]
# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()
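# --------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original training code).
# It assumes './logs' is writable and keeps TensorFlow summaries disabled.
if __name__ == '__main__':
  logger = Logger('./logs', seed=0, create_model_dir=True, use_tf=False)
  logger.log('best checkpoint will be written to {:}'.format(logger.path('best')))
  logger.scalar_summary('train/loss', 1.23, step=0)  # only warns, since use_tf=False
  logger.close()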

View File

@@ -0,0 +1,98 @@
import numpy as np
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
class RecorderMeter(object):
"""Computes and stores the minimum loss value and its epoch index"""
def __init__(self, total_epoch):
self.reset(total_epoch)
def reset(self, total_epoch):
assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch)
self.total_epoch = total_epoch
self.current_epoch = 0
self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
self.epoch_losses = self.epoch_losses - 1
    self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]
def update(self, idx, train_loss, train_acc, val_loss, val_acc):
assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
self.epoch_losses [idx, 0] = train_loss
self.epoch_losses [idx, 1] = val_loss
self.epoch_accuracy[idx, 0] = train_acc
self.epoch_accuracy[idx, 1] = val_acc
self.current_epoch = idx + 1
return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]
def max_accuracy(self, istrain):
if self.current_epoch <= 0: return 0
if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
else: return self.epoch_accuracy[:self.current_epoch, 1].max()
def plot_curve(self, save_path):
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
title = 'the accuracy/loss curve of train/val'
dpi = 100
width, height = 1600, 1000
legend_fontsize = 10
figsize = width / float(dpi), height / float(dpi)
fig = plt.figure(figsize=figsize)
x_axis = np.array([i for i in range(self.total_epoch)]) # epochs
y_axis = np.zeros(self.total_epoch)
plt.xlim(0, self.total_epoch)
plt.ylim(0, 100)
interval_y = 5
interval_x = 5
plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
plt.yticks(np.arange(0, 100 + interval_y, interval_y))
plt.grid()
plt.title(title, fontsize=20)
plt.xlabel('the training epoch', fontsize=16)
plt.ylabel('accuracy', fontsize=16)
y_axis[:] = self.epoch_accuracy[:, 0]
plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_accuracy[:, 1]
plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 0]
plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
y_axis[:] = self.epoch_losses[:, 1]
plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
plt.legend(loc=4, fontsize=legend_fontsize)
if save_path is not None:
fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
print ('---- save figure {} into {}'.format(title, save_path))
plt.close(fig)
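# --------------------------------------------------------------------------
# Usage sketch (illustrative only, with made-up values): a running loss tracked
# by AverageMeter and per-epoch metrics recorded with RecorderMeter.
if __name__ == '__main__':
  losses = AverageMeter()
  for batch_loss in (0.9, 0.7, 0.6):
    losses.update(batch_loss, n=32)           # n = batch size
  print(losses)                               # shows val / avg / count
  recorder = RecorderMeter(total_epoch=2)
  recorder.update(0, train_loss=0.8, train_acc=70.0, val_loss=0.9, val_acc=65.0)
  recorder.update(1, train_loss=0.6, train_acc=78.0, val_loss=0.7, val_acc=72.0)
  print('best val accuracy : {:.2f}'.format(recorder.max_accuracy(istrain=False)))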

View File

@@ -0,0 +1,42 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import time, sys
import numpy as np
def time_for_file():
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_string_short():
ISOTIMEFORMAT='%Y%m%d'
string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string
def time_print(string, is_print=True):
if (is_print):
print('{} : {}'.format(time_string(), string))
def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60)
need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
if return_str:
    need_time = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
    return need_time
else:
return need_hour, need_mins, need_secs
def print_log(print_string, log):
#if isinstance(log, Logger): log.log('{:}'.format(print_string))
if hasattr(log, 'log'): log.log('{:}'.format(print_string))
else:
print("{:}".format(print_string))
if log is not None:
log.write('{:}\n'.format(print_string))
log.flush()
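# --------------------------------------------------------------------------
# Usage sketch (illustrative only): formatting an elapsed time of 3725 seconds
# and logging without a Logger object (passing log=None just prints).
if __name__ == '__main__':
  hours, mins, secs = convert_secs2time(3725)                                   # -> (1, 2, 5)
  print_log('elapsed : {:02d}h {:02d}m {:02d}s'.format(hours, mins, secs), None)
  print_log('as string : {:}'.format(convert_secs2time(3725, return_str=True)), None)  # [01:02:05]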

View File

@@ -0,0 +1,4 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .get_dataset_with_transform import get_datasets

View File

@@ -0,0 +1,179 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import torch.utils.data as data
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
from PIL import Image
import os
import numpy as np
def make_dataset(dir, image_ids, targets):
assert (len(image_ids) == len(targets))
images = []
dir = os.path.expanduser(dir)
for i in range(len(image_ids)):
item = (os.path.join(dir, 'data', 'images',
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images
def find_classes(classes_file):
# read classes file, separating out image IDs and class names
image_ids = []
targets = []
f = open(classes_file, 'r')
for line in f:
split_line = line.split(' ')
image_ids.append(split_line[0])
targets.append(' '.join(split_line[1:]))
f.close()
# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]
return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(data.Dataset):
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
Args:
root (string): Root directory path to dataset.
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g. ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')
def __init__(self, root, class_type='variant', split='train', transform=None,
target_transform=None, loader=default_loader, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))
self.root = os.path.expanduser(root)
self.root = os.path.join(self.root, 'fgvc-aircraft-2013b')
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
samples = make_dataset(self.root, image_ids, targets)
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def _check_exists(self):
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
os.path.exists(self.classes_file)
def download(self):
"""Download the FGVC-Aircraft data if it doesn't exist already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
print('Downloading %s ... (may take a few minutes)' % self.url)
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
tar_name = self.url.rpartition('/')[-1]
tar_path = os.path.join(parent_dir, tar_name)
data = urllib.request.urlopen(self.url)
# download .tar.gz file
with open(tar_path, 'wb') as f:
f.write(data.read())
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
    data_folder = tar_path[:-len('.tar.gz')]  # drop the '.tar.gz' suffix (str.strip removes characters, not a suffix)
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
tar = tarfile.open(tar_path)
tar.extractall(parent_dir)
# if necessary, rename data folder to self.root
if not os.path.samefile(data_folder, self.root):
print('Renaming %s to %s ...' % (data_folder, self.root))
os.rename(data_folder, self.root)
# delete .tar.gz file
print('Deleting %s ...' % tar_path)
os.remove(tar_path)
print('Done!')
if __name__ == '__main__':
  root = '/w14/dataset/fgvc-aircraft-2013b'
  for split in ('train', 'val', 'trainval', 'test'):
    air = FGVCAircraft(root, class_type='manufacturer', split=split, transform=None,
                       target_transform=None, loader=default_loader, download=False)
    print('{:8s} split : {:5d} images, {:3d} classes'.format(split, len(air), len(air.classes)))

View File

@@ -0,0 +1,304 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
# Modified by Hayeon Lee, Eunyoung Hyung 2021. 03.
##################################################
import os
import sys
import torch
import os.path as osp
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from copy import deepcopy
from PIL import Image
import random
from .aircraft import FGVCAircraft
from .pets import PetDataset
# SearchDataset (used by get_nas_search_loaders below) is assumed to come from
# .SearchDatasetWrap, as in the AutoDL-Projects code this file is adapted from.
from .SearchDatasetWrap import SearchDataset
from config_utils import load_config
Dataset2Class = {'cifar10': 10,
'cifar100': 100,
'mnist': 10,
'svhn': 10,
'aircraft': 30,
'pets': 37}
class CUTOUT(object):
def __init__(self, length):
self.length = length
def __repr__(self):
return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
imagenet_pca = {
'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
'eigvec': np.asarray([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
class Lighting(object):
def __init__(self, alphastd,
eigval=imagenet_pca['eigval'],
eigvec=imagenet_pca['eigvec']):
self.alphastd = alphastd
assert eigval.shape == (3,)
assert eigvec.shape == (3, 3)
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0.:
return img
rnd = np.random.randn(3) * self.alphastd
rnd = rnd.astype('float32')
v = rnd
old_dtype = np.asarray(img).dtype
v = v * self.eigval
v = v.reshape((3, 1))
inc = np.dot(self.eigvec, v).reshape((3,))
img = np.add(img, inc)
if old_dtype == np.uint8:
img = np.clip(img, 0, 255)
img = Image.fromarray(img.astype(old_dtype), 'RGB')
return img
def __repr__(self):
return self.__class__.__name__ + '()'
def get_datasets(name, root, cutout, use_num_cls=None):
if name == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif name == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif name.startswith('mnist'):
mean, std = [0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]
elif name.startswith('svhn'):
mean, std = [0.4376821, 0.4437697, 0.47280442], [
0.19803012, 0.20101562, 0.19703614]
elif name.startswith('aircraft'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
elif name.startswith('pets'):
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
else:
raise TypeError("Unknow dataset : {:}".format(name))
    # Data Augmentation
if name == 'cifar10' or name == 'cifar100':
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
transforms.Normalize(mean, std)]
if cutout > 0:
lists += [CUTOUT(cutout)]
train_transform = transforms.Compose(lists)
test_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(mean, std)])
xshape = (1, 3, 32, 32)
elif name.startswith('cub200'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('mnist'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean, std),
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize(mean, std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('svhn'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
xshape = (1, 3, 32, 32)
elif name.startswith('aircraft'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
xshape = (1, 3, 32, 32)
elif name.startswith('pets'):
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
xshape = (1, 3, 32, 32)
else:
raise TypeError("Unknow dataset : {:}".format(name))
if name == 'cifar10':
train_data = dset.CIFAR10(
root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'cifar100':
train_data = dset.CIFAR100(
root, train=True, transform=train_transform, download=True)
test_data = dset.CIFAR100(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 50000 and len(test_data) == 10000
elif name == 'mnist':
train_data = dset.MNIST(
root, train=True, transform=train_transform, download=True)
test_data = dset.MNIST(
root, train=False, transform=test_transform, download=True)
assert len(train_data) == 60000 and len(test_data) == 10000
elif name == 'svhn':
train_data = dset.SVHN(root, split='train',
transform=train_transform, download=True)
test_data = dset.SVHN(root, split='test',
transform=test_transform, download=True)
assert len(train_data) == 73257 and len(test_data) == 26032
elif name == 'aircraft':
train_data = FGVCAircraft(root, class_type='manufacturer', split='trainval',
transform=train_transform, download=False)
test_data = FGVCAircraft(root, class_type='manufacturer', split='test',
transform=test_transform, download=False)
assert len(train_data) == 6667 and len(test_data) == 3333
elif name == 'pets':
train_data = PetDataset(root, train=True, num_cl=37,
val_split=0.15, transforms=train_transform)
test_data = PetDataset(root, train=False, num_cl=37,
val_split=0.15, transforms=test_transform)
else:
raise TypeError("Unknow dataset : {:}".format(name))
class_num = Dataset2Class[name] if use_num_cls is None else len(
use_num_cls)
return train_data, test_data, xshape, class_num
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, num_cls=None):
if isinstance(batch_size, (list, tuple)):
batch, test_batch = batch_size
else:
batch, test_batch = batch_size, batch_size
if dataset == 'cifar10':
# split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
cifar_split = load_config(
'{:}/cifar-split.txt'.format(config_root), None, None)
# search over the proposed training and validation set
train_split, valid_split = cifar_split.train, cifar_split.valid
# logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
# To split data
xvalid_data = deepcopy(train_data)
if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
xvalid_data.transforms = valid_data.transform
xvalid_data.transform = deepcopy(valid_data.transform)
search_data = SearchDataset(
dataset, train_data, train_split, valid_split)
# data loader
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
train_split),
num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
valid_split),
num_workers=workers, pin_memory=True)
elif dataset == 'cifar100':
cifar100_test_split = load_config(
'{:}/cifar100-test-split.txt'.format(config_root), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data)
search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
list(range(len(search_train_data))),
cifar100_test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True, num_workers=workers,
pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
elif dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
if not os.path.exists('{:}/{}-test-split.txt'.format(config_root, dataset)):
import json
label_list = list(range(len(valid_data)))
random.shuffle(label_list)
strlist = [str(label_list[i]) for i in range(len(label_list))]
split = {'xvalid': ["int", strlist[:len(valid_data) // 2]],
'xtest': ["int", strlist[len(valid_data) // 2:]]}
with open('{:}/{}-test-split.txt'.format(config_root, dataset), 'w') as f:
f.write(json.dumps(split))
test_split = load_config(
'{:}/{}-test-split.txt'.format(config_root, dataset), None, None)
search_train_data = train_data
search_valid_data = deepcopy(valid_data)
search_valid_data.transform = train_data.transform
search_data = SearchDataset(dataset, [search_train_data, search_valid_data],
list(range(len(search_train_data))), test_split.xvalid)
search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True,
num_workers=workers, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch, shuffle=True,
num_workers=workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=test_batch,
sampler=torch.utils.data.sampler.SubsetRandomSampler(
test_split.xvalid), num_workers=workers, pin_memory=True)
else:
raise ValueError('invalid dataset : {:}'.format(dataset))
return search_loader, train_loader, valid_loader
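# --------------------------------------------------------------------------
# Usage sketch (illustrative only): building the CIFAR-10 datasets and a plain
# DataLoader. './data' is a placeholder download location, not a repo setting.
if __name__ == '__main__':
  train_data, test_data, xshape, class_num = get_datasets('cifar10', './data', cutout=-1)
  print('input shape : {:}, #classes : {:}'.format(xshape, class_num))
  loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
  images, labels = next(iter(loader))
  print('one batch : {:} {:}'.format(images.shape, labels.shape))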

View File

@@ -0,0 +1,45 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
def load_image(filename):
img = Image.open(filename)
img = img.convert('RGB')
return img
class PetDataset(Dataset):
def __init__(self, root, train=True, num_cl=37, val_split=0.2, transforms=None):
self.data = torch.load(os.path.join(root,'{}{}.pth'.format('train' if train else 'test',
int(100*(1-val_split)) if train else int(100*val_split))))
self.len = len(self.data)
self.transform = transforms
def __getitem__(self, index):
img, label = self.data[index]
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return self.len
if __name__ == '__main__':
  # Minimal smoke test: load the pre-built pets splits and report their sizes.
  import torchvision.transforms as transforms
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  train_transform = transforms.Compose(
    [transforms.Resize(256), transforms.RandomRotation(45), transforms.CenterCrop(224),
     transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
  test_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
  root = '/w14/dataset/MetaGen/pets'
  train_data = PetDataset(root, train=True, num_cl=37, val_split=0.2, transforms=train_transform)
  test_data = PetDataset(root, train=False, num_cl=37, val_split=0.2, transforms=test_transform)
  print(len(train_data), len(test_data))

View File

@@ -0,0 +1,34 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
def additive_func(A, B):
assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
C = min(A.size(1), B.size(1))
if A.size(1) == B.size(1):
return A + B
elif A.size(1) < B.size(1):
out = B.clone()
out[:,:C] += A
return out
else:
out = A.clone()
out[:,:C] += B
return out
def change_key(key, value):
def func(m):
if hasattr(m, key):
setattr(m, key, value)
return func
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks
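# --------------------------------------------------------------------------
# Usage sketch (illustrative only): additive_func pads the tensor with fewer
# channels, and parse_channel_info turns a spec string into nested int lists.
if __name__ == '__main__':
  A = torch.randn(2, 8, 4, 4)
  B = torch.randn(2, 12, 4, 4)
  print(additive_func(A, B).shape)            # torch.Size([2, 12, 4, 4])
  print(parse_channel_info('8-12 12-16-16'))  # [[8, 12], [12, 16, 16]]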

View File

@@ -0,0 +1,45 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from os import path as osp
from typing import List, Text
import torch
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces', \
'CellStructure', 'CellArchitectures'
]
# useful modules
from config_utils import dict2config
from .SharedUtils import change_key
from .cell_searchs import CellStructure, CellArchitectures
# Cell-based NAS Models
def get_cell_based_tiny_net(config):
if config.name == 'infer.tiny':
from .cell_infers import TinyNetwork
if hasattr(config, 'genotype'):
genotype = config.genotype
elif hasattr(config, 'arch_str'):
genotype = CellStructure.str2structure(config.arch_str)
else: raise ValueError('Can not find genotype from this config : {:}'.format(config))
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name) -> List[Text]:
if xtype == 'cell' or xtype == 'tss': # The topology search space.
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
elif xtype == 'sss': # The size search space.
if name == 'nas-bench-301':
return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64],
'numbers': 5}
else:
raise ValueError('Invalid name : {:}'.format(name))
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))
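# --------------------------------------------------------------------------
# Usage sketch (illustrative only): building a NAS-Bench-201 network from an
# architecture string. dict2config (imported above) is assumed to accept a
# plain dict plus a logger, as in the original AutoDL config_utils; the
# arch_str below is one valid NAS-Bench-201 cell and C/N follow the
# benchmark's standard 16/5 setting.
if __name__ == '__main__':
  config = dict2config({'name': 'infer.tiny', 'C': 16, 'N': 5, 'num_classes': 10,
                        'arch_str': '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|'},
                       None)
  net = get_cell_based_tiny_net(config)
  print(net.get_message())
  print(get_search_spaces('cell', 'nas-bench-201'))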

View File

@@ -0,0 +1,4 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork

View File

@@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
self.node_IN = []
self.node_IX = []
self.genotype = deepcopy(genotype)
for i in range(1, len(genotype)):
node_info = genotype[i-1]
cur_index = []
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride, True, True)
else:
layer = OPS[op_name](C_out, C_out, 1, True, True)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )
self.node_IX.append( cur_index )
self.node_IN.append( cur_innod )
self.nodes = len(genotype)
self.in_dim = C_in
self.out_dim = C_out
def extra_repr(self):
string = 'info :: nodes={nodes}, inC={in_dim}, outC={out_dim}'.format(**self.__dict__)
laystr = []
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
y = ['I{:}-L{:}'.format(_ii, _il) for _il, _ii in zip(node_layers, node_innods)]
x = '{:}<-({:})'.format(i+1, ','.join(y))
laystr.append( x )
return string + ', [{:}]'.format( ' | '.join(laystr) ) + ', {:}'.format(self.genotype.tostr())
def forward(self, inputs):
nodes = [inputs]
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX,self.node_IN)):
node_feature = sum( self.layers[_il](nodes[_ii]) for _il, _ii in zip(node_layers, node_innods) )
nodes.append( node_feature )
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetInferCell(nn.Module):
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev, affine, track_running_stats):
super(NASNetInferCell, self).__init__()
self.reduction = reduction
if reduction_prev: self.preprocess0 = OPS['skip_connect'](C_prev_prev, C, 2, affine, track_running_stats)
else : self.preprocess0 = OPS['nor_conv_1x1'](C_prev_prev, C, 1, affine, track_running_stats)
self.preprocess1 = OPS['nor_conv_1x1'](C_prev, C, 1, affine, track_running_stats)
if not reduction:
nodes, concats = genotype['normal'], genotype['normal_concat']
else:
nodes, concats = genotype['reduce'], genotype['reduce_concat']
self._multiplier = len(concats)
self._concats = concats
self._steps = len(nodes)
self._nodes = nodes
self.edges = nn.ModuleDict()
for i, node in enumerate(nodes):
for in_node in node:
name, j = in_node[0], in_node[1]
stride = 2 if reduction and j < 2 else 1
node_str = '{:}<-{:}'.format(i+2, j)
self.edges[node_str] = OPS[name](C, C, stride, affine, track_running_stats)
# [TODO] to support drop_prob in this function..
def forward(self, s0, s1, unused_drop_prob):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
for i, node in enumerate(self._nodes):
clist = []
for in_node in node:
name, j = in_node[0], in_node[1]
node_str = '{:}<-{:}'.format(i+2, j)
op = self.edges[ node_str ]
clist.append( op(states[j]) )
states.append( sum(clist) )
return torch.cat([states[x] for x in self._concats], dim=1)
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, num_classes):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0),-1))
return x

View File

@@ -0,0 +1,66 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
from ..cell_operations import ResNetBasicblock
from .cells import InferCell
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
def __init__(self, C, N, genotype, num_classes):
super(TinyNetwork, self).__init__()
self._C = C
self._layerN = N
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C))
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
else:
cell = InferCell(genotype, C_prev, C_curr, 1)
self.cells.append( cell )
C_prev = cell.out_dim
self._Layer= len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, num_classes)
def get_message(self):
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_C}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
'''
out2 = self.lastact(feature)
out = self.global_pooling( out2 )
out = out.view(out.size(0), -1)
out2 = out2.view(out2.size(0), -1)
logits = self.classifier(out)
return out2, logits
'''
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits
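# --------------------------------------------------------------------------
# Usage sketch (illustrative only): a forward pass with a hand-written
# ResNet-like genotype. CellStructure is imported lazily here purely for the
# sketch (it needs to run inside the package); the training code normally
# builds the genotype in the caller and passes it in.
if __name__ == '__main__':
  import torch
  from ..cell_searchs import CellStructure
  genotype = CellStructure.str2structure(
    '|nor_conv_3x3~0|+|nor_conv_3x3~1|+|skip_connect~0|skip_connect~2|')
  net = TinyNetwork(C=16, N=5, genotype=genotype, num_classes=10)
  features, logits = net(torch.randn(2, 3, 32, 32))
  print(features.shape, logits.shape)  # torch.Size([2, 64]) torch.Size([2, 10])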

View File

@@ -0,0 +1,308 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride, affine, track_running_stats: Zero(C_in, C_out, stride),
'avg_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'avg', affine, track_running_stats),
'max_pool_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: POOLING(C_in, C_out, stride, 'max', affine, track_running_stats),
'nor_conv_7x7' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine, track_running_stats),
'nor_conv_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'nor_conv_1x1' : lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine, track_running_stats),
'dua_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine, track_running_stats),
'dua_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(C_in, C_out, (5,5), (stride,stride), (2,2), (1,1), affine, track_running_stats),
'dil_sepc_3x3' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (3,3), (stride,stride), (2,2), (2,2), affine, track_running_stats),
'dil_sepc_5x5' : lambda C_in, C_out, stride, affine, track_running_stats: SepConv(C_in, C_out, (5,5), (stride,stride), (4,4), (2,2), affine, track_running_stats),
'skip_connect' : lambda C_in, C_out, stride, affine, track_running_stats: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
}
CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
NAS_BENCH_201 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
DARTS_SPACE = ['none', 'skip_connect', 'dua_sepc_3x3', 'dua_sepc_5x5', 'dil_sepc_3x3', 'dil_sepc_5x5', 'avg_pool_3x3', 'max_pool_3x3']
SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,
'nas-bench-201': NAS_BENCH_201,
'nas-bench-301': NAS_BENCH_201,
'darts' : DARTS_SPACE}
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
)
def forward(self, x):
return self.op(x)
class DualSepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine, track_running_stats=True):
super(DualSepConv, self).__init__()
self.op_a = SepConv(C_in, C_in , kernel_size, stride, padding, dilation, affine, track_running_stats)
self.op_b = SepConv(C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats)
def forward(self, x):
x = self.op_a(x)
x = self.op_b(x)
return x
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def extra_repr(self):
string = '{name}(inC={in_dim}, outC={out_dim}, stride={stride})'.format(name=self.__class__.__name__, **self.__dict__)
return string
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return residual + basicblock
class POOLING(nn.Module):
def __init__(self, C_in, C_out, stride, mode, affine=True, track_running_stats=True):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1, affine, track_running_stats)
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def forward(self, inputs):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1: return x.mul(0.)
else : return x[:,:,::self.stride,::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
elif stride == 1:
self.conv = nn.Conv2d(C_in, C_out, 1, stride=stride, padding=0, bias=False)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats)
def forward(self, x):
if self.stride == 2:
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
else:
out = self.conv(x)
out = self.bn(out)
return out
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
super().__init__()
self.part = 4
self.hidden = C_in // 3
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.local_conv_list = nn.ModuleList()
for i in range(self.part):
self.local_conv_list.append(
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
)
self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden)
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2)
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1)
else: raise ValueError('Invalid Stride : {:}'.format(stride))
def forward(self, x):
batch, C, H, W = x.size()
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
IHs = [0]
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
local_feat_list = []
for i in range(self.part):
feature = x[:, :, IHs[i]:IHs[i+1], :]
xfeax = self.avg_pool(feature)
xfea = self.local_conv_list[i]( xfeax )
local_feat_list.append( xfea )
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
part_feature = part_feature.transpose(1,2).contiguous()
part_K = self.W_K(part_feature)
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
weight_att = torch.bmm(part_K, part_Q)
attention = torch.softmax(weight_att, dim=2)
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
features = []
for i in range(self.part):
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
features.append( feature )
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x,features), dim=1)
outputs = self.last( final_fea )
return outputs
def drop_path(x, drop_prob):
if drop_prob > 0.:
keep_prob = 1. - drop_prob
mask = x.new_zeros(x.size(0), 1, 1, 1)
mask = mask.bernoulli_(keep_prob)
x = torch.div(x, keep_prob)
x.mul_(mask)
return x
# Searching for A Robust Neural Architecture in Four GPU Hours
class GDAS_Reduction_Cell(nn.Module):
def __init__(self, C_prev_prev, C_prev, C, reduction_prev, multiplier, affine, track_running_stats):
super(GDAS_Reduction_Cell, self).__init__()
if reduction_prev:
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine, track_running_stats)
else:
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, 1, affine, track_running_stats)
self.multiplier = multiplier
self.reduction = True
self.ops1 = nn.ModuleList(
[nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, 2), padding=(0, 1), groups=8, bias=False),
nn.Conv2d(C, C, (3, 1), stride=(2, 1), padding=(1, 0), groups=8, bias=False),
nn.BatchNorm2d(C, affine=True),
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C, affine=True))])
self.ops2 = nn.ModuleList(
[nn.Sequential(
nn.MaxPool2d(3, stride=1, padding=1),
nn.BatchNorm2d(C, affine=True)),
nn.Sequential(
nn.MaxPool2d(3, stride=2, padding=1),
nn.BatchNorm2d(C, affine=True))])
def forward(self, s0, s1, drop_prob = -1):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
X0 = self.ops1[0] (s0)
X1 = self.ops1[1] (s1)
if self.training and drop_prob > 0.:
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
#X2 = self.ops2[0] (X0+X1)
X2 = self.ops2[0] (s0)
X3 = self.ops2[1] (s1)
if self.training and drop_prob > 0.:
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
return torch.cat([X0, X1, X2, X3], dim=1)
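# --------------------------------------------------------------------------
# Usage sketch (illustrative only): every OPS entry is a factory with signature
# (C_in, C_out, stride, affine, track_running_stats) that returns an nn.Module;
# with C_in == C_out and stride 1 all NAS-Bench-201 ops preserve the input shape.
if __name__ == '__main__':
  x = torch.randn(2, 16, 32, 32)
  for op_name in NAS_BENCH_201:
    op = OPS[op_name](16, 16, 1, True, True)
    print('{:14s} -> output shape {:}'.format(op_name, tuple(op(x).shape)))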

View File

@@ -0,0 +1,26 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
# from .search_model_darts import TinyNetworkDarts
# from .search_model_gdas import TinyNetworkGDAS
# from .search_model_setn import TinyNetworkSETN
# from .search_model_enas import TinyNetworkENAS
# from .search_model_random import TinyNetworkRANDOM
# from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures
# NASNet-based macro structure
# from .search_model_gdas_nasnet import NASNetworkGDAS
# from .search_model_darts_nasnet import NASNetworkDARTS
# nas201_super_nets = {'DARTS-V1': TinyNetworkDarts,
# "DARTS-V2": TinyNetworkDarts,
# "GDAS": TinyNetworkGDAS,
# "SETN": TinyNetworkSETN,
# "ENAS": TinyNetworkENAS,
# "RANDOM": TinyNetworkRANDOM,
# "generic": GenericNAS201Model}
# nasnet_super_nets = {"GDAS": NASNetworkGDAS,
# "DARTS": NASNetworkDARTS}

View File

@@ -0,0 +1,198 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from copy import deepcopy
def get_combination(space, num):
combs = []
for i in range(num):
if i == 0:
for func in space:
combs.append( [(func, i)] )
else:
new_combs = []
for string in combs:
for func in space:
xstring = string + [(func, i)]
new_combs.append( xstring )
combs = new_combs
return combs
class Structure:
def __init__(self, genotype):
assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype))
self.node_num = len(genotype) + 1
self.nodes = []
self.node_N = []
for idx, node_info in enumerate(genotype):
assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info))
assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info))
for node_in in node_info:
assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in))
assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in)
self.node_N.append( len(node_info) )
self.nodes.append( tuple(deepcopy(node_info)) )
def tolist(self, remove_str):
# convert this class to the list, if remove_str is 'none', then remove the 'none' operation.
# note that we re-order the input node in this function
# return the-genotype-list and success [if unsuccess, it is not a connectivity]
genotypes = []
for node_info in self.nodes:
node_info = list( node_info )
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
if len(node_info) == 0: return None, False
genotypes.append( node_info )
return genotypes, True
def node(self, index):
assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self))
return self.nodes[index]
def tostr(self):
strings = []
for node_info in self.nodes:
string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info])
string = '|{:}|'.format(string)
strings.append( string )
return '+'.join(strings)
def check_valid(self):
nodes = {0: True}
for i, node_info in enumerate(self.nodes):
sums = []
for op, xin in node_info:
if op == 'none' or nodes[xin] is False: x = False
else: x = True
sums.append( x )
nodes[i+1] = sum(sums) > 0
return nodes[len(self.nodes)]
def to_unique_str(self, consider_zero=False):
    # this is used to identify the isomorphic cell, which requires prior knowledge of the operations
# two operations are special, i.e., none and skip_connect
nodes = {0: '0'}
for i_node, node_info in enumerate(self.nodes):
cur_node = []
for op, xin in node_info:
if consider_zero is None:
x = '('+nodes[xin]+')' + '@{:}'.format(op)
elif consider_zero:
if op == 'none' or nodes[xin] == '#': x = '#' # zero
elif op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
else:
if op == 'skip_connect': x = nodes[xin]
else: x = '('+nodes[xin]+')' + '@{:}'.format(op)
cur_node.append(x)
nodes[i_node+1] = '+'.join( sorted(cur_node) )
return nodes[ len(self.nodes) ]
def check_valid_op(self, op_names):
for node_info in self.nodes:
for inode_edge in node_info:
#assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
if inode_edge[0] not in op_names: return False
return True
def __repr__(self):
return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__))
def __len__(self):
return len(self.nodes) + 1
def __getitem__(self, index):
return self.nodes[index]
@staticmethod
def str2structure(xstr):
if isinstance(xstr, Structure): return xstr
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs)
genotypes.append( input_infos )
return Structure( genotypes )
@staticmethod
def str2fullstructure(xstr, default_name='none'):
assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr))
nodestrs = xstr.split('+')
genotypes = []
for i, node_str in enumerate(nodestrs):
inputs = list(filter(lambda x: x != '', node_str.split('|')))
for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput)
inputs = ( xi.split('~') for xi in inputs )
input_infos = list( (op, int(IDX)) for (op, IDX) in inputs)
all_in_nodes= list(x[1] for x in input_infos)
for j in range(i):
if j not in all_in_nodes: input_infos.append((default_name, j))
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
genotypes.append( tuple(node_info) )
return Structure( genotypes )
@staticmethod
def gen_all(search_space, num, return_ori):
assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space))
assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num)
all_archs = get_combination(search_space, 1)
for i, arch in enumerate(all_archs):
all_archs[i] = [ tuple(arch) ]
for inode in range(2, num):
cur_nodes = get_combination(search_space, inode)
new_all_archs = []
for previous_arch in all_archs:
for cur_node in cur_nodes:
new_all_archs.append( previous_arch + [tuple(cur_node)] )
all_archs = new_all_archs
if return_ori:
return all_archs
else:
return [Structure(x) for x in all_archs]
ResNet_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 1), ), # node-2
(('skip_connect', 0), ('skip_connect', 2))] # node-3
)
AllConv3x3_CODE = Structure(
[(('nor_conv_3x3', 0), ), # node-1
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2
(('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3
)
AllFull_CODE = Structure(
[(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2
(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3
)
AllConv1x1_CODE = Structure(
[(('nor_conv_1x1', 0), ), # node-1
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2
(('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3
)
AllIdentity_CODE = Structure(
[(('skip_connect', 0), ), # node-1
(('skip_connect', 0), ('skip_connect', 1)), # node-2
(('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3
)
architectures = {'resnet' : ResNet_CODE,
'all_c3x3': AllConv3x3_CODE,
'all_c1x1': AllConv1x1_CODE,
'all_idnt': AllIdentity_CODE,
'all_full': AllFull_CODE}
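# --------------------------------------------------------------------------
# Usage sketch (illustrative only): round-tripping an architecture string
# through Structure and checking validity / the isomorphism-aware encoding.
if __name__ == '__main__':
  arch = Structure.str2structure(
    '|nor_conv_3x3~0|+|skip_connect~0|none~1|+|skip_connect~0|nor_conv_3x3~2|')
  print(arch.tostr())                  # back to the string form
  print(arch.check_valid())            # True: the output node is reachable without 'none'
  print(arch.to_unique_str(True))      # encoding used to detect isomorphic cells
  print(len(architectures['resnet']))  # 4 nodes in the ResNet-like cell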

View File

@@ -0,0 +1,167 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual):
super(InferCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
    self.message = 'InferCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits
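# ---------------------------------------------------------------------------
# A minimal usage sketch (illustrative only; this module is normally imported
# from the package rather than run directly). The depth-20 basic-block config
# below is made up: xchannels needs 1 (input) + 1 (stem) + 3 stages * 3 blocks
# * 2 convs = 20 entries, and xblocks=[3, 3, 3] keeps every block of every stage.
if __name__ == '__main__':
    import torch
    xchannels = [3, 16] + [16] * 6 + [32] * 6 + [64] * 6
    net = InferCifarResNet('ResNetBasicblock', depth=20, xblocks=[3, 3, 3],
                           xchannels=xchannels, num_classes=10, zero_init_residual=False)
    features, logits = net(torch.randn(2, 3, 32, 32))
    print(logits.shape)  # torch.Size([2, 10])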

View File

@@ -0,0 +1,150 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU( planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, inplanes, planes, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU( planes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
elif inplanes != planes*self.expansion:
self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
else:
self.downsample = None
self.out_dim = planes*self.expansion
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferDepthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
super(InferDepthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == 3, 'invalid xblocks : {:}'.format(xblocks)
self.message = 'InferDepthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.channels = [16]
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, planes, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
break
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.channels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,160 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=False, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=False, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferWidthCifarResNet(nn.Module):
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
super(InferWidthCifarResNet, self).__init__()
#Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
if block_name == 'ResNetBasicblock':
block = ResNetBasicblock
assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
layer_blocks = (depth - 2) // 6
elif block_name == 'ResNetBottleneck':
block = ResNetBottleneck
assert (depth - 2) % 9 == 0, 'depth should be one of 164'
layer_blocks = (depth - 2) // 9
else:
raise ValueError('invalid block : {:}'.format(block_name))
self.message = 'InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.xchannels = xchannels
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
for stage in range(3):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,170 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch.nn as nn
import torch.nn.functional as F
from ..initialization import initialize_resnet
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
if has_bn : self.bn = nn.BatchNorm2d(nOut)
else : self.bn = None
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
def forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.bn : out = self.bn( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
return out
class ResNetBasicblock(nn.Module):
num_conv = 2
expansion = 1
def __init__(self, iCs, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 3,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_a = ConvBNReLU(iCs[0], iCs[1], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_b = ConvBNReLU(iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=True, has_bn=True, has_relu=False)
residual_in = iCs[2]
elif iCs[0] != iCs[2]:
self.downsample = ConvBNReLU(iCs[0], iCs[2], 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[2])
self.out_dim = iCs[2]
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + basicblock
return F.relu(out, inplace=True)
class ResNetBottleneck(nn.Module):
expansion = 4
num_conv = 3
def __init__(self, iCs, stride):
super(ResNetBottleneck, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
assert isinstance(iCs, tuple) or isinstance(iCs, list), 'invalid type of iCs : {:}'.format( iCs )
assert len(iCs) == 4,'invalid lengths of iCs : {:}'.format(iCs)
self.conv_1x1 = ConvBNReLU(iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_3x3 = ConvBNReLU(iCs[1], iCs[2], 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
self.conv_1x4 = ConvBNReLU(iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[0]
if stride == 2:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=True , has_bn=True, has_relu=False)
residual_in = iCs[3]
elif iCs[0] != iCs[3]:
self.downsample = ConvBNReLU(iCs[0], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False)
residual_in = iCs[3]
else:
self.downsample = None
#self.out_dim = max(residual_in, iCs[3])
self.out_dim = iCs[3]
def forward(self, inputs):
bottleneck = self.conv_1x1(inputs)
bottleneck = self.conv_3x3(bottleneck)
bottleneck = self.conv_1x4(bottleneck)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
out = residual + bottleneck
return F.relu(out, inplace=True)
class InferImagenetResNet(nn.Module):
def __init__(self, block_name, layers, xblocks, xchannels, deep_stem, num_classes, zero_init_residual):
super(InferImagenetResNet, self).__init__()
#Model type specifies the residual block used for the ImageNet model
if block_name == 'BasicBlock':
block = ResNetBasicblock
elif block_name == 'Bottleneck':
block = ResNetBottleneck
else:
raise ValueError('invalid block : {:}'.format(block_name))
assert len(xblocks) == len(layers), 'invalid layers : {:} vs xblocks : {:}'.format(layers, xblocks)
self.message = 'InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}'.format(sum(layers)*block.num_conv, sum(xblocks)*block.num_conv, xblocks)
self.num_classes = num_classes
self.xchannels = xchannels
if not deep_stem:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 7, 2, 3, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 1
else:
self.layers = nn.ModuleList( [ ConvBNReLU(xchannels[0], xchannels[1], 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True)
,ConvBNReLU(xchannels[1], xchannels[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
last_channel_idx = 2
self.layers.append( nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
for stage, layer_blocks in enumerate(layers):
for iL in range(layer_blocks):
num_conv = block.num_conv
iCs = self.xchannels[last_channel_idx:last_channel_idx+num_conv+1]
stride = 2 if stage > 0 and iL == 0 else 1
module = block(iCs, stride)
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iCs, module.out_dim, stride)
if iL + 1 == xblocks[stage]: # reach the maximum depth
out_channel = module.out_dim
for iiL in range(iL+1, layer_blocks):
last_channel_idx += num_conv
self.xchannels[last_channel_idx] = module.out_dim
break
assert last_channel_idx + 1 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
self.apply(initialize_resnet)
if zero_init_residual:
for m in self.modules():
if isinstance(m, ResNetBasicblock):
nn.init.constant_(m.conv_b.bn.weight, 0)
elif isinstance(m, ResNetBottleneck):
nn.init.constant_(m.conv_1x4.bn.weight, 0)
def get_message(self):
return self.message
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@@ -0,0 +1,122 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
from torch import nn
from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info
class ConvBNReLU(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, groups, has_bn=True, has_relu=True):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
if has_bn: self.bn = nn.BatchNorm2d(out_planes)
else : self.bn = None
if has_relu: self.relu = nn.ReLU6(inplace=True)
else : self.relu = None
def forward(self, x):
out = self.conv( x )
if self.bn: out = self.bn ( out )
if self.relu: out = self.relu( out )
return out
class InvertedResidual(nn.Module):
def __init__(self, channels, stride, expand_ratio, additive):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2], 'invalid stride : {:}'.format(stride)
assert len(channels) in [2, 3], 'invalid channels : {:}'.format(channels)
if len(channels) == 2:
layers = []
else:
layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
layers.extend([
# dw
ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
# pw-linear
ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
])
self.conv = nn.Sequential(*layers)
self.additive = additive
if self.additive and channels[0] != channels[-1]:
self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
else:
self.shortcut = None
self.out_dim = channels[-1]
def forward(self, x):
out = self.conv(x)
# if self.additive: return additive_func(out, x)
if self.shortcut: return out + self.shortcut(x)
else : return out
class InferMobileNetV2(nn.Module):
def __init__(self, num_classes, xchannels, xblocks, dropout):
super(InferMobileNetV2, self).__init__()
block = InvertedResidual
inverted_residual_setting = [
# t, c, n, s
[1, 16 , 1, 1],
[6, 24 , 2, 2],
[6, 32 , 3, 2],
[6, 64 , 4, 2],
[6, 96 , 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
assert len(inverted_residual_setting) == len(xblocks), 'invalid number of layers : {:} vs {:}'.format(len(inverted_residual_setting), len(xblocks))
for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
assert block_num <= ir_setting[2], '{:} vs {:}'.format(block_num, ir_setting)
xchannels = parse_channel_info(xchannels)
#for i, chs in enumerate(xchannels):
# if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
self.xchannels = xchannels
self.message = 'InferMobileNetV2 : xblocks={:}'.format(xblocks)
# building first layer
features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
last_channel_idx = 1
# building inverted residual blocks
for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
for i in range(n):
stride = s if i == 0 else 1
additv = True if i > 0 else False
module = block(self.xchannels[last_channel_idx], stride, t, additv)
features.append(module)
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(stage, i, n, len(features), self.xchannels[last_channel_idx], stride, t, c)
last_channel_idx += 1
if i + 1 == xblocks[stage]:
out_channel = module.out_dim
for iiL in range(i+1, n):
last_channel_idx += 1
self.xchannels[last_channel_idx][0] = module.out_dim
break
# building last several layers
features.append(ConvBNReLU(self.xchannels[last_channel_idx][0], self.xchannels[last_channel_idx][1], 1, 1, 1))
assert last_channel_idx + 2 == len(self.xchannels), '{:} vs {:}'.format(last_channel_idx, len(self.xchannels))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
)
# weight initialization
self.apply( initialize_resnet )
def get_message(self):
return self.message
def forward(self, inputs):
features = self.features(inputs)
vectors = features.mean([2, 3])
predicts = self.classifier(vectors)
return features, predicts

View File

@@ -0,0 +1,58 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn
from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell
class DynamicShapeTinyNet(nn.Module):
def __init__(self, channels: List[int], genotype: Any, num_classes: int):
super(DynamicShapeTinyNet, self).__init__()
self._channels = channels
if len(channels) % 3 != 2:
raise ValueError('invalid number of layers : {:}'.format(len(channels)))
self._num_stage = N = len(channels) // 3
self.stem = nn.Sequential(
nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(channels[0]))
# layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
c_prev = channels[0]
self.cells = nn.ModuleList()
for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
if reduction : cell = ResNetBasicblock(c_prev, c_curr, 2, True)
else : cell = InferCell(genotype, c_prev, c_curr, 1)
self.cells.append( cell )
c_prev = cell.out_dim
self._num_layer = len(self.cells)
self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(c_prev, num_classes)
def get_message(self) -> Text:
string = self.extra_repr()
for i, cell in enumerate(self.cells):
string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
return string
def extra_repr(self):
return ('{name}(C={_channels}, N={_num_stage}, L={_num_layer})'.format(name=self.__class__.__name__, **self.__dict__))
def forward(self, inputs):
feature = self.stem(inputs)
for i, cell in enumerate(self.cells):
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling( out )
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return out, logits

View File

@@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet

View File

@@ -0,0 +1,5 @@
def parse_channel_info(xstring):
blocks = xstring.split(' ')
blocks = [x.split('-') for x in blocks]
blocks = [[int(_) for _ in x] for x in blocks]
return blocks
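# Illustrative only: each space-separated token describes one layer, and dashes
# separate its channel counts (the string below is a made-up example).
if __name__ == '__main__':
    print(parse_channel_info('3-16 16-32-64 64-128'))  # [[3, 16], [16, 32, 64], [64, 128]]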

View File

@@ -0,0 +1,2 @@
from .evaluation_utils import obtain_accuracy
from .flop_benchmark import get_model_infos

View File

@@ -0,0 +1,17 @@
import torch
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
# correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
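# A small self-contained sketch with random logits (shapes are placeholders).
if __name__ == '__main__':
    logits = torch.randn(8, 10)           # 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))   # ground-truth labels
    top1, top5 = obtain_accuracy(logits, target, topk=(1, 5))
    print(top1.item(), top5.item())       # percentages in [0, 100]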

View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import numpy as np
def count_parameters_in_MB(model):
if isinstance(model, nn.Module):
return sum(np.prod(v.size()) for v in model.parameters())/1e6
else:
return sum(np.prod(v.size()) for v in model)/1e6
def get_model_infos(model, shape):
#model = copy.deepcopy( model )
model = add_flops_counting_methods(model)
#model = model.cuda()
model.eval()
#cache_inputs = torch.zeros(*shape).cuda()
#cache_inputs = torch.zeros(*shape)
cache_inputs = torch.rand(*shape)
if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda()
#print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
with torch.no_grad():
_____ = model(cache_inputs)
FLOPs = compute_average_flops_cost( model ) / 1e6
Param = count_parameters_in_MB(model)
if hasattr(model, 'auxiliary_param'):
aux_params = count_parameters_in_MB(model.auxiliary_param())
print ('The auxiliary params of this model is : {:}'.format(aux_params))
print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param))
Param = Param - aux_params
#print_log('FLOPs : {:} MB'.format(FLOPs), log)
torch.cuda.empty_cache()
model.apply( remove_hook_function )
return FLOPs, Param
# ---- Public functions
def add_flops_counting_methods( model ):
model.__batch_counter__ = 0
add_batch_counter_hook_function( model )
model.apply( add_flops_counter_variable_or_reset )
model.apply( add_flops_counter_hook_function )
return model
def compute_average_flops_cost(model):
"""
A method that will be available after add_flops_counting_methods() is called on a desired net object.
Returns current mean flops consumption per image.
"""
batches_count = model.__batch_counter__
flops_sum = 0
#or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
for module in model.modules():
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or hasattr(module, 'calculate_flop_self'):
flops_sum += module.__flops__
return flops_sum / batches_count
# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
batch_size = inputs[0].size(0)
kernel_size = pool_module.kernel_size
out_C, output_height, output_width = output.shape[1:]
assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size())
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
pool_module.__flops__ += overall_flops
def self_calculate_flops_counter_hook(self_module, inputs, output):
overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
self_module.__flops__ += overall_flops
def fc_flops_counter_hook(fc_module, inputs, output):
batch_size = inputs[0].size(0)
xin, xout = fc_module.in_features, fc_module.out_features
assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout)
overall_flops = batch_size * xin * xout
if fc_module.bias is not None:
overall_flops += batch_size * xout
fc_module.__flops__ += overall_flops
def conv1d_flops_counter_hook(conv_module, inputs, outputs):
batch_size = inputs[0].size(0)
outL = outputs.shape[-1]
[kernel] = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel * in_channels * out_channels / groups
active_elements_count = batch_size * outL
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def conv2d_flops_counter_hook(conv_module, inputs, output):
batch_size = inputs[0].size(0)
output_height, output_width = output.shape[2:]
kernel_height, kernel_width = conv_module.kernel_size
in_channels = conv_module.in_channels
out_channels = conv_module.out_channels
groups = conv_module.groups
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
active_elements_count = batch_size * output_height * output_width
overall_flops = conv_per_position_flops * active_elements_count
if conv_module.bias is not None:
overall_flops += out_channels * active_elements_count
conv_module.__flops__ += overall_flops
def batch_counter_hook(module, inputs, output):
# Can have multiple inputs, getting the first one
inputs = inputs[0]
batch_size = inputs.shape[0]
module.__batch_counter__ += batch_size
def add_batch_counter_hook_function(module):
if not hasattr(module, '__batch_counter_handle__'):
handle = module.register_forward_hook(batch_counter_hook)
module.__batch_counter_handle__ = handle
def add_flops_counter_variable_or_reset(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \
or isinstance(module, torch.nn.Conv1d) \
or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
or hasattr(module, 'calculate_flop_self'):
module.__flops__ = 0
def add_flops_counter_hook_function(module):
if isinstance(module, torch.nn.Conv2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv2d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Conv1d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(conv1d_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.Linear):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(fc_flops_counter_hook)
module.__flops_handle__ = handle
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(pool_flops_counter_hook)
module.__flops_handle__ = handle
elif hasattr(module, 'calculate_flop_self'): # self-defined module
if not hasattr(module, '__flops_handle__'):
handle = module.register_forward_hook(self_calculate_flops_counter_hook)
module.__flops_handle__ = handle
def remove_hook_function(module):
hookers = ['__batch_counter_handle__', '__flops_handle__']
for hooker in hookers:
if hasattr(module, hooker):
handle = getattr(module, hooker)
handle.remove()
keys = ['__flops__', '__batch_counter__'] + hookers
for ckey in keys:
if hasattr(module, ckey): delattr(module, ckey)
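# A minimal sketch of the typical call (the toy model below is hypothetical;
# both returned values are divided by 1e6, i.e. MFLOPs per image and millions
# of parameters).
if __name__ == '__main__':
    toy = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10))
    flops, params = get_model_infos(toy, (1, 3, 32, 32))
    print('{:.3f} MFLOPs, {:.3f}M params'.format(flops, params))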

View File

@@ -0,0 +1,28 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .starts import get_machine_info, save_checkpoint, copy_checkpoint
from .optimizers import get_optim_scheduler
from .starts import prepare_seed #, prepare_logger, get_machine_info, save_checkpoint, copy_checkpoint
'''
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
from .funcs_nasbench import get_nas_bench_loaders
def get_procedures(procedure):
from .basic_main import basic_train, basic_valid
from .search_main import search_train, search_valid
from .search_main_v2 import search_train_v2
from .simple_KD_main import simple_KD_train, simple_KD_valid
train_funcs = {'basic' : basic_train, \
'search': search_train,'Simple-KD': simple_KD_train, \
'search-v2': search_train_v2}
valid_funcs = {'basic' : basic_valid, \
'search': search_valid,'Simple-KD': simple_KD_valid, \
'search-v2': search_valid}
train_func = train_funcs[procedure]
valid_func = valid_funcs[procedure]
return train_func, valid_func
'''

View File

@@ -0,0 +1,204 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import math, torch
import torch.nn as nn
from bisect import bisect_right
from torch.optim import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, warmup_epochs, epochs):
if not isinstance(optimizer, Optimizer):
raise TypeError('{:} is not an Optimizer'.format(type(optimizer).__name__))
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
self.current_iter = 0
def extra_repr(self):
return ''
def __repr__(self):
return ('{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}'.format(name=self.__class__.__name__, **self.__dict__)
+ ', {:})'.format(self.extra_repr()))
def state_dict(self):
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
def get_lr(self):
raise NotImplementedError
def get_min_info(self):
lrs = self.get_lr()
return '#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#'.format(min(lrs), max(lrs), self.current_epoch, self.current_iter)
def get_min_lr(self):
return min( self.get_lr() )
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert isinstance(cur_epoch, int) and cur_epoch>=0, 'invalid cur-epoch : {:}'.format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert isinstance(cur_iter, float) and cur_iter>=0, 'invalid cur-iter : {:}'.format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CosineAnnealingLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, T-max={:}, eta-min={:}'.format('cosine', self.T_max, self.eta_min)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
#if last_epoch < self.T_max:
#if last_epoch < self.max_epochs:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
#else:
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
elif self.current_epoch >= self.max_epochs:
lr = self.eta_min
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), 'invalid {:} vs {:}'.format(len(milestones), len(gammas))
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, milestones={:}, gammas={:}, base-lrs={:}'.format('multistep', self.milestones, self.gammas, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
idx = bisect_right(self.milestones, last_epoch)
lr = base_lr
for x in self.gammas[:idx]: lr *= x
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class ExponentialLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, gamma={:}, base-lrs={:}'.format('exponential', self.gamma, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class LinearLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
self.max_LR = max_LR
self.min_LR = min_LR
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return 'type={:}, max_LR={:}, min_LR={:}, base-lrs={:}'.format('LinearLR', self.max_LR, self.min_LR, self.base_lrs)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, 'invalid last_epoch : {:}'.format(last_epoch)
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
lr = base_lr * (1-ratio)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lrs.append( lr )
return lrs
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
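# A quick sanity sketch for CrossEntropyLabelSmooth (hypothetical shapes); with
# epsilon=0 it should match nn.CrossEntropyLoss (mean reduction) up to numerics:
#   criterion = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)
#   logits, targets = torch.randn(4, 10), torch.randint(0, 10, (4,))
#   loss = criterion(logits, targets)   # scalar tensor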
def get_optim_scheduler(parameters, config):
assert hasattr(config, 'optim') and hasattr(config, 'scheduler') and hasattr(config, 'criterion'), 'config must have optim / scheduler / criterion keys instead of {:}'.format(config)
if config.optim == 'SGD':
optim = torch.optim.SGD(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov)
elif config.optim == 'RMSprop':
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
else:
raise ValueError('invalid optim : {:}'.format(config.optim))
if config.scheduler == 'cos':
T_max = getattr(config, 'T_max', config.epochs)
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
elif config.scheduler == 'multistep':
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
elif config.scheduler == 'exponential':
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == 'linear':
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
else:
raise ValueError('invalid scheduler : {:}'.format(config.scheduler))
if config.criterion == 'Softmax':
criterion = torch.nn.CrossEntropyLoss()
elif config.criterion == 'SmoothSoftmax':
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
else:
raise ValueError('invalid criterion : {:}'.format(config.criterion))
return optim, scheduler, criterion
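# A minimal sketch of driving get_optim_scheduler with an argparse-style config
# (every hyper-parameter value below is a placeholder, not a recommended setting).
if __name__ == '__main__':
    from argparse import Namespace
    config = Namespace(optim='SGD', LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
                       scheduler='cos', warmup=5, epochs=200, eta_min=0.0,
                       criterion='Softmax')
    net = nn.Linear(16, 10)  # hypothetical model
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    for epoch in range(config.epochs):
        scheduler.update(epoch, 0.0)  # cur_iter must be a float; can also be updated per batch
        # ... run one epoch of training with `optimizer` and `criterion` here ...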

View File

@@ -0,0 +1,64 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile
def prepare_seed(rand_seed):
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)
torch.cuda.manual_seed_all(rand_seed)
def prepare_logger(xargs):
args = copy.deepcopy( xargs )
from log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed)
logger.log('Main Function with logger : {:}'.format(logger))
logger.log('Arguments : -------------------------------')
for name, value in args._get_kwargs():
logger.log('{:16} : {:}'.format(name, value))
logger.log("Python Version : {:}".format(sys.version.replace('\n', ' ')))
logger.log("Pillow Version : {:}".format(PIL.__version__))
logger.log("PyTorch Version : {:}".format(torch.__version__))
logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log("CUDA_VISIBLE_DEVICES : {:}".format(os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'None'))
return logger
def get_machine_info():
info = "Python Version : {:}".format(sys.version.replace('\n', ' '))
info+= "\nPillow Version : {:}".format(PIL.__version__)
info+= "\nPyTorch Version : {:}".format(torch.__version__)
info+= "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
info+= "\nCUDA available : {:}".format(torch.cuda.is_available())
info+= "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
if 'CUDA_VISIBLE_DEVICES' in os.environ:
info+= "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ['CUDA_VISIBLE_DEVICES'])
else:
info+= "\nDoes not set CUDA_VISIBLE_DEVICES"
return info
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, 'log'): logger.log('Found existing {:}, deleting it before saving'.format(filename))
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(filename), 'failed to save the checkpoint to {:} (file not found after saving)'.format(filename)
if hasattr(logger, 'log'): logger.log('save checkpoint into {:}'.format(filename))
return filename
def copy_checkpoint(src, dst, logger):
if osp.isfile(dst):
if hasattr(logger, 'log'): logger.log('Found existing {:}, deleting it before copying'.format(dst))
os.remove(dst)
copyfile(src, dst)
if hasattr(logger, 'log'): logger.log('copy the file from {:} into {:}'.format(src, dst))
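# A brief usage sketch (paths are placeholders; passing logger=None simply skips
# the log messages, and note that save_checkpoint does not create directories).
if __name__ == '__main__':
    prepare_seed(777)
    os.makedirs('./output', exist_ok=True)
    path = save_checkpoint({'epoch': 1, 'accuracy': 0.0}, './output/checkpoint.pth', logger=None)
    copy_checkpoint(path, './output/checkpoint-best.pth', logger=None)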

View File

@@ -0,0 +1,83 @@
from torch.multiprocessing import Process
import os
from absl import app, flags
import sys
import torch
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from nas_bench_201 import train_single_model
from all_path import NASBENCH201
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_split", 15, "The number of splits")
flags.DEFINE_list("arch_idx_lst", None, "arch index list")
flags.DEFINE_list("arch_str_lst", None, "arch str list")
flags.DEFINE_string("meta_test_path", None, "meta test path")
flags.DEFINE_string("data_name", None, "data_name")
flags.DEFINE_string("raw_data_path", None, "raw_data_path")
def run_single_process(rank, seed, arch_idx, meta_test_path, data_name,
raw_data_path, num_split=15, backend="nccl"):
# 8 GPUs
device = ['0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7',
'0', '1', '2', '3', '4', '5', '6', '7', '0', '1', '2', '3', '4', '5', '6', '7'][rank]
os.environ["CUDA_VISIBLE_DEVICES"] = device
save_path = os.path.join(meta_test_path, str(arch_idx))
if type(seed) == int:
seeds = [seed]
elif type(seed) in [list, tuple]:
    seeds = seed
else:
    raise ValueError('invalid seed : {:}'.format(seed))
nasbench201 = torch.load(NASBENCH201)
arch_str = nasbench201['arch']['str'][arch_idx]
os.makedirs(save_path, exist_ok=True)
train_single_model(save_dir=save_path,
workers=24,
datasets=[data_name],
xpaths=[f'{raw_data_path}/{data_name}'],
splits=[0],
use_less=False,
seeds=seeds,
model_str=arch_str,
arch_config={'channel': 16, 'num_cells': 5})
def run_multi_process(argv):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "1234"
os.environ["WANDB_SILENT"] = "true"
processes = []
arch_idx_lst = [int(i) for i in FLAGS.arch_idx_lst]
seeds = [777, 888, 999] * len(arch_idx_lst)
arch_idx_lst_ = []
for i in arch_idx_lst:
arch_idx_lst_ += [i] * 3
for arch_idx in arch_idx_lst:
os.makedirs(os.path.join(FLAGS.meta_test_path, str(arch_idx)), exist_ok=True)
for rank in range(FLAGS.num_split):
arch_idx = arch_idx_lst_[rank]
seed = seeds[rank]
p = Process(target=run_single_process, args=(rank,
seed,
arch_idx,
FLAGS.meta_test_path,
FLAGS.data_name,
FLAGS.raw_data_path))
p.start()
processes.append(p)
for p in processes:
p.join()
while any(p.is_alive() for p in processes):
continue
print("All processes have completed.")
if __name__ == "__main__":
app.run(run_multi_process)

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from set_encoder.setenc_modules import *
class SetPool(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
super(SetPool, self).__init__()
if 'sab' in mode: # [32, 400, 128]
self.enc = nn.Sequential(
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
else: # [32, 400, 128]
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
if 'PF' in mode: # [32, 1, 501]
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
nn.Linear(dim_hidden, dim_output))
elif 'P' in mode:
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
else: # torch.Size([32, 1, 501])
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
# "", sm, sab, sabsm
def forward(self, X):
x1 = self.enc(X)
x2 = self.dec(x1)
return x2
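# A shape-level sketch (all dimensions are placeholders): with mode='sabPF' the
# encoder is a stack of SABs and the decoder is PMA + Linear, so a batch of sets
# [B, M, dim_input] is pooled down to [B, num_outputs, dim_output].
if __name__ == '__main__':
    import torch
    pool = SetPool(dim_input=512, num_outputs=1, dim_output=56,
                   dim_hidden=56, num_heads=4, ln=False, mode='sabPF')
    x = torch.randn(32, 400, 512)   # 32 sets, 400 elements each, 512-d features
    print(pool(x).shape)            # torch.Size([32, 1, 56])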

View File

@@ -0,0 +1,67 @@
#####################################################################################
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MAB(nn.Module):
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
super(MAB, self).__init__()
self.dim_V = dim_V
self.num_heads = num_heads
self.fc_q = nn.Linear(dim_Q, dim_V)
self.fc_k = nn.Linear(dim_K, dim_V)
self.fc_v = nn.Linear(dim_K, dim_V)
if ln:
self.ln0 = nn.LayerNorm(dim_V)
self.ln1 = nn.LayerNorm(dim_V)
self.fc_o = nn.Linear(dim_V, dim_V)
def forward(self, Q, K):
Q = self.fc_q(Q)
K, V = self.fc_k(K), self.fc_v(K)
dim_split = self.dim_V // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
V_ = torch.cat(V.split(dim_split, 2), 0)
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
O = O + F.relu(self.fc_o(O))
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
return O
class SAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, ln=False):
super(SAB, self).__init__()
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
def forward(self, X):
return self.mab(X, X)
class ISAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
super(ISAB, self).__init__()
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
nn.init.xavier_uniform_(self.I)
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
def forward(self, X):
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
return self.mab1(X, H)
class PMA(nn.Module):
def __init__(self, dim, num_heads, num_seeds, ln=False):
super(PMA, self).__init__()
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
nn.init.xavier_uniform_(self.S)
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
def forward(self, X):
return self.mab(self.S.repeat(X.size(0), 1, 1), X)

View File

@@ -0,0 +1,243 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from set_encoder.setenc_models import SetPool
class MetaSurrogateUnnoisedModel(nn.Module):
def __init__(self, args, graph_config):
super(MetaSurrogateUnnoisedModel, self).__init__()
self.max_n = graph_config['max_n'] # maximum number of vertices
self.nvt = args.nvt # number of vertex types
self.START_TYPE = graph_config['START_TYPE']
self.END_TYPE = graph_config['END_TYPE']
self.hs = args.hs # hidden state size of each vertex
self.nz = args.nz # size of latent representation z
self.gs = args.hs # size of graph state
self.bidir = True # whether to use bidirectional encoding
self.vid = True
self.device = None
self.input_type = 'DG'
self.num_sample = args.num_sample
if self.vid:
self.vs = self.hs + self.max_n # vertex state size = hidden state + vid
else:
self.vs = self.hs
# 0. encoding-related
self.grue_forward = nn.GRUCell(self.nvt, self.hs) # encoder GRU
self.grue_backward = nn.GRUCell(
self.nvt, self.hs) # backward encoder GRU
self.fc1 = nn.Linear(self.gs, self.nz) # latent mean
self.fc2 = nn.Linear(self.gs, self.nz) # latent logvar
# 2. gate-related
self.gate_forward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.gate_backward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.mapper_forward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
) # disable bias to ensure padded zeros also mapped to zeros
self.mapper_backward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
)
# 3. bidir-related, to unify sizes
if self.bidir:
self.hv_unify = nn.Sequential(
nn.Linear(self.hs * 2, self.hs),
)
self.hg_unify = nn.Sequential(
nn.Linear(self.gs * 2, self.gs),
)
# 4. other
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.logsoftmax1 = nn.LogSoftmax(1)
# 6. predictor
np = self.gs
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'G' in self.input_type:
input_dim += self.nz
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.mseloss = nn.MSELoss(reduction='sum')
def predict(self, D_mu, G_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'G' in self.input_type:
input_vec.append(G_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def get_device(self):
if self.device is None:
self.device = next(self.parameters()).device
return self.device
def _get_zeros(self, n, length):
# get a zero hidden state
return torch.zeros(n, length).to(self.get_device())
def _get_zero_hidden(self, n=1):
return self._get_zeros(n, self.hs) # get a zero hidden state
def _one_hot(self, idx, length):
if type(idx) in [list, range]:
if idx == []:
return None
idx = torch.LongTensor(idx).unsqueeze(0).t()
x = torch.zeros((len(idx), length)).scatter_(
1, idx, 1).to(self.get_device())
else:
idx = torch.LongTensor([idx]).unsqueeze(0)
x = torch.zeros((1, length)).scatter_(
1, idx, 1).to(self.get_device())
return x
def _gated(self, h, gate, mapper):
return gate(h) * mapper(h)
def _collate_fn(self, G):
return [g.copy() for g in G]
def _propagate_to(self, G, v, propagator, H=None, reverse=False, gate=None, mapper=None):
# propagate messages to vertex index v for all graphs in G
# return the new messages (states) at v
G = [g for g in G if g.vcount() > v]
if len(G) == 0:
return
if H is not None:
idx = [i for i, g in enumerate(G) if g.vcount() > v]
H = H[idx]
v_types = [g.vs[v]['type'] for g in G]
X = self._one_hot(v_types, self.nvt)
if reverse:
H_name = 'H_backward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
gate, mapper = self.gate_backward, self.mapper_backward
else:
H_name = 'H_forward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.predecessors(v), self.max_n)
for g in G]
if gate is None:
gate, mapper = self.gate_forward, self.mapper_forward
if self.vid:
H_pred = [[torch.cat([x[i], y[i:i + 1]], 1)
for i in range(len(x))] for x, y in zip(H_pred, vids)]
# if h is not provided, use gated sum of v's predecessors' states as the input hidden state
if H is None:
# maximum number of predecessors
max_n_pred = max([len(x) for x in H_pred])
if max_n_pred == 0:
H = self._get_zero_hidden(len(G))
else:
H_pred = [torch.cat(h_pred +
[self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
for h_pred in H_pred] # pad all to same length
H_pred = torch.cat(H_pred, 0) # batch * max_n_pred * vs
H = self._gated(H_pred, gate, mapper).sum(1) # batch * hs
Hv = propagator(X, H)
for i, g in enumerate(G):
g.vs[v][H_name] = Hv[i:i + 1]
return Hv
def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
# perform a series of propagation_to steps starting from v following a topo order
# assume the original vertex indices are in a topological order
if reverse:
prop_order = range(v, -1, -1)
else:
prop_order = range(v, self.max_n)
Hv = self._propagate_to(G, v, propagator, H0,
reverse=reverse) # the initial vertex
for v_ in prop_order[1:]:
self._propagate_to(G, v_, propagator, reverse=reverse)
return Hv
def _get_graph_state(self, G, decode=False):
# get the graph states
# when decoding, use the last generated vertex's state as the graph state
# when encoding, use the ending vertex state or unify the starting and ending vertex states
Hg = []
for g in G:
hg = g.vs[g.vcount() - 1]['H_forward']
if self.bidir and not decode: # decoding never uses backward propagation
hg_b = g.vs[0]['H_backward']
hg = torch.cat([hg, hg_b], 1)
Hg.append(hg)
Hg = torch.cat(Hg, 0)
if self.bidir and not decode:
Hg = self.hg_unify(Hg)
return Hg
def set_encode(self, X):
proto_batch = []
for x in X:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def graph_encode(self, G):
# encode graphs G into latent vectors
if type(G) != list:
G = [G]
self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
reverse=False)
if self.bidir:
self._propagate_from(G, self.max_n - 1, self.grue_backward,
H0=self._get_zero_hidden(len(G)), reverse=True)
Hg = self._get_graph_state(G)
mu = self.fc1(Hg)
# logvar = self.fc2(Hg)
return mu # , logvar
def reparameterize(self, mu, logvar, eps_scale=0.01):
# return z ~ N(mu, std)
if self.training:
std = logvar.mul(0.5).exp_()
eps = torch.randn_like(std) * eps_scale
return eps.mul(std).add_(mu)
else:
return mu

View File

@@ -0,0 +1,33 @@
import os
import logging
import torch
import numpy as np
import random
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
def reset_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
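A minimal usage sketch for the two helpers above; the linear model, checkpoint path, and seed below are placeholders rather than values from this repository:

import torch
import torch.nn as nn

reset_seed(42)                                     # make the run reproducible
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
state = dict(optimizer=optimizer, model=model, step=0)
ckpt = "./checkpoints/example/checkpoint.pth"      # illustrative path
# resume=False only creates the checkpoint directory and returns `state` unchanged;
# resume=True loads 'optimizer'/'model'/'ema' via load_state_dict and copies the rest.
state = restore_checkpoint(ckpt, state, device="cpu", resume=False)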

View File

View File

@@ -0,0 +1,391 @@
# Most of this code is from https://github.com/AIoT-MLSys-Lab/CATE.git
# which was authored by Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang, 2021
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import utils
from .transformer import Encoder, SemanticEmbedding
from .set_encoder.setenc_models import SetPool
class MLP(torch.nn.Module):
def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False, activate_func=F.relu):
"""
num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
input_dim: dimensionality of input features
hidden_dim: dimensionality of hidden units at ALL layers
output_dim: number of classes for prediction
num_classes: the number of classes of input, to be treated with different gains and biases,
(see the definition of class `ConditionalLayer1d`)
"""
super(MLP, self).__init__()
self.linear_or_not = True # default is linear model
self.num_layers = num_layers
self.use_bn = use_bn
self.activate_func = activate_func
if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = torch.nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.linears.append(torch.nn.Linear(input_dim, hidden_dim))
for layer in range(num_layers - 2):
self.linears.append(torch.nn.Linear(hidden_dim, hidden_dim))
self.linears.append(torch.nn.Linear(hidden_dim, output_dim))
if self.use_bn:
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers - 1):
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))
def forward(self, x):
"""
:param x: [num_classes * batch_size, N, F_i], batch of node features
note that in self.cond_layers[layer],
`x` is splited into `num_classes` groups in dim=0,
and then treated with different gains and biases
"""
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for layer in range(self.num_layers - 1):
h = self.linears[layer](h)
if self.use_bn:
h = self.batch_norms[layer](h)
h = self.activate_func(h)
return self.linears[self.num_layers - 1](h)
""" Transformer Encoder """
class GraphEncoder(nn.Module):
def __init__(self, config):
super(GraphEncoder, self).__init__()
# Forward Transformers
self.encoder_f = Encoder(config)
def forward(self, x, mask):
h_f, hs_f, attns_f = self.encoder_f(x, mask)
h = torch.cat(hs_f, dim=-1)
return h
@staticmethod
def get_embeddings(h_x):
h_x = h_x.cpu()
return h_x[:, -1]
class CLSHead(nn.Module):
def __init__(self, config, init_weights=None):
super(CLSHead, self).__init__()
self.layer_1 = nn.Linear(config.d_model, config.d_model)
self.dropout = nn.Dropout(p=config.dropout)
self.layer_2 = nn.Linear(config.d_model, config.n_vocab)
if init_weights is not None:
self.layer_2.weight = init_weights
def forward(self, x):
x = self.dropout(torch.tanh(self.layer_1(x)))
return F.log_softmax(self.layer_2(x), dim=-1)
@utils.register_model(name='CATE')
class CATE(nn.Module):
def __init__(self, config):
super(CATE, self).__init__()
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
if 'pos_enc_type' in config.model:
self.pos_enc_type = config.model.pos_enc_type
if self.pos_enc_type == 1:
raise NotImplementedError
elif self.pos_enc_type == 2:
if config.data.name == 'NASBench201':
self.pos_encoder = PositionalEncoding_Cell(d_model=self.d_model, max_len=config.data.max_node)
else:
self.pos_encoder = PositionalEncoding_StageWise(d_model=self.d_model, max_len=config.data.max_node)
elif self.pos_enc_type == 3:
raise NotImplementedError
else:
self.pos_encoder = None
else:
self.pos_encoder = None
def forward(self, X, time_cond, maskX):
emb_x = self.dropout_op(self.opEmb(X))
if self.pos_encoder is not None:
emb_p = self.pos_encoder(emb_x)
emb_x = emb_x + emb_p
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
return h_x
@utils.register_model(name='PredictorCATE')
class PredictorCATE(nn.Module):
def __init__(self, config):
super(PredictorCATE, self).__init__()
# Shared Embedding Layer
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
# 2 GraphEncoder for X and Y
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.rdim = int(config.data.max_node * config.data.n_vocab)
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=1,
use_bn=False, activate_func=F.elu)
def forward(self, X, time_cond, maskX):
emb_x = self.dropout_op(self.opEmb(X))
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t))
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
h_x = h_x.reshape(h_x.size(0), -1)
h_x = self.regeress(h_x)
return h_x
class PositionalEncoding_StageWise(nn.Module):
def __init__(self, d_model, max_len):
super(PositionalEncoding_StageWise, self).__init__()
NUM_STAGE = 5
max_len = int(max_len / NUM_STAGE)
self.encoding = torch.zeros(max_len, d_model)
self.encoding.requires_grad = False
pos = torch.arange(0, max_len)
pos = pos.float().unsqueeze(dim=1)
_2i = torch.arange(0, d_model, step=2).float()
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
def forward(self, x):
batch_size, seq_len, _ = x.size()
return self.encoding[:seq_len, :].to(x.device)
class PositionalEncoding_Cell(nn.Module):
def __init__(self, d_model, max_len):
super(PositionalEncoding_Cell, self).__init__()
NUM_STAGE = 1
max_len = int(max_len / NUM_STAGE)
self.encoding = torch.zeros(max_len, d_model)
self.encoding.requires_grad = False
pos = torch.arange(0, max_len)
pos = pos.float().unsqueeze(dim=1)
_2i = torch.arange(0, d_model, step=2).float()
self.encoding[:, ::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
self.encoding = torch.cat([self.encoding] * NUM_STAGE, dim=0)
def forward(self, x):
batch_size, seq_len, _ = x.size()
return self.encoding[:seq_len, :].to(x.device)
@utils.register_model(name='MetaPredictorCATE')
class MetaPredictorCATE(nn.Module):
def __init__(self, config):
super(MetaPredictorCATE, self).__init__()
        self.input_type = config.model.input_type
self.hs = config.model.hs
self.opEmb = SemanticEmbedding(config.model.graph_encoder)
self.dropout_op = nn.Dropout(p=config.model.dropout)
self.d_model = config.model.graph_encoder.d_model
self.act = act = get_act(config)
# Time
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.graph_encoder = GraphEncoder(config.model.graph_encoder)
self.fdim = int(config.model.graph_encoder.n_layers * config.model.graph_encoder.d_model)
self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=config.data.n_vocab,
use_bn=False, activate_func=F.elu)
self.rdim = int(config.data.max_node * config.data.n_vocab)
self.regeress = MLP(num_layers=2, input_dim=self.rdim, hidden_dim=2*self.rdim, output_dim=2*self.rdim,
use_bn=False, activate_func=F.elu)
# Set
self.nz = config.model.nz
self.num_sample = config.model.num_sample
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'A' in self.input_type:
input_dim += 2*self.rdim
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.sample_state = False
self.D_mu = None
def arch_encode(self, X, time_cond, maskX):
emb_x = self.dropout_op(self.opEmb(X))
# Time embedding
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t) # [32, 512]
emb_t = self.timeEmb2(self.act(emb_t)) # [32, 64]
emb_t = emb_t.unsqueeze(1)
emb = emb_x + emb_t
h_x = self.graph_encoder(emb, maskX)
h_x = self.final(h_x)
h_x = h_x.reshape(h_x.size(0), -1)
h_x = self.regeress(h_x)
return h_x
def set_encode(self, task):
proto_batch = []
for x in task:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def predict(self, D_mu, A_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'A' in self.input_type:
input_vec.append(A_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def forward(self, X, time_cond, maskX, task):
if self.sample_state:
if self.D_mu is None:
self.D_mu = self.set_encode(task)
D_mu = self.D_mu
else:
D_mu = self.set_encode(task)
A_mu = self.arch_encode(X, time_cond, maskX)
y_pred = self.predict(D_mu, A_mu)
return y_pred
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')
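A quick standalone check of the sinusoidal time embedding defined in this file; the batch size and embedding width below are arbitrary:

import torch

t = torch.arange(0, 8).float()               # 8 diffusion time steps
emb = get_timestep_embedding(t, embedding_dim=64)
print(emb.shape)                             # torch.Size([8, 64])
# sin terms fill the first half of the last dimension and cos terms the second half,
# so every row has norm sqrt(embedding_dim / 2).
print(torch.allclose(emb.norm(dim=1), torch.full((8,), (64 / 2) ** 0.5), atol=1e-4))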

View File

@@ -0,0 +1,125 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from . import utils
from models.cate import PositionalEncoding_StageWise
def normalize_adj(adj):
# Row-normalize matrix
last_dim = adj.size(-1)
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
return torch.div(adj, rowsum)
def graph_pooling(inputs, num_vertices):
num_vertices = num_vertices.to(inputs.device)
out = inputs.sum(1)
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
class DirectedGraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
self.dropout = nn.Dropout(0.1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight1.data)
nn.init.xavier_uniform_(self.weight2.data)
def forward(self, inputs, adj):
inputs = inputs.to(self.weight1.device)
adj = adj.to(self.weight1.device)
norm_adj = normalize_adj(adj)
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
out = (output1 + output2) / 2
out = self.dropout(out)
return out
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
@utils.register_model(name='NeuralPredictor')
class NeuralPredictor(nn.Module):
def __init__(self, config):
super().__init__()
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
config.model.graph_encoder.gcn_hidden)
for i in range(config.model.graph_encoder.gcn_layers)]
self.gcn = nn.ModuleList(self.gcn)
self.dropout = nn.Dropout(0.1)
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
self.fc2 = nn.Linear(config.model.graph_encoder.linear_hidden, 1, bias=False)
# Time
self.d_model = config.model.graph_encoder.gcn_hidden
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.act = act = get_act(config)
def forward(self, X, time_cond, maskX):
out = X
adj = maskX
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device) # 20
gs = adj.size(1) # graph node number
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)# time embedding
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t)) # (5, 144)
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device)) # assuming diagonal is not 1
for layer in self.gcn:
out = layer(out, adj_with_diag)
out = graph_pooling(out, numv)
# time
out = out + emb_t
out = self.fc1(out)
out = self.dropout(out)
out = self.fc2(out)
return out
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')
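A toy-shape sketch of the two graph utilities above, using a batch of two fully connected 4-node graphs; all values are arbitrary:

import torch

adj = torch.ones(2, 4, 4)             # dense adjacency, [B, N, N]
norm = normalize_adj(adj)             # row-normalized: each row sums to 1
print(norm.sum(-1))                   # all ones, shape [2, 4]

feat = torch.randn(2, 4, 16)          # node features, [B, N, F]
numv = torch.tensor([4, 4])           # number of valid nodes per graph
pooled = graph_pooling(feat, numv)    # mean over nodes -> [B, F]
print(pooled.shape)                   # torch.Size([2, 16])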

View File

@@ -0,0 +1,190 @@
# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from . import utils
from .set_encoder.setenc_models import SetPool
def normalize_adj(adj):
# Row-normalize matrix
last_dim = adj.size(-1)
rowsum = adj.sum(2, keepdim=True).repeat(1, 1, last_dim)
return torch.div(adj, rowsum)
def graph_pooling(inputs, num_vertices):
num_vertices = num_vertices.to(inputs.device)
out = inputs.sum(1)
return torch.div(out, num_vertices.unsqueeze(-1).expand_as(out))
class DirectedGraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
self.dropout = nn.Dropout(0.1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight1.data)
nn.init.xavier_uniform_(self.weight2.data)
def forward(self, inputs, adj):
inputs = inputs.to(self.weight1.device)
adj = adj.to(self.weight1.device)
norm_adj = normalize_adj(adj)
output1 = F.relu(torch.matmul(norm_adj, torch.matmul(inputs, self.weight1)))
inv_norm_adj = normalize_adj(adj.transpose(1, 2))
output2 = F.relu(torch.matmul(inv_norm_adj, torch.matmul(inputs, self.weight2)))
out = (output1 + output2) / 2
out = self.dropout(out)
return out
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
@utils.register_model(name='MetaNeuralPredictor')
class MetaNeuralPredictor(nn.Module):
def __init__(self, config):
super().__init__()
# Arch
self.gcn = [DirectedGraphConvolution(config.model.graph_encoder.initial_hidden if i == 0 else config.model.graph_encoder.gcn_hidden,
config.model.graph_encoder.gcn_hidden)
for i in range(config.model.graph_encoder.gcn_layers)]
self.gcn = nn.ModuleList(self.gcn)
self.dropout = nn.Dropout(0.1)
self.fc1 = nn.Linear(config.model.graph_encoder.gcn_hidden, config.model.graph_encoder.linear_hidden, bias=False)
# Time
self.d_model = config.model.graph_encoder.gcn_hidden
self.timeEmb1 = nn.Linear(self.d_model, self.d_model * 4)
self.timeEmb2 = nn.Linear(self.d_model * 4, self.d_model)
self.act = act = get_act(config)
self.input_type = config.model.input_type
self.hs = config.model.hs
# Set
self.nz = config.model.nz
self.num_sample = config.model.num_sample
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
input_dim = 0
if 'D' in self.input_type:
input_dim += self.nz
if 'A' in self.input_type:
input_dim += config.model.graph_encoder.linear_hidden
self.pred_fc = nn.Sequential(
nn.Linear(input_dim, self.hs),
nn.Tanh(),
nn.Linear(self.hs, 1)
)
self.sample_state = False
self.D_mu = None
def arch_encode(self, X, time_cond, maskX):
out = X
adj = maskX
numv = torch.tensor([adj.size(1)] * adj.size(0)).to(out.device)
gs = adj.size(1) # graph node number
timesteps = time_cond
emb_t = get_timestep_embedding(timesteps, self.d_model)
emb_t = self.timeEmb1(emb_t)
emb_t = self.timeEmb2(self.act(emb_t))
adj_with_diag = normalize_adj(adj + torch.eye(gs, device=adj.device))
for layer in self.gcn:
out = layer(out, adj_with_diag)
out = graph_pooling(out, numv)
# time
out = out + emb_t
out = self.fc1(out)
out = self.dropout(out)
return out
def set_encode(self, task):
proto_batch = []
for x in task:
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
return v
def predict(self, D_mu, A_mu):
input_vec = []
if 'D' in self.input_type:
input_vec.append(D_mu)
if 'A' in self.input_type:
input_vec.append(A_mu)
input_vec = torch.cat(input_vec, dim=1)
return self.pred_fc(input_vec)
def forward(self, X, time_cond, maskX, task):
if self.sample_state:
if self.D_mu is None:
self.D_mu = self.set_encode(task)
D_mu = self.D_mu
else:
D_mu = self.set_encode(task)
A_mu = self.arch_encode(X, time_cond, maskX)
y_pred = self.predict(D_mu, A_mu)
return y_pred
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')

View File

@@ -0,0 +1,85 @@
import torch
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
"""
def __init__(self, parameters, decay, use_num_updates=True):
"""
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`.
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing averages.
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
self.collected_params = []
def update(self, parameters):
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of the `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to
initialize this object.
"""
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
s_param.sub_(one_minus_decay * (s_param - param))
def copy_to(self, parameters):
"""
Copy current parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the original optimization process.
Store the parameters before the `copy_to` method.
After validation (or model saving), use this to restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def state_dict(self):
return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params)
def load_state_dict(self, state_dict):
self.decay = state_dict['decay']
self.num_updates = state_dict['num_updates']
self.shadow_params = state_dict['shadow_params']
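An illustrative train/eval pattern for the EMA helper above; the linear model and decay value are placeholders:

import torch.nn as nn

model = nn.Linear(16, 1)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

# during training, after every optimizer.step():
ema.update(model.parameters())

# at evaluation time:
ema.store(model.parameters())      # stash the current (raw) weights
ema.copy_to(model.parameters())    # swap in the smoothed EMA weights
# ... run validation with the EMA weights here ...
ema.restore(model.parameters())    # put the raw weights back before resuming training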

View File

@@ -0,0 +1,82 @@
import math
import torch
import torch.nn as nn
from .trans_layers import *
class pos_gnn(nn.Module):
def __init__(self, act, x_ch, pos_ch, out_ch, max_node, graph_layer, n_layers=3, edge_dim=None, heads=4,
temb_dim=None, dropout=0.1, attn_clamp=False):
super().__init__()
self.out_ch = out_ch
self.Dropout_0 = nn.Dropout(dropout)
self.act = act
self.max_node = max_node
self.n_layers = n_layers
if temb_dim is not None:
self.Dense_node0 = nn.Linear(temb_dim, x_ch)
self.Dense_node1 = nn.Linear(temb_dim, pos_ch)
self.Dense_edge0 = nn.Linear(temb_dim, edge_dim)
self.Dense_edge1 = nn.Linear(temb_dim, edge_dim)
self.convs = nn.ModuleList()
self.edge_convs = nn.ModuleList()
self.edge_layer = nn.Linear(edge_dim * 2 + self.out_ch, edge_dim)
for i in range(n_layers):
if i == 0:
self.convs.append(eval(graph_layer)(x_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2,
act=act, attn_clamp=attn_clamp))
else:
self.convs.append(eval(graph_layer)
(self.out_ch, pos_ch, self.out_ch//heads, heads, edge_dim=edge_dim*2, act=act,
attn_clamp=attn_clamp))
self.edge_convs.append(nn.Linear(self.out_ch, edge_dim*2))
def forward(self, x_degree, x_pos, edge_index, dense_ori, dense_spd, dense_index, temb=None):
"""
Args:
x_degree: node degree feature [B*N, x_ch]
x_pos: node rwpe feature [B*N, pos_ch]
edge_index: [2, edge_length]
dense_ori: edge feature [B, N, N, nf//2]
dense_spd: edge shortest path distance feature [B, N, N, nf//2] # Do we need this part? # TODO
dense_index
temb: [B, temb_dim]
"""
B, N, _, _ = dense_ori.shape
if temb is not None:
dense_ori = dense_ori + self.Dense_edge0(self.act(temb))[:, None, None, :]
dense_spd = dense_spd + self.Dense_edge1(self.act(temb))[:, None, None, :]
temb = temb.unsqueeze(1).repeat(1, self.max_node, 1)
temb = temb.reshape(-1, temb.shape[-1])
x_degree = x_degree + self.Dense_node0(self.act(temb))
x_pos = x_pos + self.Dense_node1(self.act(temb))
dense_edge = torch.cat([dense_ori, dense_spd], dim=-1)
ori_edge_attr = dense_edge
h = x_degree
h_pos = x_pos
for i_layer in range(self.n_layers):
h_edge = dense_edge[dense_index]
# update node feature
h, h_pos = self.convs[i_layer](h, h_pos, edge_index, h_edge)
h = self.Dropout_0(h)
h_pos = self.Dropout_0(h_pos)
# update dense edge feature
h_dense_node = h.reshape(B, N, -1)
cur_edge_attr = h_dense_node.unsqueeze(1) + h_dense_node.unsqueeze(2) # [B, N, N, nf]
dense_edge = (dense_edge + self.act(self.edge_convs[i_layer](cur_edge_attr))) / math.sqrt(2.)
dense_edge = self.Dropout_0(dense_edge)
# Concat edge attribute
h_dense_edge = torch.cat([ori_edge_attr, dense_edge], dim=-1)
h_dense_edge = self.edge_layer(h_dense_edge).permute(0, 3, 1, 2)
return h_dense_edge

View File

@@ -0,0 +1,44 @@
"""Common layers"""
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
def get_act(config):
"""Get actiuvation functions from the config file."""
if config.model.nonlinearity.lower() == 'elu':
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
return nn.SiLU()
elif config.model.nonlinearity.lower() == 'tanh':
return nn.Tanh()
else:
raise NotImplementedError('activation function does not exist!')
def conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, padding=0):
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
padding=padding)
return conv
# from DDPM
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
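A tiny sketch of get_act with the config layout this file assumes (config.model.nonlinearity); 'swish' is just one of the supported options:

from types import SimpleNamespace

cfg = SimpleNamespace(model=SimpleNamespace(nonlinearity='swish'))
act = get_act(cfg)    # returns nn.SiLU()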

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .setenc_modules import *
class SetPool(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
super(SetPool, self).__init__()
if 'sab' in mode: # [32, 400, 128]
self.enc = nn.Sequential(
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
else: # [32, 400, 128]
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
if 'PF' in mode: # [32, 1, 501]
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
nn.Linear(dim_hidden, dim_output))
elif 'P' in mode:
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
else: # torch.Size([32, 1, 501])
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
# "", sm, sab, sabsm
def forward(self, X):
x1 = self.enc(X)
x2 = self.dec(x1)
return x2

View File

@@ -0,0 +1,67 @@
#####################################################################################
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MAB(nn.Module):
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
super(MAB, self).__init__()
self.dim_V = dim_V
self.num_heads = num_heads
self.fc_q = nn.Linear(dim_Q, dim_V)
self.fc_k = nn.Linear(dim_K, dim_V)
self.fc_v = nn.Linear(dim_K, dim_V)
if ln:
self.ln0 = nn.LayerNorm(dim_V)
self.ln1 = nn.LayerNorm(dim_V)
self.fc_o = nn.Linear(dim_V, dim_V)
def forward(self, Q, K):
Q = self.fc_q(Q)
K, V = self.fc_k(K), self.fc_v(K)
dim_split = self.dim_V // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
V_ = torch.cat(V.split(dim_split, 2), 0)
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
O = O + F.relu(self.fc_o(O))
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
return O
class SAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, ln=False):
super(SAB, self).__init__()
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
def forward(self, X):
return self.mab(X, X)
class ISAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
super(ISAB, self).__init__()
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
nn.init.xavier_uniform_(self.I)
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
def forward(self, X):
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
return self.mab1(X, H)
class PMA(nn.Module):
def __init__(self, dim, num_heads, num_seeds, ln=False):
super(PMA, self).__init__()
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
nn.init.xavier_uniform_(self.S)
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
def forward(self, X):
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
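A shape sketch for the attention blocks above, mirroring how SetPool composes them elsewhere in this codebase; the nz, num_sample, and batch values are arbitrary:

import torch

nz, num_sample = 56, 20
x = torch.randn(10, num_sample, 512)             # 10 classes, 20 samples of 512-d features each

sab = SAB(dim_in=512, dim_out=nz, num_heads=4)   # self-attention over each set
pma = PMA(dim=nz, num_heads=4, num_seeds=1)      # pool each set down to one vector

h = sab(x)                                       # [10, num_sample, nz]
proto = pma(h)                                   # [10, 1, nz] per-class prototypes
print(h.shape, proto.shape)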

View File

@@ -0,0 +1,144 @@
import math
from typing import Union, Tuple, Optional
from torch_geometric.typing import PairTensor, Adj, OptTensor
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
import numpy as np
class PosTransLayer(MessagePassing):
"""Involving the edge feature and updating position feature. Multiply Msg."""
_alpha: OptTensor
def __init__(self, x_channels: int, pos_channels: int, out_channels: int,
heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
bias: bool = True, act=None, attn_clamp: bool = False, **kwargs):
kwargs.setdefault('aggr', 'add')
super(PosTransLayer, self).__init__(node_dim=0, **kwargs)
self.x_channels = x_channels
self.pos_channels = pos_channels
self.in_channels = in_channels = x_channels + pos_channels
self.out_channels = out_channels
self.heads = heads
self.dropout = dropout
self.edge_dim = edge_dim
self.attn_clamp = attn_clamp
if act is None:
self.act = nn.LeakyReLU(negative_slope=0.2)
else:
self.act = act
self.lin_key = Linear(in_channels, heads * out_channels)
self.lin_query = Linear(in_channels, heads * out_channels)
self.lin_value = Linear(in_channels, heads * out_channels)
self.lin_edge0 = Linear(edge_dim, heads * out_channels, bias=False)
self.lin_edge1 = Linear(edge_dim, heads * out_channels, bias=False)
self.lin_pos = Linear(heads * out_channels, pos_channels, bias=False)
self.lin_skip = Linear(x_channels, heads * out_channels, bias=bias)
self.norm1 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
num_channels=heads * out_channels, eps=1e-6)
self.norm2 = nn.GroupNorm(num_groups=min(heads * out_channels // 4, 32),
num_channels=heads * out_channels, eps=1e-6)
# FFN
self.FFN = nn.Sequential(Linear(heads * out_channels, heads * out_channels),
self.act,
Linear(heads * out_channels, heads * out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
self.lin_skip.reset_parameters()
self.lin_edge0.reset_parameters()
self.lin_edge1.reset_parameters()
self.lin_pos.reset_parameters()
def forward(self, x: OptTensor,
pos: Tensor,
edge_index: Adj,
edge_attr: OptTensor = None
) -> Tuple[Tensor, Tensor]:
""""""
H, C = self.heads, self.out_channels
x_feat = torch.cat([x, pos], -1)
query = self.lin_query(x_feat).view(-1, H, C)
key = self.lin_key(x_feat).view(-1, H, C)
value = self.lin_value(x_feat).view(-1, H, C)
# propagate_type: (x: PairTensor, edge_attr: OptTensor)
out_x, out_pos = self.propagate(edge_index, query=query, key=key, value=value, pos=pos, edge_attr=edge_attr,
size=None)
out_x = out_x.view(-1, self.heads * self.out_channels)
# skip connection for x
x_r = self.lin_skip(x)
out_x = (out_x + x_r) / math.sqrt(2)
out_x = self.norm1(out_x)
# FFN
out_x = (out_x + self.FFN(out_x)) / math.sqrt(2)
out_x = self.norm2(out_x)
# skip connection for pos
out_pos = pos + torch.tanh(pos + out_pos)
return out_x, out_pos
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
pos_j: Tensor,
edge_attr: OptTensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tuple[Tensor, Tensor]:
edge_attn = self.lin_edge0(edge_attr).view(-1, self.heads, self.out_channels)
alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(self.out_channels)
if self.attn_clamp:
alpha = alpha.clamp(min=-5., max=5.)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# node feature message
msg = value_j
msg = msg * self.lin_edge1(edge_attr).view(-1, self.heads, self.out_channels)
msg = msg * alpha.view(-1, self.heads, 1)
# node position message
pos_msg = pos_j * self.lin_pos(msg.reshape(-1, self.heads * self.out_channels))
return msg, pos_msg
def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[Tensor, Tensor]:
if ptr is not None:
raise NotImplementedError("Not implement Ptr in aggregate")
else:
return (scatter(inputs[0], index, 0, dim_size=dim_size, reduce=self.aggr),
scatter(inputs[1], index, 0, dim_size=dim_size, reduce="mean"))
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
return inputs
def __repr__(self):
return '{}({}, {}, heads={})'.format(self.__class__.__name__,
self.in_channels,
self.out_channels, self.heads)

View File

@@ -0,0 +1,255 @@
from copy import deepcopy as cp
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def clones(module, N):
return nn.ModuleList([cp(module) for _ in range(N)])
def attention(query, key, value, mask = None, dropout = None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim = -1)
if dropout is not None:
attn = dropout(attn)
return torch.matmul(attn, value), attn
class MultiHeadAttention(nn.Module):
def __init__(self, config):
super(MultiHeadAttention, self).__init__()
self.d_model = config.d_model
self.n_head = config.n_head
self.d_k = config.d_model // config.n_head
self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
self.dropout = nn.Dropout(p=config.dropout)
def forward(self, query, key, value, mask = None):
if mask is not None:
mask = mask.unsqueeze(1)
batch_size = query.size(0)
query, key , value = [l(x).view(batch_size, -1, self.n_head, self.d_k).transpose(1,2) for l, x in zip(self.linears, (query, key, value))]
x, attn = attention(query, key, value, mask = mask, dropout = self.dropout)
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
return self.linears[3](x), attn
class PositionwiseFeedForward(nn.Module):
def __init__(self, config):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(config.d_model, config.d_ff)
self.w_2 = nn.Linear(config.d_ff, config.d_model)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class PositionwiseFeedForwardLast(nn.Module):
def __init__(self, config):
super(PositionwiseFeedForwardLast, self).__init__()
self.w_1 = nn.Linear(config.d_model, config.d_ff)
self.w_2 = nn.Linear(config.d_ff, config.n_vocab)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SelfAttentionBlock(nn.Module):
def __init__(self, config):
super(SelfAttentionBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.attn = MultiHeadAttention(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x, mask):
x_ = self.norm(x)
x_ , attn = self.attn(x_, x_, x_, mask)
return self.dropout(x_) + x, attn
class SourceAttentionBlock(nn.Module):
def __init__(self, config):
super(SourceAttentionBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.attn = MultiHeadAttention(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x, m, mask):
x_ = self.norm(x)
x_, attn = self.attn(x_, m, m, mask)
return self.dropout(x_) + x, attn
class FeedForwardBlock(nn.Module):
def __init__(self, config):
super(FeedForwardBlock, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.feed_forward = PositionwiseFeedForward(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, x):
x_ = self.norm(x)
x_ = self.feed_forward(x_)
return self.dropout(x_) + x
class FeedForwardBlockLast(nn.Module):
def __init__(self, config):
super(FeedForwardBlockLast, self).__init__()
self.norm = nn.LayerNorm(config.d_model)
self.feed_forward = PositionwiseFeedForwardLast(config)
self.dropout = nn.Dropout(p = config.dropout)
# Only for the last layer
self.proj_fc = nn.Linear(config.d_model, config.n_vocab)
def forward(self, x):
x_ = self.norm(x)
x_ = self.feed_forward(x_)
return self.dropout(x_) + self.proj_fc(x)
class EncoderBlock(nn.Module):
def __init__(self, config):
super(EncoderBlock, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.feed_forward = FeedForwardBlock(config)
def forward(self, x, mask):
x, attn = self.self_attn(x, mask)
x = self.feed_forward(x)
return x, attn
class EncoderBlockLast(nn.Module):
def __init__(self, config):
super(EncoderBlockLast, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.feed_forward = FeedForwardBlockLast(config)
def forward(self, x, mask):
x, attn = self.self_attn(x, mask)
x = self.feed_forward(x)
return x, attn
class DecoderBlock(nn.Module):
def __init__(self, config):
super(DecoderBlock, self).__init__()
self.self_attn = SelfAttentionBlock(config)
self.src_attn = SourceAttentionBlock(config)
self.feed_forward = FeedForwardBlock(config)
def forward(self, x, m, src_mask, tgt_mask):
x, attn_tgt = self.self_attn(x, tgt_mask)
x, attn_src = self.src_attn(x, m, src_mask)
x = self.feed_forward(x)
return x, attn_src, attn_tgt
class Encoder(nn.Module):
def __init__(self, config):
super(Encoder, self).__init__()
self.layers = clones(EncoderBlock(config), config.n_layers)
self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)
def forward(self, x, mask):
outputs = []
attns = []
for layer, norm in zip(self.layers, self.norms):
x, attn = layer(x, mask)
outputs.append(norm(x))
attns.append(attn)
return outputs[-1], outputs, attns
class PositionalEmbedding(nn.Module):
def __init__(self, config):
super(PositionalEmbedding, self).__init__()
p2e = torch.zeros(config.max_len, config.d_model)
position = torch.arange(0.0, config.max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
p2e[:, 0::2] = torch.sin(position * div_term)
p2e[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('p2e', p2e)
def forward(self, x):
shp = x.size()
with torch.no_grad():
emb = torch.index_select(self.p2e, 0, x.view(-1)).view(shp + (-1,))
return emb
class Transformer(nn.Module):
def __init__(self, config):
super(Transformer, self).__init__()
self.p2e = PositionalEmbedding(config)
self.encoder = Encoder(config)
def forward(self, input_emb, position_ids, attention_mask):
# position embedding projection
projection = self.p2e(position_ids) + input_emb
return self.encoder(projection, attention_mask)
class TokenTypeEmbedding(nn.Module):
def __init__(self, config):
super(TokenTypeEmbedding, self).__init__()
self.t2e = nn.Embedding(config.n_token_type, config.d_model)
self.d_model = config.d_model
def forward(self, x):
return self.t2e(x) * math.sqrt(self.d_model)
class SemanticEmbedding(nn.Module):
def __init__(self, config):
super(SemanticEmbedding, self).__init__()
self.d_model = config.d_model
self.fc = nn.Linear(config.n_vocab, config.d_model)
def forward(self, x):
return self.fc(x) * math.sqrt(self.d_model)
class Embeddings(nn.Module):
def __init__(self, config):
super(Embeddings, self).__init__()
self.w2e = SemanticEmbedding(config)
self.p2e = PositionalEmbedding(config)
self.t2e = TokenTypeEmbedding(config)
self.dropout = nn.Dropout(p = config.dropout)
def forward(self, input_ids, position_ids = None, token_type_ids = None):
if position_ids is None:
batch_size, length = input_ids.size()
with torch.no_grad():
position_ids = torch.arange(0, length).repeat(batch_size, 1)
if torch.cuda.is_available():
position_ids = position_ids.cuda(device=input_ids.device)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
embeddings = self.w2e(input_ids) + self.p2e(position_ids) + self.t2e(token_type_ids)
return self.dropout(embeddings)
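A minimal forward pass through the Encoder above, using a SimpleNamespace in place of the real config; the field names follow this file and the values are placeholders:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(d_model=64, n_head=4, d_ff=128, dropout=0.1, n_layers=2, n_vocab=7)
op_emb = SemanticEmbedding(cfg)        # linear lift from n_vocab-dim operation vectors to d_model
enc = Encoder(cfg)

x = torch.randn(2, 8, 7)               # [B, max_node, n_vocab] operation matrix
mask = torch.ones(2, 8, 8)             # dense attention mask, [B, N, N]
h_last, hs, attns = enc(op_emb(x), mask)
print(h_last.shape)                    # torch.Size([2, 8, 64])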

View File

@@ -0,0 +1,289 @@
import torch
import torch.nn.functional as F
import sde_lib
_MODELS = {}
def register_model(cls=None, *, name=None):
"""A decorator for registering model classes."""
def _register(cls):
if name is None:
local_name = cls.__name__
else:
local_name = name
if local_name in _MODELS:
raise ValueError(
f'Already registered model with name: {local_name}')
_MODELS[local_name] = cls
return cls
if cls is None:
return _register
else:
return _register(cls)
def get_model(name):
return _MODELS[name]
def create_model(config):
"""Create the model."""
model_name = config.model.name
model = get_model(model_name)(config)
model = model.to(config.device)
return model
def get_model_fn(model, train=False):
"""Create a function to give the output of the score-based model.
Args:
model: The score model.
train: `True` for training and `False` for evaluation.
Returns:
A model function.
"""
def model_fn(x, labels, *args, **kwargs):
"""Compute the output of the score-based model.
Args:
x: A mini-batch of input data (Adjacency matrices).
labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
for different models.
mask: Mask for adjacency matrices.
Returns:
A tuple of (model output, new mutable states)
"""
if not train:
model.eval()
return model(x, labels, *args, **kwargs)
else:
model.train()
return model(x, labels, *args, **kwargs)
return model_fn
def get_score_fn(sde, model, train=False, continuous=False):
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
model: A score model.
train: `True` for training and `False` for evaluation.
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
Returns:
A score function.
"""
model_fn = get_model_fn(model, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def score_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
                # The maximum value of the time embedding is assumed to be 999 for continuously-trained models.
labels = t * 999
score = model_fn(x, labels, *args, **kwargs)
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
score = model_fn(x, labels, *args, **kwargs)
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
labels.long()]
score = -score / std[:, None, None]
return score
elif isinstance(sde, sde_lib.VESDE):
def score_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
score = model_fn(x, labels, *args, **kwargs)
return score
else:
raise NotImplementedError(
f"SDE class {sde.__class__.__name__} not yet supported.")
return score_fn
def get_classifier_grad_fn(sde, classifier, train=False, continuous=False,
regress=True, labels='max'):
logit_fn = get_logit_fn(sde, classifier, train, continuous)
def classifier_grad_fn(x, t, *args, **kwargs):
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
if regress:
assert labels in ['max', 'min']
logit = logit_fn(x_in, t, *args, **kwargs)
if labels == 'max':
prob = logit.sum()
elif labels == 'min':
prob = -logit.sum()
else:
logit = logit_fn(x_in, t, *args, **kwargs)
log_prob = F.log_softmax(logit, dim=-1)
prob = log_prob[range(len(logit)), labels.view(-1)].sum()
classifier_grad = torch.autograd.grad(prob, x_in)[0]
return classifier_grad
return classifier_grad_fn
def get_logit_fn(sde, classifier, train=False, continuous=False):
classifier_fn = get_model_fn(classifier, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def logit_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
                # The maximum value of the time embedding is assumed to be 999 for continuously-trained models.
labels = t * 999
logit = classifier_fn(x, labels, *args, **kwargs)
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
logit = classifier_fn(x, labels, *args, **kwargs)
return logit
elif isinstance(sde, sde_lib.VESDE):
def logit_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
logit = classifier_fn(x, labels, *args, **kwargs)
return logit
    else:
        raise NotImplementedError(
            f"SDE class {sde.__class__.__name__} not yet supported.")
    return logit_fn
def get_predictor_fn(sde, model, train=False, continuous=False):
"""Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
model: A predictor model.
train: `True` for training and `False` for evaluation.
continuous: If `True`, the score-based model is expected to directly take continuous time steps.
Returns:
A score function.
"""
model_fn = get_model_fn(model, train=train)
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
def predictor_fn(x, t, *args, **kwargs):
# Scale neural network output by standard deviation and flip sign
if continuous or isinstance(sde, sde_lib.subVPSDE):
# For VP-trained models, t=0 corresponds to the lowest noise level
                # The maximum value of the time embedding is assumed to be 999 for continuously-trained models.
labels = t * 999
pred = model_fn(x, labels, *args, **kwargs)
std = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VP-trained models, t=0 corresponds to the lowest noise level
labels = t * (sde.N - 1)
pred = model_fn(x, labels, *args, **kwargs)
std = sde.sqrt_1m_alpha_cumprod.to(labels.device)[
labels.long()]
return pred
elif isinstance(sde, sde_lib.VESDE):
def predictor_fn(x, t, *args, **kwargs):
if continuous:
labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
else:
# For VE-trained models, t=0 corresponds to the highest noise level
labels = sde.T - t
labels *= sde.N - 1
labels = torch.round(labels).long()
pred = model_fn(x, labels, *args, **kwargs)
return pred
else:
raise NotImplementedError(
f"SDE class {sde.__class__.__name__} not yet supported.")
return predictor_fn
def to_flattened_numpy(x):
"""Flatten a torch tensor `x` and convert it to numpy."""
return x.detach().cpu().numpy().reshape((-1,))
def from_flattened_numpy(x, shape):
"""Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
return torch.from_numpy(x.reshape(shape))
@torch.no_grad()
def mask_adj2node(adj_mask):
"""Convert batched adjacency mask matrices to batched node mask matrices.
Args:
adj_mask: [B, N, N] Batched adjacency mask matrices without self-loop edge.
Output:
node_mask: [B, N] Batched node mask matrices indicating the valid nodes.
"""
batch_size, max_num_nodes, _ = adj_mask.shape
node_mask = adj_mask[:, 0, :].clone()
node_mask[:, 0] = 1
return node_mask
@torch.no_grad()
def get_rw_feat(k_step, dense_adj):
"""Compute k_step Random Walk for given dense adjacency matrix."""
rw_list = []
deg = dense_adj.sum(-1, keepdims=True)
AD = dense_adj / (deg + 1e-8)
rw_list.append(AD)
for _ in range(k_step):
rw = torch.bmm(rw_list[-1], AD)
rw_list.append(rw)
rw_map = torch.stack(rw_list[1:], dim=1) # [B, k_step, N, N]
rw_landing = torch.diagonal(
rw_map, offset=0, dim1=2, dim2=3) # [B, k_step, N]
rw_landing = rw_landing.permute(0, 2, 1) # [B, N, rw_depth]
# get the shortest path distance indices
tmp_rw = rw_map.sort(dim=1)[0]
spd_ind = (tmp_rw <= 0).sum(dim=1) # [B, N, N]
spd_onehot = torch.nn.functional.one_hot(
spd_ind, num_classes=k_step+1).to(torch.float)
spd_onehot = spd_onehot.permute(0, 3, 1, 2) # [B, kstep, N, N]
return rw_landing, spd_onehot
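A sketch of the registry pattern defined above, assuming this module imports cleanly (it needs sde_lib on the path); DummyScoreNet and the config namespace are purely illustrative:

import torch.nn as nn
from types import SimpleNamespace

@register_model(name='DummyScoreNet')
class DummyScoreNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.net = nn.Linear(config.data.n_vocab, config.data.n_vocab)
    def forward(self, x, t):
        return self.net(x)

cfg = SimpleNamespace(device='cpu',
                      model=SimpleNamespace(name='DummyScoreNet'),
                      data=SimpleNamespace(n_vocab=7))
model = create_model(cfg)              # looks up 'DummyScoreNet' in _MODELS and moves it to cfg.device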

NAS-Bench-201/run_lib.py
View File

@@ -0,0 +1,520 @@
import os
import torch
import numpy as np
import random
import logging
from absl import flags
from scipy.stats import pearsonr, spearmanr
import torch
from models import cate
from models import digcn
from models import digcn_meta
import losses
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets_nas
import sde_lib
from utils import *
from logger import Logger
from analysis.arch_metrics import SamplingArchMetrics, SamplingArchMetricsMeta
FLAGS = flags.FLAGS
def set_exp_name(config):
if config.task == 'tr_scorenet':
exp_name = f'./results/{config.task}/{config.folder_name}'
data = config.data
elif config.task == 'tr_meta_surrogate':
exp_name = f'./results/{config.task}/{config.folder_name}'
os.makedirs(exp_name, exist_ok=True)
config.exp_name = exp_name
set_random_seed(config)
return exp_name
def set_random_seed(config):
seed = config.seed
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def scorenet_train(config):
"""Runs the score network training pipeline.
Args:
config: Configuration to use.
"""
## Set logger
exp_name = set_exp_name(config)
logger = Logger(
log_dir=exp_name,
write_textfile=True)
logger.update_config(config, is_args=True)
logger.write_str(str(vars(config)))
logger.write_str('-' * 100)
## Create directories for experimental logs
sample_dir = os.path.join(exp_name, "samples")
os.makedirs(sample_dir, exist_ok=True)
## Initialize model and optimizer
score_model = mutils.create_model(config)
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
optimizer = losses.get_optimizer(config, score_model.parameters())
state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, config=config)
## Create checkpoints directory
checkpoint_dir = os.path.join(exp_name, "checkpoints")
## Intermediate checkpoints to resume training
checkpoint_meta_dir = os.path.join(exp_name, "checkpoints-meta", "checkpoint.pth")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
## Resume training when intermediate checkpoints are detected
if config.resume:
state = restore_checkpoint(config.resume_ckpt_path, state, config.device, resume=config.resume)
initial_step = int(state['step'])
## Build dataloader and iterators
train_ds, eval_ds, test_ds = datasets_nas.get_dataset(config)
train_loader, eval_loader, test_loader = datasets_nas.get_dataloader(config, train_ds, eval_ds, test_ds)
train_iter = iter(train_loader)
# Create data normalizer and its inverse
scaler = datasets_nas.get_data_scaler(config)
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
## Setup SDEs
if config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
# Build one-step training and evaluation functions
optimize_fn = losses.optimization_manager(config)
continuous = config.training.continuous
reduce_mean = config.training.reduce_mean
likelihood_weighting = config.training.likelihood_weighting
train_step_fn = losses.get_step_fn(sde=sde,
train=True,
optimize_fn=optimize_fn,
reduce_mean=reduce_mean,
continuous=continuous,
likelihood_weighting=likelihood_weighting,
data=config.data.name)
eval_step_fn = losses.get_step_fn(sde=sde,
train=False,
optimize_fn=optimize_fn,
reduce_mean=reduce_mean,
continuous=continuous,
likelihood_weighting=likelihood_weighting,
data=config.data.name)
## Build sampling functions
if config.training.snapshot_sampling:
sampling_shape = (config.training.eval_batch_size, config.data.max_node, config.data.n_vocab)
sampling_fn = sampling.get_sampling_fn(config=config,
sde=sde,
shape=sampling_shape,
inverse_scaler=inverse_scaler,
eps=sampling_eps)
## Build analysis tools
sampling_metrics = SamplingArchMetrics(config, train_ds, exp_name)
## Start training the score network
logging.info("Starting training loop at step %d." % (initial_step,))
element = {'train': ['training_loss'],
'eval': ['eval_loss'],
'test': ['test_loss'],
'sample': ['r_valid', 'r_unique', 'r_novel']}
num_train_steps = config.training.n_iters
is_best = False
min_test_loss = 1e05
for step in range(initial_step, num_train_steps+1):
try:
x, adj, extra = next(train_iter)
except StopIteration:
train_iter = train_loader.__iter__()
x, adj, extra = next(train_iter)
mask = aug_mask(adj, algo=config.data.aug_mask_algo, data=config.data.name)
x, adj, mask = scaler(x.to(config.device)), adj.to(config.device), mask.to(config.device)
batch = (x, adj, mask)
## Execute one training step
loss = train_step_fn(state, batch)
logger.update(key="training_loss", v=loss.item())
if step % config.training.log_freq == 0:
logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
## Report the loss on evaluation dataset periodically
if step % config.training.eval_freq == 0:
for eval_x, eval_adj, eval_extra in eval_loader:
eval_mask = aug_mask(eval_adj, algo=config.data.aug_mask_algo, data=config.data.name)
eval_x, eval_adj, eval_mask = scaler(eval_x.to(config.device)), eval_adj.to(config.device), eval_mask.to(config.device)
eval_batch = (eval_x, eval_adj, eval_mask)
eval_loss = eval_step_fn(state, eval_batch)
logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item()))
logger.update(key="eval_loss", v=eval_loss.item())
for test_x, test_adj, test_extra in test_loader:
test_mask = aug_mask(test_adj, algo=config.data.aug_mask_algo, data=config.data.name)
test_x, test_adj, test_mask = scaler(test_x.to(config.device)), test_adj.to(config.device), test_mask.to(config.device)
test_batch = (test_x, test_adj, test_mask)
test_loss = eval_step_fn(state, test_batch)
logging.info("step: %d, test_loss: %.5e" % (step, test_loss.item()))
logger.update(key="test_loss", v=test_loss.item())
if logger.logs['test_loss'].avg < min_test_loss:
min_test_loss = logger.logs['test_loss'].avg
is_best = True
## Save the checkpoint
if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
save_step = step // config.training.snapshot_freq
save_checkpoint(checkpoint_dir, state, step, save_step, is_best)
## Generate samples
if config.training.snapshot_sampling:
ema.store(score_model.parameters())
ema.copy_to(score_model.parameters())
sample, sample_steps, _ = sampling_fn(score_model, mask)
ema.restore(score_model.parameters())  # restore training weights after sampling with the EMA weights
quantized_sample = quantize(sample)
this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
os.makedirs(this_sample_dir, exist_ok=True)
## Evaluate samples
arch_metric = sampling_metrics(arch_list=quantized_sample, this_sample_dir=this_sample_dir)
r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2]
logger.update(key="r_valid", v=r_valid)
logger.update(key="r_unique", v=r_unique)
logger.update(key="r_novel", v=r_novel)
logging.info("r_valid: %.5e" % (r_valid))
logging.info("r_unique: %.5e" % (r_unique))
logging.info("r_novel: %.5e" % (r_novel))
if step % config.training.eval_freq == 0:
logger.write_log(element=element, step=step)
else:
logger.write_log(element={'train': ['training_loss']}, step=step)
logger.reset()
logger.save_log()
def scorenet_evaluate(config):
"""Evaluate trained score network.
Args:
config: Configuration to use.
"""
## Set logger
exp_name = set_exp_name(config)
logger = Logger(
log_dir=exp_name,
write_textfile=True)
logger.update_config(config, is_args=True)
logger.write_str(str(vars(config)))
logger.write_str('-' * 100)
## Load the config of pre-trained score network
score_config = torch.load(config.scorenet_ckpt_path)['config']
## Setup SDEs
if score_config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(beta_min=score_config.model.beta_min, beta_max=score_config.model.beta_max, N=score_config.model.num_scales)
sampling_eps = 1e-3
elif score_config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(sigma_min=score_config.model.sigma_min, sigma_max=score_config.model.sigma_max, N=score_config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {score_config.training.sde} unknown.")
## Create data normalizer and its inverse
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
# Build the sampling function
sampling_shape = (config.eval.batch_size, score_config.data.max_node, score_config.data.n_vocab)
sampling_fn = sampling.get_sampling_fn(config=config,
sde=sde,
shape=sampling_shape,
inverse_scaler=inverse_scaler,
eps=sampling_eps)
## Load pre-trained score network
score_model = mutils.create_model(score_config)
ema = ExponentialMovingAverage(score_model.parameters(), decay=score_config.model.ema_rate)
state = dict(model=score_model, ema=ema, step=0, config=score_config)
state = restore_checkpoint(config.scorenet_ckpt_path, state, device=config.device, resume=True)
ema.store(score_model.parameters())
ema.copy_to(score_model.parameters())
## Build dataset
train_ds, eval_ds, test_ds = datasets_nas.get_dataset(score_config)
## Build analysis tools
sampling_metrics = SamplingArchMetrics(config, train_ds, exp_name)
## Create directories for experimental logs
sample_dir = os.path.join(exp_name, "samples")
os.makedirs(sample_dir, exist_ok=True)
## Start sampling
logging.info("Starting sampling")
element = {'sample': ['r_valid', 'r_unique', 'r_novel']}
num_sampling_rounds = int(np.ceil(config.eval.num_samples / config.eval.batch_size))
print(f'>>> Sampling for {num_sampling_rounds} rounds...')
all_samples = []
adj = train_ds.adj.to(config.device)
mask = train_ds.mask(algo=score_config.data.aug_mask_algo).to(config.device)
if len(adj.shape) == 2: adj = adj.unsqueeze(0)
if len(mask.shape) == 2: mask = mask.unsqueeze(0)
for _ in range(num_sampling_rounds):
sample, sample_steps, _ = sampling_fn(score_model, mask)
quantized_sample = quantize(sample)
all_samples += quantized_sample
## Evaluate samples
all_samples = all_samples[:config.eval.num_samples]
arch_metric = sampling_metrics(arch_list=all_samples, this_sample_dir=sample_dir)
r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2]
logger.update(key="r_valid", v=r_valid)
logger.update(key="r_unique", v=r_unique)
logger.update(key="r_novel", v=r_novel)
logger.write_log(element=element, step=1)
logger.save_log()
def meta_surrogate_train(config):
"""Runs the meta-predictor model training pipeline.
Args:
config: Configuration to use.
"""
## Set logger
exp_name = set_exp_name(config)
logger = Logger(
log_dir=exp_name,
write_textfile=True)
logger.update_config(config, is_args=True)
logger.write_str(str(vars(config)))
logger.write_str('-' * 100)
## Create directories for experimental logs
sample_dir = os.path.join(exp_name, "samples")
os.makedirs(sample_dir, exist_ok=True)
## Initialize model and optimizer
surrogate_model = mutils.create_model(config)
optimizer = losses.get_optimizer(config, surrogate_model.parameters())
state = dict(optimizer=optimizer, model=surrogate_model, step=0, config=config)
## Create checkpoints directory
checkpoint_dir = os.path.join(exp_name, "checkpoints")
## Intermediate checkpoints to resume training
checkpoint_meta_dir = os.path.join(exp_name, "checkpoints-meta", "checkpoint.pth")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
## Resume training when intermediate checkpoints are detected and resume=True
state = restore_checkpoint(checkpoint_meta_dir, state, config.device, resume=config.resume)
initial_step = int(state['step'])
## Build dataloader and iterators
train_ds, eval_ds, test_ds = datasets_nas.get_meta_dataset(config)
train_loader, eval_loader, _ = datasets_nas.get_dataloader(config, train_ds, eval_ds, test_ds)
train_iter = iter(train_loader)
## Create data normalizer and its inverse
scaler = datasets_nas.get_data_scaler(config)
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
## Setup SDEs
if config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
## Build one-step training and evaluation functions
optimize_fn = losses.optimization_manager(config)
continuous = config.training.continuous
reduce_mean = config.training.reduce_mean
likelihood_weighting = config.training.likelihood_weighting
train_step_fn = losses.get_step_fn_predictor(sde=sde,
train=True,
optimize_fn=optimize_fn,
reduce_mean=reduce_mean,
continuous=continuous,
likelihood_weighting=likelihood_weighting,
data=config.data.name,
label_list=config.data.label_list,
noised=config.training.noised)
eval_step_fn = losses.get_step_fn_predictor(sde,
train=False,
optimize_fn=optimize_fn,
reduce_mean=reduce_mean,
continuous=continuous,
likelihood_weighting=likelihood_weighting,
data=config.data.name,
label_list=config.data.label_list,
noised=config.training.noised)
## Build sampling functions
if config.training.snapshot_sampling:
sampling_shape = (config.training.eval_batch_size, config.data.max_node, config.data.n_vocab)
sampling_fn = sampling.get_sampling_fn(config=config,
sde=sde,
shape=sampling_shape,
inverse_scaler=inverse_scaler,
eps=sampling_eps,
conditional=True,
data_name=config.sampling.check_dataname, # for sanity check
num_sample=config.model.num_sample)
## Load pre-trained score network
score_config = torch.load(config.scorenet_ckpt_path)['config']
check_config(score_config, config)
score_model = mutils.create_model(score_config)
score_ema = ExponentialMovingAverage(score_model.parameters(), decay=score_config.model.ema_rate)
score_state = dict(model=score_model, ema=score_ema, step=0, config=score_config)
score_state = restore_checkpoint(config.scorenet_ckpt_path, score_state, device=config.device, resume=True)
score_ema.copy_to(score_model.parameters())
## Build analysis tools
sampling_metrics = SamplingArchMetricsMeta(config, train_ds, exp_name)
## Start training
logging.info("Starting training loop at step %d." % (initial_step,))
element = {'train': ['training_loss'],
'eval': ['eval_loss', 'eval_p_corr', 'eval_s_corr'],
'sample': ['r_valid', 'r_unique', 'r_novel']}
num_train_steps = config.training.n_iters
is_best = False
max_eval_p_corr = -1
for step in range(initial_step, num_train_steps + 1):
try:
x, adj, extra, task = next(train_iter)
except StopIteration:
train_iter = train_loader.__iter__()
x, adj, extra, task = next(train_iter)
mask = aug_mask(adj, algo=config.data.aug_mask_algo, data=config.data.name)
x, adj, mask, task = scaler(x.to(config.device)), adj.to(config.device), mask.to(config.device), task.to(config.device)
batch = (x, adj, mask, extra, task)
## Execute one training step
loss, pred, labels = train_step_fn(state, batch)
logger.update(key="training_loss", v=loss.item())
if step % config.training.log_freq == 0:
logging.info("step: %d, training_loss: %.5e" % (step, loss.item()))
## Report the loss on evaluation dataset periodically
if step % config.training.eval_freq == 0:
eval_pred_list, eval_labels_list = list(), list()
for eval_x, eval_adj, eval_extra, eval_task in eval_loader:
eval_mask = aug_mask(eval_adj, algo=config.data.aug_mask_algo, data=config.data.name)
eval_x, eval_adj, eval_mask, eval_task = scaler(eval_x.to(config.device)), eval_adj.to(config.device), eval_mask.to(config.device), eval_task.to(config.device)
eval_batch = (eval_x, eval_adj, eval_mask, eval_extra, eval_task)
eval_loss, eval_pred, eval_labels = eval_step_fn(state, eval_batch)
eval_pred_list += [v.detach().item() for v in eval_pred.squeeze()]
eval_labels_list += [v.detach().item() for v in eval_labels.squeeze()]
logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item()))
logger.update(key="eval_loss", v=eval_loss.item())
eval_p_corr = pearsonr(np.array(eval_pred_list), np.array(eval_labels_list))[0]
eval_s_corr = spearmanr(np.array(eval_pred_list), np.array(eval_labels_list))[0]
logging.info("step: %d, eval_p_corr: %.5e" % (step, eval_p_corr))
logging.info("step: %d, eval_s_corr: %.5e" % (step, eval_s_corr))
logger.update(key="eval_p_corr", v=eval_p_corr)
logger.update(key="eval_s_corr", v=eval_s_corr)
if eval_p_corr > max_eval_p_corr:
is_best = True
max_eval_p_corr = eval_p_corr
## Save a checkpoint periodically and generate samples
if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:
## Save the checkpoint.
save_step = step // config.training.snapshot_freq
save_checkpoint(checkpoint_dir, state, step, save_step, is_best)
## Generate and save samples
if config.training.snapshot_sampling:
score_ema.store(score_model.parameters())
score_ema.copy_to(score_model.parameters())
sample = sampling_fn(score_model=score_model,
mask=mask,
classifier=surrogate_model,
classifier_scale=config.sampling.classifier_scale)
quantized_sample = quantize(sample) # quantization
this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
os.makedirs(this_sample_dir, exist_ok=True)
## Evaluate samples
arch_metric = sampling_metrics(arch_list=quantized_sample,
this_sample_dir=this_sample_dir,
check_dataname=config.sampling.check_dataname)
r_valid, r_unique, r_novel = arch_metric[0][0], arch_metric[0][1], arch_metric[0][2]
logging.info("step: %d, r_valid: %.5e" % (step, r_valid))
logging.info("step: %d, r_unique: %.5e" % (step, r_unique))
logging.info("step: %d, r_novel: %.5e" % (step, r_novel))
logger.update(key="r_valid", v=r_valid)
logger.update(key="r_unique", v=r_unique)
logger.update(key="r_novel", v=r_novel)
if step % config.training.eval_freq == 0:
logger.write_log(element=element, step=step)
else:
logger.write_log(element={'train': ['training_loss']}, step=step)
logger.reset()
def check_config(config1, config2):
assert config1.model.sigma_min == config2.model.sigma_min
assert config1.model.sigma_max == config2.model.sigma_max
assert config1.training.sde == config2.training.sde
assert config1.training.continuous == config2.training.continuous
assert config1.data.centered == config2.data.centered
assert config1.data.max_node == config2.data.max_node
assert config1.data.n_vocab == config2.data.n_vocab
run_train_dict = {
'scorenet': scorenet_train,
'meta_surrogate': meta_surrogate_train
}
run_eval_dict = {
'scorenet': scorenet_evaluate,
}
def train(config):
run_train_dict[config.model_type](config)
def evaluate(config):
run_eval_dict[config.model_type](config)

NAS-Bench-201/sampling.py Normal file

@@ -0,0 +1,579 @@
"""Various sampling methods."""
import functools
import torch
import numpy as np
import abc
from tqdm import trange
import sde_lib
from models import utils as mutils
from datasets_nas import MetaTestDataset
from all_path import DATA_PATH
_CORRECTORS = {}
_PREDICTORS = {}
def register_predictor(cls=None, *, name=None):
"""A decorator for registering predictor classes."""
def _register(cls):
if name is None:
local_name = cls.__name__
else:
local_name = name
if local_name in _PREDICTORS:
raise ValueError(f'Already registered predictor with name: {local_name}')
_PREDICTORS[local_name] = cls
return cls
if cls is None:
return _register
else:
return _register(cls)
def register_corrector(cls=None, *, name=None):
"""A decorator for registering corrector classes."""
def _register(cls):
if name is None:
local_name = cls.__name__
else:
local_name = name
if local_name in _CORRECTORS:
raise ValueError(f'Already registered corrector with name: {local_name}')
_CORRECTORS[local_name] = cls
return cls
if cls is None:
return _register
else:
return _register(cls)
def get_predictor(name):
return _PREDICTORS[name]
def get_corrector(name):
return _CORRECTORS[name]
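# Illustrative sketch (not part of the original file): new sampler components are added by
# decorating a class and are later looked up with the same string that appears in the
# sampling config (e.g. `config.sampling.corrector = 'langevin'`). The name 'my_corrector'
# below is hypothetical.
#
#     @register_corrector(name='my_corrector')
#     class MyCorrector(Corrector):
#         def update_fn(self, x, t, *args, **kwargs):
#             return x, x
#
#     corrector_cls = get_corrector('my_corrector')   # -> MyCorrector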
def get_sampling_fn(
config,
sde,
shape,
inverse_scaler,
eps,
conditional=False,
data_name='cifar10',
num_sample=20):
"""Create a sampling function.
Args:
config: A `ml_collections.ConfigDict` object that contains all configuration information.
sde: A `sde_lib.SDE` object that represents the forward SDE.
shape: A sequence of integers representing the expected shape of a single sample.
inverse_scaler: The inverse data normalizer function.
eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
conditional: If `True`, build a conditional sampler guided by the surrogate predictor.
data_name: A `str` name of the dataset.
num_sample: An `int` number of samples for each class of the dataset.
Returns:
A function that takes random states and a replicated training state and outputs samples with the
trailing dimensions matching `shape`.
"""
sampler_name = config.sampling.method
# Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
if sampler_name.lower() == 'pc':
predictor = get_predictor(config.sampling.predictor.lower())
corrector = get_corrector(config.sampling.corrector.lower())
if not conditional:
print('>>> Get pc_sampler...')
sampling_fn = get_pc_sampler_nas(sde=sde,
shape=shape,
predictor=predictor,
corrector=corrector,
inverse_scaler=inverse_scaler,
snr=config.sampling.snr,
n_steps=config.sampling.n_steps_each,
probability_flow=config.sampling.probability_flow,
continuous=config.training.continuous,
denoise=config.sampling.noise_removal,
eps=eps,
device=config.device)
else:
print('>>> Get pc_conditional_sampler...')
sampling_fn = get_pc_conditional_sampler_meta_nas(sde=sde,
shape=shape,
predictor=predictor,
corrector=corrector,
inverse_scaler=inverse_scaler,
snr=config.sampling.snr,
n_steps=config.sampling.n_steps_each,
probability_flow=config.sampling.probability_flow,
continuous=config.training.continuous,
denoise=config.sampling.noise_removal,
eps=eps,
device=config.device,
regress=config.sampling.regress,
labels=config.sampling.labels,
data_name=data_name,
num_sample=num_sample)
else:
raise NotImplementedError(f"Sampler name {sampler_name} unknown.")
return sampling_fn
class Predictor(abc.ABC):
"""The abstract class for a predictor algorithm."""
def __init__(self, sde, score_fn, probability_flow=False):
super().__init__()
self.sde = sde
# Compute the reverse SDE/ODE
if isinstance(sde, tuple):
self.rsde = (sde[0].reverse(score_fn, probability_flow), sde[1].reverse(score_fn, probability_flow))
else:
self.rsde = sde.reverse(score_fn, probability_flow)
self.score_fn = score_fn
@abc.abstractmethod
def update_fn(self, x, t, *args, **kwargs):
"""One update of the predictor.
Args:
x: A PyTorch tensor representing the current state.
t: A PyTorch tensor representing the current time step.
Returns:
x: A PyTorch tensor of the next state.
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
"""
pass
class Corrector(abc.ABC):
"""The abstract class for a corrector algorithm."""
def __init__(self, sde, score_fn, snr, n_steps):
super().__init__()
self.sde = sde
self.score_fn = score_fn
self.snr = snr
self.n_steps = n_steps
@abc.abstractmethod
def update_fn(self, x, t, *args, **kwargs):
"""One update of the corrector.
Args:
x: A PyTorch tensor representing the current state.
t: A PyTorch tensor representing the current time step.
Returns:
x: A PyTorch tensor of the next state.
x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
"""
pass
@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
def __init__(self, sde, score_fn, probability_flow=False):
super().__init__(sde, score_fn, probability_flow)
def update_fn(self, x, t, *args, **kwargs):
dt = -1. / self.rsde.N
z = torch.randn_like(x)
drift, diffusion = self.rsde.sde(x, t, *args, **kwargs)
x_mean = x + drift * dt
x = x_mean + diffusion[:, None, None] * np.sqrt(-dt) * z
return x, x_mean
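# Descriptive note (not in the original file): the Euler-Maruyama predictor above takes one
# reverse-time step of size dt = -1/N,
#     x_{i-1} = x_i + f_rev(x_i, t_i) * dt + g(t_i) * sqrt(|dt|) * z,   z ~ N(0, I),
# where (f_rev, g) come from `self.rsde.sde`, the reverse-time SDE built in SDE.reverse.
# `x_mean` is the same step without the noise term; it is what the samplers return when
# `noise_removal` (denoise) is enabled.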
@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
def __init__(self, sde, score_fn, probability_flow=False):
super().__init__(sde, score_fn, probability_flow)
def update_fn(self, x, t, *args, **kwargs):
f, G = self.rsde.discretize(x, t, *args, **kwargs)
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None] * z
return x, x_mean
@register_predictor(name='none')
class NonePredictor(Predictor):
"""An empty predictor that does nothing."""
def __init__(self, sde, score_fn, probability_flow=False):
pass
def update_fn(self, x, t, *args, **kwargs):
return x, x
@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
def __init__(self, sde, score_fn, snr, n_steps):
super().__init__(sde, score_fn, snr, n_steps)
def update_fn(self, x, t, *args, **kwargs):
sde = self.sde
score_fn = self.score_fn
n_steps = self.n_steps
target_snr = self.snr
if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
timestep = (t * (sde.N - 1) / sde.T).long()
# Note: subVPSDE does not define `alphas`, so this branch only supports VPSDE.
alpha = sde.alphas.to(t.device)[timestep]
else:
alpha = torch.ones_like(t)
for i in range(n_steps):
grad = score_fn(x, t, *args, **kwargs)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
x_mean = x + step_size[:, None, None] * grad
x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise
return x, x_mean
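# Descriptive note (not in the original file): each Langevin corrector iteration above is
#     x <- x + eps * score(x, t) + sqrt(2 * eps) * z,   z ~ N(0, I),
# with the step size eps chosen so that sqrt(eps / 2) * ||score|| / ||z|| matches the
# configured `snr`, i.e. eps = 2 * alpha * (snr * ||z|| / ||score||)^2, where alpha is the
# DDPM alpha_t for VP-type SDEs and 1 otherwise. This is exactly the `step_size` in the loop.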
@register_corrector(name='none')
class NoneCorrector(Corrector):
"""An empty corrector that does nothing."""
def __init__(self, sde, score_fn, snr, n_steps):
pass
def update_fn(self, x, t, *args, **kwargs):
return x, x
def shared_predictor_update_fn(x, t, sde, model,
predictor, probability_flow, continuous, *args, **kwargs):
"""A wrapper that configures and returns the update function of predictors."""
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
if predictor is None:
# Corrector-only sampler
predictor_obj = NonePredictor(sde, score_fn, probability_flow)
else:
predictor_obj = predictor(sde, score_fn, probability_flow)
return predictor_obj.update_fn(x, t, *args, **kwargs)
def shared_corrector_update_fn(x, t, sde, model,
corrector, continuous, snr, n_steps, *args, **kwargs):
"""A wrapper that configures and returns the update function of correctors."""
score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)
if corrector is None:
# Predictor-only sampler
corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
else:
corrector_obj = corrector(sde, score_fn, snr, n_steps)
return corrector_obj.update_fn(x, t, *args, **kwargs)
def get_pc_sampler(sde,
shape,
predictor,
corrector,
inverse_scaler,
snr,
n_steps=1,
probability_flow=False,
continuous=False,
denoise=True,
eps=1e-3,
device='cuda'):
"""Create a Predictor-Corrector (PC) sampler.
Args:
sde: An `sde_lib.SDE` object representing the forward SDE.
shape: A sequence of integers. The expected shape of a single sample.
predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
inverse_scaler: The inverse data normalizer.
snr: A `float` number. The signal-to-noise ratio for configuring correctors.
n_steps: An integer. The number of corrector steps per predictor update.
probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
continuous: `True` indicates that the score model was continuously trained.
denoise: If `True`, add one-step denoising to the final samples.
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
device: PyTorch device.
Returns:
A sampling function that returns samples and the number of function evaluations during sampling.
"""
# Create predictor & corrector update functions
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps)
def pc_sampler(model, n_nodes_pmf):
"""The PC sampler function.
Args:
model: A score model.
n_nodes_pmf: Probability mass function of graph nodes.
Returns:
Samples, number of function evaluations.
"""
with torch.no_grad():
# Initial sample
x = sde.prior_sampling(shape).to(device)
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
# Sample the number of nodes
n_nodes = torch.multinomial(n_nodes_pmf, shape[0], replacement=True)
mask = torch.zeros((shape[0], shape[-1]), device=device)
for i in range(shape[0]):
mask[i][:n_nodes[i]] = 1.
mask = (mask[:, None, :] * mask[:, :, None]).unsqueeze(1)
mask = torch.tril(mask, -1)
mask = mask + mask.transpose(-1, -2)
x = x * mask
for i in range(sde.N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = corrector_update_fn(x, vec_t, model=model, mask=mask)
x = x * mask
x, x_mean = predictor_update_fn(x, vec_t, model=model, mask=mask)
x = x * mask
return inverse_scaler(x_mean if denoise else x) * mask, sde.N * (n_steps + 1), n_nodes
return pc_sampler
def get_pc_sampler_nas(sde,
shape,
predictor,
corrector,
inverse_scaler,
snr,
n_steps=1,
probability_flow=False,
continuous=False,
denoise=True,
eps=1e-3,
device='cuda'):
"""Create a Predictor-Corrector (PC) sampler.
Args:
sde: An `sde_lib.SDE` object representing the forward SDE.
shape: A sequence of integers. The expected shape of a single sample.
predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
inverse_scaler: The inverse data normalizer.
snr: A `float` number. The signal-to-noise ratio for configuring correctors.
n_steps: An integer. The number of corrector steps per predictor update.
probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
continuous: `True` indicates that the score model was continuously trained.
denoise: If `True`, add one-step denoising to the final samples.
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
device: PyTorch device.
Returns:
A sampling function that returns samples and the number of function evaluations during sampling.
"""
# Create predictor & corrector update functions
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps)
def pc_sampler(model, mask):
"""The PC sampler function.
Args:
model: A score model.
mask: The dense adjacency mask; its first entry is repeated across the batch.
Returns:
Samples, number of function evaluations.
"""
with torch.no_grad():
# Initial sample
x = sde.prior_sampling(shape).to(device)
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
mask = mask[0].unsqueeze(0).repeat(x.size(0), 1, 1)
for i in trange(sde.N, desc='[PC sampling]', position=1, leave=False):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = corrector_update_fn(x, vec_t, model=model, maskX=mask)
x, x_mean = predictor_update_fn(x, vec_t, model=model, maskX=mask)
return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1), None
return pc_sampler
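# Descriptive note (not in the original file): `pc_sampler` returns a tuple of
# (samples, number_of_update_steps, None), where the step count is N * (n_steps + 1) because
# each of the N reverse-time iterations runs one predictor update plus `n_steps` corrector
# updates. The training scripts unpack it as `sample, sample_steps, _ = sampling_fn(...)`.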
def get_pc_conditional_sampler_meta_nas(
sde,
shape,
predictor,
corrector,
inverse_scaler,
snr,
n_steps=1,
probability_flow=False,
continuous=False,
denoise=True,
eps=1e-5,
device='cuda',
regress=True,
labels='max',
data_name='cifar10',
num_sample=20):
"""Class-conditional sampling with Predictor-Corrector (PC) samplers.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
score_model: A `torch.nn.Module` object that represents the architecture of the score-based model (an argument of the returned sampler, not of this function).
classifier: A `torch.nn.Module` object that represents the architecture of the noise-dependent surrogate predictor (also an argument of the returned sampler).
shape: A sequence of integers. The expected shape of a single sample.
predictor: A subclass of `sampling.predictor` that represents a predictor algorithm.
corrector: A subclass of `sampling.corrector` that represents a corrector algorithm.
inverse_scaler: The inverse data normalizer.
snr: A `float` number. The signal-to-noise ratio for correctors.
n_steps: An integer. The number of corrector steps per update of the predictor.
probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor.
continuous: `True` indicates the score-based model was trained with continuous time.
denoise: If `True`, add one-step denoising to final samples.
eps: A `float` number. The SDE/ODE will be integrated to `eps` to avoid numerical issues.
Returns: A class-conditional Predictor-Corrector sampler for neural architectures.
"""
# --------- Meta-NAS ---------- #
test_dataset = MetaTestDataset(
data_path=DATA_PATH,
data_name=data_name,
num_sample=num_sample)
def conditional_predictor_update_fn(score_model, classifier, x, t, labels, maskX, classifier_scale, *args, **kwargs):
"""The predictor update function for class-conditional sampling."""
score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
regress=regress, labels=labels)
def total_grad_fn(x, t, *args, **kwargs):
score = score_fn(x, t, maskX)
classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)
return score + classifier_scale * classifier_grad
if predictor is None:
predictor_obj = NonePredictor(sde, total_grad_fn, probability_flow)
else:
predictor_obj = predictor(sde, total_grad_fn, probability_flow)
return predictor_obj.update_fn(x, t, *args, **kwargs)
def conditional_corrector_update_fn(score_model, classifier, x, t, labels, maskX, classifier_scale, *args, **kwargs):
"""The corrector update function for class-conditional sampling."""
score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
classifier_grad_fn = mutils.get_classifier_grad_fn(sde, classifier, train=False, continuous=continuous,
regress=regress, labels=labels)
def total_grad_fn(x, t, *args, **kwargs):
score = score_fn(x, t, maskX)
classifier_grad = classifier_grad_fn(x, t, maskX, *args, **kwargs)
return score + classifier_scale * classifier_grad
if corrector is None:
corrector_obj = NoneCorrector(sde, total_grad_fn, snr, n_steps)
else:
corrector_obj = corrector(sde, total_grad_fn, snr, n_steps)
return corrector_obj.update_fn(x, t, *args, **kwargs)
def pc_conditional_sampler(
score_model,
mask,
classifier,
classifier_scale=None,
task=None):
"""Generate class-conditional samples with Predictor-Corrector (PC) samplers.
Args:
score_model: A `torch.nn.Module` object that represents the score-based model.
mask: The dense adjacency mask used to constrain sampling.
classifier: The noise-dependent surrogate predictor that provides guidance gradients.
classifier_scale: A `float` number that scales the guidance gradients.
task: An optional task embedding; if `None`, one is drawn from the meta-test dataset.
Returns:
Class-conditional samples.
"""
# to accelerate sampling
with torch.no_grad():
if task is None:
task = test_dataset[0]
task = task.repeat(shape[0], 1, 1)
task = task.to(device)
else:
task = task.repeat(shape[0], 1, 1)
task = task.to(device)
classifier.sample_state = True
classifier.D_mu = None
# initial sample
x = sde.prior_sampling(shape).to(device)
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
if len(mask.shape) == 3: mask = mask[0]
mask = mask.unsqueeze(0).repeat(x.size(0), 1, 1) # adj
for i in trange(sde.N, desc='[PC conditional sampling]', position=1, leave=False):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = conditional_corrector_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask, task=task, classifier_scale=classifier_scale)
x, x_mean = conditional_predictor_update_fn(score_model, classifier, x, vec_t, labels=labels, maskX=mask, task=task, classifier_scale=classifier_scale)
classifier.sample_state = False
return inverse_scaler(x_mean if denoise else x)
return pc_conditional_sampler
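# Descriptive note (not in the original file): both conditional update functions above guide
# the unconditional score with the surrogate's gradient, following
#     s_guided(x, t) = s_theta(x, t) + classifier_scale * grad_x log p(y | x, t).
# A minimal sketch of that combination (the helper name below is hypothetical and simply
# mirrors `total_grad_fn`):
def _guided_score_sketch(score, classifier_grad, classifier_scale):
    """Combine an unconditional score with a scaled surrogate gradient."""
    return score + classifier_scale * classifier_grad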


@@ -0,0 +1,4 @@
export LD_LIBRARY_PATH=/opt/conda/envs/gtctnz_2/lib/python3.7/site-packages/nvidia/cublas/lib/
echo '[Downloading processed]'
python main_exp/transfer_nag/get_files/get_preprocessed_data.py


@@ -0,0 +1,15 @@
export LD_LIBRARY_PATH=/opt/conda/envs/gtctnz_2/lib/python3.7/site-packages/nvidia/cublas/lib/
DATANAME=$1
if [[ $DATANAME = 'aircraft' ]]; then
echo '[Downloading aircraft]'
python main_exp/transfer_nag/get_files/get_aircraft.py
elif [[ $DATANAME = 'pets' ]]; then
echo '[Downloading pets]'
python main_exp/transfer_nag/get_files/get_pets.py
else
echo 'Not Implemented'
fi


@@ -0,0 +1,6 @@
FOLDER_NAME='tr_meta_surrogate_nb201'
CUDA_VISIBLE_DEVICES=$1 python main.py --config configs/tr_meta_surrogate.py \
--mode train \
--config.folder_name $FOLDER_NAME


@@ -0,0 +1,5 @@
FOLDER_NAME='tr_scorenet_nb201'
CUDA_VISIBLE_DEVICES=$1 python main.py --config configs/tr_scorenet.py \
--mode train \
--config.folder_name $FOLDER_NAME


@@ -0,0 +1,10 @@
FOLDER_NAME='transfer_nag_nb201'
GPU=$1
DATANAME=$2
CUDA_VISIBLE_DEVICES=$GPU python main_exp/transfer_nag/main.py \
--gpu $GPU \
--test \
--folder_name $FOLDER_NAME \
--data-name $DATANAME

NAS-Bench-201/sde_lib.py Normal file

@@ -0,0 +1,300 @@
"""Abstract SDE classes, Reverse SDE, and VP SDEs."""
import abc
import torch
import numpy as np
class SDE(abc.ABC):
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""
def __init__(self, N):
"""Construct an SDE.
Args:
N: number of discretization time steps.
"""
super().__init__()
self.N = N
@property
@abc.abstractmethod
def T(self):
"""End time of the SDE."""
pass
@abc.abstractmethod
def sde(self, x, t):
pass
@abc.abstractmethod
def marginal_prob(self, x, t):
"""Parameters to determine the marginal distribution of the SDE, $p_t(x)$"""
pass
@abc.abstractmethod
def prior_sampling(self, shape):
"""Generate one sample from the prior distribution, $p_T(x)$."""
pass
@abc.abstractmethod
def prior_logp(self, z, mask):
"""Compute log-density of the prior distribution.
Useful for computing the log-likelihood via probability flow ODE.
Args:
z: latent code
Returns:
log probability density
"""
pass
def discretize(self, x, t):
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probability flow sampling.
Defaults to Euler-Maruyama discretization.
Args:
x: a torch tensor
t: a torch float representing the time step (from 0 to `self.T`)
Returns:
f, G
"""
dt = 1 / self.N
drift, diffusion = self.sde(x, t)
f = drift * dt
G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
return f, G
def reverse(self, score_fn, probability_flow=False):
"""Create the reverse-time SDE/ODE.
Args:
score_fn: A time-dependent score-based model that takes x and t and returns the score.
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
"""
N = self.N
T = self.T
sde_fn = self.sde
discretize_fn = self.discretize
# Build the class for reverse-time SDE.
class RSDE(self.__class__):
def __init__(self):
self.N = N
self.probability_flow = probability_flow
@property
def T(self):
return T
def sde(self, x, t, *args, **kwargs):
"""Create the drift and diffusion functions for the reverse SDE/ODE."""
drift, diffusion = sde_fn(x, t)
score = score_fn(x, t, *args, **kwargs)
drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
# Set the diffusion function to zero for ODEs.
diffusion = 0. if self.probability_flow else diffusion
return drift, diffusion
'''
def sde_score(self, x, t, score):
"""Create the drift and diffusion functions for the reverse SDE/ODE, given score values."""
drift, diffusion = sde_fn(x, t)
if len(score.shape) == 4:
drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
elif len(score.shape) == 3:
drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
else:
raise ValueError
diffusion = 0. if self.probability_flow else diffusion
return drift, diffusion
'''
def discretize(self, x, t, *args, **kwargs):
"""Create discretized iteration rules for the reverse diffusion sampler."""
f, G = discretize_fn(x, t)
rev_f = f - G[:, None, None] ** 2 * score_fn(x, t, *args, **kwargs) * \
(0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
'''
def discretize_score(self, x, t, score):
"""Create discretized iteration rules for the reverse diffusion sampler, given score values."""
f, G = discretize_fn(x, t)
if len(score.shape) == 4:
rev_f = f - G[:, None, None, None] ** 2 * score * \
(0.5 if self.probability_flow else 1.)
elif len(score.shape) == 3:
rev_f = f - G[:, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
else:
raise ValueError
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
'''
return RSDE()
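# Descriptive note (not in the original file): `RSDE.sde` implements the reverse-time SDE
#     dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw_bar,
# and, when `probability_flow=True`, the deterministic probability-flow ODE
#     dx = [f(x, t) - 0.5 * g(t)^2 * score(x, t)] dt,
# which is why the score correction is halved and the diffusion is zeroed in that case.
# `RSDE.discretize` applies the same construction to the DDPM/SMLD discretization.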
class VPSDE(SDE):
def __init__(self, beta_min=0.1, beta_max=20, N=1000):
"""Construct a Variance Preserving SDE.
Args:
beta_min: value of beta(0)
beta_max: value of beta(1)
N: number of discretization steps
"""
super().__init__(N)
self.beta_0 = beta_min
self.beta_1 = beta_max
self.N = N
self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
self.alphas = 1. - self.discrete_betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
@property
def T(self):
return 1
def sde(self, x, t):
beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
if len(x.shape) == 4:
drift = -0.5 * beta_t[:, None, None, None] * x
elif len(x.shape) == 3:
drift = -0.5 * beta_t[:, None, None] * x
else:
raise NotImplementedError
diffusion = torch.sqrt(beta_t)
return drift, diffusion
def marginal_prob(self, x, t):
log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
if len(x.shape) == 4:
mean = torch.exp(log_mean_coeff[:, None, None, None]) * x
elif len(x.shape) == 3:
mean = torch.exp(log_mean_coeff[:, None, None]) * x
else:
raise ValueError("The shape of x in marginal_prob is not correct.")
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
return mean, std
def prior_sampling(self, shape):
return torch.randn(*shape)
def prior_logp(self, z, mask):
N = torch.sum(mask, dim=tuple(range(1, len(mask.shape))))
logps = -N / 2. * np.log(2 * np.pi) - torch.sum((z * mask) ** 2, dim=(1, 2, 3)) / 2.
return logps
def discretize(self, x, t):
"""DDPM discretization."""
timestep = (t * (self.N - 1) / self.T).long()
beta = self.discrete_betas.to(x.device)[timestep]
alpha = self.alphas.to(x.device)[timestep]
sqrt_beta = torch.sqrt(beta)
if len(x.shape) == 4:
f = torch.sqrt(alpha)[:, None, None, None] * x - x
elif len(x.shape) == 3:
f = torch.sqrt(alpha)[:, None, None] * x - x
else:
raise NotImplementedError
G = sqrt_beta
return f, G
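# Minimal sanity-check sketch (not part of the original file); `_vpsde_marginal_sanity_check`
# is a hypothetical helper showing how `marginal_prob` behaves: the VP perturbation kernel is
# N(x(t); x(0) * exp(log_mean_coeff), (1 - exp(2 * log_mean_coeff)) I), so near t = 0 the mean
# stays close to x(0) and the std is close to 0.
def _vpsde_marginal_sanity_check():
    sde = VPSDE(beta_min=0.1, beta_max=20, N=1000)
    x = torch.randn(4, 8, 7)                   # dummy [batch, max_node, n_vocab]-shaped input
    t = torch.full((4,), 1e-3)
    mean, std = sde.marginal_prob(x, t)
    assert torch.allclose(mean, x, atol=1e-2)  # near t = 0 the mean is essentially x
    assert float(std.max()) < 0.1              # and almost no noise has been added
    return mean, std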
class subVPSDE(SDE):
def __init__(self, beta_min=0.1, beta_max=20, N=1000):
"""Construct the sub-VP SDE that excels at likelihoods.
Args:
beta_min: value of beta(0)
beta_max: value of beta(1)
N: number of discretization steps
"""
super().__init__(N)
self.beta_0 = beta_min
self.beta_1 = beta_max
self.N = N
@property
def T(self):
return 1
def sde(self, x, t):
beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
drift = -0.5 * beta_t[:, None, None] * x
discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)
diffusion = torch.sqrt(beta_t * discount)
return drift, diffusion
def marginal_prob(self, x, t):
log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
mean = torch.exp(log_mean_coeff)[:, None, None] * x
std = 1 - torch.exp(2. * log_mean_coeff)
return mean, std
def prior_sampling(self, shape):
return torch.randn(*shape)
def prior_logp(self, z):
shape = z.shape
N = np.prod(shape[1:])
return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.
class VESDE(SDE):
def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
"""Construct a Variance Exploding SDE.
Args:
sigma_min: smallest sigma.
sigma_max: largest sigma.
N: number of discretization steps
"""
super().__init__(N)
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.N = N
@property
def T(self):
return 1
def sde(self, x, t):
sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
drift = torch.zeros_like(x)
diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
device=t.device))
return drift, diffusion
def marginal_prob(self, x, t):
std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
mean = x
return mean, std
def prior_sampling(self, shape):
return torch.randn(*shape) * self.sigma_max
def prior_logp(self, z):
shape = z.shape
N = np.prod(shape[1:])
return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)
def discretize(self, x, t):
"""SMLD(NCSN) discretization."""
timestep = (t * (self.N - 1) / self.T).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.discrete_sigmas[timestep.cpu() - 1].to(t.device))
f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
return f, G

NAS-Bench-201/utils.py Normal file

@@ -0,0 +1,262 @@
import os
import logging
import torch
from torch_scatter import scatter
import shutil
@torch.no_grad()
def to_dense_adj(edge_index, batch=None, edge_attr=None, max_num_nodes=None):
"""Converts batched sparse adjacency matrices given by edge indices and
edge attributes to a single dense batched adjacency matrix.
Args:
edge_index (LongTensor): The edge indices.
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
features. (default: :obj:`None`)
max_num_nodes (int, optional): The size of the output node dimension.
(default: :obj:`None`)
Returns:
adj: [batch_size, max_num_nodes, max_num_nodes] Dense adjacency matrices.
mask: Mask for dense adjacency matrices.
"""
if batch is None:
batch = edge_index.new_zeros(edge_index.max().item() + 1)
batch_size = batch.max().item() + 1
one = batch.new_ones(batch.size(0))
num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='add')
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
idx0 = batch[edge_index[0]]
idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]
if max_num_nodes is None:
max_num_nodes = num_nodes.max().item()
elif idx1.max() >= max_num_nodes or idx2.max() >= max_num_nodes:
mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
idx0 = idx0[mask]
idx1 = idx1[mask]
idx2 = idx2[mask]
edge_attr = None if edge_attr is None else edge_attr[mask]
if edge_attr is None:
edge_attr = torch.ones(idx0.numel(), device=edge_index.device)
size = [batch_size, max_num_nodes, max_num_nodes]
size += list(edge_attr.size())[1:]
adj = torch.zeros(size, dtype=edge_attr.dtype, device=edge_index.device)
flattened_size = batch_size * max_num_nodes * max_num_nodes
adj = adj.view([flattened_size] + list(adj.size())[3:])
idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
scatter(edge_attr, idx, dim=0, out=adj, reduce='add')
adj = adj.view(size)
node_idx = torch.arange(batch.size(0), dtype=torch.long, device=edge_index.device)
node_idx = (node_idx - cum_nodes[batch]) + (batch * max_num_nodes)
mask = torch.zeros(batch_size * max_num_nodes, dtype=adj.dtype, device=adj.device)
mask[node_idx] = 1
mask = mask.view(batch_size, max_num_nodes)
mask = mask[:, None, :] * mask[:, :, None]
return adj, mask
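# Minimal usage sketch (not part of the original file); `_to_dense_adj_example` is a
# hypothetical helper: a 3-node path graph 0-1-2 in a single-graph batch becomes one dense
# 3x3 adjacency matrix plus the all-ones node-pair mask.
def _to_dense_adj_example():
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    batch = torch.zeros(3, dtype=torch.long)   # all three nodes belong to graph 0
    adj, mask = to_dense_adj(edge_index, batch, max_num_nodes=3)
    # adj[0] has ones at (0, 1), (1, 0), (1, 2), (2, 1); mask[0] is all ones.
    return adj, mask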
def restore_checkpoint_partial(model, pretrained_stdict):
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_stdict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
return model
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
def save_checkpoint(ckpt_dir, state, step, save_step, is_best, remove_except_best=False):
saved_state = {}
for k in state:
if k in ['optimizer', 'model', 'ema']:
saved_state.update({k: state[k].state_dict()})
else:
saved_state.update({k: state[k]})
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{step}_{save_step}.pth.tar'))
if is_best:
shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{step}_{save_step}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar'))
# remove the ckpt except is_best state
if remove_except_best:
for ckpt_file in sorted(os.listdir(ckpt_dir)):
if not ckpt_file.startswith('checkpoint'):
continue
if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
os.remove(os.path.join(ckpt_dir, ckpt_file))
def floyed(r):
"""
:param r: a numpy NxN matrix with float 0,1
:return: a numpy NxN matrix with float 0,1
"""
# r = np.array(r)
if type(r) == torch.Tensor:
r = r.cpu().numpy()
N = r.shape[0]
# import pdb; pdb.set_trace()
for k in range(N):
for i in range(N):
for j in range(N):
if r[i, k] > 0 and r[k, j] > 0:
r[i, j] = 1
return r
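# Descriptive note (not in the original file): `floyed` computes the transitive closure of a
# 0/1 reachability matrix in O(N^3), Floyd-Warshall style, so r[i, j] becomes 1 whenever j is
# reachable from i through any intermediate node k. `aug_mask` below uses it so that the
# attention mask also covers long-range (multi-hop) connections. A hypothetical example:
def _floyed_example():
    import numpy as np  # utils.py does not import numpy at module level
    r = np.array([[0., 1., 0.],
                  [0., 0., 1.],
                  [0., 0., 0.]])               # chain 0 -> 1 -> 2
    closed = floyed(r.copy())
    assert closed[0, 2] == 1.                  # the transitive edge 0 -> 2 appears
    return closed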
def aug_mask(adj, algo='floyed', data='NASBench201'):
if len(adj.shape) == 2:
adj = adj.unsqueeze(0)
if data.lower() in ['nasbench201', 'ofa']:
assert len(adj.shape) == 3
r = adj[0].clone().detach()
if algo == 'long_range':
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
elif algo == 'floyed':
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
else:
mask_i = r
masks = [mask_i] * adj.size(0)
return torch.stack(masks)
else:
masks = []
for r in adj:
if algo == 'long_range':
mask_i = torch.from_numpy(long_range(r)).float().to(adj.device)
elif algo == 'floyed':
mask_i = torch.from_numpy(floyed(r)).float().to(adj.device)
else:
mask_i = r
masks.append(mask_i)
return torch.stack(masks)
def long_range(r):
"""
:param r: a numpy NxN matrix with float 0,1
:return: a numpy NxN matrix with float 0,1
"""
# r = np.array(r)
if type(r) == torch.Tensor:
r = r.cpu().numpy()
N = r.shape[0]
for j in range(1, N):
col_j = r[:, j][:j]
in_to_j = [i for i, val in enumerate(col_j) if val > 0]
if len(in_to_j) > 0:
for i in in_to_j:
col_i = r[:, i][:i]
in_to_i = [i for i, val in enumerate(col_i) if val > 0]
if len(in_to_i) > 0:
for k in in_to_i:
r[k, j] = 1
return r
def dense_adj(graph_data, max_num_nodes, scaler=None, dequantization=False):
"""Convert PyG DataBatch to dense adjacency matrices.
Args:
graph_data: DataBatch object.
max_num_nodes: The size of the output node dimension.
scaler: Data normalizer.
dequantization: uniform dequantization.
Returns:
adj: Dense adjacency matrices.
mask: Mask for adjacency matrices.
"""
adj, adj_mask = to_dense_adj(graph_data.edge_index, graph_data.batch, max_num_nodes=max_num_nodes) # [B, N, N]
# adj: [32, 20, 20] / adj_mask: [32, 20, 20]
if dequantization:
noise = torch.rand_like(adj)
noise = torch.tril(noise, -1)
noise = noise + noise.transpose(1, 2)
adj = (noise + adj) / 2.
adj = scaler(adj[:, None, :, :]) # [32, 1, 20, 20]
# set diag = 0 in adj_mask
adj_mask = torch.tril(adj_mask, -1) # [32, 20, 20]
adj_mask = adj_mask + adj_mask.transpose(1, 2)
return adj, adj_mask[:, None, :, :]
def adj2graph(adj, sample_nodes):
"""Covert the PyTorch tensor adjacency matrices to numpy array.
Args:
adj: [Batch_size, channel, Max_node, Max_node], assume channel=1
sample_nodes: [Batch_size]
"""
adj_list = []
# discretization
adj[adj >= 0.5] = 1.
adj[adj < 0.5] = 0.
for i in range(adj.shape[0]):
adj_tmp = adj[i, 0]
# symmetric
adj_tmp = torch.tril(adj_tmp, -1)
adj_tmp = adj_tmp + adj_tmp.transpose(0, 1)
# truncate
adj_tmp = adj_tmp.cpu().numpy()[:sample_nodes[i], :sample_nodes[i]]
adj_list.append(adj_tmp)
return adj_list
def quantize(x):
"""Covert the PyTorch tensor x, adj matrices to numpy array.
Args:
x: [Batch_size, Max_node, N_vocab]
adj: [Batch_size, Max_node, Max_node]
"""
x_list = []
# discretization
x[x >= 0.5] = 1.
x[x < 0.5] = 0.
for i in range(x.shape[0]):
x_tmp = x[i]
x_tmp = x_tmp.cpu().numpy()
x_list.append(x_tmp)
return x_list
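# Minimal usage sketch (not part of the original file); `_quantize_example` is a hypothetical
# helper: continuous operator-type scores are thresholded at 0.5 into {0, 1} entries, one
# numpy array per sample. Note that `quantize` modifies its input in place, hence the clone.
def _quantize_example():
    x = torch.tensor([[[0.9, 0.1, 0.2],
                       [0.4, 0.6, 0.5]]])      # [batch=1, max_node=2, n_vocab=3]
    x_list = quantize(x.clone())
    # x_list[0] == [[1., 0., 0.], [0., 1., 1.]]
    return x_list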