first commit
This commit is contained in:
475
MobileNetV3/analysis/arch_functions.py
Normal file
475
MobileNetV3/analysis/arch_functions.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import wandb
|
||||
import igraph
|
||||
from torch.nn.functional import one_hot
|
||||
|
||||
|
||||
# Choice lists of the search space. The 'ks' / 'e' halves of every op name
# below are drawn from the first two lists (presumably kernel size and expand
# ratio -- confirm against the OFA config).
KS_LIST = [3, 5, 7]
EXPAND_LIST = [3, 4, 6]
DEPTH_LIST = [2, 3, 4]

NUM_STAGE = 5
MAX_LAYER_PER_STAGE = 4
MAX_N_BLOCK = NUM_STAGE * MAX_LAYER_PER_STAGE  # 20

# Op name '<ks>-<e>' -> class index, numbered with ks as the outer loop:
# '3-3' -> 0, '3-4' -> 1, ..., '7-6' -> 8.
OPS = {
    f'{ks}-{e}': ks_pos * len(EXPAND_LIST) + e_pos
    for ks_pos, ks in enumerate(KS_LIST)
    for e_pos, e in enumerate(EXPAND_LIST)
}

# Inverse table: class index -> op name.
OPS2STR = {index: op_name for op_name, index in OPS.items()}

