first commit
NAS-Bench-201/analysis/arch_functions.py (new file, 347 lines)
@@ -0,0 +1,347 @@
import numpy as np
import torch

# all_path is expected to provide NASBENCH201_INFO, the path to the cached
# NAS-Bench-201 benchmark info file loaded below.
from all_path import *


class BasicArchMetrics(object):
    """Validity / uniqueness / novelty metrics for generated NAS-Bench-201 cells."""

    def __init__(self, train_ds=None, train_arch_str_list=None):
        if train_ds is None:
            self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
        else:
            self.ops_decoder = train_ds.ops_decoder
        self.nasbench201 = torch.load(NASBENCH201_INFO)
        self.train_arch_str_list = train_arch_str_list

    def compute_validity(self, generated):
        START_TYPE = self.ops_decoder.index('input')
        END_TYPE = self.ops_decoder.index('output')

        valid = []
        valid_arch_str = []
        all_arch_str = []
        for x in generated:
            is_valid, error_type = is_valid_NAS201_x(x, START_TYPE, END_TYPE)
            if is_valid:
                valid.append(x)
                arch_str = decode_x_to_NAS_BENCH_201_string(x, self.ops_decoder)
                valid_arch_str.append(arch_str)
            else:
                arch_str = None
            all_arch_str.append(arch_str)
        validity = 0 if len(generated) == 0 else len(valid) / len(generated)
        return valid, validity, valid_arch_str, all_arch_str

    def compute_uniqueness(self, valid_arch_str):
        unique = list(set(valid_arch_str))
        return unique, len(unique) / len(valid_arch_str)

    def compute_novelty(self, unique):
        num_novel = 0
        novel = []
        if self.train_arch_str_list is None:
            print("Dataset arch_str is None, novelty computation skipped")
            return 1, 1
        for arch_str in unique:
            if arch_str not in self.train_arch_str_list:
                novel.append(arch_str)
                num_novel += 1
        return novel, num_novel / len(unique)

    def evaluate(self, generated, check_dataname='cifar10'):
        valid, validity, valid_arch_str, all_arch_str = self.compute_validity(generated)

        if validity > 0:
            unique, uniqueness = self.compute_uniqueness(valid_arch_str)
            if self.train_arch_str_list is not None:
                _, novelty = self.compute_novelty(unique)
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []

        if uniqueness > 0.:
            arch_idx_list, flops_list, params_list, latency_list = [], [], [], []
            for arch in unique:
                arch_index, flops, params, latency = \
                    get_arch_acc_info(self.nasbench201, arch=arch, dataname=check_dataname)
                arch_idx_list.append(arch_index)
                flops_list.append(flops)
                params_list.append(params)
                latency_list.append(latency)
        else:
            # Sentinel values for the degenerate case of no unique valid architecture.
            arch_idx_list, flops_list, params_list, latency_list = [-1], [0], [0], [0]

        return ([validity, uniqueness, novelty],
                unique,
                dict(arch_idx_list=arch_idx_list, flops_list=flops_list,
                     params_list=params_list, latency_list=latency_list),
                all_arch_str)


class BasicArchMetricsMeta(object):
    """Variant of BasicArchMetrics used by the meta version of the pipeline."""

    def __init__(self, train_ds=None, train_arch_str_list=None):
        if train_ds is None:
            self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
        else:
            self.ops_decoder = train_ds.ops_decoder
        self.nasbench201 = torch.load(NASBENCH201_INFO)
        self.train_arch_str_list = train_arch_str_list

    def compute_validity(self, generated):
        START_TYPE = self.ops_decoder.index('input')
        END_TYPE = self.ops_decoder.index('output')

        valid = []
        valid_arch_str = []
        all_arch_str = []
        error_types = []

        for x in generated:
            is_valid, error_type = is_valid_NAS201_x(x, START_TYPE, END_TYPE)
            if is_valid:
                valid.append(x)
                arch_str = decode_x_to_NAS_BENCH_201_string(x, self.ops_decoder)
                valid_arch_str.append(arch_str)
            else:
                arch_str = None
                error_types.append(error_type)
            all_arch_str.append(arch_str)

        # Exceptional case: no valid architecture at all.
        validity = 0 if len(generated) == 0 else len(valid) / len(generated)
        if len(valid) == 0:
            validity = 0
            valid_arch_str = []

        return valid, validity, valid_arch_str, all_arch_str

    def compute_uniqueness(self, valid_arch_str):
        unique = list(set(valid_arch_str))
        return unique, len(unique) / len(valid_arch_str)

    def compute_novelty(self, unique):
        num_novel = 0
        novel = []
        if self.train_arch_str_list is None:
            print("Dataset arch_str is None, novelty computation skipped")
            return 1, 1
        for arch_str in unique:
            if arch_str not in self.train_arch_str_list:
                novel.append(arch_str)
                num_novel += 1
        return novel, num_novel / len(unique)

    def evaluate(self, generated, check_dataname='cifar10'):
        valid, validity, valid_arch_str, all_arch_str = self.compute_validity(generated)

        if validity > 0:
            unique, uniqueness = self.compute_uniqueness(valid_arch_str)
            if self.train_arch_str_list is not None:
                _, novelty = self.compute_novelty(unique)
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []

        if uniqueness > 0.:
            arch_idx_list, flops_list, params_list, latency_list = [], [], [], []
            for arch in unique:
                arch_index, flops, params, latency = \
                    get_arch_acc_info_meta(self.nasbench201, arch=arch, dataname=check_dataname)
                arch_idx_list.append(arch_index)
                flops_list.append(flops)
                params_list.append(params)
                latency_list.append(latency)
        else:
            # Sentinel values for the degenerate case of no unique valid architecture.
            arch_idx_list, flops_list, params_list, latency_list = [-1], [0], [0], [0]

        return ([validity, uniqueness, novelty],
                unique,
                dict(arch_idx_list=arch_idx_list, flops_list=flops_list,
                     params_list=params_list, latency_list=latency_list),
                all_arch_str)


def get_arch_acc_info(nasbench201, arch, dataname='cifar10'):
    # Look up an architecture's benchmark index and its FLOPs / params /
    # latency statistics for the given dataset in the pre-computed info dict.
    arch_index = nasbench201['str'].index(arch)
    flops = nasbench201['flops'][dataname][arch_index]
    params = nasbench201['params'][dataname][arch_index]
    latency = nasbench201['latency'][dataname][arch_index]
    return arch_index, flops, params, latency


def get_arch_acc_info_meta(nasbench201, arch, dataname='cifar10'):
    # Identical lookup to get_arch_acc_info; kept as a separate entry point
    # for the meta variant of the metrics.
    arch_index = nasbench201['str'].index(arch)
    flops = nasbench201['flops'][dataname][arch_index]
    params = nasbench201['params'][dataname][arch_index]
    latency = nasbench201['latency'][dataname][arch_index]
    return arch_index, flops, params, latency


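# Usage sketch (illustrative, not part of the original commit; assumes the
# info-file layout implied above, i.e. keys 'str', 'flops', 'params',
# 'latency' with per-dataset sub-dicts):
#
#   nb201 = torch.load(NASBENCH201_INFO)
#   arch = nb201['str'][0]   # any benchmark arch string
#   idx, flops, params, latency = get_arch_acc_info(nb201, arch, dataname='cifar10')
#   assert idx == 0

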
def decode_igraph_to_NAS_BENCH_201_string(g):
    if not is_valid_NAS201(g):
        return None
    m = decode_igraph_to_NAS201_matrix(g)
    types = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
    return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
        types[int(m[1][0])],
        types[int(m[2][0])], types[int(m[2][1])],
        types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])


def decode_igraph_to_NAS201_matrix(g):
    m = [[0.0] * 4 for _ in range(4)]
    # Lower-triangular positions of the 4-node cell's six edges.
    xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
    for i, xy in enumerate(xys):
        # Vertex types 2..6 encode the five operations, hence the -2 offset.
        m[xy[0]][xy[1]] = float(g.vs[i + 1]['type']) - 2
    return np.array(m)


def decode_x_to_NAS_BENCH_201_matrix(x):
    m = [[0.0] * 4 for _ in range(4)]
    xys = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
    for i, xy in enumerate(xys):
        # Rows 1..6 of x are the one-hot edge operations; keep the raw argmax
        # index so the result can be decoded with the full ops_decoder.
        m[xy[0]][xy[1]] = int(torch.argmax(torch.tensor(x[i + 1])).item())
    return np.array(m)


def decode_x_to_NAS_BENCH_201_string(x, ops_decoder):
    """Decode a one-hot node-type matrix into a NAS-Bench-201 arch string.

    Args:
        x (torch.Tensor): one architecture, shape [8, 7]
        ops_decoder (list): operation names indexed by node type

    Returns:
        arch_str, or None if x is not a valid NAS-Bench-201 encoding
    """
    is_valid, error_type = is_valid_NAS201_x(x)
    if not is_valid:
        return None
    m = decode_x_to_NAS_BENCH_201_matrix(x)
    types = ops_decoder
    arch_str = '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(
        types[int(m[1][0])],
        types[int(m[2][0])], types[int(m[2][1])],
        types[int(m[3][0])], types[int(m[3][1])], types[int(m[3][2])])
    return arch_str


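# Round-trip sketch (illustrative, not part of the original commit): build a
# one-hot [8, 7] encoding by hand and decode it with the default ops_decoder.
#
#   ops = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
#   x = torch.zeros(8, 7)
#   x[0, ops.index('input')] = 1            # node 0 is the input
#   for i in range(1, 7):                   # nodes 1-6 are the six edge ops
#       x[i, ops.index('nor_conv_3x3')] = 1
#   x[7, ops.index('output')] = 1           # node 7 is the output
#   decode_x_to_NAS_BENCH_201_string(x, ops)
#   # -> '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|'

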
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    # A valid computation graph is a DAG with exactly one start and one end
    # node; every other node must be reachable (indegree > 0) and must reach
    # the output (outdegree > 0).
    res = g.is_dag()
    n_start, n_end = 0, 0
    for v in g.vs:
        if v['type'] == START_TYPE:
            n_start += 1
        elif v['type'] == END_TYPE:
            n_end += 1
        if v.indegree() == 0 and v['type'] != START_TYPE:
            return False
        if v.outdegree() == 0 and v['type'] != END_TYPE:
            return False
    return res and n_start == 1 and n_end == 1


def is_valid_NAS201(g, START_TYPE=0, END_TYPE=1):
    # First it needs to be a valid DAG computation graph,
    res = is_valid_DAG(g, START_TYPE, END_TYPE)
    # and a NAS-Bench-201 cell has exactly 8 nodes, with the start/end types
    # appearing only at the boundaries.
    res = res and len(g.vs['type']) == 8
    res = res and not (START_TYPE in g.vs['type'][1:-1])
    res = res and not (END_TYPE in g.vs['type'][1:-1])
    return res


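# Minimal igraph sketch (illustrative; a simple chain over the 8 typed
# vertices is enough to satisfy these checks, though the real encoder
# presumably mirrors the cell's DAG structure):
#
#   import igraph
#   g = igraph.Graph(directed=True)
#   g.add_vertices(8)
#   g.vs['type'] = [0, 3, 3, 3, 3, 3, 3, 1]   # input, six ops, output
#   g.add_edges([(i, i + 1) for i in range(7)])
#   is_valid_NAS201(g)   # -> True

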
def check_single_node_type(x):
    # Every row must be one-hot: exactly one type per node.
    for x_elem in x:
        if int(np.sum(x_elem)) != 1:
            return False
    return True


def check_start_end_nodes(x, START_TYPE, END_TYPE):
    # The first node must be the input and the last node the output.
    if x[0][START_TYPE] != 1:
        return False
    if x[-1][END_TYPE] != 1:
        return False
    return True


def check_interm_node_types(x, START_TYPE, END_TYPE):
    # Intermediate nodes may only carry operation types, never input/output.
    for x_elem in x[1:-1]:
        if x_elem[START_TYPE] == 1:
            return False
        if x_elem[END_TYPE] == 1:
            return False
    return True


ERROR_NB201 = {
    'MULTIPLE_NODE_TYPES': 1,
    'NO_START_END': 2,
    'INTERM_START_END': 3,
    'NO_ERROR': -1
}


def is_valid_NAS201_x(x, START_TYPE=0, END_TYPE=1):
    # Validity check for the one-hot matrix encoding; returns
    # (is_valid, error_code) with codes from ERROR_NB201.
    assert len(x.shape) == 2

    if not check_single_node_type(x):
        return False, ERROR_NB201['MULTIPLE_NODE_TYPES']

    if not check_start_end_nodes(x, START_TYPE, END_TYPE):
        return False, ERROR_NB201['NO_START_END']

    if not check_interm_node_types(x, START_TYPE, END_TYPE):
        return False, ERROR_NB201['INTERM_START_END']

    return True, ERROR_NB201['NO_ERROR']


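# For instance (illustrative): an all-zero matrix fails the one-hot check,
#
#   is_valid_NAS201_x(torch.zeros(8, 7))
#   # -> (False, ERROR_NB201['MULTIPLE_NODE_TYPES'])

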
def compute_arch_metrics(arch_list,
                         train_arch_str_list,
                         train_ds,
                         check_dataname='cifar10'):
    metrics = BasicArchMetrics(train_ds, train_arch_str_list)
    arch_metrics = metrics.evaluate(arch_list, check_dataname=check_dataname)
    all_arch_str = arch_metrics[-1]
    return arch_metrics, all_arch_str


def compute_arch_metrics_meta(arch_list,
                              train_arch_str_list,
                              train_ds,
                              check_dataname='cifar10'):
    metrics = BasicArchMetricsMeta(train_ds, train_arch_str_list)
    arch_metrics = metrics.evaluate(arch_list, check_dataname=check_dataname)
    return arch_metrics
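

# End-to-end sketch (illustrative; `samples` is a hypothetical batch of
# generated [8, 7] one-hot tensors, and both train_ds and the train arch
# strings may be None, in which case novelty is reported as -1.0):
#
#   arch_metrics, all_arch_str = compute_arch_metrics(
#       arch_list=samples, train_arch_str_list=None, train_ds=None,
#       check_dataname='cifar10')
#   validity, uniqueness, novelty = arch_metrics[0]
#   unique_arch_strs = arch_metrics[1]
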
NAS-Bench-201/analysis/arch_metrics.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
import torch.nn as nn


class SamplingArchMetrics(nn.Module):
    def __init__(self,
                 config,
                 train_ds,
                 exp_name):
        super().__init__()
        self.exp_name = exp_name
        self.train_ds = train_ds
        self.train_arch_str_list = train_ds.arch_str_list_

    def forward(self,
                arch_list: list,
                this_sample_dir,
                check_dataname='cifar10'):
        arch_metrics, all_arch_str = compute_arch_metrics(arch_list=arch_list,
                                                          train_arch_str_list=self.train_arch_str_list,
                                                          train_ds=self.train_ds,
                                                          check_dataname=check_dataname)

        valid_unique_arch = arch_metrics[1]            # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # flops, params, latency
        # Log every valid & unique architecture with its benchmark statistics.
        with open(f'{this_sample_dir}/valid_unique_archs.txt', 'w') as textfile:
            for i in range(len(valid_unique_arch)):
                textfile.write(f"Arch: {valid_unique_arch[i]} \n")
                textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
                textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
                textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
            textfile.writelines(valid_unique_arch)

        return arch_metrics


class SamplingArchMetricsMeta(nn.Module):
    def __init__(self,
                 config,
                 train_ds,
                 exp_name):
        super().__init__()
        self.exp_name = exp_name
        self.train_ds = train_ds
        self.search_space = config.data.name
        self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]

    def forward(self,
                arch_list: list,
                this_sample_dir,
                check_dataname='cifar10'):
        arch_metrics = compute_arch_metrics_meta(arch_list=arch_list,
                                                 train_arch_str_list=self.train_arch_str_list,
                                                 train_ds=self.train_ds,
                                                 check_dataname=check_dataname)

        valid_unique_arch = arch_metrics[1]            # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # flops, params, latency
        with open(f'{this_sample_dir}/valid_unique_archs.txt', 'w') as textfile:
            for i in range(len(valid_unique_arch)):
                textfile.write(f"Arch: {valid_unique_arch[i]} \n")
                textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
                textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
                textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n\n")
            textfile.writelines(valid_unique_arch)

        return arch_metrics
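

# Usage sketch (illustrative; `cfg` and `ds` are hypothetical stand-ins for a
# config exposing data.name and a dataset exposing arch_str_list_ /
# ops_decoder as assumed by the classes above):
#
#   sampler_metrics = SamplingArchMetrics(cfg, ds, exp_name='debug')
#   arch_metrics = sampler_metrics(arch_list=samples, this_sample_dir='./samples')
#   validity, uniqueness, novelty = arch_metrics[0]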