NUM_OPS = len(OPS)
LONGEST_PATH_LENGTH = 20
|
||||
|
||||
|
||||
class BasicArchMetricsOFA(object):
    """Validity, uniqueness and novelty statistics for sampled OFA architectures.

    An architecture is handled as a 2-D per-layer matrix: an all-zero row is
    treated as a skipped layer and a one-hot row selects an op from ``OPS``.
    (assumes NUM_STAGE * MAX_LAYER_PER_STAGE rows -- TODO confirm with callers)
    """

    def __init__(self, train_ds=None, train_arch_str_list=None, except_inout=False, data_root=None):
        """Optionally load the training architectures used for novelty.

        Args:
            train_ds: Unused; kept for signature compatibility.
            train_arch_str_list: Unused; kept for signature compatibility.
            except_inout: Stored on the instance as ``except_inout``.
            data_root: Path passed to ``torch.load``; the loaded object's
                ``'x'`` entry becomes ``train_arch_list``. When ``None`` no
                training set is loaded and novelty is not computed.
        """
        if data_root is not None:
            # NOTE(review): torch.load unpickles the file -- only point
            # data_root at trusted data.
            self.ofa = torch.load(data_root)
            self.train_arch_list = self.ofa['x']
        else:
            self.ofa = None
            self.train_arch_list = None
        self.ops_decoder = OPS
        self.except_inout = except_inout

    @staticmethod
    def _as_long_copy(x):
        """Return an independent int64 tensor copy of ``x``."""
        if isinstance(x, torch.Tensor):
            # clone() avoids the copy-construct warning of torch.tensor(tensor).
            return x.detach().clone().long()
        return torch.tensor(x).long()

    def get_string_from_onehot_x(self, x):
        """Encode a per-layer matrix as ``'<depth>-<ks>-<e>'`` tokens joined by ``'_'``.

        All-zero rows become ``'0-0-0'``. ``<depth>`` is the sum of all entries
        in the row's stage, i.e. the number of active layers when rows are
        one-hot.
        """
        x = torch.as_tensor(x)
        stage_depths = torch.sum(x.reshape(NUM_STAGE, -1), dim=1)
        tokens = []
        for i, row in enumerate(x):
            if row.sum() == 0:
                tokens.append('0-0-0')
            else:
                depth = int(stage_depths[i // MAX_LAYER_PER_STAGE])
                # .item() raises if the row has more than one non-zero entry.
                tokens.append(f'{depth}-' + OPS2STR[torch.nonzero(row).item()])
        return '_'.join(tokens)

    def compute_validity(self, generated, adj=None, mask=None):
        """Split ``generated`` into valid architectures and error codes.

        Args:
            generated: Iterable of per-layer matrices checked by ``is_valid_OFA_x``.
            adj: Unused.
            mask: Unused.

        Returns:
            ``(valid, validity, valid_str, None, error_types)`` where ``valid``
            holds int64 tensor copies, ``validity`` is the valid fraction
            (0 for an empty input), ``valid_str`` the string encodings and
            ``error_types`` the error code of every rejected sample.
        """
        valid = []
        error_types = []
        valid_str = []
        for x in generated:
            is_valid, error_type = is_valid_OFA_x(x)
            if is_valid:
                valid.append(self._as_long_copy(x))
                valid_str.append(self.get_string_from_onehot_x(x))
            else:
                error_types.append(error_type)

        # Same empty-input guard as BasicArchMetricsMetaOFA.compute_validity.
        validity = 0 if len(generated) == 0 else len(valid) / len(generated)
        return valid, validity, valid_str, None, error_types

    def compute_uniqueness(self, valid_arch):
        """Drop duplicates (by ``torch.equal``), keeping first occurrences.

        Returns:
            ``(unique, ratio)``; ``ratio`` is 0.0 for an empty input.
        """
        unique = []
        for x in valid_arch:
            if not any(torch.equal(x, seen) for seen in unique):
                unique.append(x)
        ratio = len(unique) / len(valid_arch) if len(valid_arch) else 0.0
        return unique, ratio

    def compute_novelty(self, unique):
        """Keep the architectures that do not appear in the training set.

        Returns:
            ``(novel, ratio)``; ``ratio`` is 0.0 for an empty input. When no
            training set was loaded the legacy placeholder ``(1, 1)`` is
            returned.
        """
        if self.train_arch_list is None:
            print("Dataset arch_str is None, novelty computation skipped")
            return 1, 1
        novel = [
            arch for arch in unique
            if not any(torch.equal(arch, tr_m) for tr_m in self.train_arch_list)
        ]
        ratio = len(novel) / len(unique) if len(unique) else 0.0
        return novel, ratio

    def evaluate(self, generated, adj, mask, check_dataname='cifar10'):
        """Print and return validity / uniqueness / novelty for ``generated``.

        Args:
            generated: Sampled per-layer matrices.
            adj: Forwarded to ``compute_validity`` (unused there).
            mask: Forwarded to ``compute_validity`` (unused there).
            check_dataname: Unused.

        Returns:
            ``([validity, uniqueness, novelty, error_1, error_2, error_3],
            unique, property_dict, None)``. The error entries are 0-dim
            tensors holding the fraction of samples rejected with each code;
            novelty is -1.0 when it was not computed; every property list is
            the placeholder ``[0]``.
        """
        valid_arch, validity, _, _, error_types = self.compute_validity(generated, adj, mask)

        print(f"Validity over {len(generated)} archs: {validity * 100 :.2f}%")
        error_codes = torch.tensor(error_types)
        denom = max(len(generated), 1)  # avoid NaN fractions on an empty batch
        error_1 = torch.sum(error_codes == 1) / denom
        error_2 = torch.sum(error_codes == 2) / denom
        error_3 = torch.sum(error_codes == 3) / denom
        print(f"Unvalid-Multi_Node_Type over {len(generated)} archs: {error_1 * 100 :.2f}%")
        print(f"INVALID_1OR2 over {len(generated)} archs: {error_2 * 100 :.2f}%")
        print(f"INVALID_3AND4 over {len(generated)} archs: {error_3 * 100 :.2f}%")

        if validity > 0:
            unique, uniqueness = self.compute_uniqueness(valid_arch)
            print(f"Uniqueness over {len(valid_arch)} valid archs: {uniqueness * 100 :.2f}%")

            if self.train_arch_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid archs: {novelty * 100 :.2f}%")
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []

        test_acc_list, flops_list, params_list, latency_list = [0], [0], [0], [0]
        all_arch_str = None
        return ([validity, uniqueness, novelty, error_1, error_2, error_3],
                unique,
                dict(test_acc_list=test_acc_list, flops_list=flops_list,
                     params_list=params_list, latency_list=latency_list),
                all_arch_str)
|
||||
|
||||
|
||||
class BasicArchMetricsMetaOFA(object):
    """Validity, uniqueness and novelty statistics for sampled OFA architectures (meta variant).

    An architecture is handled as a 2-D per-layer matrix: an all-zero row is
    treated as a skipped layer and a one-hot row selects an op from ``OPS``.
    (assumes NUM_STAGE * MAX_LAYER_PER_STAGE rows -- TODO confirm with callers)
    """

    def __init__(self, train_ds=None, train_arch_str_list=None, except_inout=False, data_root=None):
        """Optionally load the training architectures used for novelty.

        Args:
            train_ds: Unused; kept for signature compatibility.
            train_arch_str_list: Unused; kept for signature compatibility.
            except_inout: Unused; kept for signature compatibility.
            data_root: Path passed to ``torch.load``; the loaded object's
                ``'x'`` entry becomes ``train_arch_list``. When ``None`` no
                training set is loaded and novelty is not computed.
        """
        if data_root is not None:
            # NOTE(review): torch.load unpickles the file -- only point
            # data_root at trusted data.
            self.ofa = torch.load(data_root)
            self.train_arch_list = self.ofa['x']
        else:
            self.ofa = None
            self.train_arch_list = None
        self.ops_decoder = OPS

    @staticmethod
    def _as_long_copy(x):
        """Return an independent int64 tensor copy of ``x``."""
        if isinstance(x, torch.Tensor):
            # clone() avoids the copy-construct warning of torch.tensor(tensor).
            return x.detach().clone().long()
        return torch.tensor(x).long()

    def get_string_from_onehot_x(self, x):
        """Encode a per-layer matrix as ``'<depth>-<ks>-<e>'`` tokens joined by ``'_'``.

        All-zero rows become ``'0-0-0'``. ``<depth>`` is the sum of all entries
        in the row's stage, i.e. the number of active layers when rows are
        one-hot.
        """
        x = torch.as_tensor(x)
        stage_depths = torch.sum(x.reshape(NUM_STAGE, -1), dim=1)
        tokens = []
        for i, row in enumerate(x):
            if row.sum() == 0:
                tokens.append('0-0-0')
            else:
                depth = int(stage_depths[i // MAX_LAYER_PER_STAGE])
                # .item() raises if the row has more than one non-zero entry.
                tokens.append(f'{depth}-' + OPS2STR[torch.nonzero(row).item()])
        return '_'.join(tokens)

    def compute_validity(self, generated, adj=None, mask=None):
        """Split ``generated`` into valid architectures and error codes.

        Args:
            generated: Iterable of per-layer matrices checked by ``is_valid_OFA_x``.
            adj: Unused.
            mask: Unused.

        Returns:
            ``(valid, validity, valid_arch_str, all_arch_str, error_types)``.
            ``all_arch_str`` has one entry per input sample: the string
            encoding for valid samples and ``None`` for rejected ones.
            ``validity`` is 0 for an empty input.
        """
        valid = []
        valid_arch_str = []
        all_arch_str = []
        error_types = []
        for x in generated:
            is_valid, error_type = is_valid_OFA_x(x)
            if is_valid:
                valid.append(self._as_long_copy(x))
                arch_str = self.get_string_from_onehot_x(x)
                valid_arch_str.append(arch_str)
            else:
                arch_str = None
                error_types.append(error_type)
            all_arch_str.append(arch_str)
        validity = 0 if len(generated) == 0 else (len(valid) / len(generated))
        return valid, validity, valid_arch_str, all_arch_str, error_types

    def compute_uniqueness(self, valid_arch):
        """Drop duplicates (by ``torch.equal``), keeping first occurrences.

        Returns:
            ``(unique, ratio)``; ``ratio`` is 0.0 for an empty input.
        """
        unique = []
        for x in valid_arch:
            if not any(torch.equal(x, seen) for seen in unique):
                unique.append(x)
        ratio = len(unique) / len(valid_arch) if len(valid_arch) else 0.0
        return unique, ratio

    def compute_novelty(self, unique):
        """Keep the architectures that do not appear in the training set.

        Returns:
            ``(novel, ratio)``; ``ratio`` is 0.0 for an empty input. When no
            training set was loaded the legacy placeholder ``(1, 1)`` is
            returned.
        """
        if self.train_arch_list is None:
            print("Dataset arch_str is None, novelty computation skipped")
            return 1, 1
        novel = [
            arch for arch in unique
            if not any(torch.equal(arch, tr_m) for tr_m in self.train_arch_list)
        ]
        ratio = len(novel) / len(unique) if len(unique) else 0.0
        return novel, ratio

    def evaluate(self, generated, adj, mask, check_dataname='imagenet1k'):
        """Print and return validity / uniqueness / novelty for ``generated``.

        Args:
            generated: Sampled per-layer matrices.
            adj: Forwarded to ``compute_validity`` (unused there).
            mask: Forwarded to ``compute_validity`` (unused there).
            check_dataname: Unused.

        Returns:
            ``([validity, uniqueness, novelty, error_1, error_2, error_3],
            unique, property_dict, None)``. The error entries are 0-dim
            tensors holding the fraction of samples rejected with each code;
            novelty is -1.0 when it was not computed; every property list is
            the placeholder ``[0]``.
        """
        valid_arch, validity, _, _, error_types = self.compute_validity(generated, adj, mask)

        print(f"Validity over {len(generated)} archs: {validity * 100 :.2f}%")
        error_codes = torch.tensor(error_types)
        denom = max(len(generated), 1)  # avoid NaN fractions on an empty batch
        error_1 = torch.sum(error_codes == 1) / denom
        error_2 = torch.sum(error_codes == 2) / denom
        error_3 = torch.sum(error_codes == 3) / denom
        print(f"Unvalid-Multi_Node_Type over {len(generated)} archs: {error_1 * 100 :.2f}%")
        print(f"INVALID_1OR2 over {len(generated)} archs: {error_2 * 100 :.2f}%")
        print(f"INVALID_3AND4 over {len(generated)} archs: {error_3 * 100 :.2f}%")

        if validity > 0:
            unique, uniqueness = self.compute_uniqueness(valid_arch)
            print(f"Uniqueness over {len(valid_arch)} valid archs: {uniqueness * 100 :.2f}%")

            if self.train_arch_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid archs: {novelty * 100 :.2f}%")
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []

        test_acc_list, flops_list, params_list, latency_list = [0], [0], [0], [0]
        all_arch_str = None
        return ([validity, uniqueness, novelty, error_1, error_2, error_3],
                unique,
                dict(test_acc_list=test_acc_list, flops_list=flops_list,
                     params_list=params_list, latency_list=latency_list),
                all_arch_str)
|
||||
|
||||
|
||||
def get_arch_acc_info(nasbench201, arch, dataname='cifar10'):
    """Look up the benchmark record of ``arch`` for one dataset.

    Args:
        nasbench201: Mapping with a ``'str'`` list of architecture strings and
            ``'test-acc'`` / ``'flops'`` / ``'params'`` / ``'latency'`` tables
            indexed as ``table[dataname][position]``.
        arch: Architecture string; must be present in ``nasbench201['str']``
            (``list.index`` raises ``ValueError`` otherwise).
        dataname: Dataset key used inside each table.

    Returns:
        ``(test_acc, flops, params, latency)``.
    """
    position = nasbench201['str'].index(arch)
    return tuple(
        nasbench201[field][dataname][position]
        for field in ('test-acc', 'flops', 'params', 'latency')
    )
|
||||
|
||||
|
||||
def get_arch_acc_info_meta(nasbench201, arch, dataname='cifar10'):
    """Look up the benchmark record of ``arch``, including its table index.

    Args:
        nasbench201: Mapping with a ``'str'`` list of architecture strings and
            per-dataset ``'flops'`` / ``'params'`` / ``'latency'`` /
            ``'test-acc'`` tables indexed as ``table[dataname][position]``.
        arch: Architecture string; must be present in ``nasbench201['str']``.
        dataname: Dataset key used inside each table.

    Returns:
        ``(arch_index, test_acc, flops, params, latency)``. ``test_acc`` is
        ``None`` unless ``'cifar'`` occurs in ``dataname``.
    """
    position = nasbench201['str'].index(arch)
    flops, params, latency = (
        nasbench201[field][dataname][position]
        for field in ('flops', 'params', 'latency')
    )
    # TODO: accuracy lookup is only wired up for the cifar datasets.
    accuracy = nasbench201['test-acc'][dataname][position] if 'cifar' in dataname else None
    return position, accuracy, flops, params, latency
|
||||
|
||||
|
||||
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    """Check that ``g`` is a DAG with exactly one start and one end vertex.

    Args:
        g: Graph exposing ``is_dag()`` and ``vs``; each vertex supports
            ``v['type']``, ``indegree()`` and ``outdegree()``.
        START_TYPE: Type value of the start vertex.
        END_TYPE: Type value of the end vertex.

    Returns:
        ``False`` as soon as a vertex without incoming edges is not a start
        vertex, or a vertex without outgoing edges is not an end vertex;
        otherwise whether ``g`` is acyclic and has exactly one vertex of each
        of the two types.
    """
    acyclic = g.is_dag()
    start_count = 0
    end_count = 0
    for vertex in g.vs:
        kind = vertex['type']
        if kind == START_TYPE:
            start_count += 1
        elif kind == END_TYPE:
            end_count += 1

        orphan_source = vertex.indegree() == 0 and kind != START_TYPE
        if orphan_source:
            return False
        orphan_sink = vertex.outdegree() == 0 and kind != END_TYPE
        if orphan_sink:
            return False
    return acyclic and start_count == 1 and end_count == 1
|
||||
|
||||
def check_single_node_type(x):
    """Return True when every row of ``x`` sums (truncated to int) to exactly 1."""
    return all(int(np.sum(row)) == 1 for row in x)
|
||||
|
||||
|
||||
def check_start_end_nodes(x, START_TYPE, END_TYPE):
    """Return True when the first row is flagged START_TYPE and the last row END_TYPE.

    A row is "flagged" when its entry at the given type index equals 1.
    """
    first_row, last_row = x[0], x[-1]
    return bool(first_row[START_TYPE] == 1 and last_row[END_TYPE] == 1)
|
||||
|
||||
def check_interm_node_types(x, START_TYPE, END_TYPE):
    """Return True when no row between the first and the last is flagged start or end.

    A row is "flagged" when its entry at the given type index equals 1.
    """
    return not any(
        row[START_TYPE] == 1 or row[END_TYPE] == 1
        for row in x[1:-1]
    )
|
||||
|
||||
|
||||
def construct_igraph(node_type, edge_type, ops_decoder, except_inout=True):
    """Build a directed igraph graph from node types and pairwise edge scores.

    Args:
        node_type: Sequence with ``.shape``; each element provides ``.item()``
            giving the vertex type.
        edge_type: Square score table indexed as ``edge_type[src][dst]``; an
            edge is added when the score is at least 0.5.
        ops_decoder: List of op names that must contain ``'input'`` and
            ``'output'`` (``list.index`` raises ``ValueError`` otherwise).
        except_inout: Unused.

    Returns:
        The constructed ``igraph.Graph``. An output-type vertex is connected
        from every earlier vertex that has no outgoing edge yet; any other
        vertex is connected from the earlier vertices selected by the scores.
    """
    assert node_type.shape[0] == edge_type.shape[0]

    # The 'input' lookup is kept only for its validation side effect.
    ops_decoder.index('input')
    output_type = ops_decoder.index('output')

    dag = igraph.Graph(directed=True)
    for position, node in enumerate(node_type):
        current_type = node.item()
        dag.add_vertex(type=current_type)

        if current_type == output_type:
            newest = dag.vcount() - 1
            # Collected before any edge is added so the selection is stable.
            dangling = {v.index for v in dag.vs.select(_outdegree_eq=0) if v.index != newest}
            for source in dangling:
                dag.add_edge(source, position)
            continue

        if position > 0:
            for source in range(position):
                if edge_type[source][position].item() >= 0.5:
                    dag.add_edge(source, position)

    return dag
|
||||
|
||||
|
||||
def compute_arch_metrics(arch_list, adj, mask, train_arch_str_list,
                         train_ds, timestep=None, name=None, except_inout=False, data_root=None):
    """Evaluate sampled architectures and log summary statistics to wandb.

    Args:
        arch_list: Sampled architectures passed to ``BasicArchMetricsOFA.evaluate``.
        adj: Forwarded to ``evaluate``.
        mask: Forwarded to ``evaluate``.
        train_arch_str_list: Unused.
        train_ds: Unused.
        timestep: When not ``None``, logged under the ``'step'`` key.
        name: Unused.
        except_inout: Unused.
        data_root: Forwarded to the ``BasicArchMetricsOFA`` constructor.

    Returns:
        ``(arch_metrics, all_arch_str)`` where ``all_arch_str`` is the last
        element of ``arch_metrics``.
    """
    evaluator = BasicArchMetricsOFA(data_root=data_root)
    arch_metrics = evaluator.evaluate(arch_list, adj, mask, check_dataname='cifar10')
    all_arch_str = arch_metrics[-1]

    if not wandb.run:
        return arch_metrics, all_arch_str

    properties = arch_metrics[2]
    prefixes = ('test_acc', 'flops', 'params', 'latency')
    series = {prefix: properties[f'{prefix}_list'] for prefix in prefixes}

    scores = arch_metrics[0]
    has_unique = scores[1] > 0.  # uniqueness > 0.
    payload = {'Validity': scores[0], 'Uniqueness': scores[1], 'Novelty': scores[2]}
    for prefix in prefixes:
        values = series[prefix]
        if has_unique:
            payload[f'{prefix}_max'] = np.max(values)
            payload[f'{prefix}_min'] = np.min(values)
            payload[f'{prefix}_mean'] = np.mean(values)
            payload[f'{prefix}_std'] = np.std(values)
        else:
            # Placeholders used when there is nothing to summarise.
            payload[f'{prefix}_max'] = -1
            payload[f'{prefix}_min'] = -1
            payload[f'{prefix}_mean'] = -1
            payload[f'{prefix}_std'] = 0

    if timestep is not None:
        payload['step'] = timestep

    wandb.log(payload)
    return arch_metrics, all_arch_str
|
||||
|
||||
def compute_arch_metrics_meta(
        arch_list, adj, mask, train_arch_str_list, train_ds,
        timestep=None, check_dataname='cifar10', name=None):
    """Evaluate sampled architectures (meta variant) and log statistics to wandb.

    Args:
        arch_list: Sampled architectures passed to ``BasicArchMetricsMetaOFA.evaluate``.
        adj: Forwarded to ``evaluate``.
        mask: Forwarded to ``evaluate``.
        train_arch_str_list: Forwarded to the ``BasicArchMetricsMetaOFA`` constructor.
        train_ds: Forwarded to the ``BasicArchMetricsMetaOFA`` constructor.
        timestep: When not ``None``, logged under the ``'step'`` key.
        check_dataname: Forwarded to ``evaluate``.
        name: Unused.

    Returns:
        The tuple returned by ``evaluate``.
    """
    metrics = BasicArchMetricsMetaOFA(train_ds, train_arch_str_list)
    arch_metrics = metrics.evaluate(arch_list, adj, mask, check_dataname=check_dataname)

    if wandb.run:
        scores = arch_metrics[0]
        # The property dict carries no 'arch_idx_list' entry, so only the four
        # lists that evaluate() actually returns are read here.
        arch_prop = arch_metrics[2]
        has_unique = scores[1] > 0.  # uniqueness > 0.

        dic = {'Validity': scores[0], 'Uniqueness': scores[1], 'Novelty': scores[2]}
        for prefix in ('test_acc', 'flops', 'params', 'latency'):
            values = arch_prop[f'{prefix}_list']
            if has_unique:
                dic[f'{prefix}_max'] = np.max(values)
                dic[f'{prefix}_min'] = np.min(values)
                dic[f'{prefix}_mean'] = np.mean(values)
                dic[f'{prefix}_std'] = np.std(values)
            else:
                # Placeholders used when there is nothing to summarise.
                dic[f'{prefix}_max'] = -1
                dic[f'{prefix}_min'] = -1
                dic[f'{prefix}_mean'] = -1
                dic[f'{prefix}_std'] = 0

        if timestep is not None:
            dic['step'] = timestep

        # Previously the dict was built and silently dropped; log it the same
        # way compute_arch_metrics does.
        wandb.log(dic)

    return arch_metrics
|
||||
|
||||
|
||||
def check_multiple_nodes(x):
    """Return True when no row of the 2-D matrix ``x`` sums (truncated to int) above 1."""
    assert len(x.shape) == 2
    return all(int(np.sum(np.array(row))) <= 1 for row in x)
|
||||
|
||||
def check_inout_node(x, START_TYPE=0, END_TYPE=1):
    """Tell whether the first row of 2-D ``x`` is flagged START_TYPE and the last END_TYPE.

    A row is "flagged" when its entry at the given type index equals 1. The
    result is whatever the ``==`` / ``and`` expression yields for the element
    type of ``x``.
    """
    assert len(x.shape) == 2
    first_row, last_row = x[0], x[-1]
    return first_row[START_TYPE] == 1 and last_row[END_TYPE] == 1
|
||||
|
||||
def check_none_in_1_and_2_layers(x, NONE_TYPE=None):
    """Return True when rows 0, 1, 4, 5, 8, 9, 12, 13, 16 and 17 of ``x`` all have a non-zero sum.

    These are the first two rows of each group of four consecutive rows.
    ``NONE_TYPE`` is unused.
    """
    assert len(x.shape) == 2
    required_rows = [4 * group + offset for group in range(5) for offset in (0, 1)]
    return all(int(np.sum(x[row])) != 0 for row in required_rows)
|
||||
|
||||
def check_none_in_3_and_4_layers(x, NONE_TYPE=None):
    """Return False when a row 2, 6, 10, 14 or 18 of ``x`` sums to zero but the next row does not.

    In other words, within each group of four rows the fourth may only be
    non-empty when the third is. ``NONE_TYPE`` is unused.
    """
    assert len(x.shape) == 2
    gap_found = any(
        int(np.sum(x[third])) == 0 and int(np.sum(x[third + 1])) != 0
        for third in range(2, 20, 4)
    )
    return not gap_found
|
||||
|
||||
|
||||
def check_interm_inout_node(x, START_TYPE, END_TYPE):
    """Return True when no row between the first and the last is flagged start or end.

    A row is "flagged" when its entry at the given type index equals 1.
    The success path used to fall off the end of the function and return
    ``None``, which callers would read as a failure; it now returns ``True``.
    """
    for x_elem in x[1:-1]:
        if x_elem[START_TYPE] == 1:
            return False
        if x_elem[END_TYPE] == 1:
            return False
    return True
|
||||
|
||||
|
||||
def is_valid_OFA_x(x):
    """Run the structural checks on ``x`` and report the first one that fails.

    Returns:
        ``(True, -1)`` when every check passes, otherwise ``(False, code)``
        with ``code`` 1 (a row selects several ops), 2 (a first/second layer
        row is empty) or 3 (a fourth layer is set while the third is empty).
    """
    no_error = -1
    # (check, code reported when it fails) -- evaluated in this order.
    ordered_checks = (
        (check_multiple_nodes, 1),
        (check_none_in_1_and_2_layers, 2),
        (check_none_in_3_and_4_layers, 3),
    )
    for check, error_code in ordered_checks:
        if not check(x):
            return False, error_code
    return True, no_error
|
||||
|
||||
|
||||
def get_x_adj_from_opsdict_ofa(ops):
    """Convert an OFA config dict into a one-hot node matrix and a chain adjacency.

    Args:
        ops: Mapping with per-stage depths under ``'d'`` and per-layer values
            under ``'ks'`` and ``'e'``.

    Returns:
        ``(x, adj)``: ``x`` is the float one-hot encoding of the per-layer
        class indices (``OPS['<ks>-<e>']``), ``adj`` a NumPy array with ones
        on the first super-diagonal, i.e. layer ``i`` feeds layer ``i + 1``.

    NOTE(review): layers beyond a stage's depth keep class index 0, so they
    are one-hot encoded as op ``'3-3'`` rather than as an all-zero row, which
    is what the validity checks in this file treat as a skipped layer --
    confirm this is intended.
    """
    num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
    num_vertices = len(OPS.values())
    node_types = torch.zeros(num_nodes).long()  # w/o in / out

    # One entry per layer slot: the stage depth for active slots, 'none' otherwise.
    d_matrix = []
    for stage in range(NUM_STAGE):
        depth = ops['d'][stage]
        d_matrix.extend(depth for _ in range(depth))
        d_matrix.extend('none' for _ in range(MAX_LAYER_PER_STAGE - depth))

    for slot, (ks, e, d) in enumerate(zip(ops['ks'], ops['e'], d_matrix)):
        if d != 'none':
            node_types[slot] = OPS[f'{ks}-{e}']

    x = one_hot(node_types, num_vertices).float()

    chain = torch.zeros(num_nodes, num_nodes)
    for src in range(num_nodes - 1):
        chain[src, src + 1] = 1
    adj = np.array(chain)

    return x, adj
|
||||
|
||||
|
||||
def get_string_from_onehot_x(x):
    """Encode a per-layer matrix as ``'<depth>-<ks>-<e>'`` tokens joined by ``'_'``.

    Args:
        x: 2-D per-layer matrix (tensor, array or nested list) whose total
            size is divisible by ``NUM_STAGE``.
            (assumes NUM_STAGE * MAX_LAYER_PER_STAGE rows -- TODO confirm)

    Returns:
        One token per row: ``'0-0-0'`` for an all-zero row, otherwise the sum
        of all entries in the row's stage followed by the op name of the
        row's single non-zero column.

    Raises:
        Whatever ``Tensor.item()`` raises when a non-empty row has more than
        one non-zero entry.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() emits for
    # tensor inputs; reshape (unlike view) also accepts non-contiguous data.
    x = torch.as_tensor(x)
    stage_depths = torch.sum(x.reshape(NUM_STAGE, -1), dim=1)
    tokens = []
    for i, row in enumerate(x):
        if row.sum() == 0:
            tokens.append('0-0-0')
        else:
            depth = int(stage_depths[i // MAX_LAYER_PER_STAGE])
            tokens.append(f'{depth}-' + OPS2STR[torch.nonzero(row).item()])
    return '_'.join(tokens)
|
114
MobileNetV3/analysis/arch_metrics.py
Normal file
114
MobileNetV3/analysis/arch_metrics.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
|
||||
from torch import Tensor
|
||||
import wandb
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SamplingArchMetrics(nn.Module):
    """Module wrapper that scores a batch of sampled architectures."""

    def __init__(self, config, train_ds, exp_name):
        """Store the experiment settings read from ``config.data``.

        Args:
            config: Object exposing ``data.name``, ``data.except_inout`` and
                ``data.root``.
            train_ds: Training dataset; ``x_list_`` is read when
                ``config.data.name == 'ofa'``, ``arch_str_list_`` otherwise.
            exp_name: Stored on the instance as ``exp_name``.
        """
        super().__init__()

        self.exp_name = exp_name
        self.train_ds = train_ds
        if config.data.name == 'ofa':
            self.train_arch_str_list = train_ds.x_list_
        else:
            self.train_arch_str_list = train_ds.arch_str_list_
        self.name = config.data.name
        self.except_inout = config.data.except_inout
        self.data_root = config.data.root

    def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False, timestep=None):
        """Compute the architecture metrics and, for non-'ofa' data, dump text reports.

        :params arch_list: list of archs
        :params adj: [batch_size, num_nodes, num_nodes]
        :params mask: [batch_size, num_nodes, num_nodes]
        :params this_sample_dir: directory receiving ``valid_unique_archs.txt``
        :params test: also write ``final_.txt`` (non-'ofa' data only)
        :params timestep: forwarded to ``compute_arch_metrics``
        :returns: the metrics tuple
            ``([validity, uniqueness, novelty, ...], unique, property_dict, all_arch_str)``
        """
        arch_metrics, all_arch_str = compute_arch_metrics(
            arch_list, adj, mask, self.train_arch_str_list, self.train_ds, timestep=timestep,
            name=self.name, except_inout=self.except_inout, data_root=self.data_root)

        if test and self.name != 'ofa':
            with open(r'final_.txt', 'w') as fp:
                # all_arch_str may be None; write an empty file in that case
                # instead of failing on iteration.
                for arch_str in all_arch_str or ():
                    # write each item on a new line
                    fp.write("%s\n" % arch_str)
                print('All archs saved')

        if self.name != 'ofa':
            # NOTE(review): this report assumes one property entry per unique
            # arch and string-typed archs -- confirm for non-'ofa' data.
            valid_unique_arch = arch_metrics[1]
            valid_unique_arch_prop_dict = arch_metrics[2]  # test_acc, flops, params, latency
            # The context manager closes the file even if a write fails.
            with open(f'{this_sample_dir}/valid_unique_archs.txt', "w") as textfile:
                for i, arch in enumerate(valid_unique_arch):
                    textfile.write(f"Arch: {arch} \n")
                    textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
                    textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n ")
                    textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                    textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
                textfile.writelines(valid_unique_arch)

        return arch_metrics
|
||||
|
||||
class SamplingArchMetricsMeta(nn.Module):
    """Module wrapper that scores a batch of sampled architectures (meta variant)."""

    def __init__(self, config, train_ds, exp_name, train_index=None, nasbench=None):
        """Store the experiment settings.

        Args:
            config: Object exposing ``data.name`` (kept as ``search_space``).
            train_ds: Training dataset; for search spaces other than 'ofa' its
                ``arch_str_list`` is indexed by ``idx_lst['train']``.
            exp_name: Stored on the instance as ``exp_name``.
            train_index: Unused.
            nasbench: Unused.
        """
        super().__init__()

        self.exp_name = exp_name
        self.train_ds = train_ds
        self.search_space = config.data.name
        if self.search_space == 'ofa':
            self.train_arch_str_list = None
        else:
            self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]

    def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False,
                timestep=None, check_dataname='cifar10'):
        """Compute the architecture metrics and optionally dump text reports.

        :params arch_list: list of archs
        :params adj: [batch_size, num_nodes, num_nodes]
        :params mask: [batch_size, num_nodes, num_nodes]
        :params this_sample_dir: directory receiving ``valid_unique_archs.txt``
        :params test: also write ``final_.txt``
        :params timestep: forwarded to ``compute_arch_metrics_meta``
        :params check_dataname: forwarded to ``compute_arch_metrics_meta``
        :returns: the metrics tuple returned by ``compute_arch_metrics_meta``
        """
        arch_metrics = compute_arch_metrics_meta(arch_list, adj, mask, self.train_arch_str_list,
                                                 self.train_ds, timestep=timestep,
                                                 check_dataname=check_dataname,
                                                 name=self.search_space)
        all_arch_str = arch_metrics[-1]

        if test:
            with open(r'final_.txt', 'w') as fp:
                # all_arch_str may be None; write an empty file in that case
                # instead of failing on iteration.
                for arch_str in all_arch_str or ():
                    # write each item on a new line
                    fp.write("%s\n" % arch_str)
                print('All archs saved')

        valid_unique_arch = arch_metrics[1]  # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # test_acc, flops, params, latency
        if self.search_space != 'ofa':
            # NOTE(review): this report reads an 'arch_idx_list' entry and one
            # property entry per unique arch -- confirm the metrics provide them.
            # The context manager closes the file even if a write fails.
            with open(f'{this_sample_dir}/valid_unique_archs.txt', "w") as textfile:
                for i, arch in enumerate(valid_unique_arch):
                    textfile.write(f"Arch: {arch} \n")
                    textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
                    textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
                    textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n ")
                    textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                    textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
                textfile.writelines(valid_unique_arch)

        return arch_metrics
|
547
MobileNetV3/analysis/visualization.py
Normal file
547
MobileNetV3/analysis/visualization.py
Normal file
@@ -0,0 +1,547 @@
|
||||
import os
|
||||
import torch
|
||||
import imageio
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
# import rdkit.Chem
|
||||
import wandb
|
||||
import matplotlib.pyplot as plt
|
||||
# import igraph
|
||||
# import pygraphviz as pgv
|
||||
import datasets_nas
|
||||
from configs.ckpt import DATAROOT_NB201
|
||||
|
||||
|
||||
class ArchVisualization:
|
||||
def __init__(self, config, remove_none=False, exp_name=None):
    """Store the settings and set up the label / colour tables used for drawing.

    Args:
        config: Object exposing ``log.num_graphs_to_visualize``.
        remove_none: Stored on the instance as ``remove_none``.
        exp_name: Stored on the instance as ``exp_name``.
    """
    self.config, self.remove_none, self.exp_name = config, remove_none, exp_name
    self.num_graphs_to_visualize = config.log.num_graphs_to_visualize
    self.nasbench201 = torch.load(DATAROOT_NB201)

    # Index -> label passed to nx.draw as the `labels` mapping.
    op_names = ('input', 'output', 'conv3', 'sep3', 'conv5', 'sep5', 'avg3', 'max3')
    self.labels = dict(enumerate(op_names))

    # Passed to nx.draw as `node_color`.
    self.colors = [
        'skyblue', 'pink', 'yellow', 'orange',
        'greenyellow', 'green', 'azure', 'beige',
    ]
|
||||
|
||||
|
||||
def to_networkx_directed(self, node_list, adjacency_matrix):
    """Build a directed networkx graph from node types and an adjacency matrix.

    Args:
        node_list: Node types of one architecture; entries equal to -1 are
            skipped.
        adjacency_matrix: Square matrix; only entries strictly above the
            diagonal with value >= 1 become edges.

    Returns:
        ``nx.DiGraph`` whose nodes carry ``number`` / ``symbol`` /
        ``color_val`` attributes and whose edges carry ``color`` and
        ``weight`` (three times the adjacency entry).
    """
    digraph = nx.DiGraph()

    for idx, node_type in enumerate(node_list):
        if node_type == -1:
            continue
        digraph.add_node(idx, number=idx, symbol=node_type, color_val=node_type)

    upper = torch.triu(torch.tensor(adjacency_matrix), diagonal=1).numpy()
    src_idx, dst_idx = np.where(upper >= 1)
    for src, dst in zip(src_idx.tolist(), dst_idx.tolist()):
        strength = adjacency_matrix[src][dst]
        digraph.add_edge(src, dst, color=float(strength), weight=3 * strength)

    return digraph
|
||||
|
||||
def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=1200, largest_component=False):
    """Draw ``graph`` with the instance's labels and colours and save it to ``path``.

    Args:
        graph: networkx graph to draw.
        pos: Node positions; when ``None`` a graphviz "dot" layout is computed.
        path: Output image path.
        iterations: Unused.
        node_size: Node size passed to ``nx.draw``.
        largest_component: Draw only the largest connected component.
    """
    if largest_component:
        # nx.connected_components is not implemented for directed graphs,
        # which is what to_networkx_directed produces; use the weak variant.
        if graph.is_directed():
            components = nx.weakly_connected_components(graph)
        else:
            components = nx.connected_components(graph)
        largest = max(components, key=len, default=None)
        if largest is not None:
            graph = graph.subgraph(largest)

    # Plot the graph structure with colors
    if pos is None:
        pos = nx.nx_pydot.graphviz_layout(graph, prog="dot")

    plt.figure()
    nx.draw(graph, pos=pos, labels=self.labels, arrows=True, node_shape="s",
            node_size=node_size, node_color=self.colors, edge_color='grey', with_labels=True)

    plt.savefig(path)
    plt.close("all")
|
||||
|
||||
def visualize(self, path: str, graphs: list, log='graph', adj=None):
    """Render the first sampled graphs to PNG files and optionally log them to wandb.

    Args:
        path: Output directory (created when missing).
        graphs: Node-type lists, one per graph.
        log: wandb key for the images; ``None`` disables logging.
        adj: Batch of adjacency tensors. NOTE(review): only ``adj[0]`` is
            used, for every graph -- confirm all samples share one adjacency.
    """
    # define path to save figures
    os.makedirs(path, exist_ok=True)

    # Never index past the supplied graphs when fewer than the configured
    # number were sampled.
    num_to_draw = min(self.num_graphs_to_visualize, len(graphs))
    for i in range(num_to_draw):
        file_path = os.path.join(path, 'graph_{}.png'.format(i))
        graph = self.to_networkx_directed(graphs[i], adj[0].detach().cpu().numpy())
        self.visualize_non_molecule(graph, pos=None, path=file_path)
        im = plt.imread(file_path)
        if wandb.run and log is not None:
            wandb.log({log: [wandb.Image(im, caption=file_path)]})
|
||||
|
||||
def visualize_chain(self, path, sample_list, adjacency_matrix,
                    r_valid_chain, r_uniqueness_chain, r_novel_chain):
    """Render every step of a sampling chain and assemble the frames into a GIF.

    Args:
        path: Directory receiving the per-frame PNG files; the GIF is written
            next to it, named after the last ``'/'``-separated component.
        sample_list: Array with ``.shape``; ``sample_list[i]`` holds the node
            types of step ``i``.
        adjacency_matrix: ``adjacency_matrix[i]`` is the adjacency of step ``i``.
        r_valid_chain: Unused.
        r_uniqueness_chain: Unused.
        r_novel_chain: Unused.
    """
    os.makedirs(path, exist_ok=True)

    # convert graphs to networkx
    graphs = [self.to_networkx_directed(sample_list[i], adjacency_matrix[i])
              for i in range(sample_list.shape[0])]
    # Lay out the final graph once and reuse its positions for every frame.
    final_pos = nx.nx_pydot.graphviz_layout(graphs[-1], prog="dot")

    # draw gif: one frame per chain step (the frame count used to be taken
    # from the sample array itself instead of its length).
    save_paths = []
    for frame, graph in enumerate(graphs):
        file_name = os.path.join(path, 'frame_{}.png'.format(frame))
        self.visualize_non_molecule(graph, pos=final_pos, path=file_name)
        save_paths.append(file_name)

    imgs = [imageio.imread(fn) for fn in save_paths]
    gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
    print(f'==> Save gif at {gif_path}')
    # Repeat the last frame so the GIF lingers on the final graph.
    imgs.extend([imgs[-1]] * 10)
    imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
    if wandb.run:
        wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})
|
||||
|
||||
|
||||
def visualize_chain_vun(self, path, r_valid_chain, r_unique_chain, r_novel_chain, sde, sampling_eps, number_chain_steps=None):
    """Plot validity, uniqueness and novelty along the sampling chain.

    Args:
        path: Output directory (created when missing).
        r_valid_chain: Validity value per plotted time step.
        r_unique_chain: Uniqueness value per plotted time step.
        r_novel_chain: Novelty value per plotted time step.
        sde: Object exposing ``T`` and ``N``; time steps are
            ``linspace(sde.T, sampling_eps, sde.N)``.
        sampling_eps: Final time of the chain.
        number_chain_steps: When given, only every ``sde.N / number_chain_steps``-th
            time step is kept on the x axis.

    Writes ``validity.png``, ``uniquness.png`` and ``novelty.png`` into
    ``path`` and logs each image to wandb when a run is active.
    """
    os.makedirs(path, exist_ok=True)
    timesteps = torch.linspace(sde.T, sampling_eps, sde.N)

    if number_chain_steps is not None:
        n = int(sde.N / number_chain_steps)
        timesteps_ = [t.item() for i, t in enumerate(timesteps) if i % n == n - 1]
        assert len(timesteps_) == number_chain_steps
        timesteps_ = timesteps_[::-1]
    else:
        timesteps_ = list(timesteps.numpy())[::-1]

    def plot_metric(values, title, color, file_name, wandb_key, show):
        """Plot one metric against time, save it, then optionally display it."""
        fig, ax = plt.subplots()
        ax.plot(timesteps_, values, color=color)
        ax.set_title(title)
        ax.set_xlabel('time')
        ax.set_ylabel(title)
        file_path = os.path.join(path, file_name)
        # Save before showing: on interactive backends the figure can be
        # gone once show() returns, which would produce a blank image.
        plt.savefig(file_path)
        if show:
            plt.show()
        plt.close("all")
        print(f'==> Save scatter plot at {file_path}')
        im = plt.imread(file_path)
        if wandb.run:
            wandb.log({wandb_key: [wandb.Image(im, caption=file_path)]})

    # Only the validity and uniqueness plots are displayed, as before.
    plot_metric(r_valid_chain, 'Validity', 'red', 'validity.png', 'r_valid_chains', show=True)
    plot_metric(r_unique_chain, 'Uniqueness', 'green', 'uniquness.png', 'r_uniqueness_chains', show=True)
    plot_metric(r_novel_chain, 'Novelty', 'blue', 'novelty.png', 'r_novelty_chains', show=False)
|
||||
|
||||
|
||||
def visualize_grad_norm(self, path, score_grad_norm_p, classifier_grad_norm_p,
                        score_grad_norm_c, classifier_grad_norm_c, sde, sampling_eps,
                        number_chain_steps=None):
    """Plot score / classifier gradient norms against time.

    Two dual-axis figures are written into ``path``: ``grad_norm_p.png``
    (predictor) and ``grad_norm_c.png`` (corrector). Each image is also
    logged to wandb when a run is active.

    Args:
        path: Output directory; created if it does not exist.
        score_grad_norm_p: Score gradient norms of the predictor.
        classifier_grad_norm_p: Classifier gradient norms of the predictor.
        score_grad_norm_c: Score gradient norms of the corrector; when empty
            a constant ``-1`` series is plotted instead.
        classifier_grad_norm_c: Classifier gradient norms of the corrector;
            when empty a constant ``-1`` series is plotted instead.
        sde: Object providing ``T`` and ``N`` for the time grid.
        sampling_eps: Smallest time value of the grid.
        number_chain_steps: Unused; kept so existing callers keep working.
    """
    os.makedirs(path, exist_ok=True)

    # The grid is generated from sde.T down to sampling_eps and then reversed,
    # so the x values run from sampling_eps up to sde.T.
    # NOTE(review): every series must hold sde.N values or ax.plot will raise.
    timesteps = torch.linspace(sde.T, sampling_eps, sde.N)
    timesteps_ = list(timesteps.numpy())[::-1]

    # Placeholder series so the corrector figure can still be drawn.
    if len(score_grad_norm_c) == 0:
        score_grad_norm_c = [-1] * len(score_grad_norm_p)
    if len(classifier_grad_norm_c) == 0:
        classifier_grad_norm_c = [-1] * len(classifier_grad_norm_p)

    def _plot_pair(stage, score_norms, classifier_norms,
                   score_color, classifier_color, file_name, log_key):
        # Draw one figure with the score norm on the left y-axis and the
        # classifier norm on the right y-axis, save it and log it.
        plt.clf()
        fig, ax1 = plt.subplots()

        ax1.set_title(f'grad_norm ({stage})')
        ax1.set_xlabel('time')
        ax1.set_ylabel(f'score_grad_norm ({stage})', color=score_color)
        ax1.plot(timesteps_, score_norms, color=score_color)
        ax1.tick_params(axis='y', labelcolor=score_color)

        ax2 = ax1.twinx()
        ax2.set_ylabel(f'classifier_grad_norm ({stage})', color=classifier_color)
        ax2.plot(timesteps_, classifier_norms, color=classifier_color)
        ax2.tick_params(axis='y', labelcolor=classifier_color)
        fig.tight_layout()

        # Save BEFORE show(): a blocking show() closes the figure, and a
        # later savefig would then write an empty image.
        file_path = os.path.join(path, file_name)
        fig.savefig(file_path)
        plt.show()
        plt.close("all")
        print(f'==> Save scatter plot at {file_path}')

        # Only read the image back when it is actually going to be logged.
        if wandb.run:
            im = plt.imread(file_path)
            wandb.log({log_key: [wandb.Image(im, caption=file_path)]})

    _plot_pair('predictor', score_grad_norm_p, classifier_grad_norm_p,
               'red', 'blue', 'grad_norm_p.png', 'grad_norm_p')
    _plot_pair('corrector', score_grad_norm_c, classifier_grad_norm_c,
               'green', 'yellow', 'grad_norm_c.png', 'grad_norm_c')
|
||||
|
||||
|
||||
def visualize_scatter(self, path,
                      score_config, classifier_config,
                      sampled_arch_metric, plot_textstr=True,
                      x_axis='latency', y_axis='test-acc', x_label='Latency (ms)', y_label='Accuracy (%)',
                      log='scatter', check_dataname='cifar10-valid',
                      selected_arch_idx_list_topN=None, selected_arch_idx_list=None,
                      train_idx_list=None, return_file_path=False):
    """Draw a scatter plot of sampled architectures against the search space.

    The figure is saved as ``scatter.png`` inside ``path`` and, unless
    ``return_file_path`` is set, logged to wandb under the key ``log`` when
    a run is active.

    Args:
        path: Output directory; created if it does not exist.
        score_config: Config passed to ``datasets_nas.get_dataset`` to obtain
            the dataset providing the "entire" point cloud.
        classifier_config: Config for the predictor dataset; also supplies
            ``data.tg_dataset``.
        sampled_arch_metric: Indexable result of the sampling evaluation;
            element ``2`` must map ``'<metric>_list'`` keys to value lists.
        plot_textstr: Whether to add the summary text box.
        x_axis, y_axis: Metric names; one of ``'val-acc'``, ``'test-acc'``,
            ``'latency'``, ``'flops'``, ``'params'``.
        x_label, y_label: Axis labels.
        log: wandb key; ``None`` disables logging.
        check_dataname: Dataset name used when looking up selected
            architectures in ``self.nasbench201``.
        selected_arch_idx_list_topN: Optional indices plotted as
            'Selected_topN'.
        selected_arch_idx_list: Optional indices plotted as 'Selected'; when
            given, the predictor dataset is loaded with ``get_dataset_iter``.
        train_idx_list: Optional indices forwarded to ``get_textstr`` for the
            oracle-hit check.
        return_file_path: If true, return the saved path and skip wandb.

    Returns:
        The path of the saved image when ``return_file_path`` is true,
        otherwise ``None``.
    """
    os.makedirs(path, exist_ok=True)

    tg_dataset = classifier_config.data.tg_dataset

    train_ds_s, _, _ = datasets_nas.get_dataset(score_config)
    if selected_arch_idx_list is None:
        train_ds_c, _, _ = datasets_nas.get_dataset(classifier_config)
    else:
        train_ds_c, _, _ = datasets_nas.get_dataset_iter(classifier_config)

    plt.clf()
    fig, ax = plt.subplots()

    # entire architectures
    entire_ds_x = train_ds_s.get_unnoramlized_entire_data(x_axis, tg_dataset)
    entire_ds_y = train_ds_s.get_unnoramlized_entire_data(y_axis, tg_dataset)
    ax.scatter(entire_ds_x, entire_ds_y, color = 'lightgray', alpha = 0.5, label='Entire', marker=',')

    # architectures trained by the classifier
    train_ds_c_x = train_ds_c.get_unnoramlized_data(x_axis, tg_dataset)
    train_ds_c_y = train_ds_c.get_unnoramlized_data(y_axis, tg_dataset)
    ax.scatter(train_ds_c_x, train_ds_c_y, color = 'black', alpha = 0.8, label='Trained by Predictor Model')

    # oracle: the architecture with the best y-axis value in the entire set
    oracle_idx = torch.argmax(torch.tensor(entire_ds_y)).item()
    oracle_item_x = entire_ds_x[oracle_idx]
    oracle_item_y = entire_ds_y[oracle_idx]
    ax.scatter(oracle_item_x, oracle_item_y, color = 'red', alpha = 1.0, label='Oracle', marker='*', s=150)

    # architectures sampled by the score_model & classifier
    AXIS_TO_PROP = {
        'val-acc': 'val_acc_list',
        'test-acc': 'test_acc_list',
        'latency': 'latency_list',
        'flops': 'flops_list',
        'params': 'params_list',
    }
    sampled_ds_c_x = sampled_arch_metric[2][AXIS_TO_PROP[x_axis]]
    sampled_ds_c_y = sampled_arch_metric[2][AXIS_TO_PROP[y_axis]]
    ax.scatter(sampled_ds_c_x, sampled_ds_c_y, color = 'limegreen', alpha = 0.8, label='Sampled', marker='x')

    ax.set_title(f'{tg_dataset.upper()} Dataset')
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Both pairs default to None so get_textstr can always be called, even
    # when no top-N / selected index list was supplied.
    selected_topN_ds_x, selected_topN_ds_y = None, None
    if selected_arch_idx_list_topN is not None:
        selected_arch_topN_info_dict = get_arch_acc_info_dict(
            self.nasbench201, dataname=check_dataname, arch_index_list=selected_arch_idx_list_topN)
        selected_topN_ds_x = selected_arch_topN_info_dict[AXIS_TO_PROP[x_axis]]
        selected_topN_ds_y = selected_arch_topN_info_dict[AXIS_TO_PROP[y_axis]]
        ax.scatter(selected_topN_ds_x, selected_topN_ds_y, color = 'pink', alpha = 0.8, label='Selected_topN', marker='x')

    # architectures selected by the predictor
    selected_ds_x, selected_ds_y = None, None
    if selected_arch_idx_list is not None:
        selected_arch_info_dict = get_arch_acc_info_dict(
            self.nasbench201, dataname=check_dataname, arch_index_list=selected_arch_idx_list)
        selected_ds_x = selected_arch_info_dict[AXIS_TO_PROP[x_axis]]
        selected_ds_y = selected_arch_info_dict[AXIS_TO_PROP[y_axis]]
        ax.scatter(selected_ds_x, selected_ds_y, color = 'blue', alpha = 0.8, label='Selected', marker='x')

    if plot_textstr:
        textstr = self.get_textstr(sampled_arch_metric=sampled_arch_metric,
                                   sampled_ds_c_x=sampled_ds_c_x, sampled_ds_c_y=sampled_ds_c_y,
                                   x_axis=x_axis, y_axis=y_axis,
                                   classifier_config=classifier_config,
                                   selected_ds_x=selected_ds_x, selected_ds_y=selected_ds_y,
                                   selected_topN_ds_x=selected_topN_ds_x, selected_topN_ds_y=selected_topN_ds_y,
                                   oracle_idx=oracle_idx, train_idx_list=train_idx_list
                                   )

        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        ax.text(0.6, 0.4, textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props, fontsize='x-small')
    ax.legend(loc="lower right")

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
    plt.tight_layout()

    # Save BEFORE show(): a blocking show() closes the figure, and a later
    # savefig would then write an empty image.
    file_path = os.path.join(path, 'scatter.png')
    fig.savefig(file_path)
    plt.show()
    plt.close("all")
    print(f'==> Save scatter plot at {path}')

    if return_file_path:
        return file_path

    # Only read the image back when it is actually going to be logged.
    if wandb.run and log is not None:
        im = plt.imread(file_path)
        wandb.log({log: [wandb.Image(im, caption=file_path)]})
|
||||
|
||||
# if return_selected_arch_info_dict:
|
||||
# return selected_arch_info_dict, selected_arch_topN_info_dict
|
||||
|
||||
def visualize_scatter_chain(self, path, score_config, classifier_config, sampled_arch_metric_chain, plot_textstr=True,
                            x_axis='latency', y_axis='test-acc', x_label='Latency (ms)', y_label='Accuracy (%)',
                            log='scatter_chain'):
    """Render one scatter plot per chain step and assemble them into a GIF.

    Each frame is saved as ``frame_<i>.png`` inside ``path`` and logged to
    wandb under ``log`` when a run is active. The frames are then combined,
    in reverse order, into ``scatter.gif``, with the final image repeated ten
    times so the animation lingers on it.

    Args:
        path: Output directory; created if it does not exist.
        score_config: Config passed to ``datasets_nas.get_dataset`` for the
            score-model dataset.
        classifier_config: Config for the predictor dataset; also supplies
            ``data.tg_dataset``.
        sampled_arch_metric_chain: Sequence of per-step metric results; for
            each entry, element ``2`` must map ``'<metric>_list'`` keys to
            value lists.
        plot_textstr: Whether to add the summary text box to every frame.
        x_axis, y_axis: Metric names; one of ``'val-acc'``, ``'test-acc'``,
            ``'latency'``, ``'flops'``, ``'params'``.
        x_label, y_label: Axis labels.
        log: wandb key for the individual frames; ``None`` disables it.
    """
    os.makedirs(path, exist_ok=True)
    save_paths = []
    num_frames = len(sampled_arch_metric_chain)

    tg_dataset = classifier_config.data.tg_dataset

    train_ds_s, _, _ = datasets_nas.get_dataset(score_config)
    train_ds_c, _, _ = datasets_nas.get_dataset(classifier_config)

    # entire architectures
    entire_ds_x = train_ds_s.get_unnoramlized_entire_data(x_axis, tg_dataset)
    entire_ds_y = train_ds_s.get_unnoramlized_entire_data(y_axis, tg_dataset)

    # architectures trained by the score_model
    train_ds_s_x = train_ds_s.get_unnoramlized_data(x_axis, tg_dataset)
    train_ds_s_y = train_ds_s.get_unnoramlized_data(y_axis, tg_dataset)

    # architectures trained by the classifier
    train_ds_c_x = train_ds_c.get_unnoramlized_data(x_axis, tg_dataset)
    train_ds_c_y = train_ds_c.get_unnoramlized_data(y_axis, tg_dataset)

    # oracle: chosen by validation accuracy, independent of the plotted y-axis
    oracle_idx = torch.argmax(torch.tensor(train_ds_s.get_unnoramlized_entire_data('val-acc', tg_dataset))).item()
    oracle_item_x = entire_ds_x[oracle_idx]
    oracle_item_y = entire_ds_y[oracle_idx]

    # Loop-invariant; same axis names as accepted by visualize_scatter.
    AXIS_TO_PROP = {
        'val-acc': 'val_acc_list',
        'test-acc': 'test_acc_list',
        'latency': 'latency_list',
        'flops': 'flops_list',
        'params': 'params_list',
    }

    for frame in range(num_frames):
        sampled_arch_metric = sampled_arch_metric_chain[frame]

        plt.clf()
        fig, ax = plt.subplots()

        # entire architectures
        ax.scatter(entire_ds_x, entire_ds_y, color = 'lightgray', alpha = 0.5, label='Entire', marker=',')
        # architectures trained by the score_model
        ax.scatter(train_ds_s_x, train_ds_s_y, color = 'gray', alpha = 0.8, label='Trained by Score Model')
        # architectures trained by the classifier
        ax.scatter(train_ds_c_x, train_ds_c_y, color = 'black', alpha = 0.8, label='Trained by Predictor Model')
        # oracle
        ax.scatter(oracle_item_x, oracle_item_y, color = 'red', alpha = 1.0, label='Oracle', marker='*', s=150)
        # architectures sampled by the score_model & classifier
        sampled_ds_c_x = sampled_arch_metric[2][AXIS_TO_PROP[x_axis]]
        sampled_ds_c_y = sampled_arch_metric[2][AXIS_TO_PROP[y_axis]]
        ax.scatter(sampled_ds_c_x, sampled_ds_c_y, color = 'limegreen', alpha = 0.8, label='Sampled', marker='x')

        ax.set_title(f'{tg_dataset.upper()} Dataset')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        if plot_textstr:
            textstr = self.get_textstr(sampled_arch_metric, sampled_ds_c_x, sampled_ds_c_y,
                                       x_axis, y_axis, classifier_config)
            props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
            ax.text(0.6, 0.3, textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props)
        ax.legend(loc="lower right")

        plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

        # Save BEFORE show(): a blocking show() closes the figure, and a
        # later savefig would then write an empty image.
        file_path = os.path.join(path, f'frame_{frame}.png')
        fig.savefig(file_path)
        plt.show()
        plt.close("all")
        print(f'==> Save scatter plot at {file_path}')
        save_paths.append(file_path)

        # Only read the image back when it is actually going to be logged.
        if wandb.run and log is not None:
            im = plt.imread(file_path)
            wandb.log({log: [wandb.Image(im, caption=file_path)]})

    # An empty chain yields no frames; there is nothing to animate.
    if not save_paths:
        return

    # draw gif (frames are played in reverse order of generation)
    imgs = [imageio.imread(fn) for fn in save_paths[::-1]]
    gif_path = os.path.join(path, 'scatter.gif')
    print(f'==> Save gif at {gif_path}')
    # Hold the final image for ten extra frames.
    imgs.extend([imgs[-1]] * 10)
    imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
    if wandb.run:
        wandb.log({'chain_gif': [wandb.Video(gif_path, caption=gif_path, format="gif")]})
|
||||
|
||||
def get_textstr(self,
                sampled_arch_metric,
                sampled_ds_c_x, sampled_ds_c_y,
                x_axis='latency', y_axis='test-acc',
                classifier_config=None,
                selected_ds_x=None, selected_ds_y=None,
                selected_topN_ds_x=None, selected_topN_ds_y=None,
                oracle_idx=None, train_idx_list=None):
    """Build the multi-line summary shown in the scatter-plot text box.

    The text lists the validity / uniqueness / novelty ratios taken from
    ``sampled_arch_metric[0][0:3]``, the predictor settings, and
    mean / std / max / min (rounded to 4 decimals) of the sampled values on
    both axes. Statistics for the selected values, the best of the first ten
    top-N y values, and an oracle-hit note are appended when the matching
    arguments are supplied.

    Args:
        sampled_arch_metric: Indexable whose element ``0`` holds the three
            ratios (validity, uniqueness, novelty).
        sampled_ds_c_x, sampled_ds_c_y: Sampled values on the x / y axis.
        x_axis, y_axis: Axis names used in the section headers.
        classifier_config: Must expose ``training.noised``; it is read
            unconditionally, so the ``None`` default is not usable.
        selected_ds_x, selected_ds_y: Optional selected values per axis.
        selected_topN_ds_x: Accepted but not used.
        selected_topN_ds_y: Optional top-N y values.
        oracle_idx: Index of the oracle architecture.
        train_idx_list: Optional collection checked for ``oracle_idx``.

    Returns:
        The assembled summary string.
    """
    def _summary(values):
        # (mean, std, max, min), each rounded to four decimals.
        return (
            round(np.mean(np.array(values)), 4),
            round(np.std(np.array(values)), 4),
            round(np.max(np.array(values)), 4),
            round(np.min(np.array(values)), 4),
        )

    x_mean, x_std, x_max, x_min = _summary(sampled_ds_c_x)
    y_mean, y_std, y_max, y_min = _summary(sampled_ds_c_y)

    has_selected_x = selected_ds_x is not None
    has_selected_y = selected_ds_y is not None
    sel_x_stats = _summary(selected_ds_x) if has_selected_x else None
    sel_y_stats = _summary(selected_ds_y) if has_selected_y else None

    ratios = sampled_arch_metric[0]
    validity = round(ratios[0], 4)
    uniqueness = round(ratios[1], 4)
    novelty = round(ratios[2], 4)

    pieces = [
        f'V-{validity} | U-{uniqueness} | N-{novelty} \n',
        f'Predictor (Noise-aware-{str(classifier_config.training.noised)[0]}, k={self.config.sampling.classifier_scale}) \n',
        f'=> Sampled {x_axis} \n',
        f'Mean-{x_mean} | Std-{x_std} \n',
        f'Max-{x_max} | Min-{x_min} \n',
        f'=> Sampled {y_axis} \n',
        f'Mean-{y_mean} | Std-{y_std} \n',
        f'Max-{y_max} | Min-{y_min} \n',
    ]

    for present, axis_name, stats in ((has_selected_x, x_axis, sel_x_stats),
                                      (has_selected_y, y_axis, sel_y_stats)):
        if not present:
            continue
        s_mean, s_std, s_max, s_min = stats
        pieces.append(f'==> Selected {axis_name} \n')
        pieces.append(f'Mean-{s_mean} | Std-{s_std} \n')
        pieces.append(f'Max-{s_max} | Min-{s_min} \n')

    if selected_topN_ds_y is not None:
        best_top10 = round(max(selected_topN_ds_y[:10]), 4)
        pieces.append(f'==> Predicted TopN (10) -{str(best_top10)} \n')

    if train_idx_list is not None and oracle_idx in train_idx_list:
        pieces.append(f'==> Hit Oracle ({oracle_idx}) !')

    return ''.join(pieces)
|
||||
|
||||
|
||||
def get_arch_acc_info_dict(nasbench201, dataname='cifar10-valid', arch_index_list=None):
    """Collect per-architecture metrics for the given indices.

    For every index in ``arch_index_list`` the value
    ``nasbench201[<metric>][dataname][index]`` is looked up for each of the
    metrics ``'val-acc'``, ``'test-acc'``, ``'flops'``, ``'params'`` and
    ``'latency'``.

    Args:
        nasbench201: Nested lookup indexed as ``[metric][dataname][index]``.
        dataname: Dataset key used for the second lookup level.
        arch_index_list: Iterable of architecture indices; must be supplied
            (the ``None`` default cannot be iterated).

    Returns:
        A dict with keys ``'val_acc_list'``, ``'test_acc_list'``,
        ``'flops_list'``, ``'params_list'`` and ``'latency_list'``, each
        holding the values in the order of ``arch_index_list``.
    """
    # (source metric, output key) pairs, in lookup order.
    metric_to_key = (
        ('val-acc', 'val_acc_list'),
        ('test-acc', 'test_acc_list'),
        ('flops', 'flops_list'),
        ('params', 'params_list'),
        ('latency', 'latency_list'),
    )
    collected = {
        'val_acc_list': [],
        'test_acc_list': [],
        'flops_list': [],
        'params_list': [],
        'latency_list': [],
    }

    for idx in arch_index_list:
        for metric, out_key in metric_to_key:
            collected[out_key].append(nasbench201[metric][dataname][idx])

    return collected
|
Reference in New Issue
Block a user