first commit

CownowAn
2024-03-15 14:38:51 +00:00
commit bc2ed1304f
321 changed files with 44802 additions and 0 deletions

.gitignore vendored Normal file

@@ -0,0 +1,9 @@
__pycache__
checkpoints/
*.pt
data/
exp/
vis/
results/
.empty/
.prev/

MobileNetV3/all_path.py Normal file

@@ -0,0 +1,9 @@
RAW_DATA_PATH = "./data/ofa/raw_data"
PROCESSED_DATA_PATH = "./data/ofa/data_transfer_nag"
SCORE_MODEL_DATA_PATH = "./data/ofa/data_score_model/ofa_database_500000.pt"
SCORE_MODEL_DATA_IDX_PATH = "./data/ofa/data_score_model/ridx-500000.pt"
NOISE_META_PREDICTOR_CKPT_PATH = "./checkpoints/ofa/noise_aware_meta_surrogate/model_best.pth.tar"
SCORE_MODEL_CKPT_PATH = "./checkpoints/ofa/score_model/model_best.pth.tar"
UNNOISE_META_PREDICTOR_CKPT_PATH = "./checkpoints/ofa/unnoised_meta_surrogate_from_metad2a"
CONFIG_PATH = "./configs/transfer_nag_ofa.pt"


@@ -0,0 +1,475 @@
import numpy as np
import torch
import wandb
import igraph
from torch.nn.functional import one_hot
KS_LIST = [3, 5, 7]
EXPAND_LIST = [3, 4, 6]
DEPTH_LIST = [2, 3, 4]
NUM_STAGE = 5
MAX_LAYER_PER_STAGE = 4
MAX_N_BLOCK= NUM_STAGE * MAX_LAYER_PER_STAGE # 20
OPS = {
'3-3': 0, '3-4': 1, '3-6': 2,
'5-3': 3, '5-4': 4, '5-6': 5,
'7-3': 6, '7-4': 7, '7-6': 8,
}
OPS2STR = {
0: '3-3', 1: '3-4', 2: '3-6',
3: '5-3', 4: '5-4', 5: '5-6',
6: '7-3', 7: '7-4', 8: '7-6',
}
NUM_OPS = len(OPS)
LONGEST_PATH_LENGTH = 20
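# A minimal usage sketch of the op vocabulary, assuming the constants above are in
# scope: each op string is 'kernel_size-expand_ratio', so '5-4' is a block with
# kernel size 5 and expand ratio 4, and each block is one row of a 20 x 9 one-hot matrix.
idx = OPS['5-4']
row = one_hot(torch.tensor(idx), NUM_OPS)
print(idx, OPS2STR[idx], row.tolist())  # 4 5-4 [0, 0, 0, 0, 1, 0, 0, 0, 0]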
class BasicArchMetricsOFA(object):
def __init__(self, train_ds=None, train_arch_str_list=None, except_inout=False, data_root=None):
if data_root is not None:
self.ofa = torch.load(data_root)
self.train_arch_list = self.ofa['x']
else:
self.ofa = None
self.train_arch_list = None
# self.ofa = torch.load(data_root)
self.ops_decoder = OPS
self.except_inout = except_inout
def get_string_from_onehot_x(self, x):
# node_types = torch.nonzero(torch.tensor(x).long(), as_tuple=True)[1]
x = torch.tensor(x)
ds = torch.sum(x.view(NUM_STAGE, -1), dim=1)
string = ''
for i, _ in enumerate(x):
if sum(_) == 0:
string += '0-0-0_'
else:
string += f'{int(ds[int(i/MAX_LAYER_PER_STAGE)])}-' + OPS2STR[torch.nonzero(torch.tensor(_)).item()] + '_'
return string[:-1]
def compute_validity(self, generated, adj=None, mask=None):
""" generated: list of couples (positions, node_types)"""
valid = []
error_types = []
valid_str = []
for x in generated:
is_valid, error_type = is_valid_OFA_x(x)
if is_valid:
valid.append(torch.tensor(x).long())
valid_str.append(self.get_string_from_onehot_x(x))
else:
error_types.append(error_type)
validity = 0 if len(generated) == 0 else len(valid) / len(generated)
return valid, validity, valid_str, None, error_types
def compute_uniqueness(self, valid_arch):
unique = []
for x in valid_arch:
if not any([torch.equal(x, tr_m) for tr_m in unique]):
unique.append(x)
return unique, len(unique) / len(valid_arch)
def compute_novelty(self, unique):
num_novel = 0
novel = []
if self.train_arch_list is None:
print("Dataset arch_str is None, novelty computation skipped")
return 1, 1
for arch in unique:
if not any([torch.equal(arch, tr_m) for tr_m in self.train_arch_list]):
# if arch not in self.train_arch_list[1:]:
novel.append(arch)
num_novel += 1
return novel, num_novel / len(unique)
def evaluate(self, generated, adj, mask, check_dataname='cifar10'):
""" generated: list of pairs """
valid_arch, validity, _, _, error_types = self.compute_validity(generated, adj, mask)
print(f"Validity over {len(generated)} archs: {validity * 100 :.2f}%")
error_1 = torch.sum(torch.tensor(error_types) == 1) / len(generated)
error_2 = torch.sum(torch.tensor(error_types) == 2) / len(generated)
error_3 = torch.sum(torch.tensor(error_types) == 3) / len(generated)
print(f"Unvalid-Multi_Node_Type over {len(generated)} archs: {error_1 * 100 :.2f}%")
print(f"INVALID_1OR2 over {len(generated)} archs: {error_2 * 100 :.2f}%")
print(f"INVALID_3AND4 over {len(generated)} archs: {error_3 * 100 :.2f}%")
# print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")
if validity > 0:
unique, uniqueness = self.compute_uniqueness(valid_arch)
print(f"Uniqueness over {len(valid_arch)} valid archs: {uniqueness * 100 :.2f}%")
if self.train_arch_list is not None:
_, novelty = self.compute_novelty(unique)
print(f"Novelty over {len(unique)} unique valid archs: {novelty * 100 :.2f}%")
else:
novelty = -1.0
else:
novelty = -1.0
uniqueness = 0.0
unique = []
test_acc_list, flops_list, params_list, latency_list = [0], [0], [0], [0]
all_arch_str = None
return ([validity, uniqueness, novelty, error_1, error_2, error_3],
unique,
dict(test_acc_list=test_acc_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
all_arch_str)
class BasicArchMetricsMetaOFA(object):
def __init__(self, train_ds=None, train_arch_str_list=None, except_inout=False, data_root=None):
if data_root is not None:
self.ofa = torch.load(data_root)
self.train_arch_list = self.ofa['x']
else:
self.ofa = None
self.train_arch_list = None
self.ops_decoder = OPS
def get_string_from_onehot_x(self, x):
x = torch.tensor(x)
ds = torch.sum(x.view(NUM_STAGE, -1), dim=1)
string = ''
for i, _ in enumerate(x):
if sum(_) == 0:
string += '0-0-0_'
else:
string += f'{int(ds[int(i/MAX_LAYER_PER_STAGE)])}-' + OPS2STR[torch.nonzero(torch.tensor(_)).item()] + '_'
return string[:-1]
def compute_validity(self, generated, adj=None, mask=None):
""" generated: list of couples (positions, node_types)"""
valid = []
valid_arch_str = []
all_arch_str = []
error_types = []
for x in generated:
is_valid, error_type = is_valid_OFA_x(x)
if is_valid:
valid.append(torch.tensor(x).long())
arch_str = self.get_string_from_onehot_x(x)
valid_arch_str.append(arch_str)
else:
arch_str = None
error_types.append(error_type)
all_arch_str.append(arch_str)
validity = 0 if len(generated) == 0 else (len(valid)/len(generated))
return valid, validity, valid_arch_str, all_arch_str, error_types
def compute_uniqueness(self, valid_arch):
unique = []
for x in valid_arch:
if not any([torch.equal(x, tr_m) for tr_m in unique]):
unique.append(x)
return unique, len(unique) / len(valid_arch)
def compute_novelty(self, unique):
num_novel = 0
novel = []
if self.train_arch_list is None:
print("Dataset arch_str is None, novelty computation skipped")
return 1, 1
for arch in unique:
if not any([torch.equal(arch, tr_m) for tr_m in self.train_arch_list]):
novel.append(arch)
num_novel += 1
return novel, num_novel / len(unique)
def evaluate(self, generated, adj, mask, check_dataname='imagenet1k'):
""" generated: list of pairs """
valid_arch, validity, _, _, error_types = self.compute_validity(generated, adj, mask)
print(f"Validity over {len(generated)} archs: {validity * 100 :.2f}%")
error_1 = torch.sum(torch.tensor(error_types) == 1) / len(generated)
error_2 = torch.sum(torch.tensor(error_types) == 2) / len(generated)
error_3 = torch.sum(torch.tensor(error_types) == 3) / len(generated)
print(f"Unvalid-Multi_Node_Type over {len(generated)} archs: {error_1 * 100 :.2f}%")
print(f"INVALID_1OR2 over {len(generated)} archs: {error_2 * 100 :.2f}%")
print(f"INVALID_3AND4 over {len(generated)} archs: {error_3 * 100 :.2f}%")
if validity > 0:
unique, uniqueness = self.compute_uniqueness(valid_arch)
print(f"Uniqueness over {len(valid_arch)} valid archs: {uniqueness * 100 :.2f}%")
if self.train_arch_list is not None:
_, novelty = self.compute_novelty(unique)
print(f"Novelty over {len(unique)} unique valid archs: {novelty * 100 :.2f}%")
else:
novelty = -1.0
else:
novelty = -1.0
uniqueness = 0.0
unique = []
test_acc_list, flops_list, params_list, latency_list = [0], [0], [0], [0]
all_arch_str = None
return ([validity, uniqueness, novelty, error_1, error_2, error_3],
unique,
dict(test_acc_list=test_acc_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
all_arch_str)
def get_arch_acc_info(nasbench201, arch, dataname='cifar10'):
arch_index = nasbench201['str'].index(arch)
test_acc = nasbench201['test-acc'][dataname][arch_index]
flops = nasbench201['flops'][dataname][arch_index]
params = nasbench201['params'][dataname][arch_index]
latency = nasbench201['latency'][dataname][arch_index]
return test_acc, flops, params, latency
def get_arch_acc_info_meta(nasbench201, arch, dataname='cifar10'):
arch_index = nasbench201['str'].index(arch)
flops = nasbench201['flops'][dataname][arch_index]
params = nasbench201['params'][dataname][arch_index]
latency = nasbench201['latency'][dataname][arch_index]
if 'cifar' in dataname:
test_acc = nasbench201['test-acc'][dataname][arch_index]
else:
# TODO
test_acc = None
return arch_index, test_acc, flops, params, latency
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
res = g.is_dag()
n_start, n_end = 0, 0
for v in g.vs:
if v['type'] == START_TYPE:
n_start += 1
elif v['type'] == END_TYPE:
n_end += 1
if v.indegree() == 0 and v['type'] != START_TYPE:
return False
if v.outdegree() == 0 and v['type'] != END_TYPE:
return False
return res and n_start == 1 and n_end == 1
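# A quick sanity sketch for is_valid_DAG, assuming python-igraph: a two-node
# input -> output graph is the smallest valid DAG under the rules above.
g = igraph.Graph(directed=True)
g.add_vertex(type=0)  # START_TYPE
g.add_vertex(type=1)  # END_TYPE
g.add_edge(0, 1)
print(is_valid_DAG(g))  # True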
def check_single_node_type(x):
for x_elem in x:
if int(np.sum(x_elem)) != 1:
return False
return True
def check_start_end_nodes(x, START_TYPE, END_TYPE):
if x[0][START_TYPE] != 1:
return False
if x[-1][END_TYPE] != 1:
return False
return True
def check_interm_node_types(x, START_TYPE, END_TYPE):
for x_elem in x[1:-1]:
if x_elem[START_TYPE] == 1:
return False
if x_elem[END_TYPE] == 1:
return False
return True
def construct_igraph(node_type, edge_type, ops_decoder, except_inout=True):
assert node_type.shape[0] == edge_type.shape[0]
START_TYPE = ops_decoder.index('input')
END_TYPE = ops_decoder.index('output')
g = igraph.Graph(directed=True)
for i, node in enumerate(node_type):
new_type = node.item()
g.add_vertex(type=new_type)
if new_type == END_TYPE:
end_vertices = set([v.index for v in g.vs.select(_outdegree_eq=0) if v.index != g.vcount()-1])
for v in end_vertices:
g.add_edge(v, i)
elif i > 0:
for ek in range(i):
ek_score = edge_type[ek][i].item()
if ek_score >= 0.5:
g.add_edge(ek, i)
return g
def compute_arch_metrics(arch_list, adj, mask, train_arch_str_list,
train_ds, timestep=None, name=None, except_inout=False, data_root=None):
""" arch_list: (dict) """
metrics = BasicArchMetricsOFA(data_root=data_root)
arch_metrics = metrics.evaluate(arch_list, adj, mask, check_dataname='cifar10')
all_arch_str = arch_metrics[-1]
if wandb.run:
arch_prop = arch_metrics[2]
test_acc_list = arch_prop['test_acc_list']
flops_list = arch_prop['flops_list']
params_list = arch_prop['params_list']
latency_list = arch_prop['latency_list']
if arch_metrics[0][1] > 0.: # uniqueness > 0.
dic = {
'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
'test_acc_max': np.max(test_acc_list), 'test_acc_min': np.min(test_acc_list), 'test_acc_mean': np.mean(test_acc_list), 'test_acc_std': np.std(test_acc_list),
'flops_max': np.max(flops_list), 'flops_min': np.min(flops_list), 'flops_mean': np.mean(flops_list), 'flops_std': np.std(flops_list),
'params_max': np.max(params_list), 'params_min': np.min(params_list), 'params_mean': np.mean(params_list), 'params_std': np.std(params_list),
'latency_max': np.max(latency_list), 'latency_min': np.min(latency_list), 'latency_mean': np.mean(latency_list), 'latency_std': np.std(latency_list),
}
else:
dic = {
'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
'test_acc_max': -1, 'test_acc_min': -1, 'test_acc_mean': -1, 'test_acc_std': 0,
'flops_max': -1, 'flops_min': -1, 'flops_mean': -1, 'flops_std': 0,
'params_max': -1, 'params_min': -1, 'params_mean': -1, 'params_std': 0,
'latency_max': -1, 'latency_min': -1, 'latency_mean': -1, 'latency_std': 0,
}
if timestep is not None:
dic.update({'step': timestep})
wandb.log(dic)
return arch_metrics, all_arch_str
def compute_arch_metrics_meta(
arch_list, adj, mask, train_arch_str_list, train_ds,
timestep=None, check_dataname='cifar10', name=None):
""" arch_list: (dict) """
metrics = BasicArchMetricsMetaOFA(train_ds, train_arch_str_list)
arch_metrics = metrics.evaluate(arch_list, adj, mask, check_dataname=check_dataname)
if wandb.run:
arch_prop = arch_metrics[2]
if name != 'ofa':
arch_idx_list = arch_prop['arch_idx_list']
test_acc_list = arch_prop['test_acc_list']
flops_list = arch_prop['flops_list']
params_list = arch_prop['params_list']
latency_list = arch_prop['latency_list']
if arch_metrics[0][1] > 0.: # uniqueness > 0.
dic = {
'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
'test_acc_max': np.max(test_acc_list), 'test_acc_min': np.min(test_acc_list), 'test_acc_mean': np.mean(test_acc_list), 'test_acc_std': np.std(test_acc_list),
'flops_max': np.max(flops_list), 'flops_min': np.min(flops_list), 'flops_mean': np.mean(flops_list), 'flops_std': np.std(flops_list),
'params_max': np.max(params_list), 'params_min': np.min(params_list), 'params_mean': np.mean(params_list), 'params_std': np.std(params_list),
'latency_max': np.max(latency_list), 'latency_min': np.min(latency_list), 'latency_mean': np.mean(latency_list), 'latency_std': np.std(latency_list),
}
else:
dic = {
'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
'test_acc_max': -1, 'test_acc_min': -1, 'test_acc_mean': -1, 'test_acc_std': 0,
'flops_max': -1, 'flops_min': -1, 'flops_mean': -1, 'flops_std': 0,
'params_max': -1, 'params_min': -1, 'params_mean': -1, 'params_std': 0,
'latency_max': -1, 'latency_min': -1, 'latency_mean': -1, 'latency_std': 0,
}
if timestep is not None:
dic.update({'step': timestep})
return arch_metrics
def check_multiple_nodes(x):
assert len(x.shape) == 2
for x_elem in x:
x_elem = np.array(x_elem)
if int(np.sum(x_elem)) > 1:
return False
return True
def check_inout_node(x, START_TYPE=0, END_TYPE=1):
assert len(x.shape) == 2
return x[0][START_TYPE] == 1 and x[-1][END_TYPE] == 1
def check_none_in_1_and_2_layers(x, NONE_TYPE=None):
assert len(x.shape) == 2
first_and_second_layers = [0, 1, 4, 5, 8, 9, 12, 13, 16, 17]
for layer in first_and_second_layers:
if int(np.sum(x[layer])) == 0:
return False
return True
def check_none_in_3_and_4_layers(x, NONE_TYPE=None):
assert len(x.shape) == 2
third_layers = [2, 6, 10, 14, 18]
for layer in third_layers:
if int(np.sum(x[layer])) == 0:
if int(np.sum(x[layer+1])) != 0:
return False
return True
def check_interm_inout_node(x, START_TYPE, END_TYPE):
for x_elem in x[1:-1]:
if x_elem[START_TYPE] == 1:
return False
if x_elem[END_TYPE] == 1:
return False
return True
def is_valid_OFA_x(x):
ERROR = {
'MULTIPLE_NODES': 1,
'INVALID_1OR2_LAYERS': 2,
'INVALID_3AND4_LAYERS': 3,
'NO_ERROR': -1
}
if not check_multiple_nodes(x):
return False, ERROR['MULTIPLE_NODES']
if not check_none_in_1_and_2_layers(x):
return False, ERROR['INVALID_1OR2_LAYERS']
if not check_none_in_3_and_4_layers(x):
return False, ERROR['INVALID_3AND4_LAYERS']
return True, ERROR['NO_ERROR']
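# A small sanity sketch for the validity rules, assuming numpy and the helpers
# above: a 20 x 9 one-hot matrix with every layer active passes, while zeroing a
# first-layer row trips the 1-or-2-layer rule (error code 2).
x = np.eye(NUM_OPS)[np.zeros(MAX_N_BLOCK, dtype=int)]
print(is_valid_OFA_x(x))  # (True, -1)
x[0] = 0
print(is_valid_OFA_x(x))  # (False, 2)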
def get_x_adj_from_opsdict_ofa(ops):
node_types = torch.zeros(NUM_STAGE * MAX_LAYER_PER_STAGE).long() # w/o in / out
num_vertices = len(OPS.values())
num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
d_matrix = []
for i in range(NUM_STAGE):
ds = ops['d'][i]
for j in range(ds):
d_matrix.append(ds)
for j in range(MAX_LAYER_PER_STAGE - ds):
d_matrix.append('none')
for i, (ks, e, d) in enumerate(zip(
ops['ks'], ops['e'], d_matrix)):
if d == 'none':
pass
else:
node_types[i] = OPS[f'{ks}-{e}']
x = one_hot(node_types, num_vertices).float()
def get_adj():
adj = torch.zeros(num_nodes, num_nodes)
for i in range(num_nodes-1):
adj[i, i+1] = 1
adj = np.array(adj)
return adj
adj = get_adj()
return x, adj
def get_string_from_onehot_x(x):
x = torch.tensor(x)
ds = torch.sum(x.view(NUM_STAGE, -1), dim=1)
string = ''
for i, _ in enumerate(x):
if sum(_) == 0:
string += '0-0-0_'
else:
string += f'{int(ds[int(i/MAX_LAYER_PER_STAGE)])}-' + OPS2STR[torch.nonzero(torch.tensor(_)).item()] + '_'
return string[:-1]
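# A hedged round-trip sketch tying the helpers together; the ops dict keys
# ('d', 'ks', 'e') are assumed to follow the usage in get_x_adj_from_opsdict_ofa.
ops = {'d': [4] * NUM_STAGE, 'ks': [3] * MAX_N_BLOCK, 'e': [6] * MAX_N_BLOCK}
x, adj = get_x_adj_from_opsdict_ofa(ops)
print(x.shape, adj.shape)  # torch.Size([20, 9]) (20, 20)
print(get_string_from_onehot_x(x))  # 4-3-6_4-3-6_... repeated for all 20 blocks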


@@ -0,0 +1,114 @@
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
from torch import Tensor
import wandb
import torch.nn as nn
class SamplingArchMetrics(nn.Module):
def __init__(self, config, train_ds, exp_name):
super().__init__()
self.exp_name = exp_name
self.train_ds = train_ds
if config.data.name == 'ofa':
self.train_arch_str_list = train_ds.x_list_
else:
self.train_arch_str_list = train_ds.arch_str_list_
self.name = config.data.name
self.except_inout = config.data.except_inout
self.data_root = config.data.root
def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False, timestep=None):
"""_summary_
:params arch_list: list of archs
:params adj: [batch_size, num_nodes, num_nodes]
:params mask: [batch_size, num_nodes, num_nodes]
"""
arch_metrics, all_arch_str = compute_arch_metrics(
arch_list, adj, mask, self.train_arch_str_list, self.train_ds, timestep=timestep,
name=self.name, except_inout=self.except_inout, data_root=self.data_root)
# arch_metrics
# ([validity, uniqueness, novelty],
# unique,
# dict(test_acc_list=test_acc_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
# all_arch_str)
if test and self.name != 'ofa':
with open(r'final_.txt', 'w') as fp:
for arch_str in all_arch_str:
# write each item on a new line
fp.write("%s\n" % arch_str)
print('All archs saved')
if self.name != 'ofa':
valid_unique_arch = arch_metrics[1]
valid_unique_arch_prop_dict = arch_metrics[2] # test_acc, flops, params, latency
# textfile = open(f'{this_sample_dir}/archs/{name}/valid_unique_arch_step-{current_step}.txt', "w")
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
for i in range(len(valid_unique_arch)):
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n ")
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
textfile.writelines(valid_unique_arch)
textfile.close()
# res_dic = {
# 'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
# 'test_acc_max': -1, 'test_acc_min':-1, 'test_acc_mean': -1, 'test_acc_std': 0,
# 'flops_max': -1, 'flops_min':-1, 'flops_mean': -1, 'flops_std': 0,
# 'params_max': -1, 'params_min':-1, 'params_mean': -1, 'params_std': 0,
# 'latency_max': -1, 'latency_min':-1, 'latency_mean': -1, 'latency_std': 0,
# }
return arch_metrics
class SamplingArchMetricsMeta(nn.Module):
def __init__(self, config, train_ds, exp_name, train_index=None, nasbench=None):
super().__init__()
self.exp_name = exp_name
self.train_ds = train_ds
self.search_space = config.data.name
if self.search_space == 'ofa':
self.train_arch_str_list = None
else:
self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]
def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False,
timestep=None, check_dataname='cifar10'):
"""_summary_
:params arch_list: list of archs
:params adj: [batch_size, num_nodes, num_nodes]
:params mask: [batch_size, num_nodes, num_nodes]
"""
arch_metrics = compute_arch_metrics_meta(arch_list, adj, mask, self.train_arch_str_list,
self.train_ds, timestep=timestep, check_dataname=check_dataname,
name=self.search_space)
all_arch_str = arch_metrics[-1]
if test:
with open(r'final_.txt', 'w') as fp:
for arch_str in all_arch_str:
# write each item on a new line
fp.write("%s\n" % arch_str)
print('All archs saved')
valid_unique_arch = arch_metrics[1] # arch_str
valid_unique_arch_prop_dict = arch_metrics[2] # test_acc, flops, params, latency
# textfile = open(f'{this_sample_dir}/archs/{name}/valid_unique_arch_step-{current_step}.txt', "w")
if self.search_space != 'ofa':
textfile = open(f'{this_sample_dir}/valid_unique_archs.txt', "w")
for i in range(len(valid_unique_arch)):
textfile.write(f"Arch: {valid_unique_arch[i]} \n")
textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n ")
textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
textfile.writelines(valid_unique_arch)
textfile.close()
return arch_metrics


@@ -0,0 +1,547 @@
import os
import torch
import imageio
import networkx as nx
import numpy as np
# import rdkit.Chem
import wandb
import matplotlib.pyplot as plt
# import igraph
# import pygraphviz as pgv
import datasets_nas
from configs.ckpt import DATAROOT_NB201
class ArchVisualization:
def __init__(self, config, remove_none=False, exp_name=None):
self.config = config
self.remove_none = remove_none
self.exp_name = exp_name
self.num_graphs_to_visualize = config.log.num_graphs_to_visualize
self.nasbench201 = torch.load(DATAROOT_NB201)
self.labels = {
0: 'input',
1: 'output',
2: 'conv3',
3: 'sep3',
4: 'conv5',
5: 'sep5',
6: 'avg3',
7: 'max3',
}
self.colors = ['skyblue', 'pink', 'yellow', 'orange', 'greenyellow', 'green', 'azure', 'beige']
def to_networkx_directed(self, node_list, adjacency_matrix):
"""
Convert a sampled architecture graph into a networkx directed graph.
node_list: the node types of one graph (n)
adjacency_matrix: the adjacency matrix of the graph (n x n)
"""
graph = nx.DiGraph()
# add nodes to the graph
for i in range(len(node_list)):
if node_list[i] == -1:
continue
graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])
rows, cols = np.where(torch.triu(torch.tensor(adjacency_matrix), diagonal=1).numpy() >= 1)
edges = zip(rows.tolist(), cols.tolist())
for edge in edges:
edge_type = adjacency_matrix[edge[0]][edge[1]]
graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)
return graph
def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=1200, largest_component=False):
if largest_component:
CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
graph = CGs[0]
# Plot the graph structure with colors
if pos is None:
pos = nx.nx_pydot.graphviz_layout(graph, prog="dot")
# pos = nx.multipartite_layout(graph, subset_key='number')
# pos = nx.spring_layout(graph, iterations=iterations)
# Set node colors based on the operations
plt.figure()
nx.draw(graph, pos=pos, labels=self.labels, arrows=True, node_shape="s",
node_size=node_size, node_color=self.colors, edge_color='grey', with_labels=True)
# nx.draw(graph, pos, font_size=5, node_size=node_size, with_labels=False, node_color=U[:, 1],
# cmap=plt.cm.coolwarm, vmin=vmin, vmax=vmax, edge_color='grey')
# import pdb; pdb.set_trace()
# plt.tight_layout()
plt.savefig(path)
plt.close("all")
def visualize(self, path: str, graphs: list, log='graph', adj=None):
# define path to save figures
os.makedirs(path, exist_ok=True)
# visualize the final architectures
for i in range(self.num_graphs_to_visualize):
file_path = os.path.join(path, 'graph_{}.png'.format(i))
graph = self.to_networkx_directed(graphs[i], adj[0].detach().cpu().numpy())
self.visualize_non_molecule(graph, pos=None, path=file_path)
im = plt.imread(file_path)
if wandb.run and log is not None:
wandb.log({log: [wandb.Image(im, caption=file_path)]})
def visualize_chain(self, path, sample_list, adjacency_matrix,
r_valid_chain, r_uniqueness_chain, r_novel_chain):
# convert graphs to networkx
graphs = [self.to_networkx_directed(sample_list[i], adjacency_matrix[i]) for i in range(sample_list.shape[0])]
# find the node coordinates in the final graph
final_graph = graphs[-1]
final_pos = nx.nx_pydot.graphviz_layout(final_graph, prog="dot")
# final_pos = None
# draw gif
save_paths = []
num_frames = len(sample_list)
for frame in range(num_frames):
file_name = os.path.join(path, 'frame_{}.png'.format(frame))
self.visualize_non_molecule(graphs[frame], pos=final_pos, path=file_name)
save_paths.append(file_name)
imgs = [imageio.imread(fn) for fn in save_paths]
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
print(f'==> Save gif at {gif_path}')
imgs.extend([imgs[-1]] * 10)
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
if wandb.run:
wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})
def visualize_chain_vun(self, path, r_valid_chain, r_unique_chain, r_novel_chain, sde, sampling_eps, number_chain_steps=None):
os.makedirs(path, exist_ok=True)
# timesteps = torch.linspace(sampling_eps, sde.T, sde.N)
timesteps = torch.linspace(sde.T, sampling_eps, sde.N)
if number_chain_steps is not None:
timesteps_ = []
n = int(sde.N / number_chain_steps)
for i, t in enumerate(timesteps):
if i % n == n - 1:
timesteps_.append(t.item())
# timesteps_ = [t for i, t in enumerate(timesteps) if i % n == n-1]
assert len(timesteps_) == number_chain_steps
timesteps_ = timesteps_[::-1]
else:
timesteps_ = list(timesteps.numpy())[::-1]
# validity
plt.clf()
fig, ax = plt.subplots()
ax.plot(timesteps_, r_valid_chain, color='red')
ax.set_title(f'Validity')
ax.set_xlabel('time')
ax.set_ylabel('Validity')
plt.show()
file_path = os.path.join(path, 'validity.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
im = plt.imread(file_path)
if wandb.run:
wandb.log({'r_valid_chains': [wandb.Image(im, caption=file_path)]})
# Uniqueness
plt.clf()
fig, ax = plt.subplots()
ax.plot(timesteps_, r_unique_chain, color='green')
ax.set_title(f'Uniqueness')
ax.set_xlabel('time')
ax.set_ylabel('Uniqueness')
plt.show()
file_path = os.path.join(path, 'uniqueness.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
im = plt.imread(file_path)
if wandb.run:
wandb.log({'r_uniqueness_chains': [wandb.Image(im, caption=file_path)]})
# Novelty
plt.clf()
fig, ax = plt.subplots()
ax.plot(timesteps_, r_novel_chain, color='blue')
ax.set_title(f'Novelty')
ax.set_xlabel('time')
ax.set_ylabel('Novelty')
file_path = os.path.join(path, 'novelty.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
im = plt.imread(file_path)
if wandb.run:
wandb.log({'r_novelty_chains': [wandb.Image(im, caption=file_path)]})
def visualize_grad_norm(self, path, score_grad_norm_p, classifier_grad_norm_p,
score_grad_norm_c, classifier_grad_norm_c, sde, sampling_eps,
number_chain_steps=None):
os.makedirs(path, exist_ok=True)
# timesteps = torch.linspace(sampling_eps, sde.T, sde.N)
timesteps = torch.linspace(sde.T, sampling_eps, sde.N)
timesteps_ = list(timesteps.numpy())[::-1]
if len(score_grad_norm_c) == 0:
score_grad_norm_c = [-1] * len(score_grad_norm_p)
if len(classifier_grad_norm_c) == 0:
classifier_grad_norm_c = [-1] * len(classifier_grad_norm_p)
plt.clf()
fig, ax1 = plt.subplots()
color_1 = 'red'
ax1.set_title(f'grad_norm (predictor)')
ax1.set_xlabel('time')
ax1.set_ylabel('score_grad_norm (predictor)', color=color_1)
ax1.plot(timesteps_, score_grad_norm_p, color=color_1)
ax1.tick_params(axis='y', labelcolor=color_1)
ax2 = ax1.twinx()
color_2 = 'blue'
ax2.set_ylabel('classifier_grad_norm (predictor)', color=color_2)
ax2.plot(timesteps_, classifier_grad_norm_p, color=color_2)
ax2.tick_params(axis='y', labelcolor=color_2)
fig.tight_layout()
plt.show()
file_path = os.path.join(path, 'grad_norm_p.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
im = plt.imread(file_path)
if wandb.run:
wandb.log({'grad_norm_p': [wandb.Image(im, caption=file_path)]})
plt.clf()
fig, ax1 = plt.subplots()
color_1 = 'green'
ax1.set_title(f'grad_norm (corrector)')
ax1.set_xlabel('time')
ax1.set_ylabel('score_grad_norm (corrector)', color=color_1)
ax1.plot(timesteps_, score_grad_norm_c, color=color_1)
ax1.tick_params(axis='y', labelcolor=color_1)
ax2 = ax1.twinx()
color_2 = 'yellow'
ax2.set_ylabel('classifier_grad_norm (corrector)', color=color_2)
ax2.plot(timesteps_, classifier_grad_norm_c, color=color_2)
ax2.tick_params(axis='y', labelcolor=color_2)
fig.tight_layout()
plt.show()
file_path = os.path.join(path, 'grad_norm_c.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
im = plt.imread(file_path)
if wandb.run:
wandb.log({'grad_norm_c': [wandb.Image(im, caption=file_path)]})
def visualize_scatter(self, path,
score_config, classifier_config,
sampled_arch_metric, plot_textstr=True,
x_axis='latency', y_axis='test-acc', x_label='Latency (ms)', y_label='Accuracy (%)',
log='scatter', check_dataname='cifar10-valid',
selected_arch_idx_list_topN=None, selected_arch_idx_list=None,
train_idx_list=None, return_file_path=False):
os.makedirs(path, exist_ok=True)
tg_dataset = classifier_config.data.tg_dataset
train_ds_s, eval_ds_s, test_ds_s = datasets_nas.get_dataset(score_config)
if selected_arch_idx_list is None:
train_ds_c, eval_ds_c, test_ds_c = datasets_nas.get_dataset(classifier_config)
else:
train_ds_c, eval_ds_c, test_ds_c = datasets_nas.get_dataset_iter(classifier_config)
plt.clf()
fig, ax = plt.subplots()
# entire architectures
entire_ds_x = train_ds_s.get_unnoramlized_entire_data(x_axis, tg_dataset)
entire_ds_y = train_ds_s.get_unnoramlized_entire_data(y_axis, tg_dataset)
ax.scatter(entire_ds_x, entire_ds_y, color = 'lightgray', alpha = 0.5, label='Entire', marker=',')
# architectures trained by the score_model
# train_ds_s_x = train_ds_s.get_unnoramlized_data(x_axis, tg_dataset)
# train_ds_s_y = train_ds_s.get_unnoramlized_data(y_axis, tg_dataset)
# ax.scatter(train_ds_s_x, train_ds_s_y, color = 'gray', alpha = 0.8, label='Trained by Score Model')
# architectures trained by the classifier
train_ds_c_x = train_ds_c.get_unnoramlized_data(x_axis, tg_dataset)
train_ds_c_y = train_ds_c.get_unnoramlized_data(y_axis, tg_dataset)
ax.scatter(train_ds_c_x, train_ds_c_y, color = 'black', alpha = 0.8, label='Trained by Predictor Model')
# oracle
oracle_idx = torch.argmax(torch.tensor(entire_ds_y)).item()
# oracle_idx = torch.argmax(torch.tensor(train_ds_s.get_unnoramlized_entire_data('val-acc', tg_dataset))).item()
oracle_item_x = entire_ds_x[oracle_idx]
oracle_item_y = entire_ds_y[oracle_idx]
ax.scatter(oracle_item_x, oracle_item_y, color = 'red', alpha = 1.0, label='Oracle', marker='*', s=150)
# architectures sampled by the score_model & classifier
AXIS_TO_PROP = {
'val-acc': 'val_acc_list',
'test-acc': 'test_acc_list',
'latency': 'latency_list',
'flops': 'flops_list',
'params': 'params_list',
}
sampled_ds_c_x = sampled_arch_metric[2][AXIS_TO_PROP[x_axis]]
sampled_ds_c_y = sampled_arch_metric[2][AXIS_TO_PROP[y_axis]]
ax.scatter(sampled_ds_c_x, sampled_ds_c_y, color = 'limegreen', alpha = 0.8, label='Sampled', marker='x')
ax.set_title(f'{tg_dataset.upper()} Dataset')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
selected_topN_ds_x, selected_topN_ds_y = None, None
if selected_arch_idx_list_topN is not None:
selected_arch_topN_info_dict = get_arch_acc_info_dict(
self.nasbench201, dataname=check_dataname, arch_index_list=selected_arch_idx_list_topN)
selected_topN_ds_x = selected_arch_topN_info_dict[AXIS_TO_PROP[x_axis]]
selected_topN_ds_y = selected_arch_topN_info_dict[AXIS_TO_PROP[y_axis]]
ax.scatter(selected_topN_ds_x, selected_topN_ds_y, color = 'pink', alpha = 0.8, label='Selected_topN', marker='x')
# architectures selected by the predictor
selected_ds_x, selected_ds_y = None, None
if selected_arch_idx_list is not None:
selected_arch_info_dict = get_arch_acc_info_dict(
self.nasbench201, dataname=check_dataname, arch_index_list=selected_arch_idx_list)
selected_ds_x = selected_arch_info_dict[AXIS_TO_PROP[x_axis]]
selected_ds_y = selected_arch_info_dict[AXIS_TO_PROP[y_axis]]
ax.scatter(selected_ds_x, selected_ds_y, color = 'blue', alpha = 0.8, label='Selected', marker='x')
if plot_textstr:
textstr = self.get_textstr(sampled_arch_metric=sampled_arch_metric,
sampled_ds_c_x=sampled_ds_c_x, sampled_ds_c_y=sampled_ds_c_y,
x_axis=x_axis, y_axis=y_axis,
classifier_config=classifier_config,
selected_ds_x=selected_ds_x, selected_ds_y=selected_ds_y,
selected_topN_ds_x=selected_topN_ds_x, selected_topN_ds_y=selected_topN_ds_y,
oracle_idx=oracle_idx, train_idx_list=train_idx_list
)
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
ax.text(0.6, 0.4, textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props, fontsize='x-small')
# ax.text(textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props)
ax.legend(loc="lower right")
plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
plt.show()
plt.tight_layout()
file_path = os.path.join(path, 'scatter.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
if return_file_path:
return file_path
im = plt.imread(file_path)
if wandb.run and log is not None:
wandb.log({log: [wandb.Image(im, caption=file_path)]})
# if return_selected_arch_info_dict:
# return selected_arch_info_dict, selected_arch_topN_info_dict
def visualize_scatter_chain(self, path, score_config, classifier_config, sampled_arch_metric_chain, plot_textstr=True,
x_axis='latency', y_axis='test-acc', x_label='Latency (ms)', y_label='Accuracy (%)',
log='scatter_chain'):
# draw gif
os.makedirs(path, exist_ok=True)
save_paths = []
num_frames = len(sampled_arch_metric_chain)
tg_dataset = classifier_config.data.tg_dataset
train_ds_s, eval_ds_s, test_ds_s = datasets_nas.get_dataset(score_config)
train_ds_c, eval_ds_c, test_ds_c = datasets_nas.get_dataset(classifier_config)
# entire architectures
entire_ds_x = train_ds_s.get_unnoramlized_entire_data(x_axis, tg_dataset)
entire_ds_y = train_ds_s.get_unnoramlized_entire_data(y_axis, tg_dataset)
# architectures trained by the score_model
train_ds_s_x = train_ds_s.get_unnoramlized_data(x_axis, tg_dataset)
train_ds_s_y = train_ds_s.get_unnoramlized_data(y_axis, tg_dataset)
# architectures trained by the classifier
train_ds_c_x = train_ds_c.get_unnoramlized_data(x_axis, tg_dataset)
train_ds_c_y = train_ds_c.get_unnoramlized_data(y_axis, tg_dataset)
# oracle
# oracle_idx = torch.argmax(torch.tensor(entire_ds_y)).item()
oracle_idx = torch.argmax(torch.tensor(train_ds_s.get_unnoramlized_entire_data('val-acc', tg_dataset))).item()
oracle_item_x = entire_ds_x[oracle_idx]
oracle_item_y = entire_ds_y[oracle_idx]
for frame in range(num_frames):
sampled_arch_metric = sampled_arch_metric_chain[frame]
plt.clf()
fig, ax = plt.subplots()
# entire architectures
ax.scatter(entire_ds_x, entire_ds_y, color = 'lightgray', alpha = 0.5, label='Entire', marker=',')
# architectures trained by the score_model
ax.scatter(train_ds_s_x, train_ds_s_y, color = 'gray', alpha = 0.8, label='Trained by Score Model')
# architectures trained by the classifier
ax.scatter(train_ds_c_x, train_ds_c_y, color = 'black', alpha = 0.8, label='Trained by Predictor Model')
# oracle
ax.scatter(oracle_item_x, oracle_item_y, color = 'red', alpha = 1.0, label='Oracle', marker='*', s=150)
# architectures sampled by the score_model & classifier
AXIS_TO_PROP = {
'test-acc': 'test_acc_list',
'latency': 'latency_list',
'flops': 'flops_list',
'params': 'params_list',
}
sampled_ds_c_x = sampled_arch_metric[2][AXIS_TO_PROP[x_axis]]
sampled_ds_c_y = sampled_arch_metric[2][AXIS_TO_PROP[y_axis]]
ax.scatter(sampled_ds_c_x, sampled_ds_c_y, color = 'limegreen', alpha = 0.8, label='Sampled', marker='x')
ax.set_title(f'{tg_dataset.upper()} Dataset')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
if plot_textstr:
textstr = self.get_textstr(sampled_arch_metric, sampled_ds_c_x, sampled_ds_c_y,
x_axis, y_axis, classifier_config)
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
ax.text(0.6, 0.3, textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props)
# ax.text(textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props)
ax.legend(loc="lower right")
plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
plt.show()
# plt.tight_layout()
file_path = os.path.join(path, f'frame_{frame}.png')
plt.savefig(file_path)
plt.close("all")
print(f'==> Save scatter plot at {file_path}')
save_paths.append(file_path)
im = plt.imread(file_path)
if wandb.run and log is not None:
wandb.log({log: [wandb.Image(im, caption=file_path)]})
# draw gif
imgs = [imageio.imread(fn) for fn in save_paths[::-1]]
# gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
gif_path = os.path.join(path, f'scatter.gif')
print(f'==> Save gif at {gif_path}')
imgs.extend([imgs[-1]] * 10)
# imgs.extend([imgs[0]] * 10)
imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
if wandb.run:
wandb.log({'chain_gif': [wandb.Video(gif_path, caption=gif_path, format="gif")]})
def get_textstr(self,
sampled_arch_metric,
sampled_ds_c_x, sampled_ds_c_y,
x_axis='latency', y_axis='test-acc',
classifier_config=None,
selected_ds_x=None, selected_ds_y=None,
selected_topN_ds_x=None, selected_topN_ds_y=None,
oracle_idx=None, train_idx_list=None):
mean_v_x = round(np.mean(np.array(sampled_ds_c_x)), 4)
std_v_x = round(np.std(np.array(sampled_ds_c_x)), 4)
max_v_x = round(np.max(np.array(sampled_ds_c_x)), 4)
min_v_x = round(np.min(np.array(sampled_ds_c_x)), 4)
mean_v_y = round(np.mean(np.array(sampled_ds_c_y)), 4)
std_v_y = round(np.std(np.array(sampled_ds_c_y)), 4)
max_v_y = round(np.max(np.array(sampled_ds_c_y)), 4)
min_v_y = round(np.min(np.array(sampled_ds_c_y)), 4)
if selected_ds_x is not None:
mean_v_x_s = round(np.mean(np.array(selected_ds_x)), 4)
std_v_x_s = round(np.std(np.array(selected_ds_x)), 4)
max_v_x_s = round(np.max(np.array(selected_ds_x)), 4)
min_v_x_s = round(np.min(np.array(selected_ds_x)), 4)
if selected_ds_y is not None:
mean_v_y_s = round(np.mean(np.array(selected_ds_y)), 4)
std_v_y_s = round(np.std(np.array(selected_ds_y)), 4)
max_v_y_s = round(np.max(np.array(selected_ds_y)), 4)
min_v_y_s = round(np.min(np.array(selected_ds_y)), 4)
textstr = ''
r_valid, r_unique, r_novel = round(sampled_arch_metric[0][0], 4), round(sampled_arch_metric[0][1], 4), round(sampled_arch_metric[0][2], 4)
textstr += f'V-{r_valid} | U-{r_unique} | N-{r_novel} \n'
textstr += f'Predictor (Noise-aware-{str(classifier_config.training.noised)[0]}, k={self.config.sampling.classifier_scale}) \n'
textstr += f'=> Sampled {x_axis} \n'
textstr += f'Mean-{mean_v_x} | Std-{std_v_x} \n'
textstr += f'Max-{max_v_x} | Min-{min_v_x} \n'
textstr += f'=> Sampled {y_axis} \n'
textstr += f'Mean-{mean_v_y} | Std-{std_v_y} \n'
textstr += f'Max-{max_v_y} | Min-{min_v_y} \n'
if selected_ds_x is not None:
textstr += f'==> Selected {x_axis} \n'
textstr += f'Mean-{mean_v_x_s} | Std-{std_v_x_s} \n'
textstr += f'Max-{max_v_x_s} | Min-{min_v_x_s} \n'
if selected_ds_y is not None:
textstr += f'==> Selected {y_axis} \n'
textstr += f'Mean-{mean_v_y_s} | Std-{std_v_y_s} \n'
textstr += f'Max-{max_v_y_s} | Min-{min_v_y_s} \n'
if selected_topN_ds_y is not None:
textstr += f'==> Predicted TopN (10) -{str(round(max(selected_topN_ds_y[:10]), 4))} \n'
if train_idx_list is not None and oracle_idx in train_idx_list:
textstr += f'==> Hit Oracle ({oracle_idx}) !'
return textstr
def get_arch_acc_info_dict(nasbench201, dataname='cifar10-valid', arch_index_list=None):
val_acc_list = []
test_acc_list = []
flops_list = []
params_list = []
latency_list = []
for arch_index in arch_index_list:
val_acc = nasbench201['val-acc'][dataname][arch_index]
val_acc_list.append(val_acc)
test_acc = nasbench201['test-acc'][dataname][arch_index]
test_acc_list.append(test_acc)
flops = nasbench201['flops'][dataname][arch_index]
flops_list.append(flops)
params = nasbench201['params'][dataname][arch_index]
params_list.append(params)
latency = nasbench201['latency'][dataname][arch_index]
latency_list.append(latency)
return {
'val_acc_list': val_acc_list,
'test_acc_list': test_acc_list,
'flops_list': flops_list,
'params_list': params_list,
'latency_list': latency_list
}
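# get_arch_acc_info_dict assumes DATAROOT_NB201 stores per-dataset metric tables.
# A minimal hypothetical fixture (keys mirror the lookups above; real files differ):
stats = {
'val-acc': {'cifar10-valid': [89.1, 90.3]},
'test-acc': {'cifar10-valid': [92.5, 93.0]},
'flops': {'cifar10-valid': [15.6, 28.1]},
'params': {'cifar10-valid': [0.4, 0.8]},
'latency': {'cifar10-valid': [0.011, 0.015]},
}
info = get_arch_acc_info_dict(stats, dataname='cifar10-valid', arch_index_list=[0, 1])
print(info['test_acc_list'])  # [92.5, 93.0]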


@@ -0,0 +1,167 @@
import ml_collections
import torch
from all_path import SCORE_MODEL_CKPT_PATH, SCORE_MODEL_DATA_PATH
def get_config():
config = ml_collections.ConfigDict()
config.search_space = None
# general
config.resume = False
config.folder_name = 'DiffusionNAG'
config.task = 'tr_meta_predictor'
config.exp_name = None
config.model_type = 'meta_predictor'
config.scorenet_ckpt_path = SCORE_MODEL_CKPT_PATH
config.is_meta = True
# training
config.training = training = ml_collections.ConfigDict()
training.sde = 'vesde'
training.continuous = True
training.reduce_mean = True
training.noised = True
training.batch_size = 128
training.eval_batch_size = 512
training.n_iters = 20000
training.snapshot_freq = 500
training.log_freq = 500
training.eval_freq = 500
## store additional checkpoints for preemption
training.snapshot_freq_for_preemption = 1000
## produce samples at each snapshot.
training.snapshot_sampling = True
training.likelihood_weighting = False
# training for perturbed data
training.t_spot = 1.
# training from pretrained score model
training.load_pretrained = False
training.pretrained_model_path = SCORE_MODEL_CKPT_PATH
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'none'
# sampling.corrector = 'langevin'
sampling.rtol = 1e-5
sampling.atol = 1e-5
sampling.ode_method = 'dopri5' # 'rk4'
sampling.ode_step = 0.01
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.16
sampling.vis_row = 4
sampling.vis_col = 4
# conditional
sampling.classifier_scale = 1.0
sampling.regress = True
sampling.labels = 'max'
sampling.weight_ratio = False
sampling.weight_scheduling = False
sampling.t_spot = 1.
sampling.t_spot_end = 0.
sampling.number_chain_steps = 50
sampling.check_dataname = 'imagenet1k'
# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.begin_ckpt = 5
evaluate.end_ckpt = 20
# evaluate.batch_size = 512
evaluate.batch_size = 128
evaluate.enable_sampling = True
evaluate.num_samples = 1024
evaluate.mmd_distance = 'RBF'
evaluate.max_subgraph = False
evaluate.save_graph = False
# data
config.data = data = ml_collections.ConfigDict()
data.centered = True
data.dequantization = False
data.root = SCORE_MODEL_DATA_PATH
data.name = 'ofa'
data.split_ratio = 0.8
data.dataset_idx = 'random'
data.max_node = 20
data.n_vocab = 9
data.START_TYPE = 0
data.END_TYPE = 1
data.num_graphs = 100000
data.num_channels = 1
data.except_inout = False # ignore
data.triu_adj = True
data.connect_prev = False
data.tg_dataset = None
data.label_list = ['meta-acc']
# aug_mask
data.aug_mask_algo = 'none' # 'long_range' | 'floyd'
# num_train
data.num_train = 150
# model
config.model = model = ml_collections.ConfigDict()
model.name = 'MetaPredictorCATE'
model.ema_rate = 0.9999
model.normalization = 'GroupNorm'
model.nonlinearity = 'swish'
model.nf = 128
model.num_gnn_layers = 4
model.size_cond = False
model.embedding_type = 'positional'
model.rw_depth = 16
model.graph_layer = 'PosTransLayer'
model.edge_th = -1.
model.heads = 8
model.attn_clamp = False
#############################################################################
# meta
model.input_type = 'DA'
model.hs = 512
model.nz = 56
model.num_sample = 20
model.num_scales = 1000
model.beta_min = 0.1
model.beta_max = 5.0
model.sigma_min = 0.1
model.sigma_max = 5.0
model.dropout = 0.1
# graph encoder
config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
graph_encoder.n_layers = 2
graph_encoder.d_model = 64
graph_encoder.n_head = 2
graph_encoder.d_ff = 32
graph_encoder.dropout = 0.1
graph_encoder.n_vocab = 9
# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 0.001
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 1000
optim.grad_clip = 1.
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# log
config.log = log = ml_collections.ConfigDict()
log.use_wandb = True
log.wandb_project_name = 'DiffusionNAG'
log.log_valid_sample_prop = False
log.num_graphs_to_visualize = 20
return config
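# The configs are plain ml_collections.ConfigDict objects, so a run script can
# load and override fields directly; a sketch using the fields defined above:
config = get_config()
config.training.batch_size = 64  # e.g. shrink the batch for a smaller GPU
config.sampling.check_dataname = 'cifar10'
print(config.model.name, config.training.n_iters)  # MetaPredictorCATE 20000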


@@ -0,0 +1,141 @@
"""Training PGSN on Community Small Dataset with GraphGDP"""
import ml_collections
import torch
def get_config():
config = ml_collections.ConfigDict()
# general
config.resume = False
config.resume_ckpt_path = './exp'
config.folder_name = 'tr_scorenet'
config.task = 'tr_scorenet'
config.exp_name = None
config.model_type = 'sde'
# training
config.training = training = ml_collections.ConfigDict()
training.sde = 'vesde'
training.continuous = True
training.reduce_mean = True
training.batch_size = 256
training.eval_batch_size = 1000
training.n_iters = 1000000
training.snapshot_freq = 10000
training.log_freq = 200
training.eval_freq = 10000
## store additional checkpoints for preemption
training.snapshot_freq_for_preemption = 5000
## produce samples at each snapshot.
training.snapshot_sampling = True
training.likelihood_weighting = False
# sampling
config.sampling = sampling = ml_collections.ConfigDict()
sampling.method = 'pc'
sampling.predictor = 'euler_maruyama'
sampling.corrector = 'none'
sampling.rtol = 1e-5
sampling.atol = 1e-5
sampling.ode_method = 'dopri5' # 'rk4'
sampling.ode_step = 0.01
sampling.n_steps_each = 1
sampling.noise_removal = True
sampling.probability_flow = False
sampling.snr = 0.16
sampling.vis_row = 4
sampling.vis_col = 4
sampling.alpha = 0.5
sampling.qtype = 'threshold'
# evaluation
config.eval = evaluate = ml_collections.ConfigDict()
evaluate.begin_ckpt = 5
evaluate.end_ckpt = 20
evaluate.batch_size = 1024
evaluate.enable_sampling = True
evaluate.num_samples = 1024
evaluate.mmd_distance = 'RBF'
evaluate.max_subgraph = False
evaluate.save_graph = False
# data
config.data = data = ml_collections.ConfigDict()
data.centered = True
data.dequantization = False
data.root = './data/ofa/data_score_model/ofa_database_500000.pt'
data.name = 'ofa'
data.split_ratio = 0.9
data.dataset_idx = 'random'
data.max_node = 20
data.n_vocab = 9
data.START_TYPE = 0
data.END_TYPE = 1
data.num_graphs = 100000
data.num_channels = 1
data.except_inout = False
data.triu_adj = True
data.connect_prev = False
data.label_list = None
data.tg_dataset = None
data.node_rule_type = 2
# aug_mask
data.aug_mask_algo = 'none'
# model
config.model = model = ml_collections.ConfigDict()
model.name = 'CATE'
model.ema_rate = 0.9999
model.normalization = 'GroupNorm'
model.nonlinearity = 'swish'
model.nf = 128
model.num_gnn_layers = 4
model.size_cond = False
model.embedding_type = 'positional'
model.rw_depth = 16
model.graph_layer = 'PosTransLayer'
model.edge_th = -1.
model.heads = 8
model.attn_clamp = False
model.num_scales = 1000
model.sigma_min = 0.1
model.sigma_max = 1.0
model.dropout = 0.1
model.pos_enc_type = 2
# graph encoder
config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
graph_encoder.n_layers = 12
graph_encoder.d_model = 64
graph_encoder.n_head = 8
graph_encoder.d_ff = 128
graph_encoder.dropout = 0.1
graph_encoder.n_vocab = 9
# optimization
config.optim = optim = ml_collections.ConfigDict()
optim.weight_decay = 0
optim.optimizer = 'Adam'
optim.lr = 2e-5
optim.beta1 = 0.9
optim.eps = 1e-8
optim.warmup = 1000
optim.grad_clip = 1.
config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# log
config.log = log = ml_collections.ConfigDict()
log.use_wandb = True
log.wandb_project_name = 'DiffusionNAG'
log.log_valid_sample_prop = False
log.num_graphs_to_visualize = 20
return config

MobileNetV3/datasets_nas.py Normal file

@@ -0,0 +1,493 @@
from __future__ import print_function
import torch
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch_geometric.utils import to_networkx
from analysis.arch_functions import get_x_adj_from_opsdict_ofa, get_string_from_onehot_x
from all_path import PROCESSED_DATA_PATH, SCORE_MODEL_DATA_IDX_PATH
from analysis.arch_functions import OPS
def get_data_scaler(config):
"""Data normalizer. Assume data are always in [0, 1]."""
if config.data.centered:
# Rescale to [-1, 1]
return lambda x: x * 2. - 1.
else:
return lambda x: x
def get_data_inverse_scaler(config):
"""Inverse data normalizer."""
if config.data.centered:
# Rescale [-1, 1] to [0, 1]
return lambda x: (x + 1.) / 2.
else:
return lambda x: x
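# The two scalers are exact inverses when config.data.centered is True; a quick
# sketch with a stub config (assumes ml_collections is installed):
import ml_collections
_cfg = ml_collections.ConfigDict({'data': {'centered': True}})
_x = torch.tensor([0.0, 0.5, 1.0])
print(get_data_scaler(_cfg)(_x))  # tensor([-1., 0., 1.])
print(get_data_inverse_scaler(_cfg)(get_data_scaler(_cfg)(_x)))  # tensor([0.0000, 0.5000, 1.0000])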
def networkx_graphs(dataset):
return [to_networkx(dataset[i], to_undirected=False, remove_self_loops=True) for i in range(len(dataset))]
def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
train_loader = DataLoader(dataset=train_dataset,
batch_size=config.training.batch_size,
shuffle=True,
collate_fn=collate_fn_ofa if config.model_type == 'meta_predictor' else None)
eval_loader = DataLoader(dataset=eval_dataset,
batch_size=config.training.batch_size,
shuffle=False,
collate_fn=collate_fn_ofa if config.model_type == 'meta_predictor' else None)
test_loader = DataLoader(dataset=test_dataset,
batch_size=config.training.batch_size,
shuffle=False,
collate_fn=collate_fn_ofa if config.model_type == 'meta_predictor' else None)
return train_loader, eval_loader, test_loader
def get_dataloader_iter(config, train_dataset, eval_dataset, test_dataset):
train_loader = DataLoader(dataset=train_dataset,
batch_size=config.training.batch_size if len(train_dataset) > config.training.batch_size else len(train_dataset),
# batch_size=8,
shuffle=True,)
eval_loader = DataLoader(dataset=eval_dataset,
batch_size=config.training.batch_size if len(eval_dataset) > config.training.batch_size else len(eval_dataset),
# batch_size=8,
shuffle=False,)
test_loader = DataLoader(dataset=test_dataset,
batch_size=config.training.batch_size if len(test_dataset) > config.training.batch_size else len(test_dataset),
# batch_size=8,
shuffle=False,)
return train_loader, eval_loader, test_loader
def is_triu(mat):
is_triu_ = np.allclose(mat, np.triu(mat))
return is_triu_
def collate_fn_ofa(batch):
# x, adj, label_dict, task
x = torch.stack([item[0] for item in batch])
adj = torch.stack([item[1] for item in batch])
label_dict = {}
for item in batch:
for k, v in item[2].items():
if not k in label_dict.keys():
label_dict[k] = []
label_dict[k].append(v)
for k, v in label_dict.items():
label_dict[k] = torch.tensor(v)
task = [item[3] for item in batch]
return x, adj, label_dict, task
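# collate_fn_ofa stacks node/adjacency matrices, merges per-item label dicts into
# tensors, and keeps tasks as a plain list; a toy batch with assumed shapes:
item = (torch.zeros(20, 9), torch.zeros(20, 20), {'meta-acc': 0.5}, torch.zeros(4, 3, 32, 32))
bx, badj, blabels, btask = collate_fn_ofa([item, item])
print(bx.shape, badj.shape)  # torch.Size([2, 20, 9]) torch.Size([2, 20, 20])
print(blabels['meta-acc'], len(btask))  # tensor([0.5000, 0.5000]) 2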
def get_dataset(config):
"""Create data loaders for training and evaluation.
Args:
config: An ml_collections.ConfigDict parsed from config files.
Returns:
train_ds, eval_ds, test_ds
"""
num_train = config.data.num_train if 'num_train' in config.data else None
NASDataset = OFADataset
train_dataset = NASDataset(
config.data.root,
config.data.split_ratio,
config.data.except_inout,
config.data.triu_adj,
config.data.connect_prev,
'train',
config.data.label_list,
config.data.tg_dataset,
config.data.dataset_idx,
num_train,
node_rule_type=config.data.node_rule_type)
eval_dataset = NASDataset(
config.data.root,
config.data.split_ratio,
config.data.except_inout,
config.data.triu_adj,
config.data.connect_prev,
'eval',
config.data.label_list,
config.data.tg_dataset,
config.data.dataset_idx,
num_train,
node_rule_type=config.data.node_rule_type)
test_dataset = NASDataset(
config.data.root,
config.data.split_ratio,
config.data.except_inout,
config.data.triu_adj,
config.data.connect_prev,
'test',
config.data.label_list,
config.data.tg_dataset,
config.data.dataset_idx,
num_train,
node_rule_type=config.data.node_rule_type)
return train_dataset, eval_dataset, test_dataset
def get_meta_dataset(config):
database = MetaTrainDatabaseOFA
data_path = PROCESSED_DATA_PATH
train_dataset = database(
data_path,
config.model.num_sample,
config.data.label_list,
True,
config.data.except_inout,
config.data.triu_adj,
config.data.connect_prev,
'train')
eval_dataset = database(
data_path,
config.model.num_sample,
config.data.label_list,
True,
config.data.except_inout,
config.data.triu_adj,
config.data.connect_prev,
'val')
# test_dataset = MetaTestDataset()
test_dataset = None
return train_dataset, eval_dataset, test_dataset
def get_meta_dataloader(config ,train_dataset, eval_dataset, test_dataset):
if config.data.name == 'ofa':
train_loader = DataLoader(dataset=train_dataset,
batch_size=config.training.batch_size,
shuffle=True,)
# collate_fn=collate_fn_ofa)
eval_loader = DataLoader(dataset=eval_dataset,
batch_size=config.training.batch_size,)
# collate_fn=collate_fn_ofa)
else:
train_loader = DataLoader(dataset=train_dataset,
batch_size=config.training.batch_size,
shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset,
batch_size=config.training.batch_size,
shuffle=False)
# test_loader = DataLoader(dataset=test_dataset,
# batch_size=config.training.batch_size,
# shuffle=False)
test_loader = None
return train_loader, eval_loader, test_loader
class MetaTestDataset(Dataset):
def __init__(self, data_path, data_name, num_sample, num_class=None):
self.num_sample = num_sample
self.data_name = data_name
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'svhn': 10,
'aircraft': 30,
'pets': 37
}
if num_class is not None:
self.num_class = num_class
else:
self.num_class = num_class_dict[data_name]
self.x = torch.load(os.path.join(data_path, f'aircraft100bylabel.pt' if 'ofa' in data_path and data_name == 'aircraft' else f'{data_name}bylabel.pt' ))
def __len__(self):
return 1000000
def __getitem__(self, index):
data = []
classes = list(range(self.num_class))
for cls in classes:
cx = self.x[cls][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
return x
class MetaTrainDatabaseOFA(Dataset):
# def __init__(self, data_path, num_sample, is_pred=False):
def __init__(
self,
data_path,
num_sample,
label_list,
is_pred=True,
except_inout=False,
triu_adj=True,
connect_prev=False,
mode='train'):
self.ops_decoder = list(OPS.keys())
self.mode = mode
self.acc_norm = True
self.num_sample = num_sample
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
if is_pred:
self.dpath = f'{data_path}/predictor/processed/'
else:
raise NotImplementedError
self.dname = 'database_219152_14.0K'
data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
self.net = data['net']
self.x_list = []
self.adj_list = []
self.arch_str_list = []
for net in self.net:
x, adj = get_x_adj_from_opsdict_ofa(net)
# ---------- matrix ---------- #
self.x_list.append(x)
self.adj_list.append(torch.tensor(adj))
# ---------- arch_str ---------- #
self.arch_str_list.append(get_string_from_onehot_x(x))
# ---------- labels ---------- #
self.label_list = label_list
if self.label_list is not None:
self.flops_list = data['flops']
self.params_list = None
self.latency_list = None
self.acc_list = data['acc']
self.mean = data['mean']
self.std = data['std']
self.task_lst = data['class']
def __len__(self):
return len(self.acc_list)
def __getitem__(self, index):
data = []
classes = self.task_lst[index]
acc = self.acc_list[index]
graph = self.net[index]
# ---------- x -----------
x = self.x_list[index]
# ---------- adj ----------
adj = self.adj_list[index]
acc = self.acc_list[index]
for i, cls in enumerate(classes):
cx = self.x[cls.item()][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
task = torch.cat(data)
if self.acc_norm:
acc = ((acc - self.mean) / self.std) / 100.0
else:
acc = acc / 100.0
label_dict = {}
if self.label_list is not None:
assert type(self.label_list) == list
for label in self.label_list:
if label == 'meta-acc':
label_dict[f"{label}"] = acc
else:
raise ValueError
return x, adj, label_dict, task
class OFADataset(Dataset):
def __init__(
self,
data_path,
split_ratio=0.8,
except_inout=False,
triu_adj=True,
connect_prev=False,
mode='train',
label_list=None,
tg_dataset=None,
dataset_idx='random',
num_train=None,
node_rule_type=None):
# ---------- entire dataset ---------- #
self.data = torch.load(data_path)
self.except_inout = except_inout
self.triu_adj = triu_adj
self.connect_prev = connect_prev
self.node_rule_type = node_rule_type
# ---------- x ---------- #
self.x_list = self.data['x_none2zero']
# ---------- adj ---------- #
assert self.connect_prev == False
self.n_adj = len(self.data['node_type'][0])
const_adj = self.get_not_connect_prev_adj()
self.adj_list = [const_adj] * len(self.x_list)
# ---------- arch_str ---------- #
self.arch_str_list = self.data['net_setting']
# ---------- labels ---------- #
self.label_list = label_list
if self.label_list is not None:
raise NotImplementedError
# ----------- split dataset ---------- #
self.ds_idx = list(torch.load(SCORE_MODEL_DATA_IDX_PATH))
self.split_ratio = split_ratio
if num_train is None:
num_train = int(len(self.x_list) * self.split_ratio)
num_test = len(self.x_list) - num_train
# ----------- compute mean and std w/ training dataset ---------- #
if self.label_list is not None:
self.train_idx_list = self.ds_idx[:num_train]
print('Computing mean and std of the training set...')
from collections import defaultdict
LABEL_TO_MEAN_STD = defaultdict(dict)
assert type(self.label_list) == list
for label in self.label_list:
if label == 'test-acc':
self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
elif label == 'flops':
self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
elif label == 'params':
self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
elif label == 'latency':
self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
else:
raise ValueError
self.mode = mode
if self.mode in ['train']:
self.idx_list = self.ds_idx[:num_train]
elif self.mode in ['eval']:
self.idx_list = self.ds_idx[:num_test]
elif self.mode in ['test']:
self.idx_list = self.ds_idx[num_train:]
self.x_list_ = [self.x_list[i] for i in self.idx_list]
self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]
if self.label_list is not None:
assert type(self.label_list) == list
for label in self.label_list:
if label == 'test-acc':
self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'flops':
self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'params':
self.params_list_ = [self.params_list[i] for i in self.idx_list]
self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
elif label == 'latency':
self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
else:
raise ValueError
def normalize(self, original, mean, std):
return [(i-mean)/std for i in original]
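# Example (illustrative): z-score normalization with the training statistics.
#   self.normalize([70.0, 80.0, 90.0], mean=80.0, std=10.0)  ->  [-1.0, 0.0, 1.0]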
def get_not_connect_prev_adj(self):
_adj = torch.zeros(self.n_adj, self.n_adj)
for i in range(self.n_adj-1):
_adj[i, i+1] = 1
_adj = _adj.to(torch.float32).to('cpu') # torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
# if self.except_inout:
# _adj = _adj[1:-1, 1:-1]
return _adj
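# Example: for n_adj = 4 the constant adjacency is the simple chain
# 0 -> 1 -> 2 -> 3:
#   tensor([[0., 1., 0., 0.],
#           [0., 0., 1., 0.],
#           [0., 0., 0., 1.],
#           [0., 0., 0., 0.]])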
@property
def adj(self):
return self.adj_list_[0]
# @property
def mask(self, algo='floyd', data='ofa'):
from utils import aug_mask
return aug_mask(self.adj, algo=algo, data=data)[0]
def get_unnoramlized_entire_data(self, label, tg_dataset):
entire_test_acc_list = self.data['test-acc'][tg_dataset]
entire_flops_list = self.data['flops'][tg_dataset]
entire_params_list = self.data['params'][tg_dataset]
entire_latency_list = self.data['latency'][tg_dataset]
if label == 'test-acc':
return entire_test_acc_list
elif label == 'flops':
return entire_flops_list
elif label == 'params':
return entire_params_list
elif label == 'latency':
return entire_latency_list
else:
raise ValueError
def get_unnoramlized_data(self, label, tg_dataset):
entire_test_acc_list = self.data['test-acc'][tg_dataset]
entire_flops_list = self.data['flops'][tg_dataset]
entire_params_list = self.data['params'][tg_dataset]
entire_latency_list = self.data['latency'][tg_dataset]
if label == 'test-acc':
return [entire_test_acc_list[i] for i in self.idx_list]
elif label == 'flops':
return [entire_flops_list[i] for i in self.idx_list]
elif label == 'params':
return [entire_params_list[i] for i in self.idx_list]
elif label == 'latency':
return [entire_latency_list[i] for i in self.idx_list]
else:
raise ValueError
def __len__(self):
return len(self.x_list_)
def __getitem__(self, index):
label_dict = {}
if self.label_list is not None:
assert type(self.label_list) == list
for label in self.label_list:
if label == 'test-acc':
label_dict[f"{label}"] = self.test_acc_list_[index]
elif label == 'flops':
label_dict[f"{label}"] = self.flops_list_[index]
elif label == 'params':
label_dict[f"{label}"] = self.params_list_[index]
elif label == 'latency':
label_dict[f"{label}"] = self.latency_list_[index]
else:
raise ValueError
return self.x_list_[index], self.adj_list_[index], label_dict

View File

@@ -0,0 +1 @@
from .evaluator import get_stats_eval, get_nn_eval

View File

@@ -0,0 +1,58 @@
import networkx as nx
from .structure_evaluator import mmd_eval
from .gin_evaluator import nn_based_eval
from torch_geometric.utils import to_networkx
import torch
import torch.nn.functional as F
import dgl
def get_stats_eval(config):
if config.eval.mmd_distance.lower() == 'rbf':
method = [('degree', 1., 'argmax'), ('cluster', 0.1, 'argmax'),
('spectral', 1., 'argmax')]
else:
raise ValueError
def eval_stats_fn(test_dataset, pred_graph_list):
pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
sub_pred_G = []
if config.eval.max_subgraph:
for G in pred_G:
CGs = [G.subgraph(c) for c in nx.connected_components(G)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
sub_pred_G += [CGs[0]]
pred_G = sub_pred_G
test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
for i in range(len(test_dataset))]
results = mmd_eval(test_G, pred_G, method)
return results
return eval_stats_fn
def get_nn_eval(config):
if hasattr(config.eval, "N_gin"):
N_gin = config.eval.N_gin
else:
N_gin = 10
def nn_eval_fn(test_dataset, pred_graph_list):
pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
sub_pred_G = []
if config.eval.max_subgraph:
for G in pred_G:
CGs = [G.subgraph(c) for c in nx.connected_components(G)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
sub_pred_G += [CGs[0]]
pred_G = sub_pred_G
test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
for i in range(len(test_dataset))]
results = nn_based_eval(test_G, pred_G, N_gin)
return results
return nn_eval_fn
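# Usage sketch (illustrative; assumes an ml_collections-style config whose
# field names mirror the attribute accesses above):
#   from ml_collections import ConfigDict
#   cfg = ConfigDict({'eval': {'mmd_distance': 'rbf', 'max_subgraph': False, 'N_gin': 10}})
#   stats_fn = get_stats_eval(cfg)  # expects (test_dataset, list of pred adjacency matrices)
#   nn_fn = get_nn_eval(cfg)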

View File

@@ -0,0 +1,311 @@
"""Modified from https://github.com/uoguelph-mlrg/GGM-metrics"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.utils import expand_as_pair
from dgl.nn import SumPooling, AvgPooling, MaxPooling
class GINConv(nn.Module):
def __init__(self,
apply_func,
aggregator_type,
init_eps=0,
learn_eps=False):
super(GINConv, self).__init__()
self.apply_func = apply_func
self._aggregator_type = aggregator_type
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'max':
self._reducer = fn.max
elif aggregator_type == 'mean':
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
else:
self.register_buffer('eps', torch.FloatTensor([init_eps]))
def forward(self, graph, feat, edge_weight=None):
r"""
Description
-----------
Compute Graph Isomorphism Network layer.
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
If ``apply_func`` is not None, :math:`D_{in}` should
fit the input dimensionality requirement of ``apply_func``.
edge_weight : torch.Tensor, optional
Optional tensor on the edge. If given, the convolution will weight
with regard to the message.
Returns
-------
torch.Tensor
The output feature of shape :math:`(N, D_{out})` where
:math:`D_{out}` is the output dimensionality of ``apply_func``.
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
"""
with graph.local_scope():
aggregate_fn = self.concat_edge_msg
# aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
diff = torch.tensor(graph.dstdata['neigh'].shape[1: ]) - torch.tensor(feat_dst.shape[1: ])
zeros = torch.zeros(feat_dst.shape[0], *diff).to(feat_dst.device)
feat_dst = torch.cat([feat_dst, zeros], dim=1)
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
return rst
def concat_edge_msg(self, edges):
if self.edge_feat_loc not in edges.data:
return {'m': edges.src['h']}
else:
m = torch.cat([edges.src['h'], edges.data[self.edge_feat_loc]], dim=1)
return {'m': m}
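# Minimal usage sketch (illustrative). With apply_func=None and eps=0 the
# layer computes h_v' = (1 + eps) * h_v + sum_{u in N(v)} h_u:
#   import dgl, torch
#   g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))  # edges 0->1->2
#   conv = GINConv(apply_func=None, aggregator_type='sum')
#   conv.edge_feat_loc = 'attr'  # normally set through GIN.edge_feat_loc below
#   out = conv(g, torch.ones(3, 4))  # out[1] == 2 (own feature + neighbor 0)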
class ApplyNodeFunc(nn.Module):
"""Update the node feature hv with MLP, BN and ReLU."""
def __init__(self, mlp):
super(ApplyNodeFunc, self).__init__()
self.mlp = mlp
self.bn = nn.BatchNorm1d(self.mlp.output_dim)
def forward(self, h):
h = self.mlp(h)
h = self.bn(h)
h = F.relu(h)
return h
class MLP(nn.Module):
"""MLP with linear output"""
def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
"""MLP layers construction
Parameters
---------
num_layers: int
The number of linear layers
input_dim: int
The dimensionality of input features
hidden_dim: int
The dimensionality of hidden units at ALL layers
output_dim: int
The number of classes for prediction
"""
super(MLP, self).__init__()
self.linear_or_not = True # default is linear model
self.num_layers = num_layers
self.output_dim = output_dim
if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.linears.append(nn.Linear(input_dim, hidden_dim))
for layer in range(num_layers - 2):
self.linears.append(nn.Linear(hidden_dim, hidden_dim))
self.linears.append(nn.Linear(hidden_dim, output_dim))
for layer in range(num_layers - 1):
self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
def forward(self, x):
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for i in range(self.num_layers - 1):
h = F.relu(self.batch_norms[i](self.linears[i](h)))
return self.linears[-1](h)
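# Example (illustrative): a 2-layer MLP mapping 4-d inputs to 3 outputs.
#   mlp = MLP(num_layers=2, input_dim=4, hidden_dim=8, output_dim=3)
#   mlp(torch.randn(5, 4)).shape  # torch.Size([5, 3])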
class GIN(nn.Module):
"""GIN model"""
def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
graph_pooling_type, neighbor_pooling_type, edge_feat_dim=0,
final_dropout=0.0, learn_eps=False, output_dim=1, **kwargs):
"""model parameters setting
Parameters
---------
num_layers: int
The number of linear layers in the neural network
num_mlp_layers: int
The number of linear layers in mlps
input_dim: int
The dimensionality of input features
hidden_dim: int
The dimensionality of hidden units at ALL layers
output_dim: int
The number of classes for prediction
final_dropout: float
dropout ratio on the final linear layer
learn_eps: boolean
If True, learn epsilon to distinguish center nodes from neighbors
If False, aggregate neighbors and center nodes altogether.
neighbor_pooling_type: str
how to aggregate neighbors (sum, mean, or max)
graph_pooling_type: str
how to aggregate entire nodes in a graph (sum, mean or max)
"""
super().__init__()
def init_weights_orthogonal(m):
if isinstance(m, nn.Linear):
torch.nn.init.orthogonal_(m.weight)
elif isinstance(m, MLP):
if hasattr(m, 'linears'):
m.linears.apply(init_weights_orthogonal)
else:
m.linear.apply(init_weights_orthogonal)
elif isinstance(m, nn.ModuleList):
pass
else:
raise Exception()
self.num_layers = num_layers
self.learn_eps = learn_eps
# List of MLPs
self.ginlayers = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
# self.preprocess_nodes = PreprocessNodeAttrs(
# node_attrs=node_preprocess, output_dim=node_preprocess_output_dim)
# print(input_dim)
for layer in range(self.num_layers - 1):
if layer == 0:
mlp = MLP(num_mlp_layers, input_dim + edge_feat_dim, hidden_dim, hidden_dim)
else:
mlp = MLP(num_mlp_layers, hidden_dim + edge_feat_dim, hidden_dim, hidden_dim)
if kwargs['init'] == 'orthogonal':
init_weights_orthogonal(mlp)
self.ginlayers.append(
GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
# Linear function for graph poolings of output of each layer
# which maps the output of different layers into a prediction score
self.linears_prediction = torch.nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.linears_prediction.append(
nn.Linear(input_dim, output_dim))
else:
self.linears_prediction.append(
nn.Linear(hidden_dim, output_dim))
if kwargs['init'] == 'orthogonal':
# print('orthogonal')
self.linears_prediction.apply(init_weights_orthogonal)
self.drop = nn.Dropout(final_dropout)
if graph_pooling_type == 'sum':
self.pool = SumPooling()
elif graph_pooling_type == 'mean':
self.pool = AvgPooling()
elif graph_pooling_type == 'max':
self.pool = MaxPooling()
else:
raise NotImplementedError
def forward(self, g, h):
# list of hidden representation at each layer (including input)
hidden_rep = [h]
# h = self.preprocess_nodes(h)
for i in range(self.num_layers - 1):
h = self.ginlayers[i](g, h)
h = self.batch_norms[i](h)
h = F.relu(h)
hidden_rep.append(h)
score_over_layer = 0
# perform pooling over all nodes in each graph in every layer
for i, h in enumerate(hidden_rep):
pooled_h = self.pool(g, h)
score_over_layer += self.drop(self.linears_prediction[i](pooled_h))
return score_over_layer
def get_graph_embed(self, g, h):
self.eval()
with torch.no_grad():
# return self.forward(g, h).detach().numpy()
hidden_rep = []
# h = self.preprocess_nodes(h)
for i in range(self.num_layers - 1):
h = self.ginlayers[i](g, h)
h = self.batch_norms[i](h)
h = F.relu(h)
hidden_rep.append(h)
# perform pooling over all nodes in each graph in every layer
graph_embed = torch.Tensor([]).to(self.device)
for i, h in enumerate(hidden_rep):
pooled_h = self.pool(g, h)
graph_embed = torch.cat([graph_embed, pooled_h], dim = 1)
return graph_embed
def get_graph_embed_no_cat(self, g, h):
self.eval()
with torch.no_grad():
hidden_rep = []
# h = self.preprocess_nodes(h)
for i in range(self.num_layers - 1):
h = self.ginlayers[i](g, h)
h = self.batch_norms[i](h)
h = F.relu(h)
hidden_rep.append(h)
return self.pool(g, hidden_rep[-1]).to(self.device)
@property
def edge_feat_loc(self):
return self.ginlayers[0].edge_feat_loc
@edge_feat_loc.setter
def edge_feat_loc(self, loc):
for layer in self.ginlayers:
layer.edge_feat_loc = loc

View File

@@ -0,0 +1,292 @@
"""Evaluation on random GIN features. Modified from https://github.com/uoguelph-mlrg/GGM-metrics"""
import torch
import numpy as np
import sklearn
import sklearn.metrics
from sklearn.preprocessing import StandardScaler
import time
import dgl
from .gin import GIN
def load_feature_extractor(
device, num_layers=3, hidden_dim=35, neighbor_pooling_type='sum',
graph_pooling_type='sum', input_dim=1, edge_feat_dim=0,
dont_concat=False, num_mlp_layers=2, output_dim=1,
node_feat_loc='attr', edge_feat_loc='attr', init='orthogonal',
**kwargs):
model = GIN(num_layers=num_layers, hidden_dim=hidden_dim, neighbor_pooling_type=neighbor_pooling_type,
graph_pooling_type=graph_pooling_type, input_dim=input_dim, edge_feat_dim=edge_feat_dim,
num_mlp_layers=num_mlp_layers, output_dim=output_dim, init=init)
model.node_feat_loc = node_feat_loc
model.edge_feat_loc = edge_feat_loc
model.eval()
if dont_concat:
model.forward = model.get_graph_embed_no_cat
else:
model.forward = model.get_graph_embed
model.device = device
return model.to(device)
def time_function(func):
def wrapper(*args, **kwargs):
start = time.time()
results = func(*args, **kwargs)
end = time.time()
return results, end - start
return wrapper
class GINMetric():
def __init__(self, model):
self.feat_extractor = model
self.get_activations = self.get_activations_gin
@time_function
def get_activations_gin(self, generated_dataset, reference_dataset):
return self._get_activations(generated_dataset, reference_dataset)
def _get_activations(self, generated_dataset, reference_dataset):
gen_activations = self.__get_activations_single_dataset(generated_dataset)
ref_activations = self.__get_activations_single_dataset(reference_dataset)
scaler = StandardScaler()
scaler.fit(ref_activations)
ref_activations = scaler.transform(ref_activations)
gen_activations = scaler.transform(gen_activations)
return gen_activations, ref_activations
def __get_activations_single_dataset(self, dataset):
node_feat_loc = self.feat_extractor.node_feat_loc
edge_feat_loc = self.feat_extractor.edge_feat_loc
ndata = [node_feat_loc] if node_feat_loc in dataset[0].ndata else '__ALL__'
edata = [edge_feat_loc] if edge_feat_loc in dataset[0].edata else '__ALL__'
graphs = dgl.batch(dataset, ndata=ndata, edata=edata).to(self.feat_extractor.device)
if node_feat_loc not in graphs.ndata: # Use degree as features
feats = graphs.in_degrees() + graphs.out_degrees()
feats = feats.unsqueeze(1).type(torch.float32)
else:
feats = graphs.ndata[node_feat_loc]
graph_embeds = self.feat_extractor(graphs, feats)
return graph_embeds.cpu().detach().numpy()
def evaluate(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by child class')
class MMDEvaluation(GINMetric):
def __init__(self, model, kernel='rbf', sigma='range', multiplier='mean'):
super().__init__(model)
if multiplier == 'mean':
self.__get_sigma_mult_factor = self.__mean_pairwise_distance
elif multiplier == 'median':
self.__get_sigma_mult_factor = self.__median_pairwise_distance
elif multiplier is None:
self.__get_sigma_mult_factor = lambda *args, **kwargs: 1
else:
raise Exception(multiplier)
if 'rbf' in kernel:
if sigma == 'range':
self.base_sigmas = np.array([0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0])
if multiplier == 'mean':
self.name = 'mmd_rbf'
elif multiplier == 'median':
self.name = 'mmd_rbf_adaptive_median'
else:
self.name = 'mmd_rbf_adaptive'
elif sigma == 'one':
self.base_sigmas = np.array([1])
if multiplier == 'mean':
self.name = 'mmd_rbf_single_mean'
elif multiplier == 'median':
self.name = 'mmd_rbf_single_median'
else:
self.name = 'mmd_rbf_single'
else:
raise Exception(sigma)
self.evaluate = self.calculate_MMD_rbf_quadratic
elif 'linear' in kernel:
self.evaluate = self.calculate_MMD_linear_kernel
else:
raise Exception()
def __get_pairwise_distances(self, generated_dataset, reference_dataset):
return sklearn.metrics.pairwise_distances(reference_dataset, generated_dataset, metric='euclidean', n_jobs=8)**2
def __mean_pairwise_distance(self, dists_GR):
return np.sqrt(dists_GR.mean())
def __median_pairwise_distance(self, dists_GR):
return np.sqrt(np.median(dists_GR))
def get_sigmas(self, dists_GR):
mult_factor = self.__get_sigma_mult_factor(dists_GR)
return self.base_sigmas * mult_factor
@time_function
def calculate_MMD_rbf_quadratic(self, generated_dataset=None, reference_dataset=None):
# https://github.com/djsutherland/opt-mmd/blob/master/two_sample/mmd.py
if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray):
(generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset)
GG = self.__get_pairwise_distances(generated_dataset, generated_dataset)
GR = self.__get_pairwise_distances(generated_dataset, reference_dataset)
RR = self.__get_pairwise_distances(reference_dataset, reference_dataset)
max_mmd = 0
sigmas = self.get_sigmas(GR)
for sigma in sigmas:
gamma = 1 / (2 * sigma**2)
K_GR = np.exp(-gamma * GR)
K_GG = np.exp(-gamma * GG)
K_RR = np.exp(-gamma * RR)
mmd = K_GG.mean() + K_RR.mean() - 2 * K_GR.mean()
max_mmd = mmd if mmd > max_mmd else max_mmd
return {self.name: max_mmd}
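# Worked example (illustrative, using a single sigma = 1 instead of the
# adaptive range above): for 1-d activations G = {0, 1} and R = {0}, the
# squared distances are GG = [[0, 1], [1, 0]], GR = [[0, 1]], RR = [[0]],
# so with gamma = 0.5:
#   MMD^2 = mean(e^{-gamma*GG}) + mean(e^{-gamma*RR}) - 2*mean(e^{-gamma*GR})
#         ~= 0.803 + 1.0 - 1.607 ~= 0.197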
@time_function
def calculate_MMD_linear_kernel(self, generated_dataset=None, reference_dataset=None):
# https://github.com/djsutherland/opt-mmd/blob/master/two_sample/mmd.py
if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray):
(generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset)
G_bar = generated_dataset.mean(axis=0)
R_bar = reference_dataset.mean(axis=0)
Z_bar = G_bar - R_bar
mmd = Z_bar.dot(Z_bar)
mmd = mmd if mmd >= 0 else 0
return {'mmd_linear': mmd}
class prdcEvaluation(GINMetric):
# From PRDC github: https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py#L54
def __init__(self, *args, use_pr=False, **kwargs):
super().__init__(*args, **kwargs)
self.use_pr = use_pr
@time_function
def evaluate(self, generated_dataset=None, reference_dataset=None, nearest_k=5):
""" Computes precision, recall, density, and coverage given two manifolds. """
if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray):
(generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset)
real_nearest_neighbour_distances = self.__compute_nearest_neighbour_distances(reference_dataset, nearest_k)
distance_real_fake = self.__compute_pairwise_distance(reference_dataset, generated_dataset)
if self.use_pr:
fake_nearest_neighbour_distances = self.__compute_nearest_neighbour_distances(generated_dataset, nearest_k)
precision = (
distance_real_fake <= np.expand_dims(real_nearest_neighbour_distances, axis=1)
).any(axis=0).mean()
recall = (
distance_real_fake <= np.expand_dims(fake_nearest_neighbour_distances, axis=0)
).any(axis=1).mean()
f1_pr = 2 / ((1 / (precision + 1e-8)) + (1 / (recall + 1e-8)))
result = dict(precision=precision, recall=recall, f1_pr=f1_pr)
else:
density = (1. / float(nearest_k)) * (
distance_real_fake <= np.expand_dims(real_nearest_neighbour_distances, axis=1)).sum(axis=0).mean()
coverage = (distance_real_fake.min(axis=1) <= real_nearest_neighbour_distances).mean()
f1_dc = 2 / ((1 / (density + 1e-8)) + (1 / (coverage + 1e-8)))
result = dict(density=density, coverage=coverage, f1_dc=f1_dc)
return result
def __compute_pairwise_distance(self, data_x, data_y=None):
"""
Args:
data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
data_y: numpy.ndarray([N, feature_dim], dtype=np.float32)
Return:
numpy.ndarray([N, N], dtype=np.float32) of pairwise distances.
"""
if data_y is None:
data_y = data_x
dists = sklearn.metrics.pairwise_distances(data_x, data_y, metric='euclidean', n_jobs=8)
return dists
def __get_kth_value(self, unsorted, k, axis=-1):
"""
Args:
unsorted: numpy.ndarray of any dimensionality.
k: int
Return:
kth values along the designated axis.
"""
indices = np.argpartition(unsorted, k, axis=axis)[..., :k]
k_smallest = np.take_along_axis(unsorted, indices, axis=axis)
kth_values = k_smallest.max(axis=axis)
return kth_values
def __compute_nearest_neighbour_distances(self, input_features, nearest_k):
"""
Args:
input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
nearest_k: int
Return:
Distances to kth nearest neighbours.
"""
distances = self.__compute_pairwise_distance(input_features)
radii = self.__get_kth_value(distances, k=nearest_k + 1, axis=-1)
return radii
def nn_based_eval(graph_ref_list, graph_pred_list, N_gin=10):
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
evaluators = []
for _ in range(N_gin):
gin = load_feature_extractor(device)
evaluators.append(MMDEvaluation(model=gin, kernel='rbf', sigma='range', multiplier='mean'))
evaluators.append(prdcEvaluation(model=gin, use_pr=True))
evaluators.append(prdcEvaluation(model=gin, use_pr=False))
ref_graphs = [dgl.from_networkx(g).to(device) for g in graph_ref_list]
gen_graphs = [dgl.from_networkx(g).to(device) for g in graph_pred_list]
metrics = {
'mmd_rbf': [],
'f1_pr': [],
'f1_dc': []
}
for evaluator in evaluators:
res, time = evaluator.evaluate(generated_dataset=gen_graphs, reference_dataset=ref_graphs)
for key in list(res.keys()):
if key in metrics:
metrics[key].append(res[key])
results = {
'MMD_RBF': (np.mean(metrics['mmd_rbf']), np.std(metrics['mmd_rbf'])),
'F1_PR': (np.mean(metrics['f1_pr']), np.std(metrics['f1_pr'])),
'F1_DC': (np.mean(metrics['f1_dc']), np.std(metrics['f1_dc']))
}
return results
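if __name__ == '__main__':
    # Smoke test (added sketch; requires dgl + networkx): compare two small
    # sets of random graphs with 2 random GINs instead of the default 10.
    import networkx as nx
    ref = [nx.erdos_renyi_graph(12, 0.3, seed=i) for i in range(8)]
    pred = [nx.erdos_renyi_graph(12, 0.3, seed=100 + i) for i in range(8)]
    print(nn_based_eval(ref, pred, N_gin=2))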

View File

@@ -0,0 +1,209 @@
"""MMD Evaluation on graph structure statistics. Modified from https://github.com/uoguelph-mlrg/GGM-metrics"""
import numpy as np
import networkx as nx
# from scipy.linalg import toeplitz
# import pyemd
import concurrent.futures
from scipy.linalg import eigvalsh
from functools import partial
class Descriptor():
def __init__(self, is_parallel=False, bins=100, kernel='rbf', sigma_type='single', **kwargs):
self.is_parallel = is_parallel
self.bins = bins
self.max_workers = kwargs.get('max_workers')
if kernel == 'rbf':
self.distance = self.l2
self.name += '_rbf'
else:
raise ValueError(kernel)
if sigma_type == 'argmax':
log_sigmas = np.linspace(-5., 5., 50)
# the first 30 sigma values is usually enough
log_sigmas = log_sigmas[:30]
self.sigmas = [np.exp(log_sigma) for log_sigma in log_sigmas]
elif sigma_type == 'single':
# wrap a scalar sigma in a list so the `for sigma in sigmas` loop works
self.sigmas = kwargs['sigma'] if isinstance(kwargs['sigma'], (list, tuple)) else [kwargs['sigma']]
else:
raise ValueError
def evaluate(self, graph_ref_list, graph_pred_list):
"""Compute the distance between the distributions of two unordered sets of graphs.
Args:
graph_ref_list, graph_pred_list: two lists of networkx graphs to be evaluated.
"""
graph_pred_list = [G for G in graph_pred_list if not G.number_of_nodes() == 0]
sample_pred = self.extract_features(graph_pred_list)
sample_ref = self.extract_features(graph_ref_list)
GG = self.disc(sample_pred, sample_pred, distance_scaling=self.distance_scaling)
GR = self.disc(sample_pred, sample_ref, distance_scaling=self.distance_scaling)
RR = self.disc(sample_ref, sample_ref, distance_scaling=self.distance_scaling)
sigmas = self.sigmas
max_mmd = 0
mmd_dict = []
for sigma in sigmas:
gamma = 1 / (2 * sigma ** 2)
K_GR = np.exp(-gamma * GR)
K_GG = np.exp(-gamma * GG)
K_RR = np.exp(-gamma * RR)
mmd = K_GG.mean() + K_RR.mean() - (2 * K_GR.mean())
mmd_dict.append((sigma, mmd))
max_mmd = mmd if mmd > max_mmd else max_mmd
# print(self.name, mmd_dict)
return max_mmd
def pad_histogram(self, x, y):
# convert histogram values x and y to float, and pad them for equal length
support_size = max(len(x), len(y))
x = x.astype(np.float64)
y = y.astype(np.float64)
if len(x) < len(y):
x = np.hstack((x, [0.] * (support_size - len(x))))
elif len(y) < len(x):
y = np.hstack((y, [0.] * (support_size - len(y))))
return x, y
# def emd(self, x, y, distance_scaling=1.0):
# support_size = max(len(x), len(y))
# x, y = self.pad_histogram(x, y)
#
# d_mat = toeplitz(range(support_size)).astype(np.float)
# distance_mat = d_mat / distance_scaling
#
# dist = pyemd.emd(x, y, distance_mat)
# return dist ** 2
def l2(self, x, y, **kwargs):
# gaussian rbf
x, y = self.pad_histogram(x, y)
dist = np.linalg.norm(x - y, 2)
return dist ** 2
def kernel_parallel_unpacked(self, x, samples2, kernel):
dist = []
for s2 in samples2:
dist += [kernel(x, s2)]
return dist
def kernel_parallel_worker(self, t):
return self.kernel_parallel_unpacked(*t)
def disc(self, samples1, samples2, **kwargs):
# Discrepancy between 2 samples
tot_dist = []
if not self.is_parallel:
for s1 in samples1:
for s2 in samples2:
tot_dist += [self.distance(s1, s2)]
else:
with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
for dist in executor.map(self.kernel_parallel_worker,
[(s1, samples2, partial(self.distance, **kwargs)) for s1 in samples1]):
tot_dist += [dist]
return np.array(tot_dist)
class degree(Descriptor):
def __init__(self, *args, **kwargs):
self.name = 'degree'
self.sigmas = [kwargs.get('sigma', 1.0)]
self.distance_scaling = 1.0
super().__init__(*args, **kwargs)
def extract_features(self, dataset):
res = []
if self.is_parallel:
with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
for deg_hist in executor.map(self.degree_worker, dataset):
res.append(deg_hist)
else:
for g in dataset:
degree_hist = self.degree_worker(g)
res.append(degree_hist)
res = [s1 / np.sum(s1) for s1 in res]
return res
def degree_worker(self, G):
return np.array(nx.degree_histogram(G))
class cluster(Descriptor):
def __init__(self, *args, **kwargs):
self.name = 'cluster'
self.sigmas = [kwargs.get('sigma', [1.0 / 10])]
super().__init__(*args, **kwargs)
self.distance_scaling = self.bins
def extract_features(self, dataset):
res = []
if self.is_parallel:
with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
for clustering_hist in executor.map(self.clustering_worker, [(G, self.bins) for G in dataset]):
res.append(clustering_hist)
else:
for g in dataset:
clustering_hist = self.clustering_worker((g, self.bins))
res.append(clustering_hist)
res = [s1 / np.sum(s1) for s1 in res]
return res
def clustering_worker(self, param):
G, bins = param
clustering_coeffs_list = list(nx.clustering(G).values())
hist, _ = np.histogram(
clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
return hist
class spectral(Descriptor):
def __init__(self, *args, **kwargs):
self.name = 'spectral'
self.sigmas = [kwargs.get('sigma', 1.0)]
self.distance_scaling = 1
super().__init__(*args, **kwargs)
def extract_features(self, dataset):
res = []
if self.is_parallel:
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
for spectral_density in executor.map(self.spectral_worker, dataset):
res.append(spectral_density)
else:
for g in dataset:
spectral_temp = self.spectral_worker(g)
res.append(spectral_temp)
return res
def spectral_worker(self, G):
eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense())
spectral_pmf, _ = np.histogram(eigs, bins=200, range=(-1e-5, 2), density=False)
spectral_pmf = spectral_pmf / spectral_pmf.sum()
return spectral_pmf
def mmd_eval(graph_ref_list, graph_pred_list, methods):
evaluators = []
for (method, sigma, sigma_type) in methods:
evaluators.append(eval(method)(sigma=sigma, sigma_type=sigma_type))
results = {}
for evaluator in evaluators:
results[evaluator.name] = evaluator.evaluate(graph_ref_list, graph_pred_list)
return results
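if __name__ == '__main__':
    # Smoke test (added sketch): structure MMD between two small sets of
    # random graphs, using the same (method, sigma, sigma_type) triples
    # that `get_stats_eval` passes in.
    ref = [nx.erdos_renyi_graph(12, 0.3, seed=i) for i in range(8)]
    pred = [nx.erdos_renyi_graph(12, 0.3, seed=100 + i) for i in range(8)]
    methods = [('degree', 1., 'argmax'), ('cluster', 0.1, 'argmax'), ('spectral', 1., 'argmax')]
    print(mmd_eval(ref, pred, methods))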

180
MobileNetV3/logger.py Normal file
View File

@@ -0,0 +1,180 @@
import os
import wandb
import torch
import numpy as np
class Logger:
def __init__(
self,
exp_name,
log_dir=None,
exp_suffix="",
write_textfile=True,
use_wandb=False,
wandb_project_name=None,
entity='hysh',
config=None
):
self.log_dir = log_dir
self.write_textfile = write_textfile
self.use_wandb = use_wandb
self.logs_for_save = {}
self.logs = {}
if self.write_textfile:
self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')
if self.use_wandb:
exp_suffix = "_".join(exp_suffix.split("/")[:-1])
wandb.init(
config=config if config is not None else wandb.config,
entity=entity,
project=wandb_project_name,
name=exp_name + "_" + exp_suffix,
group=exp_name,
reinit=True)
def write_str(self, log_str):
self.f.write(log_str+'\n')
self.f.flush()
def update_config(self, v, is_args=False):
if is_args:
self.logs_for_save.update({'args': v})
else:
self.logs_for_save.update(v)
if self.use_wandb:
wandb.config.update(v, allow_val_change=True)
def write_log_nohead(self, element, step):
log_str = f"{step} | "
log_dict = {}
for key, val in element.items():
if not key in self.logs_for_save:
self.logs_for_save[key] = []
self.logs_for_save[key].append(val)
log_str += f'{key} {val} | '
log_dict[f'{key}'] = val
if self.write_textfile:
self.f.write(log_str+'\n')
self.f.flush()
if self.use_wandb:
wandb.log(log_dict, step=step)
def write_log(self, element, step, return_log_dict=False):
log_str = f"{step} | "
log_dict = {}
for head, keys in element.items():
for k in keys:
if k in self.logs:
v = self.logs[k].avg
if not k in self.logs_for_save:
self.logs_for_save[k] = []
self.logs_for_save[k].append(v)
log_str += f'{k} {v} | '
log_dict[f'{head}/{k}'] = v
if self.write_textfile:
self.f.write(log_str+'\n')
self.f.flush()
if return_log_dict:
return log_dict
if self.use_wandb:
wandb.log(log_dict, step=step)
def log_sample(self, sample_x):
wandb.log({"sampled_x": [wandb.Image(x.unsqueeze(-1).cpu().numpy()) for x in sample_x]})
def log_valid_sample_prop(self, arch_metric, x_axis, y_axis):
assert x_axis in ['test_acc', 'flops', 'params', 'latency']
assert y_axis in ['test_acc', 'flops', 'params', 'latency']
data = [[x, y] for (x, y) in zip(arch_metric[2][f'{x_axis}_list'], arch_metric[2][f'{y_axis}_list'])]
table = wandb.Table(data=data, columns = [x_axis, y_axis])
wandb.log({f"valid_sample ({x_axis}-{y_axis})" : wandb.plot.scatter(table, x_axis, y_axis)})
def save_log(self, name=None):
name = 'logs.pt' if name is None else name
torch.save(self.logs_for_save, os.path.join(self.log_dir, name))
def update(self, key, v, n=1):
if not key in self.logs:
self.logs[key] = AverageMeter()
self.logs[key].update(v, n)
def reset(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
for key in keys:
self.logs[key] = AverageMeter()
else:
self.logs[keys] = AverageMeter()
else:
for key in self.logs.keys():
if not key in except_keys:
self.logs[key] = AverageMeter()
def avg(self, keys=None, except_keys=[]):
if keys is not None:
if isinstance(keys, list):
return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
else:
return self.logs[keys].avg
else:
avg_dict = {}
for key in self.logs.keys():
if not key in except_keys:
avg_dict[key] = self.logs[key].avg
return avg_dict
class AverageMeter(object):
"""
Computes and stores the average and current value
Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
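# Example: running average over a stream of values.
#   m = AverageMeter()
#   m.update(1.0); m.update(3.0)   # m.avg == 2.0
#   m.update(5.0, n=2)             # m.avg == (1 + 3 + 5*2) / 4 == 3.5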
def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
metrics = {}
logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
logits_per_x = logits_per_g.t().detach().cpu()
logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
ground_truth = torch.arange(len(x_embeds)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)
return metrics
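# Note: `preds` holds the 0-indexed rank of the ground-truth match per query,
# so `R@k` is the fraction of queries whose match ranks in the top-k, and the
# reported mean/median ranks are 1-indexed.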

584
MobileNetV3/losses.py Normal file
View File

@@ -0,0 +1,584 @@
"""All functions related to loss computation and optimization."""
import torch
import torch.optim as optim
import numpy as np
from models import utils as mutils
from sde_lib import VPSDE, VESDE
def get_optimizer(config, params):
"""Return a flax optimizer object based on `config`."""
if config.optim.optimizer == 'Adam':
optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
weight_decay=config.optim.weight_decay)
else:
raise NotImplementedError(
f'Optimizer {config.optim.optimizer} not supported yet!'
)
return optimizer
def optimization_manager(config):
"""Return an optimize_fn based on `config`."""
def optimize_fn(optimizer, params, step, lr=config.optim.lr,
warmup=config.optim.warmup,
grad_clip=config.optim.grad_clip):
"""Optimize with warmup and gradient clipping (disabled if negative)."""
if warmup > 0:
for g in optimizer.param_groups:
g['lr'] = lr * np.minimum(step / warmup, 1.0)
if grad_clip >= 0:
torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
optimizer.step()
return optimize_fn
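# Example: with lr = 2e-4 and warmup = 1000 the rate ramps linearly
# (step 250 -> 5e-5, step 500 -> 1e-4, step >= 1000 -> 2e-4), and gradient
# clipping is skipped whenever grad_clip < 0.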
def get_sde_loss_fn_nas(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
x, adj, mask = batch
# adj, mask: [32, 1, 20, 20]
score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
z = torch.randn_like(x) # [B, C, N, N]
# z = torch.tril(z, -1)
# z = z + z.transpose(2, 3)
mean, std = sde.marginal_prob(x, t)
# mean = torch.tril(mean, -1)
# mean = mean + mean.transpose(2, 3)
perturbed_data = mean + std[:, None, None] * z
score = score_fn(perturbed_data, t, mask)
# mask = torch.tril(mask, -1)
# mask = mask + mask.transpose(2, 3)
# mask = mask.reshape(mask.shape[0], -1) # low triangular part of adj matrices
if not likelihood_weighting:
losses = torch.square(score * std[:, None, None] + z)
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
# losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
losses = torch.mean(losses, dim=-1)
else:
losses = 0.5 * torch.sum(losses, dim=-1)
loss = losses.mean()
else:
g2 = sde.sde(torch.zeros_like(x), t)[1] ** 2
losses = torch.square(score + z / std[:, None, None])
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
# losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
losses = torch.mean(losses, dim=-1)
else:
losses = 0.5 * torch.sum(losses, dim=-1)
loss = (losses * g2).mean()
return loss
return loss_fn
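# In the default (likelihood_weighting=False) branch above, the objective is
# the standard denoising score matching loss: with x_t = mean + std_t * z,
#   L = E_t E_{x_0} E_z || std_t * s_theta(x_t, t) + z ||^2,
# which is minimized when s_theta(x_t, t) = -z / std_t, i.e. the score
# grad_x log p_t(x_t | x_0) of the perturbation kernel.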
def get_predictor_loss_fn_nas_binary(sde, train, reduce_mean=True, continuous=True,
likelihood_weighting=True, eps=1e-5, label_list=None,
noised=True, t_spot=None):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
x, adj, mask, extra = batch
# adj, mask: [32, 1, 20, 20]
# score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
if noised:
if t_spot < 1:
t = torch.rand(x.shape[0], device=adj.device) * (t_spot - eps) + eps # torch.rand: [0, 1)
else:
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
z = torch.randn_like(x) # [B, C, N, N]
# z = torch.tril(z, -1)
# z = z + z.transpose(2, 3)
mean, std = sde.marginal_prob(x, t)
# mean = torch.tril(mean, -1)
# mean = mean + mean.transpose(2, 3)
perturbed_data = mean + std[:, None, None] * z
# score = score_fn(perturbed_data, t, mask)
pred = predictor_fn(perturbed_data, t, mask)
else:
t = eps * torch.ones(x.shape[0], device=adj.device)
pred = predictor_fn(x, t, mask)
labels = extra[f"{label_list}"][1]
labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
# mask = torch.tril(mask, -1)
# mask = mask + mask.transpose(2, 3)
# mask = mask.reshape(mask.shape[0], -1) # low triangular part of adj matrices
# loss = torch.nn.MSELoss()(pred, labels)
loss = torch.nn.BCEWithLogitsLoss()(pred, labels)
return loss, pred, labels
return loss_fn
def get_predictor_loss_fn_nas(sde, train, reduce_mean=True, continuous=True,
likelihood_weighting=True, eps=1e-5, label_list=None,
noised=True, t_spot=None):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
x, adj, mask, extra = batch
# adj, mask: [32, 1, 20, 20]
# score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
if noised:
if t_spot < 1:
t = torch.rand(x.shape[0], device=adj.device) * (t_spot - eps) + eps # torch.rand: [0, 1)
else:
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
z = torch.randn_like(x) # [B, C, N, N]
# z = torch.tril(z, -1)
# z = z + z.transpose(2, 3)
mean, std = sde.marginal_prob(x, t)
# mean = torch.tril(mean, -1)
# mean = mean + mean.transpose(2, 3)
perturbed_data = mean + std[:, None, None] * z
# score = score_fn(perturbed_data, t, mask)
pred = predictor_fn(perturbed_data, t, mask)
else:
t = eps * torch.ones(x.shape[0], device=adj.device)
pred = predictor_fn(x, t, mask)
labels = extra[f"{label_list[-1]}"]
labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
# mask = torch.tril(mask, -1)
# mask = mask + mask.transpose(2, 3)
# mask = mask.reshape(mask.shape[0], -1) # low triangular part of adj matrices
loss = torch.nn.MSELoss()(pred, labels)
return loss, pred, labels
return loss_fn
def get_meta_predictor_loss_fn_nas(sde, train, reduce_mean=True, continuous=True,
likelihood_weighting=True, eps=1e-5, label_list=None,
noised=True, t_spot=None):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
x, adj, mask, extra, task = batch
predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
if noised:
if t_spot < 1:
t = torch.rand(x.shape[0], device=adj.device) * (t_spot - eps) + eps # torch.rand: [0, 1)
else:
t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps
z = torch.randn_like(x) # [B, C, N, N]
mean, std = sde.marginal_prob(x, t)
perturbed_data = mean + std[:, None, None] * z
# score = score_fn(perturbed_data, t, mask)
pred = predictor_fn(perturbed_data, t, mask, task)
else:
t = eps * torch.ones(x.shape[0], device=adj.device)
pred = predictor_fn(x, t, mask, task)
labels = extra[f"{label_list[-1]}"]
labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
loss = torch.nn.MSELoss()(pred, labels)
return loss, pred, labels
return loss_fn
def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
"""Create a loss function for training with arbitrary SDEs.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE.
train: `True` for training loss and `False` for evaluation loss.
reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
Otherwise, it requires ad-hoc interpolation to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according
to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in Score SDE paper.
eps: A `float` number. The smallest time step to sample from.
Returns:
A loss function.
"""
# reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
def loss_fn(model, batch):
"""Compute the loss function.
Args:
model: A score model.
batch: A mini-batch of training data, including adjacency matrices and mask.
Returns:
loss: A scalar that represents the average loss value across the mini-batch.
"""
adj, mask = batch
# adj, mask: [32, 1, 20, 20]
score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
t = torch.rand(adj.shape[0], device=adj.device) * (sde.T - eps) + eps
z = torch.randn_like(adj) # [B, C, N, N]
z = torch.tril(z, -1)
z = z + z.transpose(2, 3)
mean, std = sde.marginal_prob(adj, t)
mean = torch.tril(mean, -1)
mean = mean + mean.transpose(2, 3)
perturbed_data = mean + std[:, None, None, None] * z
score = score_fn(perturbed_data, t, mask=mask)
mask = torch.tril(mask, -1)
mask = mask + mask.transpose(2, 3)
mask = mask.reshape(mask.shape[0], -1) # low triangular part of adj matrices
if not likelihood_weighting:
losses = torch.square(score * std[:, None, None, None] + z)
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
else:
losses = 0.5 * torch.sum(losses * mask, dim=-1)
loss = losses.mean()
else:
g2 = sde.sde(torch.zeros_like(adj), t)[1] ** 2
losses = torch.square(score + z / std[:, None, None, None])
losses = losses.reshape(losses.shape[0], -1)
if reduce_mean:
losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
else:
losses = 0.5 * torch.sum(losses * mask, dim=-1)
loss = (losses * g2).mean()
return loss
return loss_fn
def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True,
likelihood_weighting=False, data='NASBench201'):
"""Create a one-step training/evaluation function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
(`sde_lib.SDE`, `sde_lib.SDE`) of the forward node SDE and edge SDE.
optimize_fn: An optimization function.
reduce_mean: If `True`, average the loss across data dimensions.
Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
Returns:
A one-step function for training or evaluation.
"""
if continuous:
if isinstance(sde, tuple):
loss_fn = get_multi_sde_loss_fn(sde[0], sde[1], train, reduce_mean=reduce_mean, continuous=True,
likelihood_weighting=likelihood_weighting)
else:
if data in ['NASBench201', 'ofa']:
loss_fn = get_sde_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting)
else:
loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting)
else:
assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
if isinstance(sde, VESDE):
loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
elif isinstance(sde, VPSDE):
loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
elif isinstance(sde, tuple):
raise ValueError("Discrete training for multi sde is not recommended.")
else:
raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")
def step_fn(state, batch):
"""Running one step of training or evaluation.
For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and
jit-compiled together for faster execution.
Args:
state: A dictionary of training information, containing the score model, optimizer,
EMA status, and number of optimization steps.
batch: A mini-batch of training/evaluation data, including mini-batch
Returns:
loss: The average loss value of this state.
"""
model = state['model']
if train:
optimizer = state['optimizer']
optimizer.zero_grad()
loss = loss_fn(model, batch)
loss.backward()
optimize_fn(optimizer, model.parameters(), step=state['step'])
state['step'] += 1
state['ema'].update(model.parameters())
else:
with torch.no_grad():
ema = state['ema']
ema.store(model.parameters())
ema.copy_to(model.parameters())
loss = loss_fn(model, batch)
ema.restore(model.parameters())
return loss
return step_fn
def get_step_fn_predictor(sde, train, optimize_fn=None, reduce_mean=False, continuous=True,
likelihood_weighting=False, data='NASBench201', label_list=None, noised=True,
t_spot=None, is_meta=False, is_binary=False):
"""Create a one-step training/evaluation function.
Args:
sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
(`sde_lib.SDE`, `sde_lib.SDE`) of the forward node SDE and edge SDE.
optimize_fn: An optimization function.
reduce_mean: If `True`, average the loss across data dimensions.
Otherwise, sum the loss across data dimensions.
continuous: `True` indicates that the model is defined to take continuous time steps.
likelihood_weighting: If `True`, weight the mixture of score matching losses according to
https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
Returns:
A one-step function for training or evaluation.
"""
if continuous:
if isinstance(sde, tuple):
loss_fn = get_multi_sde_loss_fn(sde[0], sde[1], train, reduce_mean=reduce_mean, continuous=True,
likelihood_weighting=likelihood_weighting)
else:
if data in ['NASBench201', 'ofa']:
if is_meta:
loss_fn = get_meta_predictor_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting,
label_list=label_list, noised=noised, t_spot=t_spot)
elif is_binary:
loss_fn = get_predictor_loss_fn_nas_binary(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting,
label_list=label_list, noised=noised, t_spot=t_spot)
else:
loss_fn = get_predictor_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting,
label_list=label_list, noised=noised, t_spot=t_spot)
else:
loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
continuous=True, likelihood_weighting=likelihood_weighting)
else:
assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
if isinstance(sde, VESDE):
loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
elif isinstance(sde, VPSDE):
loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
elif isinstance(sde, tuple):
raise ValueError("Discrete training for multi sde is not recommended.")
else:
raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")
def step_fn(state, batch):
"""Running one step of training or evaluation.
For jax version: This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and
jit-compiled together for faster execution.
Args:
state: A dictionary of training information, containing the score model, optimizer,
EMA status, and number of optimization steps.
batch: A mini-batch of training/evaluation data, including mini-batch
Returns:
loss: The average loss value of this state.
"""
model = state['model']
if train:
model.train()
optimizer = state['optimizer']
optimizer.zero_grad()
loss, pred, labels = loss_fn(model, batch)
loss.backward()
optimize_fn(optimizer, model.parameters(), step=state['step'])
state['step'] += 1
# state['ema'].update(model.parameters())
else:
model.eval()
with torch.no_grad():
# ema = state['ema']
# ema.store(model.parameters())
# ema.copy_to(model.parameters())
loss, pred, labels = loss_fn(model, batch)
# ema.restore(model.parameters())
return loss, pred, labels
return step_fn

40
MobileNetV3/main.py Normal file
View File

@@ -0,0 +1,40 @@
"""Training and evaluation"""
import run_lib
from absl import app, flags
from ml_collections.config_flags import config_flags
import logging
import os
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
'config', None, 'Training configuration.', lock_config=True
)
config_flags.DEFINE_config_file(
'classifier_config_nf', None, 'Training configuration.', lock_config=True
)
flags.DEFINE_string('workdir', None, 'Work directory.')
flags.DEFINE_enum('mode', None, ['train', 'eval'],
'Running mode: train or eval')
flags.DEFINE_string('eval_folder', 'eval', 'The folder name for storing evaluation results')
flags.mark_flags_as_required(['config', 'mode'])
def main(argv):
# Set random seed
run_lib.set_random_seed(FLAGS.config)
if FLAGS.mode == 'train':
logger = logging.getLogger()
logger.setLevel('INFO')
# Run the training pipeline
run_lib.train(FLAGS.config)
elif FLAGS.mode == 'eval':
run_lib.evaluate(FLAGS.config)
else:
raise ValueError(f"Mode {FLAGS.mode} not recognized.")
if __name__ == '__main__':
app.run(main)
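# Example invocation (the config path is illustrative; substitute one of the
# repo's actual config files):
#   python MobileNetV3/main.py --mode=train --config=./configs/<config>.py \
#       --workdir=./exp/run0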

View File

@@ -0,0 +1,329 @@
import torch
import numpy as np
import sys
from scipy.stats import pearsonr, spearmanr
from torch.utils.data import DataLoader
sys.path.append('.')
import sampling
import datasets_nas
from models import pgsn
from models import digcn
from models import cate
from models import dagformer
from models import digcn_meta
from models import regressor
from models.GDSS import scorenetx
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
import losses
from analysis.arch_functions import BasicArchMetricsOFA, NUM_STAGE, MAX_LAYER_PER_STAGE
from all_path import *
def get_sampling_fn(config, p=1, prod_w=False, weight_ratio_abs=False):
# Setup SDEs
if config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
sde = sde_lib.subVPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(
sigma_min=config.model.sigma_min,
sigma_max=config.model.sigma_max,
N=config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
# create data normalizer and its inverse
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
sampling_shape = (
config.eval.batch_size, config.data.max_node, config.data.n_vocab) # ofa: 1024, 20, 28
sampling_fn = sampling.get_sampling_fn(
config, sde, sampling_shape, inverse_scaler,
sampling_eps, config.data.name, conditional=True,
p=p, prod_w=prod_w, weight_ratio_abs=weight_ratio_abs)
return sampling_fn, sde
def get_sampling_fn_meta(config, p=1, prod_w=False, weight_ratio_abs=False, init=False, n_init=5):
# Setup SDEs
if config.training.sde.lower() == 'vpsde':
sde = sde_lib.VPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'subvpsde':
sde = sde_lib.subVPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales)
sampling_eps = 1e-3
elif config.training.sde.lower() == 'vesde':
sde = sde_lib.VESDE(
sigma_min=config.model.sigma_min,
sigma_max=config.model.sigma_max,
N=config.model.num_scales)
sampling_eps = 1e-5
else:
raise NotImplementedError(f"SDE {config.training.sde} unknown.")
# create data normalizer and its inverse
inverse_scaler = datasets_nas.get_data_inverse_scaler(config)
if init:
sampling_shape = (
n_init, config.data.max_node, config.data.n_vocab)
else:
sampling_shape = (
config.eval.batch_size, config.data.max_node, config.data.n_vocab) # ofa: 1024, 20, 28
sampling_fn = sampling.get_sampling_fn(
config, sde, sampling_shape, inverse_scaler,
sampling_eps, config.data.name, conditional=True,
is_meta=True, data_name=config.sampling.check_dataname,
num_sample=config.model.num_sample)
return sampling_fn, sde
def get_score_model(config, pos_enc_type=2):
# Build sampling functions and Load pre-trained score network & predictor network
score_config = torch.load(config.scorenet_ckpt_path)['config']
ckpt_path = config.scorenet_ckpt_path
score_config.sampling.corrector = 'langevin'
score_config.model.pos_enc_type = pos_enc_type
score_model = mutils.create_model(score_config)
score_ema = ExponentialMovingAverage(
score_model.parameters(), decay=score_config.model.ema_rate)
score_state = dict(
model=score_model, ema=score_ema, step=0, config=score_config)
score_state = restore_checkpoint(
ckpt_path, score_state,
device=config.device, resume=True)
score_ema.copy_to(score_model.parameters())
return score_model, score_ema, score_config
def get_predictor(config):
classifier_model = mutils.create_model(config)
return classifier_model
def get_adj(data_name, except_inout):
if data_name == 'NASBench201':
_adj = np.asarray(
[[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 0]]
)
_adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
if except_inout:
_adj = _adj[1:-1, 1:-1]
elif data_name == 'ofa':
assert except_inout
num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
_adj = torch.zeros(num_nodes, num_nodes)
for i in range(num_nodes-1):
_adj[i, i+1] = 1
return _adj
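# Quick sanity check (illustrative): for the OFA space this builds a simple
# 20-node chain graph, i.e. ones on the first superdiagonal only.
#   adj = get_adj('ofa', except_inout=True)
#   assert adj.shape == (20, 20) and adj.sum() == 19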
def generate_archs(
config, sampling_fn, score_model, score_ema, classifier_model,
num_samples, patient_factor, batch_size=512, classifier_scale=None,
task=None):
metrics = BasicArchMetricsOFA()
# algo = 'none'
adj_s = get_adj(config.data.name, config.data.except_inout)
mask_s = aug_mask(adj_s, algo=config.data.aug_mask_algo)[0]
adj_c = get_adj(config.data.name, config.data.except_inout)
mask_c = aug_mask(adj_c, algo=config.data.aug_mask_algo)[0]
assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
adj_s, mask_s, adj_c, mask_c = \
adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)
# Generate and save samples
score_ema.copy_to(score_model.parameters())
if num_samples > batch_size:
num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor)
else:
num_sampling_rounds = int(patient_factor)
print(f'==> Sampling for {num_sampling_rounds} rounds...')
r = 0
all_samples = []
classifier_scales = list(range(100000, 0, -int(classifier_scale)))
while r < num_sampling_rounds:
classifier_scale = classifier_scales[min(r, len(classifier_scales) - 1)]
print(f'==> round {r} classifier_scale {classifier_scale}')
sample, _, sample_chain, (score_grad_norm_p, classifier_grad_norm_p, score_grad_norm_c, classifier_grad_norm_c) \
= sampling_fn(score_model, mask_s, classifier_model,
eval_chain=True,
number_chain_steps=config.sampling.number_chain_steps,
classifier_scale=classifier_scale,
task=task, sample_bs=num_samples)
try:
sample_list = quantize(sample, adj_s) # quantization
_, validity, valid_arch_str, _, _ = metrics.compute_validity(sample_list, adj_s, mask_s)
except Exception:
validity = 0.
valid_arch_str = []
print(f' ==> [Validity]: {round(validity, 4)}')
if len(valid_arch_str) > 0:
all_samples += valid_arch_str
print(f' ==> [# Unique Arch]: {len(set(all_samples))}')
if (len(set(all_samples)) >= num_samples):
break
r += 1
return list(set(all_samples))[:num_samples]
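# Illustrative call chain (values are examples, not tuned defaults):
#   sampling_fn, sde = get_sampling_fn_meta(config)
#   score_model, score_ema, _ = get_score_model(config)
#   classifier_model = get_predictor(classifier_config)
#   arch_strs = generate_archs(config, sampling_fn, score_model, score_ema,
#                              classifier_model, num_samples=50,
#                              patient_factor=20, classifier_scale=10000)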
def noise_aware_meta_predictor_fit(config,
predictor_model=None,
xtrain=None,
seed=None,
sde=None,
batch_size=5,
epochs=50,
save_best_p_corr=False,
save_path=None,):
assert save_best_p_corr
reset_seed(seed)
data_loader = DataLoader(xtrain,
batch_size=batch_size,
shuffle=True,
drop_last=True)
# create data normalizer and its inverse
scaler = datasets_nas.get_data_scaler(config)
# Initialize model.
optimizer = losses.get_optimizer(config, predictor_model.parameters())
state = dict(optimizer=optimizer,
model=predictor_model,
step=0,
config=config)
# Build one-step training and evaluation functions
optimize_fn = losses.optimization_manager(config)
continuous = config.training.continuous
reduce_mean = config.training.reduce_mean
likelihood_weighting = config.training.likelihood_weighting
train_step_fn = losses.get_step_fn_predictor(sde, train=True, optimize_fn=optimize_fn,
reduce_mean=reduce_mean, continuous=continuous,
likelihood_weighting=likelihood_weighting,
data=config.data.name, label_list=config.data.label_list,
noised=config.training.noised,
t_spot=config.training.t_spot,
is_meta=True)
# temp
# epochs = len(xtrain) * 100
is_best = False
best_p_corr = -1
ckpt_dir = os.path.join(save_path, 'loop')
print(f'==> Training for {epochs} epochs')
for epoch in range(epochs):
pred_list, labels_list = list(), list()
for step, batch in enumerate(data_loader):
x = batch['x'].to(config.device)  # shape e.g. (5, 5, 20, 9)
adj = get_adj(config.data.name, config.data.except_inout)
task = batch['task']
extra = batch
mask = aug_mask(adj,
algo=config.data.aug_mask_algo,
data=config.data.name)
x = scaler(x.to(config.device))
adj = adj.to(config.device)
mask = mask.to(config.device)
task = task.to(config.device)
batch = (x, adj, mask, extra, task)
# Execute one training step
loss, pred, labels = train_step_fn(state, batch)
pred_list += [v.detach().item() for v in pred.squeeze()]
labels_list += [v.detach().item() for v in labels.squeeze()]
p_corr = pearsonr(np.array(pred_list), np.array(labels_list))[0]
s_corr = spearmanr(np.array(pred_list), np.array(labels_list))[0]
if epoch % 50 == 0: print(f'==> [Epoch-{epoch}] P corr: {round(p_corr, 4)} | S corr: {round(s_corr, 4)}')
if save_best_p_corr:
if p_corr > best_p_corr:
is_best = True
best_p_corr = p_corr
os.makedirs(ckpt_dir, exist_ok=True)
save_checkpoint(ckpt_dir, state, epoch, is_best)
if save_best_p_corr:
loaded_state = torch.load(os.path.join(ckpt_dir, 'model_best.pth.tar'), map_location=config.device)
predictor_model.load_state_dict(loaded_state['model'])
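# Sketch of a fitting call (assumes xtrain yields dicts with 'x' and 'task'
# keys as consumed above):
#   noise_aware_meta_predictor_fit(config, predictor_model=model, xtrain=xtrain,
#                                  seed=0, sde=sde, batch_size=5, epochs=50,
#                                  save_best_p_corr=True, save_path='./exp/meta')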
def save_checkpoint(ckpt_dir, state, epoch, is_best):
saved_state = {}
for k in state:
if k in ['optimizer', 'model', 'ema']:
saved_state.update({k: state[k].state_dict()})
else:
saved_state.update({k: state[k]})
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
if is_best:
shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'), os.path.join(ckpt_dir, 'model_best.pth.tar'))
# remove all saved checkpoints except the best one
for ckpt_file in sorted(os.listdir(ckpt_dir)):
if not ckpt_file.startswith('checkpoint'):
continue
if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
os.remove(os.path.join(ckpt_dir, ckpt_file))
def restore_checkpoint(ckpt_dir, state, device, resume=False):
if not resume:
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
return state
elif not os.path.exists(ckpt_dir):
if not os.path.exists(os.path.dirname(ckpt_dir)):
os.makedirs(os.path.dirname(ckpt_dir))
logging.warning(f"No checkpoint found at {ckpt_dir}. "
f"Returned the same state as input")
return state
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
for k in state:
if k in ['optimizer', 'model', 'ema']:
state[k].load_state_dict(loaded_state[k])
else:
state[k] = loaded_state[k]
return state
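# The two helpers above are used as a pair, e.g. (paths illustrative):
#   state = dict(optimizer=optimizer, model=model, step=0, config=config)
#   save_checkpoint('./ckpt', state, epoch=0, is_best=True)
#   state = restore_checkpoint('./ckpt/model_best.pth.tar', state,
#                              device=config.device, resume=True)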

View File

@@ -0,0 +1,63 @@
"""
@author: Hayeon Lee
2020/02/19
Script for downloading and reorganizing the aircraft dataset
for few-shot classification.
Run this file as follows:
python get_data.py
"""
import pickle
import os
import numpy as np
from tqdm import tqdm
import requests
import tarfile
from PIL import Image
import glob
import shutil
import collections
import sys
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = RAW_DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, 'fgvc-aircraft-2013b.tar.gz')
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file(
'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz',
file_name)
print("\nDownloading done.\n")
else:
print("fgvc-aircraft-2013b.tar.gz has already been downloaded. Did not download twice.\n")
untar_file_name = os.path.join(dir_path, 'aircraft')
if not os.path.exists(untar_file_name):
tarname = file_name
print("Untarring: {}".format(tarname))
tar = tarfile.open(tarname)
tar.extractall(untar_file_name)
tar.close()
else:
print(f"{untar_file_name} folder already exists. Did not untarring twice\n")
os.remove(file_name)

View File

@@ -0,0 +1,50 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
import sys
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = os.path.join(RAW_DATA_PATH, 'pets')
if not os.path.exists(dir_path):
os.makedirs(dir_path)
full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file(
'https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")
full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file(
'https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,46 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
from all_path import PROCESSED_DATA_PATH
dir_path = PROCESSED_DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
def get_preprocessed_data(file_name, url):
print(f"Downloading {file_name} datasets\n")
full_name = os.path.join(dir_path, file_name)
download_file(url, full_name)
print("Downloading done.\n")
for file_name, url in [
('aircraftbylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
('cifar100bylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
]:
get_preprocessed_data(file_name, url)

View File

@@ -0,0 +1,44 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
DATA_PATH = "./data/ofa/data_score_model"
dir_path = DATA_PATH
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:  # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
def get_preprocessed_data(file_name, url):
print(f"Downloading {file_name} datasets\n")
full_name = os.path.join(dir_path, file_name)
download_file(url, full_name)
print("Downloading done.\n")
for file_name, url in [
('ofa_database_500000.pt', 'https://www.dropbox.com/scl/fi/0asz5qnvakf6ggucuynkk/ofa_database_500000.pt?rlkey=lqa1y4d6mikgzznevtanl2ybx&dl=1'),
('ridx-500000.pt', 'https://www.dropbox.com/scl/fi/ambrm9n5efdkyydmsli0h/ridx-500000.pt?rlkey=b6iliyuiaxya4ropms8chsa7c&dl=1'),
]:
get_preprocessed_data(file_name, url)

390
MobileNetV3/main_exp/nag.py Normal file
View File

@@ -0,0 +1,390 @@
from __future__ import print_function
import torch
import os
import gc
import sys
from tqdm import tqdm
import numpy as np
import time
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.stats import pearsonr
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import load_graph_config, decode_ofa_mbv3_str_to_igraph
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import get_log
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import save_model, mean_confidence_interval
from transfer_nag_lib.MetaD2A_mobilenetV3.loader import get_meta_train_loader, MetaTestDataset
from transfer_nag_lib.encoder_FSBO_ofa import EncoderFSBO as PredictorModel
from transfer_nag_lib.MetaD2A_mobilenetV3.predictor import Predictor as MetaD2APredictor
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.train import train_single_model
from diffusion.run_lib import generate_archs
from diffusion.run_lib import get_sampling_fn_meta
from diffusion.run_lib import get_score_model
from diffusion.run_lib import get_predictor
sys.path.append(os.path.join(os.getcwd()))
from all_path import *
from utils import restore_checkpoint
class NAG:
def __init__(self, args, dgp_arch=[99, 50, 179, 194], bohb=False):
self.args = args
self.batch_size = args.batch_size
self.num_sample = args.num_sample
self.max_epoch = args.max_epoch
self.save_epoch = args.save_epoch
self.save_path = args.save_path
self.search_space = args.search_space
self.model_name = 'predictor'
self.test = args.test
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
self.max_corr_dict = {'corr': -1, 'epoch': -1}
self.train_arch = args.train_arch
self.use_metad2a_predictor_selec = args.use_metad2a_predictor_selec
self.raw_data_path = RAW_DATA_PATH
self.model_path = UNNOISE_META_PREDICTOR_CKPT_PATH
self.data_path = PROCESSED_DATA_PATH
self.classifier_ckpt_path = NOISE_META_PREDICTOR_CKPT_PATH
self.load_diffusion_model(self.args.n_training_samples, args.pos_enc_type)
graph_config = load_graph_config(
args.graph_data_name, args.nvt, self.data_path)
self.model = PredictorModel(args, graph_config, dgp_arch=dgp_arch)
self.metad2a_model = MetaD2APredictor(args).model
if self.test:
self.data_name = args.data_name
self.num_class = args.num_class
self.load_epoch = args.load_epoch
self.n_training_samples = self.args.n_training_samples
self.n_gen_samples = args.n_gen_samples
self.folder_name = args.folder_name
self.unique = args.unique
model_state_dict = self.model.state_dict()
load_max_pt = 'ckpt_max_corr.pt'
ckpt_path = os.path.join(self.model_path, load_max_pt)
ckpt = torch.load(ckpt_path)
for k, v in ckpt.items():
if k in model_state_dict.keys():
model_state_dict[k] = v
self.model.cpu()
self.model.load_state_dict(model_state_dict)
self.model.to(self.device)
self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
factor=0.1, patience=1000, verbose=True)
self.mtrloader = get_meta_train_loader(
self.batch_size, self.data_path, self.num_sample, is_pred=True)
self.acc_mean = self.mtrloader.dataset.mean
self.acc_std = self.mtrloader.dataset.std
def forward(self, x, arch, labels=None, train=False, matrix=False, metad2a=False):
if metad2a:
D_mu = self.metad2a_model.set_encode(x.to(self.device))
G_mu = self.metad2a_model.graph_encode(arch)
y_pred = self.metad2a_model.predict(D_mu, G_mu)
return y_pred
else:
D_mu = self.model.set_encode(x.to(self.device))
G_mu = self.model.graph_encode(arch, matrix=matrix)
y_pred, y_dist = self.model.predict(D_mu, G_mu, labels=labels, train=train)
return y_pred, y_dist
def meta_train(self):
sttime = time.time()
for epoch in range(1, self.max_epoch + 1):
self.mtrlog.ep_sttime = time.time()
loss, corr = self.meta_train_epoch(epoch)
self.scheduler.step(loss)
self.mtrlog.print_pred_log(loss, corr, 'train', epoch)
valoss, vacorr = self.meta_validation(epoch)
if self.max_corr_dict['corr'] < vacorr or epoch==1:
self.max_corr_dict['corr'] = vacorr
self.max_corr_dict['epoch'] = epoch
self.max_corr_dict['loss'] = valoss
save_model(epoch, self.model, self.model_path, max_corr=True)
self.mtrlog.print_pred_log(
valoss, vacorr, 'valid', max_corr_dict=self.max_corr_dict)
if epoch % self.save_epoch == 0:
save_model(epoch, self.model, self.model_path)
self.mtrlog.save_time_log()
self.mtrlog.max_corr_log(self.max_corr_dict)
def meta_train_epoch(self, epoch):
self.model.to(self.device)
self.model.train()
self.mtrloader.dataset.set_mode('train')
dlen = len(self.mtrloader.dataset)
trloss = 0
y_all, y_pred_all = [], []
pbar = tqdm(self.mtrloader)
for x, g, acc in pbar:
self.optimizer.zero_grad()
y_pred, y_dist = self.forward(x, g, labels=acc, train=True, matrix=False)
y = acc.to(self.device).double()
loss = -self.model.mll(y_dist, y)
loss.backward()
self.optimizer.step()
y = y.tolist()
y_pred = y_pred.squeeze().tolist()
y_all += y
y_pred_all += y_pred
pbar.set_description(get_log(
epoch, loss, y_pred, y, self.acc_std, self.acc_mean))
trloss += float(loss)
return trloss / dlen, pearsonr(np.array(y_all),
np.array(y_pred_all))[0]
def meta_validation(self, epoch):
self.model.to(self.device)
self.model.eval()
valoss = 0
self.mtrloader.dataset.set_mode('valid')
dlen = len(self.mtrloader.dataset)
y_all, y_pred_all = [], []
pbar = tqdm(self.mtrloader)
with torch.no_grad():
for x, g, acc in pbar:
y_pred, y_dist = self.forward(x, g, labels=acc, train=False, matrix=False)
y = acc.to(self.device)
loss = -self.model.mll(y_dist, y)
y = y.tolist()
y_pred = y_pred.squeeze().tolist()
y_all += y
y_pred_all += y_pred
pbar.set_description(get_log(
epoch, loss, y_pred, y, self.acc_std, self.acc_mean, tag='val'))
valoss += float(loss)
try:
pearson_corr = pearsonr(np.array(y_all), np.array(y_pred_all))[0]
except Exception as e:
pearson_corr = 0
return valoss / dlen, pearson_corr
def meta_test(self):
if self.data_name == 'all':
for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
acc = self.meta_test_per_dataset(data_name)
else:
acc = self.meta_test_per_dataset(self.data_name)
return acc
def meta_test_per_dataset(self, data_name):
self.test_dataset = MetaTestDataset(
self.data_path, data_name, self.num_sample, self.num_class)
meta_test_path = self.args.exp_name
os.makedirs(meta_test_path, exist_ok=True)
f_arch_str = open(os.path.join(meta_test_path, 'architecture.txt'), 'w')
f = open(os.path.join(meta_test_path, 'accuracy.txt'), 'w')
elapsed_time = []
print(f'==> select top architectures for {data_name} by meta-predictor...')
gen_arch_str = self.get_gen_arch_str()
gen_arch_igraph = [decode_ofa_mbv3_str_to_igraph(_) for _ in gen_arch_str]
y_pred_all = []
self.metad2a_model.eval()
self.metad2a_model.to(self.device)
# MetaD2A ver. prediction
sttime = time.time()
with torch.no_grad():
for i, arch_igraph in enumerate(gen_arch_igraph):
x, g = self.collect_data(arch_igraph)
y_pred = self.forward(x, g, metad2a=True)
y_pred = torch.mean(y_pred)
y_pred_all.append(y_pred.cpu().detach().item())
if self.use_metad2a_predictor_selec:
top_arch_lst = self.select_top_arch(
data_name, torch.tensor(y_pred_all), gen_arch_str, self.n_training_samples)
else:
top_arch_lst = gen_arch_str[:self.n_training_samples]
elapsed = time.time() - sttime
elapsed_time.append(elapsed)
for _, arch_str in enumerate(top_arch_lst):
f_arch_str.write(f'{arch_str}\n')
print(f'neural architecture config: {arch_str}')
support = top_arch_lst
x_support = []
y_support = []
seeds = [777, 888, 999]
y_support_per_seed = {
_: [] for _ in seeds
}
net_info = {
'params': [],
'flops': [],
}
best_acc = 0.0
best_sample_num = 0
print("Data name: %s" % data_name)
for i, arch_str in enumerate(support):
save_path = os.path.join(meta_test_path, arch_str)
os.makedirs(save_path, exist_ok=True)
acc_runs = []
for seed in seeds:
print(f'==> train for {data_name} {arch_str} ({seed})')
valid_acc, max_valid_acc, params, flops = train_single_model(save_path=save_path,
workers=8,
datasets=data_name,
xpaths=f'{self.raw_data_path}/{data_name}',
splits=[0],
use_less=False,
seed=seed,
model_str=arch_str,
device='cuda',
lr=0.01,
momentum=0.9,
weight_decay=4e-5,
report_freq=50,
epochs=20,
grad_clip=5,
cutout=True,
cutout_length=16,
autoaugment=True,
drop=0.2,
drop_path=0.2,
img_size=224)
acc_runs.append(valid_acc)
y_support_per_seed[seed].append(valid_acc)
for r, acc in enumerate(acc_runs):
msg = f'run {r + 1} {acc:.2f} (%)'
f.write(msg + '\n')
f.flush()
print(msg)
m, h = mean_confidence_interval(acc_runs)
if m > best_acc:
best_acc = m
best_sample_num = i
msg = f'Avg {m:.3f}+-{h.item():.2f} (%) (best acc {best_acc:.3f} - #{i})'
f.write(msg + '\n')
print(msg)
y_support.append(np.mean(acc_runs))
x_support.append(arch_str)
net_info['params'].append(params)
net_info['flops'].append(flops)
torch.save({'y_support': y_support, 'x_support': x_support,
'y_support_per_seed': y_support_per_seed,
'net_info': net_info,
'best_acc': best_acc,
'best_sample_num': best_sample_num},
meta_test_path+'/result.pt')
return None
def train_single_arch(self, data_name, arch_str, meta_test_path):
save_path = os.path.join(meta_test_path, arch_str)
seeds = (777, 888, 999)
train_single_model(save_path=save_path,
workers=24,
datasets=[data_name],
xpaths=[f'{self.raw_data_path}/{data_name}'],
splits=[0],
use_less=False,
seeds=seeds,
model_str=arch_str,
arch_config={'channel': 16, 'num_cells': 5})
# training epochs: 49 for mnist, 199 otherwise
epoch = 49 if data_name == 'mnist' else 199
test_acc_lst = []
for seed in seeds:
result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
return test_acc_lst
def select_top_arch(
self, data_name, y_pred_all, gen_arch_str, N):
_, sorted_idx = torch.sort(y_pred_all, descending=True)
sorted_gen_arch_str = [gen_arch_str[_] for _ in sorted_idx]
final_str = sorted_gen_arch_str[:N]
return final_str
def collect_data_only(self):
x_batch = []
x_batch.append(self.test_dataset[0])
return torch.stack(x_batch).to(self.device)
def collect_data(self, arch_igraph):
x_batch, g_batch = [], []
for _ in range(10):
x_batch.append(self.test_dataset[0])
g_batch.append(arch_igraph)
return torch.stack(x_batch).to(self.device), g_batch
def load_diffusion_model(self, n_training_samples, pos_enc_type):
self.config = torch.load(CONFIG_PATH)
self.config.data.root = SCORE_MODEL_DATA_PATH
self.config.scorenet_ckpt_path = SCORE_MODEL_CKPT_PATH
torch.save(self.config, CONFIG_PATH)
self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
self.sampling_fn_training_samples, _ = get_sampling_fn_meta(self.config, init=True, n_init=n_training_samples)
self.score_model, self.score_ema, self.score_config \
= get_score_model(self.config, pos_enc_type=pos_enc_type)
def get_gen_arch_str(self):
classifier_config = torch.load(self.classifier_ckpt_path)['config']
# Load meta-predictor
classifier_model = get_predictor(classifier_config)
classifier_state = dict(model=classifier_model, step=0, config=classifier_config)
classifier_state = restore_checkpoint(self.classifier_ckpt_path,
classifier_state, device=self.config.device, resume=True)
print(f'==> load checkpoint for our predictor: {self.classifier_ckpt_path}...')
with torch.no_grad():
x = self.collect_data_only()
generated_arch_str = generate_archs(
self.config,
self.sampling_fn,
self.score_model,
self.score_ema,
classifier_model,
num_samples=self.n_gen_samples,
patient_factor=self.args.patient_factor,
batch_size=self.args.eval_batch_size,
classifier_scale=self.args.classifier_scale,
task=x if self.args.fix_task else None)
gc.collect()
return generated_arch_str

View File

@@ -0,0 +1,154 @@
import os
import sys
import random
import numpy as np
import argparse
import torch
from nag import NAG
# sys.path.append(os.getcwd())
# from utils import str2bool
def str2bool(v):
if isinstance(v, bool):
return v
return v.lower() in ('t', 'true')
# save_path = "results"
# data_path = os.path.join('MetaD2A_nas_bench_201', 'data')
# model_load_path = '/home/data/GTAD/baselines/transferNAS'
def get_parser():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('--seed', type=int, default=444)
parser.add_argument('--gpu', type=str, default='0',
help='set visible gpus')
parser.add_argument('--search_space', type=str, default='ofa')
parser.add_argument('--save-path', type=str,
default=None, help='the path of save directory')
parser.add_argument('--data-path', type=str,
default=None, help='the path of save directory')
parser.add_argument('--model-load-path', type=str,
default=None, help='')
parser.add_argument('--save-epoch', type=int, default=20,
help='how many epochs to wait each time to save model states')
parser.add_argument('--max-epoch', type=int, default=50,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int,
default=1024, help='batch size for generator')
parser.add_argument('--graph-data-name',
default='ofa', help='graph dataset name')
parser.add_argument('--nvt', type=int, default=27,
help='number of different node types')
# set encoder
parser.add_argument('--num-sample', type=int, default=20,
help='the number of images as input for set encoder')
# graph encoder
parser.add_argument('--hs', type=int, default=512,
help='hidden size of GRUs')
parser.add_argument('--nz', type=int, default=56,
help='the number of dimensions of latent vectors z')
# test
parser.add_argument('--test', action='store_true',
default=True, help='turn on test mode')
parser.add_argument('--load-epoch', type=int, default=100,
help='checkpoint epoch loaded for meta-test')
parser.add_argument('--data-name', type=str,
default='pets', help='meta-test dataset name')
parser.add_argument('--trials', type=int, default=5)
parser.add_argument('--num-class', type=int, default=None,
help='the number of class of dataset')
parser.add_argument('--num-gen-arch', type=int, default=500,
help='the number of candidate architectures generated by the generator')
parser.add_argument('--train-arch', type=str2bool, default=True,
help='whether to train the searched architecture')
parser.add_argument('--n_training_samples', type=int, default=5)
parser.add_argument('--N', type=int, default=10)
parser.add_argument('--use_gp', type=str2bool, default=False)
parser.add_argument('--sorting', type=str2bool, default=True)
parser.add_argument('--use_metad2a_predictor_selec', type=str2bool, default=True)
parser.add_argument('--use_ensemble_selec', type=str2bool, default=False)
# ---------- For diffusion NAG ------------ #
parser.add_argument('--folder_name', type=str, default='DiffusionNAG')
parser.add_argument('--task', type=str, default='mtst')
parser.add_argument('--exp_name', type=str, default='')
parser.add_argument('--wandb_exp_name', type=str, default='')
parser.add_argument('--wandb_project_name', type=str, default='DiffusionNAG')
parser.add_argument('--use_wandb', type=str2bool, default=False)
parser.add_argument('--classifier_scale', type=int, default=10000, help='classifier scale')
parser.add_argument('--eval_batch_size', type=int, default=256)
parser.add_argument('--predictor', type=str, default='euler_maruyama',
choices=['euler_maruyama', 'reverse_diffusion', 'none'])
parser.add_argument('--corrector', type=str, default='langevin',
choices=['none', 'langevin'])
parser.add_argument('--weight_ratio', type=str2bool, default=False)
parser.add_argument('--weight_scheduling', type=str2bool, default=False)
parser.add_argument('--weight_ratio_abs', type=str2bool, default=False)
parser.add_argument('--p', type=int, default=1)
parser.add_argument('--prod_w', type=str2bool, default=False)
parser.add_argument('--t_spot', type=float, default=1.0)
parser.add_argument('--t_spot_end', type=float, default=0.0)
# Train
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--save_best_p_corr', type=str2bool, default=True)
parser.add_argument('--unique', type=str2bool, default=True)
parser.add_argument('--patient_factor', type=int, default=20)
parser.add_argument('--n_gen_samples', type=int, default=50)
################ OFA ####################
parser.add_argument('--ofa_path', type=str, default='/home/hayeon/imagenet1k', help='')
parser.add_argument('--ofa_batch_size', type=int, default=256, help='')
parser.add_argument('--ofa_workers', type=int, default=4, help='')
################ Diffusion ##############
parser.add_argument('--diffusion_lr', type=float, default=1e-3, help='')
parser.add_argument('--noise_aware_acc_norm', type=int, default=-1)
parser.add_argument('--fix_task', type=str2bool, default=True)
################ BO ####################
parser.add_argument('--bo_loop_max_epoch', type=int, default=30)
parser.add_argument('--bo_loop_acc_norm', type=int, default=1)
parser.add_argument('--gp_model_acc_norm', type=int, default=1)
parser.add_argument('--num_ensemble', type=int, default=3)
parser.add_argument('--explore_type', type=str, default='ei')
################ BO ####################
# parser.add_argument('--multi_proc', type=str2bool, default=False)
parser.add_argument('--eps', type=float, default=0.)
parser.add_argument('--beta', type=float, default=0.5)
parser.add_argument('--pos_enc_type', type=int, default=4)
args = parser.parse_args()
return args
def set_exp_name(args):
exp_name = f'./exp/{args.task}/{args.folder_name}/data-{args.data_name}'
wandb_exp_name = f'./exp/{args.task}/{args.folder_name}/{args.data_name}'
os.makedirs(exp_name, exist_ok=True)
args.exp_name = exp_name
args.wandb_exp_name = wandb_exp_name
def main():
args = get_parser()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
set_exp_name(args)
p = NAG(args)
if args.test:
p.meta_test()
else:
p.meta_train()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 6 14:02:53 2021
@author: hsjomaa
"""
import numpy as np
from scipy.stats import norm
import pandas as pd
import torch
from sklearn.preprocessing import PowerTransformer
def regret(output,response):
incumbent = output[0]
best_output = []
for _ in output:
incumbent = _ if _ > incumbent else incumbent
best_output.append(incumbent)
opt = max(response)
order = list(np.sort(np.unique(response))[::-1])
tmp = pd.DataFrame(best_output, columns=['regret_validation'])
tmp['rank_valid'] = tmp['regret_validation'].map(lambda x: order.index(x))
tmp['regret_validation'] = opt - tmp['regret_validation']
return tmp
def EI(incumbent, model_fn,support,queries,return_variance, return_score=False):
mu, stddev = model_fn(queries)
mu = mu.reshape(-1,)
stddev = stddev.reshape(-1,)
if return_variance:
stddev = np.sqrt(stddev)
with np.errstate(divide='warn'):
imp = mu - incumbent
Z = imp / stddev
score = imp * norm.cdf(Z) + stddev * norm.pdf(Z)
if not return_score:
score[support] = 0
return np.argmax(score)
else:
return score
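# Worked toy example (hypothetical surrogate returning (mean, variance)):
#   def toy_model(q):
#       return np.array([0.1, 0.5, 0.3]), np.array([0.04, 0.01, 0.09])
#   EI(incumbent=0.4, model_fn=toy_model, support=[0], queries=None,
#      return_variance=True)
# -> returns 1: index 0 is masked out as already-observed support, and index 1
#    has the largest expected improvement over the incumbent 0.4.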
class Metric(object):
def __init__(self,prefix='train: '):
self.reset()
self.message=prefix + "loss: {loss:.2f} - noise: {log_var:.2f} - mse: {mse:.2f}"
def update(self,loss,noise,mse):
self.loss.append(float(loss))
self.noise.append(float(noise))
self.mse.append(float(mse))
def reset(self,):
self.loss = []
self.noise = []
self.mse = []
def report(self):
return self.message.format(loss=np.mean(self.loss),
log_var=np.mean(self.noise),
mse=np.mean(self.mse))
def get(self):
return {"loss":np.mean(self.loss),
"noise":np.mean(self.noise),
"mse":np.mean(self.mse)}
def totorch(x, device):
if type(x) is tuple:
return tuple([torch.Tensor(e).to(device) for e in x])
return torch.Tensor(x).to(device)
def prepare_data(indexes, support, Lambda, response, metafeatures=None, output_transform=False):
# Generate indexes of the batch
X,E,Z,y,r = [],[],[],[],[]
#### get support data
for dim in indexes:
if metafeatures is not None:
Z.append(metafeatures)
E.append(Lambda[support])
X.append(Lambda[dim])
r_ = response[support,np.newaxis]
y_ = response[dim]
if output_transform:
power = PowerTransformer(method="yeo-johnson")
r_ = power.fit_transform(r_)
y_ = power.transform(y_.reshape(-1,1)).reshape(-1,)
r.append(r_)
y.append(y_)
X = np.array(X)
E = np.array(E)
Z = np.array(Z)
y = np.array(y)
r = np.array(r)
return (np.expand_dims(E, axis=-1), r, np.expand_dims(X, axis=-1), Z), y
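# Shape sketch (illustrative): with 10 configs of dimension 4, a support of
# size 3 and 2 query indexes, prepare_data(indexes, support, Lambda, response)
# returns ((E, r, X, Z), y) with E: (2, 3, 4, 1), r: (2, 3, 1), X: (2, 4, 1),
# y: (2,); Z stays empty unless `metafeatures` is passed.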

View File

@@ -0,0 +1,581 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 6 14:03:42 2021
@author: hsjomaa
"""
## Original packages
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
import copy
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter
import json
import time
## Our packages
import gpytorch
import logging
from transfer_nag_lib.DeepKernelGPHelpers import totorch,prepare_data, Metric, EI
from transfer_nag_lib.MetaD2A_nas_bench_201.generator import Generator
from transfer_nag_lib.MetaD2A_nas_bench_201.main import get_parser
np.random.seed(1203)
RandomQueryGenerator= np.random.RandomState(413)
RandomSupportGenerator= np.random.RandomState(413)
RandomTaskGenerator = np.random.RandomState(413)
class DeepKernelGP(nn.Module):
def __init__(self,X,Y,Z,kernel,backbone_fn, config, support,log_dir,seed):
super(DeepKernelGP, self).__init__()
torch.manual_seed(seed)
## GP parameters
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.X,self.Y,self.Z = X,Y,Z
self.feature_extractor = backbone_fn().to(self.device)
self.config=config
self.get_model_likelihood_mll(len(support),kernel,backbone_fn)
logging.basicConfig(filename=log_dir, level=logging.DEBUG)
def get_model_likelihood_mll(self, train_size,kernel,backbone_fn):
train_x=torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
train_y=torch.ones(train_size).to(self.device)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,
dims=self.feature_extractor.out_features)
self.model = model.to(self.device)
self.likelihood = likelihood.to(self.device)
self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)
def set_forward(self, x, is_feature=False):
pass
def set_forward_loss(self, x):
pass
def train(self, support, load_model,optimizer, checkpoint=None,epochs=1000, verbose = False):
if load_model:
assert checkpoint is not None
self.load_checkpoint(os.path.join(checkpoint, "weights"))
print("KEYS MATCHED")
inputs,labels = prepare_data(support,support,self.X,self.Y,self.Z)
inputs,labels = totorch(inputs,device=self.device), totorch(labels.reshape(-1,),device=self.device)
losses = [np.inf]
best_loss = np.inf
starttime = time.time()
initial_weights = copy.deepcopy(self.state_dict())
weights = initial_weights  # fallback in case training exits before any improvement
patience = 0
max_patience = self.config["patience"]
for _ in range(epochs):
optimizer.zero_grad()
z = self.feature_extractor(inputs)
self.model.set_train_data(inputs=z, targets=labels)
predictions = self.model(z)
try:
loss = -self.mll(predictions, self.model.train_targets)
loss.backward()
optimizer.step()
except Exception as ada:
logging.info(f"Exception {ada}")
break
if verbose:
print("Iter {iter}/{epochs} - Loss: {loss:.5f} noise: {noise:.5f}".format(
iter=_+1,epochs=epochs,loss=loss.item(),noise=self.likelihood.noise.item()))
losses.append(loss.detach().to("cpu").item())
if best_loss>losses[-1]:
best_loss = losses[-1]
weights = copy.deepcopy(self.state_dict())
if np.allclose(losses[-1],losses[-2],atol=self.config["loss_tol"]):
patience+=1
else:
patience=0
if patience>max_patience:
break
self.load_state_dict(weights)
logging.info(f"Current Iteration: {len(support)} | Incumbent {max(self.Y[support])} | Duration {np.round(time.time()-starttime)} | Epochs {_} | Noise {self.likelihood.noise.item()}")
return losses,weights,initial_weights
def load_checkpoint(self, checkpoint):
ckpt = torch.load(checkpoint,map_location=torch.device(self.device))
self.model.load_state_dict(ckpt['gp'],strict=False)
self.likelihood.load_state_dict(ckpt['likelihood'],strict=False)
self.feature_extractor.load_state_dict(ckpt['net'],strict=False)
def predict(self,support, query_range=None, noise_fn=None):
card = len(self.Y)
if noise_fn:
self.Y = noise_fn(self.Y)
x_support,y_support = prepare_data(support,support,
self.X,self.Y,self.Z)
if query_range is None:
x_query,_ = prepare_data(np.arange(card),support,
self.X,self.Y,self.Z)
else:
x_query,_ = prepare_data(query_range,support,
self.X,self.Y,self.Z)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
z_support = self.feature_extractor(totorch(x_support,self.device)).detach()
self.model.set_train_data(inputs=z_support, targets=totorch(y_support.reshape(-1,),self.device), strict=False)
with torch.no_grad():
z_query = self.feature_extractor(totorch(x_query,self.device)).detach()
pred = self.likelihood(self.model(z_query))
mu = pred.mean.detach().to("cpu").numpy().reshape(-1,)
stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1,)
return mu,stddev
class DKT(nn.Module):
def __init__(self, train_data,valid_data, kernel,backbone_fn, config):
super(DKT, self).__init__()
## GP parameters
self.train_data = train_data
self.valid_data = valid_data
self.fixed_context_size = config["fixed_context_size"]
self.minibatch_size = config["minibatch_size"]
self.n_inner_steps = config["n_inner_steps"]
self.checkpoint_path = config["checkpoint_path"]
os.makedirs(self.checkpoint_path,exist_ok=False)
json.dump(config, open(os.path.join(self.checkpoint_path,"configuration.json"),"w"))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.basicConfig(filename=os.path.join(self.checkpoint_path,"log.txt"), level=logging.DEBUG)
self.feature_extractor = backbone_fn().to(self.device)
self.config=config
self.get_model_likelihood_mll(self.fixed_context_size,kernel,backbone_fn)
self.mse = nn.MSELoss()
self.curr_valid_loss = np.inf
self.get_tasks()
self.setup_writers()
self.train_metrics = Metric()
self.valid_metrics = Metric(prefix="valid: ")
print(self)
def setup_writers(self,):
train_log_dir = os.path.join(self.checkpoint_path,"train")
os.makedirs(train_log_dir,exist_ok=True)
self.train_summary_writer = SummaryWriter(train_log_dir)
valid_log_dir = os.path.join(self.checkpoint_path,"valid")
os.makedirs(valid_log_dir,exist_ok=True)
self.valid_summary_writer = SummaryWriter(valid_log_dir)
def get_tasks(self,):
pairs = []
for space in self.train_data.keys():
for task in self.train_data[space].keys():
pairs.append([space,task])
self.tasks = pairs
##########
pairs = []
for space in self.valid_data.keys():
for task in self.valid_data[space].keys():
pairs.append([space,task])
self.valid_tasks = pairs
def get_model_likelihood_mll(self, train_size,kernel,backbone_fn):
train_x=torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
train_y=torch.ones(train_size).to(self.device)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,dims = self.feature_extractor.out_features)
self.model = model.to(self.device)
self.likelihood = likelihood.to(self.device)
self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)
def set_forward(self, x, is_feature=False):
pass
def set_forward_loss(self, x):
pass
def epoch_end(self):
RandomTaskGenerator.shuffle(self.tasks)
def train_loop(self, epoch, optimizer, scheduler_fn=None):
if scheduler_fn:
scheduler = scheduler_fn(optimizer,len(self.tasks))
self.epoch_end()
assert(self.training)
for task in self.tasks:
inputs, labels = self.get_batch(task)
for _ in range(self.n_inner_steps):
optimizer.zero_grad()
z = self.feature_extractor(inputs)
self.model.set_train_data(inputs=z, targets=labels, strict=False)
predictions = self.model(z)
loss = -self.mll(predictions, self.model.train_targets)
loss.backward()
optimizer.step()
mse = self.mse(predictions.mean, labels)
self.train_metrics.update(loss,self.model.likelihood.noise,mse)
if scheduler_fn:
scheduler.step()
training_results = self.train_metrics.get()
for k,v in training_results.items():
self.train_summary_writer.add_scalar(k, v, epoch)
for task in self.valid_tasks:
mse,loss = self.test_loop(task,train=False)
self.valid_metrics.update(loss,np.array(0),mse,)
logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
validation_results = self.valid_metrics.get()
for k,v in validation_results.items():
self.valid_summary_writer.add_scalar(k, v, epoch)
self.feature_extractor.train()
self.likelihood.train()
self.model.train()
if validation_results["loss"] < self.curr_valid_loss:
self.save_checkpoint(os.path.join(self.checkpoint_path,"weights"))
self.curr_valid_loss = validation_results["loss"]
self.valid_metrics.reset()
self.train_metrics.reset()
def test_loop(self, task, train, optimizer=None): # no optimizer needed for GP
(x_support, y_support),(x_query,y_query) = self.get_support_and_queries(task,train)
z_support = self.feature_extractor(x_support).detach()
self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
with torch.no_grad():
z_query = self.feature_extractor(x_query).detach()
pred = self.likelihood(self.model(z_query))
loss = -self.mll(pred, y_query)
lower, upper = pred.confidence_region() #2 standard deviations above and below the mean
mse = self.mse(pred.mean, y_query)
return mse,loss
def get_batch(self,task):
# we want to fit the gp given context info to new observations
# task is an algorithm/dataset pair
space,task = task
Lambda,response = np.array(self.train_data[space][task]["X"]), MinMaxScaler().fit_transform(np.array(self.train_data[space][task]["y"])).reshape(-1,)
card, dim = Lambda.shape
support = RandomSupportGenerator.choice(np.arange(card),
replace=False,size=self.fixed_context_size)
remaining = np.setdiff1d(np.arange(card),support)
indexes = RandomQueryGenerator.choice(
remaining,replace=False,size=self.minibatch_size if len(remaining)>self.minibatch_size else len(remaining))
inputs,labels = prepare_data(support,indexes,Lambda,response,np.zeros(32))
inputs,labels = totorch(inputs,device=self.device), totorch(labels.reshape(-1,),device=self.device)
return inputs, labels
def get_support_and_queries(self,task, train=False):
# task is an algorithm/dataset pair
space,task = task
hpo_data = self.valid_data if not train else self.train_data
Lambda,response = np.array(hpo_data[space][task]["X"]), MinMaxScaler().fit_transform(np.array(hpo_data[space][task]["y"])).reshape(-1,)
card, dim = Lambda.shape
support = RandomSupportGenerator.choice(np.arange(card),
replace=False,size=self.fixed_context_size)
indexes = RandomQueryGenerator.choice(
np.setdiff1d(np.arange(card),support),replace=False,size=self.minibatch_size)
support_x,support_y = prepare_data(support,support,Lambda,response,np.zeros(32))
query_x,query_y = prepare_data(support,indexes,Lambda,response,np.zeros(32))
return (totorch(support_x,self.device),totorch(support_y.reshape(-1,),self.device)),\
(totorch(query_x,self.device),totorch(query_y.reshape(-1,),self.device))
def save_checkpoint(self, checkpoint):
# save state
gp_state_dict = self.model.state_dict()
likelihood_state_dict = self.likelihood.state_dict()
nn_state_dict = self.feature_extractor.state_dict()
torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net':nn_state_dict}, checkpoint)
def load_checkpoint(self, checkpoint):
ckpt = torch.load(checkpoint)
self.model.load_state_dict(ckpt['gp'])
self.likelihood.load_state_dict(ckpt['likelihood'])
self.feature_extractor.load_state_dict(ckpt['net'])
class ExactGPLayer(gpytorch.models.ExactGP):
def __init__(self, train_x, train_y, likelihood,config,dims ):
super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
self.mean_module = gpytorch.means.ConstantMean()
## RBF kernel
if(config["kernel"]=='rbf' or config["kernel"]=='RBF'):
self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=dims if config["ard"] else None))
elif(config["kernel"]=='52'):
self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=config["nu"],ard_num_dims=dims if config["ard"] else None))
## Spectral kernel
else:
raise ValueError("[ERROR] the kernel '" + str(config["kernel"]) + "' is not supported for regression, use 'rbf' or 'spectral'.")
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
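# Minimal fit sketch (illustrative config; gpytorch's standard exact-GP loop):
#   likelihood = gpytorch.likelihoods.GaussianLikelihood()
#   gp = ExactGPLayer(train_x, train_y, likelihood,
#                     config={"kernel": "rbf", "ard": False}, dims=train_x.size(-1))
#   mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp)
#   gp.train(); likelihood.train()
#   loss = -mll(gp(train_x), train_y)  # minimize to fit hyperparameters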
class batch_mlp(nn.Module):
def __init__(self, d_in, output_sizes, nonlinearity="relu",dropout=0.0):
super(batch_mlp, self).__init__()
assert(nonlinearity=="relu")
self.nonlinearity = nn.ReLU()
self.fc = nn.ModuleList([nn.Linear(in_features=d_in, out_features=output_sizes[0])])
for d_out in output_sizes[1:]:
self.fc.append(nn.Linear(in_features=self.fc[-1].out_features, out_features=d_out))
self.out_features = output_sizes[-1]
self.dropout = nn.Dropout(dropout)
def forward(self,x):
for fc in self.fc[:-1]:
x = fc(x)
x = self.dropout(x)
x = self.nonlinearity(x)
x = self.fc[-1](x)
x = self.dropout(x)
return x
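# Shape check (illustrative):
#   net = batch_mlp(d_in=8, output_sizes=[32, 32, 16])
#   out = net(torch.randn(4, 8))  # -> (4, 16); net.out_features == 16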
class StandardDeepGP(nn.Module):
def __init__(self, configuration):
super(StandardDeepGP, self).__init__()
self.A = batch_mlp(configuration["dim"], configuration["output_size_A"],dropout=configuration["dropout"])
self.out_features = configuration["output_size_A"][-1]
def forward(self, x):
# e,r,x,z = x
hidden = self.A(x.squeeze(dim=-1)) ### NxA
return hidden
class DKTNAS(nn.Module):
def __init__(self, kernel, backbone_fn, config, pretrained_encoder=True, GP_only=False):
super(DKTNAS, self).__init__()
## GP parameters
self.fixed_context_size = config["fixed_context_size"]
self.minibatch_size = config["minibatch_size"]
self.n_inner_steps = config["n_inner_steps"]
self.set_encoder_args = get_parser()
if not os.path.exists(self.set_encoder_args.save_path):
os.makedirs(self.set_encoder_args.save_path)
self.set_encoder_args.model_path = os.path.join(self.set_encoder_args.save_path,
self.set_encoder_args.model_name, 'model')
if not os.path.exists(self.set_encoder_args.model_path):
os.makedirs(self.set_encoder_args.model_path)
self.set_encoder = Generator(self.set_encoder_args)
if pretrained_encoder:
self.dataset_enc, self.arch, self.acc = self.set_encoder.train_dgp(encode=False)
self.dataset_enc_val, self.acc_val = self.set_encoder.test_dgp(data_name='cifar100', encode=False)
else: # In case we want to train the set-encoder from scratch
self.dataset_enc = np.load("train_data_path.npy")
self.acc = np.load("train_acc.npy")
self.dataset_enc_val = np.load("cifar100_data_path.npy")
self.acc_val = np.load("cifar100_acc.npy")
self.valid_data = self.dataset_enc_val
self.checkpoint_path = config["checkpoint_path"]
os.makedirs(self.checkpoint_path, exist_ok=False)
json.dump(config, open(os.path.join(self.checkpoint_path, "configuration.json"), "w"))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.basicConfig(filename=os.path.join(self.checkpoint_path, "log.txt"), level=logging.DEBUG)
self.feature_extractor = backbone_fn().to(self.device)
self.config = config
self.GP_only = GP_only
self.get_model_likelihood_mll(self.fixed_context_size, kernel, backbone_fn)
self.mse = nn.MSELoss()
self.curr_valid_loss = np.inf
# self.get_tasks()
self.setup_writers()
self.train_metrics = Metric()
self.valid_metrics = Metric(prefix="valid: ")
self.tasks = len(self.dataset_enc)
print(self)
def setup_writers(self, ):
train_log_dir = os.path.join(self.checkpoint_path, "train")
os.makedirs(train_log_dir, exist_ok=True)
self.train_summary_writer = SummaryWriter(train_log_dir)
valid_log_dir = os.path.join(self.checkpoint_path, "valid")
os.makedirs(valid_log_dir, exist_ok=True)
self.valid_summary_writer = SummaryWriter(valid_log_dir)
def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):
if not self.GP_only:
train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
train_y = torch.ones(train_size).to(self.device)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood, config=self.config,
dims=self.feature_extractor.out_features)
else:
train_x = torch.ones(train_size, self.fixed_context_size).to(self.device)
train_y = torch.ones(train_size).to(self.device)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood, config=self.config,
dims=self.fixed_context_size)
self.model = model.to(self.device)
self.likelihood = likelihood.to(self.device)
self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)
def set_forward(self, x, is_feature=False):
pass
def set_forward_loss(self, x):
pass
def epoch_end(self):
RandomTaskGenerator.shuffle([1])
def train_loop(self, epoch, optimizer, scheduler_fn=None):
if scheduler_fn:
scheduler = scheduler_fn(optimizer, 1)
self.epoch_end()
assert (self.training)
for task in range(self.tasks):
inputs, labels = self.get_batch(task)
for _ in range(self.n_inner_steps):
optimizer.zero_grad()
z = self.feature_extractor(inputs)
self.model.set_train_data(inputs=z, targets=labels, strict=False)
predictions = self.model(z)
loss = -self.mll(predictions, self.model.train_targets)
loss.backward()
optimizer.step()
mse = self.mse(predictions.mean, labels)
self.train_metrics.update(loss, self.model.likelihood.noise, mse)
if scheduler_fn:
scheduler.step()
training_results = self.train_metrics.get()
for k, v in training_results.items():
self.train_summary_writer.add_scalar(k, v, epoch)
mse, loss = self.test_loop(train=False)
self.valid_metrics.update(loss, np.array(0), mse, )
logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
validation_results = self.valid_metrics.get()
for k, v in validation_results.items():
self.valid_summary_writer.add_scalar(k, v, epoch)
self.feature_extractor.train()
self.likelihood.train()
self.model.train()
if validation_results["loss"] < self.curr_valid_loss:
self.save_checkpoint(os.path.join(self.checkpoint_path, "weights"))
self.curr_valid_loss = validation_results["loss"]
self.valid_metrics.reset()
self.train_metrics.reset()
def test_loop(self, train=None, optimizer=None): # no optimizer needed for GP
(x_support, y_support), (x_query, y_query) = self.get_support_and_queries(train)
z_support = self.feature_extractor(x_support).detach()
self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
with torch.no_grad():
z_query = self.feature_extractor(x_query).detach()
pred = self.likelihood(self.model(z_query))
loss = -self.mll(pred, y_query)
lower, upper = pred.confidence_region() # 2 standard deviations above and below the mean
mse = self.mse(pred.mean, y_query)
return mse, loss
def get_batch(self, task, valid=False):
# we want to fit the gp given context info to new observations
#TODO: scale the response as in FSBO(needed for train)
Lambda, response = np.array(self.dataset_enc), np.array(self.acc)
inputs, labels = Lambda[task], response[task]
inputs, labels = totorch([inputs], device=self.device), totorch([labels], device=self.device)
return inputs, labels
def get_support_and_queries(self, task, train=False):
# TODO: scale the response as in FSBO(not necessary for test)
Lambda, response = np.array(self.dataset_enc_val), np.array(self.acc_val)
card, dim = Lambda.shape
support = RandomSupportGenerator.choice(np.arange(card),
replace=False, size=self.fixed_context_size)
indexes = RandomQueryGenerator.choice(
np.setdiff1d(np.arange(card), support), replace=False, size=self.minibatch_size)
support_x, support_y = Lambda[support], response[support]
query_x, query_y = Lambda[indexes], response[indexes]
return (totorch(support_x, self.device), totorch(support_y.reshape(-1, ), self.device)), \
(totorch(query_x, self.device), totorch(query_y.reshape(-1, ), self.device))
def save_checkpoint(self, checkpoint):
# save state
gp_state_dict = self.model.state_dict()
likelihood_state_dict = self.likelihood.state_dict()
nn_state_dict = self.feature_extractor.state_dict()
torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net': nn_state_dict}, checkpoint)
def load_checkpoint(self, checkpoint):
ckpt = torch.load(checkpoint)
self.model.load_state_dict(ckpt['gp'])
self.likelihood.load_state_dict(ckpt['likelihood'])
self.feature_extractor.load_state_dict(ckpt['net'])
def predict(self, x_support, y_support, x_query, y_query, GP_only=False):
if not GP_only:
z_support = self.feature_extractor(x_support).detach()
else:
z_support = x_support
self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
self.model.eval()
self.feature_extractor.eval()
self.likelihood.eval()
with torch.no_grad():
if not GP_only:
z_query = self.feature_extractor(x_query).detach()
else:
z_query = x_query
pred = self.likelihood(self.model(z_query))
mu = pred.mean.detach().to("cpu").numpy().reshape(-1, )
stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1, )
return mu, stddev
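    # Illustrative usage sketch (an assumption, not part of the original code):
    # `surrogate` stands for a trained instance of this class, with x_*/y_*
    # tensors on the same device.
    #   mu, stddev = surrogate.predict(x_support, y_support, x_query, y_query)
    #   upper, lower = mu + 2 * stddev, mu - 2 * stddev  # ~95% confidence band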

View File

@@ -0,0 +1,168 @@
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets
This code is for MobileNetV3 Search Space experiments
## Prerequisites
- Python 3.6 (Anaconda)
- PyTorch 1.6.0
- CUDA 10.2
- python-igraph==0.8.2
- tqdm==4.50.2
- torchvision==0.7.0
- scipy==1.5.2
- ofa==0.0.4-2007200808
## MobileNetV3 Search Space
Go to the folder for MobileNetV3 experiments (i.e., ```MetaD2A_mobilenetV3```).
The overall flow is summarized as follows:
- Building database for Predictor
- Meta-Training Predictor
- Building database for Generator with trained Predictor
- Meta-Training Generator
- Meta-Testing (Searching)
- Evaluating the Searched architecture
## Data Preparation
To download preprocessed data files, run ```get_files/get_preprocessed_data.py```:
```shell script
$ python get_files/get_preprocessed_data.py
```
It will take some time to download and preprocess each dataset.
## Meta Test and Evaluation
### Meta-Test
You can download the trained checkpoint files for the generator and predictor:
```shell script
$ python get_files/get_generator_checkpoint.py
$ python get_files/get_predictor_checkpoint.py
```
If you want to meta-test with your own dataset, first create your own preprocessed data
by modifying ```process_dataset.py```:
```shell script
$ python process_dataset.py
```
This code automatically generates neural architectures and then
selects high-performing architectures among the candidates.
Set ```--data-name``` to the name of the dataset (i.e. ```cifar10```, ```cifar100```, ```aircraft100```, ```pets```)
to evaluate on that specific dataset.
```shell script
# Meta-testing
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 120 --num-gen-arch 200 --data-name {DATASET_NAME}
```
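For example, to meta-test on CIFAR-10 (simply instantiating the template above):
```shell script
# Meta-testing on CIFAR-10
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 120 --num-gen-arch 200 --data-name cifar10
```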
### Architecture Evaluation (MetaD2A vs NSGANetV2)
##### Dataset Preparation
You need to download the Oxford-IIIT Pet dataset to evaluate with ```--data-name pets```:
```shell script
$ python get_files/get_pets.py
```
All other datasets (```cifar10```, ```cifar100```, ```aircraft100```) are downloaded automatically.
##### Evaluation
You can evaluate the searched architecture by running ```evaluation/main.py```; the code is based on NSGANetV2.
Go to the evaluation folder (i.e. ```evaluation```)
```shell script
$ cd evaluation
```
The following command automatically runs the top-1 architecture predicted by MetaD2A.
```shell script
python main.py --data-name cifar10 --num-gen-arch 200
```
You can also impose a FLOPs constraint with the ```--bound``` option.
```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --bound 300
```
You can compare MetaD2A with NSGANetV2,
but you first need to download the files provided
by [NSGANetV2](https://github.com/human-analysis/nsganetv2):
```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --model-config flops@232
```
## Meta-Training MetaD2A Model
To build the database for meta-training, you need to set ```IMGNET_PATH```, which is the directory of ILSVRC2012.
### Database Building for Predictor
We recommend running multiple instances of ```create_database.sh``` simultaneously to speed up database building.
You need to set ```IMGNET_PATH``` in the shell script.
```shell script
# Examples
bash create_database.sh 0,1,2,3 0 49 predictor
bash create_database.sh all 50 99 predictor
...
```
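The positional arguments are, in order: the GPU ids (e.g. ```0,1,2,3``` or ```all```), the start index, the end index, and the model name (```predictor``` or ```generator```); each run writes a partial ```database_{index}.pt``` file under the data path.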
After enough data has been gathered, run ```build_database.py``` to collect the partial results into one file.
```shell script
python build_database.py --model_name predictor --collect
```
We also provide the database we use. To download it, run ```get_files/get_predictor_database.py```:
```shell script
$ python get_files/get_predictor_database.py
```
### Meta-Train Predictor
You can train the predictor as follows:
```shell script
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56
```
### Database Building for Generator
We recommend running multiple instances of ```create_database.sh``` simultaneously to speed up database building.
```shell script
# Examples
bash create_database.sh 4,5,6,7 0 49 generator
bash create_database.sh all 50 99 generator
...
```
After enough data has been gathered, run ```build_database.py``` to collect the partial results into one file.
```shell script
python build_database.py --model_name generator --collect
```
We also provide the database we use. To download it, run ```get_files/get_generator_database.py```:
```shell script
$ python get_files/get_generator_database.py
```
### Meta-Train Generator
You can train the generator as follows:
```shell script
# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56
```
## Citation
If you find the provided code useful, please cite our work.
```
@inproceedings{
lee2021rapid,
title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
booktitle={ICLR},
year={2021}
}
```
## Reference
- [Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks (ICML2019)](https://github.com/juho-lee/set_transformer)
- [D-VAE: A Variational Autoencoder for Directed Acyclic Graphs, Advances in Neural Information Processing Systems (NeurIPS2019)](https://github.com/muhanzhang/D-VAE)
- [Once for All: Train One Network and Specialize it for Efficient Deployment (ICLR2020)](https://github.com/mit-han-lab/once-for-all)
- [NSGANetV2: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search (ECCV2020)](https://github.com/human-analysis/nsganetv2)

View File

@@ -0,0 +1,49 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from predictor import PredictorModel
from database import DatabaseOFA
from utils import load_graph_config
def main():
args = get_parser()
if args.gpu == 'all':
device_list = range(torch.cuda.device_count())
args.gpu = ','.join(str(_) for _ in device_list)
else:
device_list = [int(_) for _ in args.gpu.split(',')]
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device("cuda:0")
args.batch_size = args.batch_size * max(len(device_list), 1)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
args.model_path = os.path.join(args.save_path, args.model_name, 'model')
if args.model_name == 'generator':
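        # the generator database is built with the meta-trained predictor as a
        # performance proxy; construct it here and hand it to DatabaseOFA
        # (its trained weights are loaded later, inside RunManager)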
graph_config = load_graph_config(
args.graph_data_name, args.nvt, args.data_path)
model = PredictorModel(args, graph_config)
d = DatabaseOFA(args, model)
else:
d = DatabaseOFA(args)
if args.collect:
d.collect_db()
else:
assert args.index is not None
assert args.imgnet is not None
d.make_db()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,15 @@
# Usage: bash create_database.sh <gpu ids|all> <start idx> <end idx> <predictor|generator>
# e.g.   bash create_database.sh all 0 49 predictor
IMGNET_PATH='/w14/dataset/ILSVRC2012' # PUT YOUR ILSVRC2012 DIR
for ((ind=$2;ind<=$3;ind++))
do
python build_database.py --gpu $1 \
--model_name $4 \
--index $ind \
--imgnet $IMGNET_PATH \
--hs 512 \
--nz 56
done

View File

@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .db_ofa import DatabaseOFA

View File

@@ -0,0 +1,57 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import numpy as np
import torch
__all__ = ['DataProvider']
class DataProvider:
SUB_SEED = 937162211 # random seed for sampling subset
VALID_SEED = 2147483647 # random seed for the validation set
@staticmethod
def name():
""" Return name of the dataset """
raise NotImplementedError
@property
def data_shape(self):
""" Return shape as python list of one data entry """
raise NotImplementedError
@property
def n_classes(self):
""" Return `int` of num classes """
raise NotImplementedError
@property
def save_path(self):
""" local path to save the data """
raise NotImplementedError
@property
def data_url(self):
""" link to download the data """
raise NotImplementedError
@staticmethod
def random_sample_valid_set(train_size, valid_size):
assert train_size > valid_size
g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()
valid_indexes = rand_indexes[:valid_size]
train_indexes = rand_indexes[valid_size:]
return train_indexes, valid_indexes
@staticmethod
def labels_to_one_hot(n_classes, labels):
new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
return new_labels
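    # Illustrative example (added for clarity, not part of the original code):
    #   DataProvider.labels_to_one_hot(3, np.array([0, 2]))
    #   -> [[1., 0., 0.], [0., 0., 1.]]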

View File

@@ -0,0 +1,107 @@
import os
import torch
import time
import copy
import glob
from .imagenet import ImagenetDataProvider
from .imagenet_loader import ImagenetRunConfig
from .run_manager import RunManager
from ofa.model_zoo import ofa_net
class DatabaseOFA:
def __init__(self, args, predictor=None):
self.path = f'{args.data_path}/{args.model_name}'
self.model_name = args.model_name
self.index = args.index
self.args = args
self.predictor = predictor
ImagenetDataProvider.DEFAULT_PATH = args.imgnet
if not os.path.exists(self.path):
os.makedirs(self.path)
def make_db(self):
self.ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
self.run_config = ImagenetRunConfig(test_batch_size=self.args.batch_size,
n_worker=20)
database = []
st_time = time.time()
f = open(f'{self.path}/txt_{self.index}.txt', 'w')
for dn in range(10000):
best_pp = -1
best_info = None
dls = None
with torch.no_grad():
if self.model_name == 'generator':
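                    # generator database: sample 10 random OFA subnets for this task and
                    # keep the one with the highest predicted accuracy (best_pp)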
for i in range(10):
net_setting = self.ofa_network.sample_active_subnet()
subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
if i == 0:
run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
self.run_config, init=False, pp=self.predictor)
self.run_config.data_provider.assign_active_img_size(224)
dls = {j: copy.deepcopy(run_manager.data_loader) for j in range(1, 10)}
else:
run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
self.run_config,
init=False, data_loader=dls[i], pp=self.predictor)
run_manager.reset_running_statistics(net=subnet)
loss, (top1, top5), pred_acc \
= run_manager.validate(net=subnet, net_setting=net_setting)
if best_pp < pred_acc:
best_pp = pred_acc
print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
dn, len(run_manager.cls_lst), loss, top1, top5))
info_dict = {'loss': loss,
'top1': top1,
'top5': top5,
'net': net_setting,
'class': run_manager.cls_lst,
'params': run_manager.net_info['params'],
'flops': run_manager.net_info['flops'],
'test_transform': run_manager.test_transform
}
best_info = info_dict
elif self.model_name == 'predictor':
net_setting = self.ofa_network.sample_active_subnet()
subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
run_manager = RunManager('.tmp/eval_subnet', self.args, subnet, self.run_config, init=False)
self.run_config.data_provider.assign_active_img_size(224)
run_manager.reset_running_statistics(net=subnet)
loss, (top1, top5), _ = run_manager.validate(net=subnet)
print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
dn, len(run_manager.cls_lst), loss, top1, top5))
best_info = {'loss': loss,
'top1': top1,
'top5': top5,
'net': net_setting,
'class': run_manager.cls_lst,
'params': run_manager.net_info['params'],
'flops': run_manager.net_info['flops'],
'test_transform': run_manager.test_transform
}
database.append(best_info)
            if len(database) % 10 == 0:
msg = f'{(time.time() - st_time) / 60.0:0.2f}(min) save {len(database)} database, {self.index} id'
print(msg)
f.write(msg + '\n')
f.flush()
torch.save(database, f'{self.path}/database_{self.index}.pt')
def collect_db(self):
if not os.path.exists(self.path + f'/processed'):
os.makedirs(self.path + f'/processed')
database = []
dlst = glob.glob(self.path + '/*.pt')
for filepath in dlst:
database += torch.load(filepath)
assert len(database) != 0
        print(f'Number of collected database entries: {len(database)}')
torch.save(database, self.path + f'/processed/collected_database.pt')

View File

@@ -0,0 +1,240 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import warnings
import os
import torch
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa_local.imagenet_classification.data_providers.base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
from .metaloader import MetaImageNetDataset, EpisodeSampler, MetaDataLoader
__all__ = ['ImagenetDataProvider']
class ImagenetDataProvider(DataProvider):
DEFAULT_PATH = '/dataset/imagenet'
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = 'None' if distort_color is None else distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
from ofa.utils.my_dataloader import MyDataLoader
assert isinstance(self.image_size, list)
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size) # active resolution for test
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
########################## modification ########################
train_dataset = self.train_dataset(self.build_train_transform())
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
self.valid = None
# test_dataset = self.test_dataset(valid_transforms)
test_dataset = self.meta_test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
# self.test = torch.utils.data.DataLoader(
# test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
# )
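                # meta-test uses class-wise episodes: EpisodeSampler draws a random
                # subset of classes with `query` images each, instead of iterating
                # over the full validation split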
sampler = EpisodeSampler(
max_way=1000, query=10, ylst=test_dataset.ylst)
self.test = MetaDataLoader(dataset=test_dataset,
sampler=sampler,
batch_size=test_batch_size,
shuffle=False,
num_workers=4)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
self._save_path = self.DEFAULT_PATH
if not os.path.exists(self._save_path):
self._save_path = os.path.expanduser('~/dataset/imagenet')
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
return datasets.ImageFolder(self.train_path, _transforms)
def test_dataset(self, _transforms):
return datasets.ImageFolder(self.valid_path, _transforms)
def meta_test_dataset(self, _transforms):
return MetaImageNetDataset('val', max_way=1000, query=10,
dpath='/w14/dataset/ILSVRC2012', transform=_transforms)
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
# random_resize_crop -> random_horizontal_flip
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
# color augmentation (optional)
color_transform = None
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting BN running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,40 @@
from .imagenet import ImagenetDataProvider
from ofa_local.imagenet_classification.run_manager import RunConfig
__all__ = ['ImagenetRunConfig']
class ImagenetRunConfig(RunConfig):
def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
super(ImagenetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']

View File

@@ -0,0 +1,210 @@
from torch.utils.data.sampler import Sampler
import os
import random
from PIL import Image
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
import glob
class RandCycleIter:
'''
Return data_list per class
Shuffle the returning order after one epoch
'''
def __init__ (self, data, shuffle=True):
self.data_list = list(data)
self.length = len(self.data_list)
self.i = self.length - 1
self.shuffle = shuffle
def __iter__ (self):
return self
def __next__ (self):
self.i += 1
if self.i == self.length:
self.i = 0
if self.shuffle:
random.shuffle(self.data_list)
return self.data_list[self.i]
class EpisodeSampler(Sampler):
def __init__(self, max_way, query, ylst):
self.max_way = max_way
self.query = query
self.ylst = ylst
# self.n_epi = n_epi
clswise_xidx = defaultdict(list)
for i, y in enumerate(ylst):
clswise_xidx[y].append(i)
self.cws_xidx_iter = [RandCycleIter(cxidx, shuffle=True)
for cxidx in clswise_xidx.values()]
self.n_cls = len(clswise_xidx)
self.create_episode()
def __iter__ (self):
return self.get_index()
def __len__ (self):
return self.get_len()
def create_episode(self):
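        # sample the episode's way (number of classes) as a random multiple of 10 in
        # [10, max_way - 10], then pick that many class indices without replacement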
self.way = torch.randperm(int(self.max_way/10.0)-1)[0] * 10 + 10
cls_lst = torch.sort(torch.randperm(self.max_way)[:self.way])[0]
self.cls_itr = iter(cls_lst)
self.cls_lst = cls_lst
def get_len(self):
return self.way * self.query
def get_index(self):
x_itr = self.cws_xidx_iter
i, j = 0, 0
while i < self.query * self.way:
if j >= self.query:
j = 0
if j == 0:
cls_idx = next(self.cls_itr).item()
bb = [x_itr[cls_idx]] * self.query
didx = next(zip(*bb))
yield didx[j]
# yield (didx[j], self.way)
i += 1; j += 1
class MetaImageNetDataset(Dataset):
def __init__(self, mode='val',
max_way=1000, query=10,
dpath='/w14/dataset/ILSVRC2012', transform=None):
self.dpath = dpath
self.transform = transform
self.mode = mode
self.max_way = max_way
self.query = query
classes, class_to_idx = self._find_classes(dpath+'/'+mode)
self.classes, self.class_to_idx = classes, class_to_idx
# self.class_folder_lst = \
# glob.glob(dpath+'/'+mode+'/*')
# ## sorting alphabetically
# self.class_folder_lst = sorted(self.class_folder_lst)
self.file_path_lst, self.ylst = [], []
for cls in classes:
xlst = glob.glob(dpath+'/'+mode+'/'+cls+'/*')
self.file_path_lst += xlst[:self.query]
y = class_to_idx[cls]
self.ylst += [y] * len(xlst[:self.query])
# for y, cls in enumerate(self.class_folder_lst):
# xlst = glob.glob(cls+'/*')
# self.file_path_lst += xlst[:self.query]
# self.ylst += [y] * len(xlst[:self.query])
# # self.file_path_lst += [xlst[_] for _ in
# # torch.randperm(len(xlst))[:self.query]]
# # self.ylst += [cls.split('/')[-1]] * len(xlst)
self.way_idx = 0
self.x_idx = 0
self.way = 2
self.cls_lst = None
def __len__(self):
return self.way * self.query
def __getitem__(self, index):
# if self.way != index[1]:
# self.way = index[1]
# index = index[0]
x = Image.open(
self.file_path_lst[index]).convert('RGB')
if self.transform is not None:
x = self.transform(x)
cls_name = self.ylst[index]
y = self.cls_lst.index(cls_name)
# y = self.way_idx
# self.x_idx += 1
# if self.x_idx == self.query:
# self.way_idx += 1
# self.x_idx = 0
# if self.way_idx == self.way:
# self.way_idx = 0
# self.x_idx = 0
return x, y #, cls_name # y # cls_name #y
def _find_classes(self, dir: str):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
class MetaDataLoader(DataLoader):
def __init__(self,
dataset, sampler, batch_size, shuffle, num_workers):
super(MetaDataLoader, self).__init__(
dataset=dataset,
sampler=sampler,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers)
def create_episode(self):
self.sampler.create_episode()
self.dataset.way = self.sampler.way
self.dataset.cls_lst = self.sampler.cls_lst.tolist()
def get_cls_idx(self):
return self.sampler.cls_lst
def get_loader(mode='val', way=10, query=10,
n_epi=100, dpath='/w14/dataset/ILSVRC2012',
transform=None):
    # use the transform passed in by the caller (the original referenced an
    # undefined `get_transforms` helper)
    dataset = MetaImageNetDataset(mode, way, query, dpath, transform)
    # EpisodeSampler takes (max_way, query, ylst); `n_epi` is not part of its signature
    sampler = EpisodeSampler(way, query, dataset.ylst)
dataset.way = sampler.way
dataset.cls_lst = sampler.cls_lst
loader = MetaDataLoader(dataset=dataset,
sampler=sampler,
batch_size=10,
shuffle=False,
num_workers=4)
return loader
# trloader = get_loader()
# trloader.create_episode()
# print(len(trloader))
# print(trloader.dataset.way)
# print(trloader.sampler.way)
# for i, episode in enumerate(trloader, start=1):
# print(episode[2])

View File

@@ -0,0 +1,302 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import os
import json
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from utils import decode_ofa_mbv3_to_igraph
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
__all__ = ['RunManager']
import torchvision.models as models
class RunManager:
def __init__(self, path, args, net, run_config, init=True, measure_latency=None,
no_gpu=False, data_loader=None, pp=None):
self.path = path
self.mode = args.model_name
self.net = net
self.run_config = run_config
self.best_acc = 0
self.start_epoch = 0
os.makedirs(self.path, exist_ok=True)
# dataloader
if data_loader is not None:
self.data_loader = data_loader
cls_lst = self.data_loader.get_cls_idx()
self.cls_lst = cls_lst
else:
self.data_loader = self.run_config.valid_loader
self.data_loader.create_episode()
cls_lst = self.data_loader.get_cls_idx()
self.cls_lst = cls_lst
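        # build an episode-specific head: slice the 1000-way ImageNet classifier
        # weights/bias down to the classes sampled for this episode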
state_dict = self.net.classifier.state_dict()
new_state_dict = {'weight': state_dict['linear.weight'][cls_lst],
'bias': state_dict['linear.bias'][cls_lst]}
self.net.classifier = nn.Linear(1280, len(cls_lst), bias=True)
self.net.classifier.load_state_dict(new_state_dict)
# move network to GPU if available
if torch.cuda.is_available() and (not no_gpu):
self.device = torch.device('cuda:0')
self.net = self.net.to(self.device)
cudnn.benchmark = True
else:
self.device = torch.device('cpu')
# net info
net_info = get_net_info(
self.net, self.run_config.data_provider.data_shape, measure_latency, False)
self.net_info = net_info
self.test_transform = self.run_config.data_provider.test.dataset.transform
# criterion
if isinstance(self.run_config.mixup_alpha, float):
self.train_criterion = cross_entropy_loss_with_soft_target
elif self.run_config.label_smoothing > 0:
self.train_criterion = \
lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
else:
self.train_criterion = nn.CrossEntropyLoss()
self.test_criterion = nn.CrossEntropyLoss()
# optimizer
if self.run_config.no_decay_keys:
keys = self.run_config.no_decay_keys.split('#')
net_params = [
self.network.get_parameters(keys, mode='exclude'), # parameters with weight decay
self.network.get_parameters(keys, mode='include'), # parameters without weight decay
]
else:
# noinspection PyBroadException
try:
net_params = self.network.weight_parameters()
except Exception:
net_params = []
for param in self.network.parameters():
if param.requires_grad:
net_params.append(param)
self.optimizer = self.run_config.build_optimizer(net_params)
self.net = torch.nn.DataParallel(self.net)
if self.mode == 'generator':
# PP
save_dir = f'{args.save_path}/predictor/model/ckpt_max_corr.pt'
self.acc_predictor = pp.to('cuda')
self.acc_predictor.load_state_dict(torch.load(save_dir))
self.acc_predictor = torch.nn.DataParallel(self.acc_predictor)
model = models.resnet18(pretrained=True).eval()
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]).to(self.device)
self.feature_extractor = torch.nn.DataParallel(feature_extractor)
""" save path and log path """
@property
def save_path(self):
if self.__dict__.get('_save_path', None) is None:
save_path = os.path.join(self.path, 'checkpoint')
os.makedirs(save_path, exist_ok=True)
self.__dict__['_save_path'] = save_path
return self.__dict__['_save_path']
@property
def logs_path(self):
if self.__dict__.get('_logs_path', None) is None:
logs_path = os.path.join(self.path, 'logs')
os.makedirs(logs_path, exist_ok=True)
self.__dict__['_logs_path'] = logs_path
return self.__dict__['_logs_path']
@property
def network(self):
return self.net.module if isinstance(self.net, nn.DataParallel) else self.net
def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
write_log(self.logs_path, log_str, prefix, should_print, mode)
""" save and load models """
def save_model(self, checkpoint=None, is_best=False, model_name=None):
if checkpoint is None:
checkpoint = {'state_dict': self.network.state_dict()}
if model_name is None:
model_name = 'checkpoint.pth.tar'
checkpoint['dataset'] = self.run_config.dataset # add `dataset` info to the checkpoint
latest_fname = os.path.join(self.save_path, 'latest.txt')
model_path = os.path.join(self.save_path, model_name)
with open(latest_fname, 'w') as fout:
fout.write(model_path + '\n')
torch.save(checkpoint, model_path)
if is_best:
best_path = os.path.join(self.save_path, 'model_best.pth.tar')
torch.save({'state_dict': checkpoint['state_dict']}, best_path)
def load_model(self, model_fname=None):
latest_fname = os.path.join(self.save_path, 'latest.txt')
if model_fname is None and os.path.exists(latest_fname):
with open(latest_fname, 'r') as fin:
model_fname = fin.readline()
if model_fname[-1] == '\n':
model_fname = model_fname[:-1]
# noinspection PyBroadException
try:
if model_fname is None or not os.path.exists(model_fname):
model_fname = '%s/checkpoint.pth.tar' % self.save_path
with open(latest_fname, 'w') as fout:
fout.write(model_fname + '\n')
print("=> loading checkpoint '{}'".format(model_fname))
checkpoint = torch.load(model_fname, map_location='cpu')
except Exception:
print('fail to load checkpoint from %s' % self.save_path)
return {}
self.network.load_state_dict(checkpoint['state_dict'])
if 'epoch' in checkpoint:
self.start_epoch = checkpoint['epoch'] + 1
if 'best_acc' in checkpoint:
self.best_acc = checkpoint['best_acc']
if 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}'".format(model_fname))
return checkpoint
def save_config(self, extra_run_config=None, extra_net_config=None):
""" dump run_config and net_config to the model_folder """
run_save_path = os.path.join(self.path, 'run.config')
if not os.path.isfile(run_save_path):
run_config = self.run_config.config
if extra_run_config is not None:
run_config.update(extra_run_config)
json.dump(run_config, open(run_save_path, 'w'), indent=4)
print('Run configs dump to %s' % run_save_path)
try:
net_save_path = os.path.join(self.path, 'net.config')
net_config = self.network.config
if extra_net_config is not None:
net_config.update(extra_net_config)
json.dump(net_config, open(net_save_path, 'w'), indent=4)
print('Network configs dump to %s' % net_save_path)
except Exception:
print('%s do not support net config' % type(self.network))
""" metric related """
def get_metric_dict(self):
return {
'top1': AverageMeter(),
'top5': AverageMeter(),
}
def update_metric(self, metric_dict, output, labels):
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
metric_dict['top1'].update(acc1[0].item(), output.size(0))
metric_dict['top5'].update(acc5[0].item(), output.size(0))
def get_metric_vals(self, metric_dict, return_dict=False):
if return_dict:
return {
key: metric_dict[key].avg for key in metric_dict
}
else:
return [metric_dict[key].avg for key in metric_dict]
def get_metric_names(self):
return 'top1', 'top5'
""" train and test """
def validate(self, epoch=0, is_test=False, run_str='', net=None,
data_loader=None, no_logs=False, train_mode=False, net_setting=None):
if net is None:
net = self.net
if not isinstance(net, nn.DataParallel):
net = nn.DataParallel(net)
if data_loader is not None:
self.data_loader = data_loader
if train_mode:
net.train()
else:
net.eval()
losses = AverageMeter()
metric_dict = self.get_metric_dict()
features_stack = []
with torch.no_grad():
with tqdm(total=len(self.data_loader),
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
for i, (images, labels) in enumerate(self.data_loader):
images, labels = images.to(self.device), labels.to(self.device)
if self.mode == 'generator':
features = self.feature_extractor(images).squeeze()
features_stack.append(features)
# compute output
output = net(images)
loss = self.test_criterion(output, labels)
# measure accuracy and record loss
self.update_metric(metric_dict, output, labels)
losses.update(loss.item(), images.size(0))
t.set_postfix({
'loss': losses.avg,
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
})
t.update(1)
if self.mode == 'generator':
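            # predicted accuracy for the generator mode: encode the episode's image
            # features as a set (D_mu) and the subnet as a graph (G_mu), then query
            # the meta-learned accuracy predictor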
features_stack = torch.cat(features_stack)
igraph_g = decode_ofa_mbv3_to_igraph(net_setting)[0]
D_mu = self.acc_predictor.module.set_encode(features_stack.unsqueeze(0).to('cuda'))
G_mu = self.acc_predictor.module.graph_encode(igraph_g)
pred_acc = self.acc_predictor.module.predict(D_mu.unsqueeze(0), G_mu).item()
return losses.avg, self.get_metric_vals(metric_dict), \
pred_acc if self.mode == 'generator' else None
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
if net is None:
net = self.network
if isinstance(self.run_config.data_provider.image_size, list):
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
for img_size in self.run_config.data_provider.image_size:
img_size_list.append(img_size)
self.run_config.data_provider.assign_active_img_size(img_size)
self.reset_running_statistics(net=net)
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
loss_list.append(loss)
top1_list.append(top1)
top5_list.append(top5)
return img_size_list, loss_list, top1_list, top5_list
else:
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
from ofa_local.imagenet_classification.elastic_nn.utils import set_running_statistics
if net is None:
net = self.network
if data_loader is None:
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
set_running_statistics(net, data_loader)

View File

@@ -0,0 +1,4 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################

View File

@@ -0,0 +1,401 @@
from __future__ import print_function
import os
import math
import warnings
import numpy as np
# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
def make_dataset(dir, image_ids, targets):
assert(len(image_ids) == len(targets))
images = []
dir = os.path.expanduser(dir)
for i in range(len(image_ids)):
item = (os.path.join(dir, 'data', 'images',
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images
def find_classes(classes_file):
# read classes file, separating out image IDs and class names
image_ids = []
targets = []
f = open(classes_file, 'r')
for line in f:
split_line = line.split(' ')
image_ids.append(split_line[0])
targets.append(' '.join(split_line[1:]))
f.close()
# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]
return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(torch.utils.data.Dataset):
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
Args:
root (string): Root directory path to dataset.
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g. ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')
def __init__(self, root, class_type='variant', split='train', transform=None,
target_transform=None, loader=default_loader, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))
self.root = os.path.expanduser(root)
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
samples = make_dataset(self.root, image_ids, targets)
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def _check_exists(self):
return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
os.path.exists(self.classes_file)
def download(self):
"""Download the FGVC-Aircraft data if it doesn't exist already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
print('Downloading %s ... (may take a few minutes)' % self.url)
parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
tar_name = self.url.rpartition('/')[-1]
tar_path = os.path.join(parent_dir, tar_name)
data = urllib.request.urlopen(self.url)
# download .tar.gz file
with open(tar_path, 'wb') as f:
f.write(data.read())
# extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
        data_folder = tar_path[:-len('.tar.gz')]  # str.strip would drop characters, not the suffix
print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
tar = tarfile.open(tar_path)
tar.extractall(parent_dir)
# if necessary, rename data folder to self.root
if not os.path.samefile(data_folder, self.root):
print('Renaming %s to %s ...' % (data_folder, self.root))
os.rename(data_folder, self.root)
# delete .tar.gz file
print('Deleting %s ...' % tar_path)
os.remove(tar_path)
print('Done!')
class FGVCAircraftDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'aircraft'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 100
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/Aircraft' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/Aircraft' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = FGVCAircraft(
root=self.train_path, split='trainval', download=True, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = FGVCAircraft(
root=self.valid_path, split='test', download=True, transform=_transforms)
return dataset
@property
def train_path(self):
return self.save_path
@property
def valid_path(self):
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])
def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
if image_size is None:
image_size = self.image_size
# if print_log:
# print('Color jitter: %s, resize_scale: %s, img_size: %s' %
# (self.distort_color, self.resize_scale, image_size))
# if self.distort_color == 'torch':
# color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
# elif self.distort_color == 'tf':
# color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
# else:
# color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
img_size_min = min(image_size)
else:
resize_transform_class = transforms.RandomResizedCrop
img_size_min = image_size
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
0.5387914411673883]]),
)
aa_params['interpolation'] = transforms.Resize(image_size) # _pil_interp('bicubic')
train_transforms += [rand_augment_transform(auto_augment, aa_params)]
# if color_transform is not None:
# train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
if __name__ == '__main__':
data = FGVCAircraft(root='/mnt/datastore/Aircraft',
split='trainval', download=True)
print(len(data.classes))
print(len(data.samples))

View File

@@ -0,0 +1,238 @@
"""
Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
class ImageNetPolicy(object):
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
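        # Reading aid for the index-to-magnitude mapping above (added commentary, not part of the policy):
        #   ranges["rotate"][9]    == 30.0 degrees (strongest rotation)
        #   ranges["solarize"][9]  == 0.0 (threshold 0, i.e. every pixel is inverted)
        #   ranges["posterize"][0] == 8 bits (weakest posterize, i.e. a no-op)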
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate on an RGBA canvas, then composite over an opaque gray
            # background so the corners exposed by the rotation are filled
            # (note: this uses a fixed gray, not the fillcolor argument).
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
return img
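# A minimal usage sketch for the policies above (illustrative assumptions:
# Pillow is available and 'example.jpg' exists; neither is part of this module):
#
#   from PIL import Image
#   img = Image.open('example.jpg').convert('RGB')
#   policy = CIFAR10Policy()   # draws one of its 25 sub-policies per call
#   augmented = policy(img)    # each op in the sub-policy fires with its own probability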

View File

@@ -0,0 +1,657 @@
import os
import math
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class CIFAR10DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'cifar10'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 10
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/CIFAR' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/CIFAR' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = torchvision.datasets.CIFAR10(
root=self.valid_path, train=True, download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = torchvision.datasets.CIFAR10(
root=self.valid_path, train=False, download=False, transform=_transforms)
return dataset
@property
def train_path(self):
# return os.path.join(self.save_path, 'train')
return self.save_path
@property
def valid_path(self):
# return os.path.join(self.save_path, 'val')
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.data)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
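# Hedged usage sketch for the provider above (assumes CIFAR-10 is already
# downloaded under save_path, since the datasets are built with download=False):
#
#   provider = CIFAR10DataProvider(save_path='./data/CIFAR', train_batch_size=96,
#                                  valid_size=0.1, image_size=224)
#   images, labels = next(iter(provider.train))  # images: (96, 3, 224, 224)
#   provider.assign_active_img_size(192)         # valid/test transforms now use 192px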
class CIFAR100DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'cifar100'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 100
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/CIFAR' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/CIFAR' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = torchvision.datasets.CIFAR100(
root=self.valid_path, train=True, download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = torchvision.datasets.CIFAR100(
root=self.valid_path, train=False, download=False, transform=_transforms)
return dataset
@property
def train_path(self):
# return os.path.join(self.save_path, 'train')
return self.save_path
@property
def valid_path(self):
# return os.path.join(self.save_path, 'val')
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.data)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
class CINIC10DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'cinic10'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 10
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/CINIC10' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/CINIC10' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = torchvision.datasets.ImageFolder(self.train_path, transform=_transforms)
# dataset = torchvision.datasets.CIFAR10(
# root=self.valid_path, train=True, download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
dataset = torchvision.datasets.ImageFolder(self.valid_path, transform=_transforms)
# dataset = torchvision.datasets.CIFAR10(
# root=self.valid_path, train=False, download=False, transform=_transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train_and_valid')
# return self.save_path
@property
def valid_path(self):
return os.path.join(self.save_path, 'test')
# return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.47889522, 0.47227842, 0.43047404], std=[0.24205776, 0.23828046, 0.25874835])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
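# Expected on-disk layout for the ImageFolder datasets above (an assumption
# about how CINIC-10 was prepared, inferred from train_path / valid_path):
#
#   <save_path>/train_and_valid/<class_name>/*.png
#   <save_path>/test/<class_name>/*.png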

View File

@@ -0,0 +1,237 @@
import os
import warnings
import numpy as np
from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class DTDDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'dtd'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 47
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/dtd' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/dtd' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.valid_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'valid')
@property
def normalize(self):
return transforms.Normalize(
mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])
def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
if image_size is None:
image_size = self.image_size
# if print_log:
# print('Color jitter: %s, resize_scale: %s, img_size: %s' %
# (self.distort_color, self.resize_scale, image_size))
# if self.distort_color == 'torch':
# color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
# elif self.distort_color == 'tf':
# color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
# else:
# color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
img_size_min = min(image_size)
else:
resize_transform_class = transforms.RandomResizedCrop
img_size_min = image_size
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
0.42627281899380676]]),
)
aa_params['interpolation'] = _pil_interp('bicubic')
train_transforms += [rand_augment_transform(auto_augment, aa_params)]
# if color_transform is not None:
# train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
# transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.Resize((image_size, image_size), interpolation=3),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
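# Hedged sketch of the timm RandAugment hook used in build_train_transform above.
# 'rand-m9-mstd0.5' selects RandAugment with magnitude 9 and std-0.5 magnitude
# noise; the numeric values below are illustrative assumptions:
#
#   from timm.data.auto_augment import rand_augment_transform
#   aa = rand_augment_transform('rand-m9-mstd0.5',
#                               {'translate_const': 100, 'img_mean': (136, 121, 109)})
#   train_tfms = transforms.Compose([transforms.RandomResizedCrop(224), aa,
#                                    transforms.ToTensor()])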

View File

@@ -0,0 +1,241 @@
import warnings
import os
import math
import numpy as np
import PIL
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class Flowers102DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
# warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
weights = self.make_weights_for_balanced_classes(
train_dataset.imgs, self.n_classes)
weights = torch.DoubleTensor(weights)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
if valid_size is not None:
raise NotImplementedError("validation dataset not yet implemented")
# valid_dataset = self.valid_dataset(valid_transforms)
# self.train = train_loader_class(
# train_dataset, batch_size=train_batch_size, sampler=train_sampler,
# num_workers=n_worker, pin_memory=True)
# self.valid = torch.utils.data.DataLoader(
# valid_dataset, batch_size=test_batch_size,
# num_workers=n_worker, pin_memory=True)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'flowers102'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 102
@property
def save_path(self):
if self._save_path is None:
# self._save_path = '/mnt/datastore/Oxford102Flowers' # home server
self._save_path = '/mnt/datastore/Flowers102' # home server
if not os.path.exists(self._save_path):
# self._save_path = '/mnt/datastore/Oxford102Flowers' # home server
self._save_path = '/mnt/datastore/Flowers102' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
# def valid_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
# return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.test_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
# @property
# def valid_path(self):
# return os.path.join(self.save_path, 'train')
@property
def test_path(self):
return os.path.join(self.save_path, 'test')
@property
def normalize(self):
return transforms.Normalize(
mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])
    @staticmethod
    def make_weights_for_balanced_classes(images, nclasses):
        # Count images per class label.
        count = [0] * nclasses
        for item in images:
            count[item[1]] += 1
        # Class weight is inversely proportional to class frequency,
        # so smaller classes are over-sampled.
        N = float(sum(count))
        weight_per_class = [0.] * nclasses
        for i in range(nclasses):
            weight_per_class[i] = N / float(count[i])
        # Assign each image the weight of its class.
        weight = [0] * len(images)
        for idx, val in enumerate(images):
            weight[idx] = weight_per_class[val[1]]
        return weight
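    # Worked example: with per-class counts [2, 8] (N = 10), weight_per_class is
    # [5.0, 1.25], so WeightedRandomSampler draws each rare-class image four
    # times as often as each common-class image.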
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
transforms.RandomAffine(
45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
# transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,225 @@
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class ImagenetDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
# self._save_path = '/dataset/imagenet'
# self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012' # servers
self._save_path = '/mnt/datastore/ILSVRC2012' # home server
if not os.path.exists(self._save_path):
# self._save_path = os.path.expanduser('~/dataset/imagenet')
# self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
self._save_path = '/mnt/datastore/ILSVRC2012' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.valid_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
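# Hedged sketch of what the cached sub-train batches above are typically used
# for: re-estimating BatchNorm running statistics of a sampled sub-network.
# This is a generic recipe, not the OFA library's own API:
#
#   import torch
#   def reset_bn_stats(model, batches):
#       for m in model.modules():
#           if isinstance(m, torch.nn.BatchNorm2d):
#               m.reset_running_stats()
#       model.train()  # BN layers update running stats only in train mode
#       with torch.no_grad():
#           for images, _ in batches:  # e.g., provider.build_sub_train_loader(2000, 200)
#               model(images)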

View File

@@ -0,0 +1,237 @@
import os
import math
import warnings
import numpy as np
from PIL import Image  # for the bicubic resample constant used in build_train_transform
# from timm.data.transforms import _pil_interp  # removed in newer timm releases
from timm.data.auto_augment import rand_augment_transform
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class OxfordIIITPetsDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.samples) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'pets'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 37
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server
if not os.path.exists(self._save_path):
self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.train_path, _transforms)
return dataset
def test_dataset(self, _transforms):
dataset = datasets.ImageFolder(self.valid_path, _transforms)
return dataset
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'valid')
@property
def normalize(self):
return transforms.Normalize(
mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])
def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
if image_size is None:
image_size = self.image_size
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
img_size_min = min(image_size)
else:
resize_transform_class = transforms.RandomResizedCrop
img_size_min = image_size
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
0.39566558230789783]]),
)
        from PIL import Image  # local import: rand_augment expects a PIL interpolation flag
        aa_params['interpolation'] = Image.BICUBIC  # was transforms.Resize(image_size), which rand_augment cannot consume
train_transforms += [rand_augment_transform(auto_augment, aa_params)]
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.samples)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
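# Usage sketch (illustrative): how a provider like this is driven when
# resetting BatchNorm running statistics for an OFA subnet. Constructor
# arguments are assumed to mirror the STL-10 provider below; paths and
# counts are placeholders.
#
#   provider = OxfordIIITPetsDataProvider(image_size=224)
#   sub_loader = provider.build_sub_train_loader(n_images=2000, batch_size=100)
#   for images, labels in sub_loader:
#       pass  # forward passes over this cached subset re-estimate BN mean/var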

View File

@@ -0,0 +1,69 @@
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image
def load_image(filename):
img = Image.open(filename)
img = img.convert('RGB')
return img
class PetDataset(Dataset):
def __init__(self, root, train=True, num_cl=37, val_split=0.15, transforms=None):
pt_name = os.path.join(root, '{}{}.pth'.format('train' if train else 'test',
int(100 * (1 - val_split)) if train else int(
100 * val_split)))
if not os.path.exists(pt_name):
filenames = glob(os.path.join(root, 'images') + '/*.jpg')
classes = set()
data = []
labels = []
for image in filenames:
                class_name = os.path.basename(image).rsplit('_', 1)[0]  # os.path.basename is portable across path separators
classes.add(class_name)
img = load_image(image)
data.append(img)
labels.append(class_name)
# convert classnames to indices
            class2idx = {cl: idx for idx, cl in enumerate(sorted(classes))}  # sort for a deterministic class indexing
labels = torch.Tensor(list(map(lambda x: class2idx[x], labels))).long()
data = list(zip(data, labels))
class_values = [[] for x in range(num_cl)]
# create arrays for each class type
for d in data:
class_values[d[1].item()].append(d)
train_data = []
val_data = []
for class_dp in class_values:
split_idx = int(len(class_dp) * (1 - val_split))
train_data += class_dp[:split_idx]
val_data += class_dp[split_idx:]
torch.save(train_data, os.path.join(root, 'train{}.pth'.format(int(100 * (1 - val_split)))))
torch.save(val_data, os.path.join(root, 'test{}.pth'.format(int(100 * val_split))))
self.data = torch.load(pt_name)
self.len = len(self.data)
self.transform = transforms
def __getitem__(self, index):
img, label = self.data[index]
if self.transform:
img = self.transform(img)
return img, label
def __len__(self):
return self.len
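if __name__ == '__main__':
    # Usage sketch: the first instantiation scans <root>/images, builds a
    # class-balanced split, and caches it as .pth files under root; later runs
    # load the cache directly. The root path here is a placeholder.
    import torchvision.transforms as T
    ds = PetDataset('/path/to/Oxford-IIITPets', train=True,
                    transforms=T.Compose([T.Resize((224, 224)), T.ToTensor()]))
    img, label = ds[0]
    print(len(ds), tuple(img.shape), int(label))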

View File

@@ -0,0 +1,226 @@
import os
import math
import numpy as np
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
class STL10DataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
assert isinstance(self.image_size, list)
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size)
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_transforms = self.build_train_transform()
train_dataset = self.train_dataset(train_transforms)
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset.data) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'stl10'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 10
@property
def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/STL10'  # home server
        return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.train_path, _transforms)
dataset = torchvision.datasets.STL10(
root=self.valid_path, split='train', download=False, transform=_transforms)
return dataset
def test_dataset(self, _transforms):
# dataset = datasets.ImageFolder(self.valid_path, _transforms)
dataset = torchvision.datasets.STL10(
root=self.valid_path, split='test', download=False, transform=_transforms)
return dataset
@property
def train_path(self):
# return os.path.join(self.save_path, 'train')
return self.save_path
@property
def valid_path(self):
# return os.path.join(self.save_path, 'val')
return self.save_path
@property
def normalize(self):
return transforms.Normalize(
mean=[0.44671097, 0.4398105, 0.4066468],
std=[0.2603405, 0.25657743, 0.27126738])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset.data)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]

View File

@@ -0,0 +1,4 @@
from ofa.imagenet_codebase.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
from ofa.imagenet_codebase.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks.nsganetv2 import NSGANetV2

View File

@@ -0,0 +1,126 @@
from timm.models.layers import drop_path
from ofa.imagenet_codebase.modules.layers import *
from ofa.imagenet_codebase.networks import MobileNetV3
class MobileInvertedResidualBlock(MyModule):
"""
Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
imagenet_codebase/networks/proxyless_nets.py to include drop path in training
"""
def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
super(MobileInvertedResidualBlock, self).__init__()
self.mobile_inverted_conv = mobile_inverted_conv
self.shortcut = shortcut
self.drop_connect_rate = drop_connect_rate
def forward(self, x):
if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
res = x
elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
res = self.mobile_inverted_conv(x)
else:
# res = self.mobile_inverted_conv(x) + self.shortcut(x)
res = self.mobile_inverted_conv(x)
if self.drop_connect_rate > 0.:
res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)
res += self.shortcut(x)
return res
@property
def module_str(self):
return '(%s, %s)' % (
self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
self.shortcut.module_str if self.shortcut is not None else None
)
@property
def config(self):
return {
'name': MobileInvertedResidualBlock.__name__,
'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
'shortcut': self.shortcut.config if self.shortcut is not None else None,
}
@staticmethod
def build_from_config(config):
mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
shortcut = set_layer_from_config(config['shortcut'])
return MobileInvertedResidualBlock(
mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate'])
class NSGANetV2(MobileNetV3):
"""
Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
and option to reset classification layer
"""
@staticmethod
def build_from_config(config, drop_connect_rate=0.0):
first_conv = set_layer_from_config(config['first_conv'])
final_expand_layer = set_layer_from_config(config['final_expand_layer'])
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
classifier = set_layer_from_config(config['classifier'])
blocks = []
for block_idx, block_config in enumerate(config['blocks']):
block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))
net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, MobileInvertedResidualBlock):
if isinstance(m.mobile_inverted_conv, MBInvertedConvLayer) and isinstance(m.shortcut, IdentityLayer):
m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()
@staticmethod
def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
)
# build mobile blocks
feature_dim = input_channel
blocks = []
for stage_id, block_config_list in cfg.items():
for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
mb_conv = MBInvertedConvLayer(
feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
)
if stride == 1 and out_channel == feature_dim:
shortcut = IdentityLayer(out_channel, out_channel)
else:
shortcut = None
blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
feature_dim = out_channel
# final expand layer
final_expand_layer = ConvLayer(
feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
)
feature_dim = feature_dim * 6
# feature mix layer
feature_mix_layer = ConvLayer(
feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
)
# classifier
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
@staticmethod
def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
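# Usage sketch: how NSGANetV2 is driven elsewhere in this repo (see the
# training scripts below); `evaluator` is an OFAEvaluator and `g` an
# architecture dict.
#
#   subnet, config = evaluator.sample(g)
#   net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=0.2)
#   net.load_state_dict(subnet.state_dict())
#   NSGANetV2.reset_classifier(net, last_channel=net.classifier.in_features,
#                              n_classes=37, dropout_rate=0.2)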

View File

@@ -0,0 +1,309 @@
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.imagenet import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.cifar import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import *
from ofa.imagenet_codebase.run_manager.run_manager import *
class ImagenetRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/ILSVRC2012',
**kwargs):
super(ImagenetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.imagenet_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.imagenet_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class CIFARRunConfig(RunConfig):
def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
data_path='/mnt/datastore/CIFAR',
**kwargs):
super(CIFARRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.cifar_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == CIFAR10DataProvider.name():
DataProviderClass = CIFAR10DataProvider
elif self.dataset == CIFAR100DataProvider.name():
DataProviderClass = CIFAR100DataProvider
elif self.dataset == CINIC10DataProvider.name():
DataProviderClass = CINIC10DataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.cifar_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class Flowers102RunConfig(RunConfig):
def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
data_path='/mnt/datastore/Flowers102',
**kwargs):
super(Flowers102RunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.flowers102_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == Flowers102DataProvider.name():
DataProviderClass = Flowers102DataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.flowers102_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class STL10RunConfig(RunConfig):
def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
data_path='/mnt/datastore/STL10',
**kwargs):
super(STL10RunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.stl10_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == STL10DataProvider.name():
DataProviderClass = STL10DataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.stl10_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class DTDRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/dtd',
**kwargs):
super(DTDRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == DTDDataProvider.name():
DataProviderClass = DTDDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class PetsRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/Oxford-IIITPets',
**kwargs):
super(PetsRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
        self.pets_data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == OxfordIIITPetsDataProvider.name():
DataProviderClass = OxfordIIITPetsDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.pets_data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class AircraftRunConfig(RunConfig):
def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
mixup_alpha=None,
model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
data_path='/mnt/datastore/Aircraft',
**kwargs):
super(AircraftRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
self.data_path = data_path
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == FGVCAircraftDataProvider.name():
DataProviderClass = FGVCAircraftDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
save_path=self.data_path,
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
def get_run_config(**kwargs):
if kwargs['dataset'] == 'imagenet':
run_config = ImagenetRunConfig(**kwargs)
elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'):
run_config = CIFARRunConfig(**kwargs)
elif kwargs['dataset'] == 'flowers102':
run_config = Flowers102RunConfig(**kwargs)
elif kwargs['dataset'] == 'stl10':
run_config = STL10RunConfig(**kwargs)
elif kwargs['dataset'] == 'dtd':
run_config = DTDRunConfig(**kwargs)
elif kwargs['dataset'] == 'pets':
run_config = PetsRunConfig(**kwargs)
    elif kwargs['dataset'].startswith('aircraft'):  # covers both 'aircraft' and 'aircraft100'
        run_config = AircraftRunConfig(**kwargs)
else:
raise NotImplementedError
return run_config
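if __name__ == '__main__':
    # Usage sketch: every RunConfig above is built through this factory; extra
    # keys are absorbed by each __init__'s **kwargs. Attribute names follow the
    # OFA RunConfig base class.
    cfg = get_run_config(dataset='cifar10', n_epochs=5, train_batch_size=96)
    print(cfg.dataset, cfg.n_epochs)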

View File

@@ -0,0 +1,122 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import FGVCAircraft
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets2 import PetDataset
import torch.utils.data as Data
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.autoaugment import CIFAR10Policy
def get_dataset(data_name, batch_size, data_path, num_workers,
img_size, autoaugment, cutout, cutout_length):
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'aircraft': 100,
'svhn': 10,
'pets': 37
}
# 'aircraft30': 30,
# 'aircraft100': 100,
train_transform, valid_transform = _data_transforms(
data_name, img_size, autoaugment, cutout, cutout_length)
if data_name == 'cifar100':
train_data = torchvision.datasets.CIFAR100(
root=data_path, train=True, download=True, transform=train_transform)
valid_data = torchvision.datasets.CIFAR100(
root=data_path, train=False, download=True, transform=valid_transform)
elif data_name == 'cifar10':
train_data = torchvision.datasets.CIFAR10(
root=data_path, train=True, download=True, transform=train_transform)
valid_data = torchvision.datasets.CIFAR10(
root=data_path, train=False, download=True, transform=valid_transform)
elif data_name.startswith('aircraft'):
print(data_path)
if 'aircraft100' in data_path:
data_path = data_path.replace('aircraft100', 'aircraft/fgvc-aircraft-2013b')
else:
data_path = data_path.replace('aircraft', 'aircraft/fgvc-aircraft-2013b')
train_data = FGVCAircraft(data_path, class_type='variant', split='trainval',
transform=train_transform, download=True)
valid_data = FGVCAircraft(data_path, class_type='variant', split='test',
transform=valid_transform, download=True)
elif data_name.startswith('pets'):
train_data = PetDataset(data_path, train=True, num_cl=37,
val_split=0.15, transforms=train_transform)
valid_data = PetDataset(data_path, train=False, num_cl=37,
val_split=0.15, transforms=valid_transform)
else:
raise KeyError
train_queue = torch.utils.data.DataLoader(
train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
num_workers=num_workers)
valid_queue = torch.utils.data.DataLoader(
valid_data, batch_size=200, shuffle=False, pin_memory=True,
num_workers=num_workers)
return train_queue, valid_queue, num_class_dict[data_name]
class Cutout(object):
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def _data_transforms(data_name, img_size, autoaugment, cutout, cutout_length):
if 'cifar' in data_name:
norm_mean = [0.49139968, 0.48215827, 0.44653124]
norm_std = [0.24703233, 0.24348505, 0.26158768]
elif 'aircraft' in data_name:
norm_mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
norm_std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
elif 'pets' in data_name:
norm_mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
norm_std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
else:
raise KeyError
train_transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC), # BICUBIC interpolation
transforms.RandomHorizontalFlip(),
])
if autoaugment:
train_transform.transforms.append(CIFAR10Policy())
train_transform.transforms.append(transforms.ToTensor())
if cutout:
train_transform.transforms.append(Cutout(cutout_length))
train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std))
valid_transform = transforms.Compose([
transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC), # BICUBIC interpolation
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
return train_transform, valid_transform
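if __name__ == '__main__':
    # Usage sketch with a placeholder data path: builds the Pets loaders with
    # the autoaugment/cutout defaults from parser.py.
    train_queue, valid_queue, n_cls = get_dataset(
        'pets', batch_size=96, data_path='/path/to/Oxford-IIITPets',
        num_workers=2, img_size=224, autoaugment=True,
        cutout=True, cutout_length=16)
    print('classes:', n_cls)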

View File

@@ -0,0 +1,233 @@
import os
import torch
import numpy as np
import random
import sys
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation import eval_utils  # bind the short name used below
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.run_manager import get_run_config
from ofa.elastic_nn.networks import OFAMobileNetV3
from ofa.imagenet_codebase.run_manager import RunManager
from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
from torchprofile import profile_macs
import copy
import json
import warnings
warnings.simplefilter("ignore")
DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1
class ArchManager:
def __init__(self):
self.num_blocks = 20
self.num_stages = 5
self.kernel_sizes = [3, 5, 7]
self.expand_ratios = [3, 4, 6]
self.depths = [2, 3, 4]
self.resolutions = [160, 176, 192, 208, 224]
def random_sample(self):
sample = {}
d = []
e = []
ks = []
for i in range(self.num_stages):
d.append(random.choice(self.depths))
for i in range(self.num_blocks):
e.append(random.choice(self.expand_ratios))
ks.append(random.choice(self.kernel_sizes))
sample = {
'wid': None,
'ks': ks,
'e': e,
'd': d,
'r': [random.choice(self.resolutions)]
}
return sample
def random_resample(self, sample, i):
assert i >= 0 and i < self.num_blocks
sample['ks'][i] = random.choice(self.kernel_sizes)
sample['e'][i] = random.choice(self.expand_ratios)
def random_resample_depth(self, sample, i):
assert i >= 0 and i < self.num_stages
sample['d'][i] = random.choice(self.depths)
def random_resample_resolution(self, sample):
sample['r'][0] = random.choice(self.resolutions)
def parse_string_list(string):
if isinstance(string, str):
# convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3]
return list(map(int, string[1:-1].split()))
else:
return string
def pad_none(x, depth, max_depth):
new_x, counter = [], 0
for d in depth:
for _ in range(d):
new_x.append(x[counter])
counter += 1
if d < max_depth:
new_x += [None] * (max_depth - d)
return new_x
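# Worked example for pad_none: it aligns a flat per-layer list to a fixed
# (num_stages x max_depth) grid by inserting None for skipped layers, e.g.
#   pad_none([3, 5, 7, 3, 5], depth=[2, 3], max_depth=4)
#   -> [3, 5, None, None, 7, 3, 5, None]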
def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
net_info = eval_utils.get_net_info(
net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)
gpu_latency, cpu_latency = None, None
for k in net_info.keys():
if 'gpu' in k:
gpu_latency = np.round(net_info[k]['val'], 2)
if 'cpu' in k:
cpu_latency = np.round(net_info[k]['val'], 2)
return {
'params': np.round(net_info['params'] / 1e6, 2),
'flops': np.round(net_info['flops'] / 1e6, 2),
'gpu': gpu_latency, 'cpu': cpu_latency
}
def validate_config(config, max_depth=4):
kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d']
if isinstance(kernel_size, str): kernel_size = parse_string_list(kernel_size)
if isinstance(exp_ratio, str): exp_ratio = parse_string_list(exp_ratio)
if isinstance(depth, str): depth = parse_string_list(depth)
assert (isinstance(kernel_size, list) or isinstance(kernel_size, int))
assert (isinstance(exp_ratio, list) or isinstance(exp_ratio, int))
assert isinstance(depth, list)
if len(kernel_size) < len(depth) * max_depth:
kernel_size = pad_none(kernel_size, depth, max_depth)
if len(exp_ratio) < len(depth) * max_depth:
exp_ratio = pad_none(exp_ratio, depth, max_depth)
# return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']}
return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}
def set_nas_test_dataset(path, test_data_name, max_img):
    if test_data_name not in ['mnist', 'svhn', 'cifar10',
                              'cifar100', 'aircraft', 'pets']:
        raise ValueError(test_data_name)
dpath = path
num_cls = 10 # mnist, svhn, cifar10
if test_data_name in ['cifar100', 'aircraft']:
num_cls = 100
elif test_data_name == 'pets':
num_cls = 37
x = torch.load(dpath + f'/{test_data_name}bylabel')
img_per_cls = min(int(max_img / num_cls), 20)
return x, img_per_cls, num_cls
class OFAEvaluator:
""" based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """
def __init__(self, num_gen_arch, img_size, drop_path,
n_classes=1000,
model_path=None,
kernel_size=None, exp_ratio=None, depth=None):
# default configurations
        self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size  # depth-wise conv kernel size
        self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio  # expansion rate
        self.depth = [2, 3, 4] if depth is None else depth  # number of MB block repetitions
        # keep these around: get_architecture() below reads them back
        self.img_size = img_size
        self.drop_path = drop_path
if 'w1.0' in model_path:
self.width_mult = 1.0
elif 'w1.2' in model_path:
self.width_mult = 1.2
else:
raise ValueError
self.engine = OFAMobileNetV3(
n_classes=n_classes,
dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
expand_ratio_list=self.exp_ratio, depth_list=self.depth)
init = torch.load(model_path, map_location='cpu')['state_dict']
self.engine.load_weights_from_net(init)
print(f'load {model_path}...')
## metad2a
self.arch_manager = ArchManager()
self.num_gen_arch = num_gen_arch
def sample_random_architecture(self):
sampled_architecture = self.arch_manager.random_sample()
return sampled_architecture
    def get_architecture(self, bound=None):
        # NOTE: self.lines (one candidate per line, an accuracy prediction followed
        # by an architecture dict) must be populated by the caller beforehand;
        # it is not set in __init__.
        g_lst, pred_acc_lst, x_lst = [], [], []
        searched_g, max_pred_acc = None, 0
with torch.no_grad():
for n in range(self.num_gen_arch):
file_acc = self.lines[n].split()[0]
g_dict = ' '.join(self.lines[n].split())
g = json.loads(g_dict.replace("'", "\""))
if bound is not None:
subnet, config = self.sample(config=g)
net = NSGANetV2.build_from_config(subnet.config,
drop_connect_rate=self.drop_path)
inputs = torch.randn(1, 3, self.img_size, self.img_size)
flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
if flops <= bound:
searched_g = g
break
else:
searched_g = g
pred_acc_lst.append(file_acc)
break
        if searched_g is None:
            raise ValueError('no candidate architecture satisfied the given bound')
return searched_g, pred_acc_lst
def sample(self, config=None):
""" randomly sample a sub-network """
if config is not None:
config = validate_config(config)
self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
else:
config = self.engine.sample_active_subnet()
subnet = self.engine.get_active_subnet(preserve_weight=True)
return subnet, config
@staticmethod
def save_net_config(path, net, config_name='net.config'):
""" dump run_config and net_config to the model_folder """
net_save_path = os.path.join(path, config_name)
json.dump(net.config, open(net_save_path, 'w'), indent=4)
        print('Network config dumped to %s' % net_save_path)
@staticmethod
def save_net(path, net, model_name):
""" dump net weight as checkpoint """
if isinstance(net, torch.nn.DataParallel):
checkpoint = {'state_dict': net.module.state_dict()}
else:
checkpoint = {'state_dict': net.state_dict()}
model_path = os.path.join(path, model_name)
torch.save(checkpoint, model_path)
        print('Network model dumped to %s' % model_path)
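# Usage sketch (values mirror the defaults in parser.py; the checkpoint path
# is the one used by the training script below):
#
#   evaluator = OFAEvaluator(num_gen_arch=200, img_size=224, drop_path=0.2,
#                            model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')
#   spec = evaluator.sample_random_architecture()   # {'ks': [...], 'e': [...], 'd': [...], ...}
#   subnet, config = evaluator.sample(config=spec)  # materialise the subnet with inherited weights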

View File

@@ -0,0 +1,169 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from evaluator import OFAEvaluator
from torchprofile import profile_macs
from codebase.networks import NSGANetV2
from parser import get_parse
from eval_utils import get_dataset
args = get_parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device_list = [int(_) for _ in args.gpu.split(',')]
args.n_gpus = len(device_list)
args.device = torch.device("cuda:0")
if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
evaluator = OFAEvaluator(args.num_gen_arch, args.img_size, args.drop_path,
                         model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')
args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
if args.model_config.startswith('flops@'):
args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
else:
args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
args.data_path = os.path.join(args.data_path, args.data_name)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
if not torch.cuda.is_available():
    logging.info('no gpu device available')
sys.exit(1)
logging.info("args = %s", args)
def set_architecture(n_cls):
if args.model_config.startswith('flops@'):
names = {'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100',
'aircraft100': 'Aircraft', 'pets': 'Pets'}
p = os.path.join('./searched-architectures/{}/net-{}/net.subnet'.
format(names[args.data_name], args.model_config))
g = json.load(open(p))
else:
        g, acc = evaluator.get_architecture(args.bound)
subnet, config = evaluator.sample(g)
net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=args.drop_path)
net.load_state_dict(subnet.state_dict())
NSGANetV2.reset_classifier(
net, last_channel=net.classifier.in_features,
n_classes=n_cls, dropout_rate=args.drop)
    # calculate #Parameters and #FLOPs
inputs = torch.randn(1, 3, args.img_size, args.img_size)
flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
net_name = "net_flops@{:.0f}".format(flops)
logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
OFAEvaluator.save_net_config(args.save_path, net, net_name + '.config')
if args.n_gpus > 1:
net = nn.DataParallel(net) # data parallel in case more than 1 gpu available
net = net.to(args.device)
return net, net_name
def train(train_queue, net, criterion, optimizer):
net.train()
train_loss, correct, total = 0, 0, 0
for step, (inputs, targets) in enumerate(train_queue):
# upsample by bicubic to match imagenet training size
inputs, targets = inputs.to(args.device), targets.to(args.device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % args.report_freq == 0:
logging.info('train %03d %e %f', step, train_loss / total, 100. * correct / total)
logging.info('train acc %f', 100. * correct / total)
return train_loss / total, 100. * correct / total
def infer(valid_queue, net, criterion, early_stop=False):
net.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
for step, (inputs, targets) in enumerate(valid_queue):
inputs, targets = inputs.to(args.device), targets.to(args.device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % args.report_freq == 0:
logging.info('valid %03d %e %f', step, test_loss / total, 100. * correct / total)
if early_stop and step == 10:
break
acc = 100. * correct / total
logging.info('valid acc %f', 100. * correct / total)
return test_loss / total, acc
def main():
best_acc, top_checkpoints = 0, []
train_queue, valid_queue, n_cls = get_dataset(args)
net, net_name = set_architecture(n_cls)
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
criterion = nn.CrossEntropyLoss().to(args.device)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
for epoch in range(args.epochs):
logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
train(train_queue, net, criterion, optimizer)
_, valid_acc = infer(valid_queue, net, criterion)
# checkpoint saving
if len(top_checkpoints) < args.topk:
OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
else:
idx = np.argmin([x[1] for x in top_checkpoints])
if valid_acc > top_checkpoints[idx][1]:
OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
# remove the idx
os.remove(top_checkpoints[idx][0])
top_checkpoints.pop(idx)
print(top_checkpoints)
if valid_acc > best_acc:
OFAEvaluator.save_net(args.save_path, net, net_name + '.best')
best_acc = valid_acc
scheduler.step()
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,43 @@
import argparse
def get_parse():
parser = argparse.ArgumentParser(description='MetaD2A vs NSGANETv2')
parser.add_argument('--save-path', type=str, default='../results', help='the path of save directory')
parser.add_argument('--data-path', type=str, default='../data', help='the path of save directory')
parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
parser.add_argument('--num-gen-arch', type=int, default=200,
help='the number of candidate architectures generated by the generator')
parser.add_argument('--bound', type=int, default=None)
# original setting
parser.add_argument('--seed', type=int, default=-1, help='random seed')
parser.add_argument('--batch-size', type=int, default=96, help='batch size')
parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading')
parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
parser.add_argument('--lr', type=float, default=0.01, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--epochs', type=int, default=150, help='num of training epochs')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
    # NOTE: with action='store_true' and default=True, these two flags are always on
    parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
    parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
    parser.add_argument('--autoaugment', action='store_true', default=True, help='use auto augmentation')
parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save')
parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
# model related
    parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                        help='Name of model to train (default: "countception")')
parser.add_argument('--model-config', type=str, default='search',
help='location of a json file of specific model declaration')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--drop', type=float, default=0.2,
help='dropout rate')
parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--img-size', type=int, default=224,
help='input resolution (192 -> 256)')
args = parser.parse_args()
return args

View File

@@ -0,0 +1,261 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.evaluator import OFAEvaluator
from torchprofile import profile_macs
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.parser import get_parse
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils import get_dataset
from transfer_nag_lib.MetaD2A_nas_bench_201.metad2a_utils import reset_seed
from transfer_nag_lib.ofa_net import OFASubNet
def set_architecture(n_cls, evaluator, drop_path, drop, img_size, n_gpus, device, save_path, model_str):
# g, acc = evaluator.get_architecture(model_str)
g = OFASubNet(model_str).get_op_dict()
subnet, config = evaluator.sample(g)
net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=drop_path)
net.load_state_dict(subnet.state_dict())
NSGANetV2.reset_classifier(
net, last_channel=net.classifier.in_features,
n_classes=n_cls, dropout_rate=drop)
    # calculate #Parameters and #FLOPs
inputs = torch.randn(1, 3, img_size, img_size)
flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
net_name = "net_flops@{:.0f}".format(flops)
print('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
# OFAEvaluator.save_net_config(save_path, net, net_name + '.config')
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
net = nn.DataParallel(net)
net = net.to(device)
return net, net_name, params, flops
def train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq):
net.train()
train_loss, correct, total = 0, 0, 0
for step, (inputs, targets) in enumerate(train_queue):
# upsample by bicubic to match imagenet training size
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % report_freq == 0:
print(f'train step {step:03d} loss {train_loss / total:.4f} train acc {100. * correct / total:.4f}')
print(f'train acc {100. * correct / total:.4f}')
return train_loss / total, 100. * correct / total
def infer(valid_queue, net, criterion, device, report_freq, early_stop=False):
net.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
for step, (inputs, targets) in enumerate(valid_queue):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
if step % report_freq == 0:
print(f'valid {step:03d} {test_loss / total:.4f} {100. * correct / total:.4f}')
if early_stop and step == 10:
break
acc = 100. * correct / total
print('valid acc {:.4f}'.format(100. * correct / total))
return test_loss / total, acc
def train_single_model(save_path, workers, datasets, xpaths, splits, use_less,
seed, model_str, device,
lr=0.01,
momentum=0.9,
weight_decay=4e-5,
report_freq=50,
epochs=150,
grad_clip=5,
cutout=True,
cutout_length=16,
autoaugment=True,
drop=0.2,
drop_path=0.2,
img_size=224,
batch_size=96,
):
assert torch.cuda.is_available(), 'CUDA is not available.'
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
reset_seed(seed)
# save_dir = Path(save_dir)
# logger = Logger(str(save_dir), 0, False)
os.makedirs(save_path, exist_ok=True)
to_save_name = save_path + '/seed-{:04d}.pth'.format(seed)
print(to_save_name)
# args = get_parse()
num_gen_arch = None
evaluator = OFAEvaluator(num_gen_arch, img_size, drop_path,
model_path='/home/data/GTAD/checkpoints/ofa/ofa_net/ofa_mbv3_d234_e346_k357_w1.0')
train_queue, valid_queue, n_cls = get_dataset(datasets, batch_size,
xpaths, workers, img_size, autoaugment, cutout, cutout_length)
net, net_name, params, flops = set_architecture(n_cls, evaluator,
drop_path, drop, img_size, n_gpus=1, device=device, save_path=save_path, model_str=model_str)
# net.to(device)
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss().to(device)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
# assert epochs == 1
max_valid_acc = 0
max_epoch = 0
for epoch in range(epochs):
print('epoch {:d} lr {:.4f}'.format(epoch, scheduler.get_lr()[0]))
train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq)
_, valid_acc = infer(valid_queue, net, criterion, device, report_freq)
torch.save(valid_acc, to_save_name)
print(f'seed {seed:04d} last acc {valid_acc:.4f} max acc {max_valid_acc:.4f}')
if max_valid_acc < valid_acc:
max_valid_acc = valid_acc
max_epoch = epoch
# parent_path = os.path.abspath(os.path.join(save_path, os.pardir))
# with open(parent_path + '/accuracy.txt', 'a+') as f:
# f.write(f'{model_str} seed {seed:04d} {valid_acc:.4f}\n')
return valid_acc, max_valid_acc, params, flops
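# Usage sketch (placeholder paths and architecture string; keyword defaults
# as in the signature above):
#
#   valid_acc, max_acc, params, flops = train_single_model(
#       save_path='./exp/pets-run', workers=2, datasets='pets',
#       xpaths='/path/to/Oxford-IIITPets', splits=None, use_less=False,
#       seed=1, model_str='<OFA subnet string>', device='cuda:0')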
################ NAS BENCH 201 #####################
# def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
# seeds, model_str, arch_config):
# assert torch.cuda.is_available(), 'CUDA is not available.'
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.deterministic = True
# torch.set_num_threads(workers)
# save_dir = Path(save_dir)
# logger = Logger(str(save_dir), 0, False)
# if model_str in CellArchitectures:
# arch = CellArchitectures[model_str]
# logger.log(
# 'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
# else:
# try:
# arch = CellStructure.str2structure(model_str)
# except:
# raise ValueError(
# 'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
# assert arch.check_valid_op(get_search_spaces(
# 'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
# # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
# logger.log('Start train-evaluate {:}'.format(arch.tostr()))
# logger.log('arch_config : {:}'.format(arch_config))
# start_time, seed_time = time.time(), AverageMeter()
# for _is, seed in enumerate(seeds):
# logger.log(
# '\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds),
# seed))
# to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
# if to_save_name.exists():
# logger.log(
# 'Find the existing file {:}, directly load!'.format(to_save_name))
# checkpoint = torch.load(to_save_name)
# else:
# logger.log(
# 'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
# checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
# seed, arch_config, workers, logger)
# torch.save(checkpoint, to_save_name)
# # log information
# logger.log('{:}'.format(checkpoint['info']))
# all_dataset_keys = checkpoint['all_dataset_keys']
# for dataset_key in all_dataset_keys:
# logger.log('\n{:} dataset : {:} {:}'.format(
# '-' * 15, dataset_key, '-' * 15))
# dataset_info = checkpoint[dataset_key]
# # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
# logger.log('Flops = {:} MB, Params = {:} MB'.format(
# dataset_info['flop'], dataset_info['param']))
# logger.log('config : {:}'.format(dataset_info['config']))
# logger.log('Training State (finish) = {:}'.format(
# dataset_info['finish-train']))
# last_epoch = dataset_info['total_epoch'] - 1
# train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
# valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
# # measure elapsed time
# seed_time.update(time.time() - start_time)
# start_time = time.time()
# need_time = 'Time Left: {:}'.format(convert_secs2time(
# seed_time.avg * (len(seeds) - _is - 1), True))
# logger.log(
# '\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed,
# need_time))
# logger.close()
# ###################
if __name__ == '__main__':
# NOTE: train_single_model() takes required positional arguments (see the
# signature above); calling it bare like this raises a TypeError. Supply the
# save path, dataset settings, seed, model string, and device when invoking.
train_single_model()

View File

@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .generator import Generator

View File

@@ -0,0 +1,204 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import random
from tqdm import tqdm
import numpy as np
import time
import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import load_graph_config, decode_ofa_mbv3_to_igraph, decode_igraph_to_ofa_mbv3
from utils import Accumulator, Log
from utils import load_model, save_model
from loader import get_meta_train_loader, get_meta_test_loader
from .generator_model import GeneratorModel
class Generator:
def __init__(self, args):
self.args = args
self.batch_size = args.batch_size
self.data_path = args.data_path
self.num_sample = args.num_sample
self.max_epoch = args.max_epoch
self.save_epoch = args.save_epoch
self.model_path = args.model_path
self.save_path = args.save_path
self.model_name = args.model_name
self.test = args.test
self.device = args.device
graph_config = load_graph_config(
args.graph_data_name, args.nvt, args.data_path)
self.model = GeneratorModel(args, graph_config)
self.model.to(self.device)
if self.test:
self.data_name = args.data_name
self.num_class = args.num_class
self.load_epoch = args.load_epoch
self.num_gen_arch = args.num_gen_arch
load_model(self.model, self.model_path, self.load_epoch)
else:
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
factor=0.1, patience=10, verbose=True)
self.mtrloader = get_meta_train_loader(
self.batch_size, self.data_path, self.num_sample)
self.mtrlog = Log(self.args, open(os.path.join(
self.save_path, self.model_name, 'meta_train_generator.log'), 'w'))
self.mtrlog.print_args()
self.mtrlogger = Accumulator('loss', 'recon_loss', 'kld')
self.mvallogger = Accumulator('loss', 'recon_loss', 'kld')
def meta_train(self):
sttime = time.time()
for epoch in range(1, self.max_epoch + 1):
self.mtrlog.ep_sttime = time.time()
loss = self.meta_train_epoch(epoch)
self.scheduler.step(loss)
self.mtrlog.print(self.mtrlogger, epoch, tag='train')
self.meta_validation()
self.mtrlog.print(self.mvallogger, epoch, tag='valid')
if epoch % self.save_epoch == 0:
save_model(epoch, self.model, self.model_path)
self.mtrlog.save_time_log()
def meta_train_epoch(self, epoch):
self.model.to(self.device)
self.model.train()
self.mtrloader.dataset.set_mode('train')
pbar = tqdm(self.mtrloader)
for batch in pbar:
for x, g, acc in batch:
self.optimizer.zero_grad()
g = decode_ofa_mbv3_to_igraph(g)[0]
x_ = x.unsqueeze(0).to(self.device)
mu, logvar = self.model.set_encode(x_)
loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])
loss.backward()
self.optimizer.step()
cnt = len(x)
self.mtrlogger.accum([loss.item() / cnt,
recon.item() / cnt,
kld.item() / cnt])
return self.mtrlogger.get('loss')
def meta_validation(self):
self.model.to(self.device)
self.model.eval()
self.mtrloader.dataset.set_mode('valid')
pbar = tqdm(self.mtrloader)
for batch in pbar:
for x, g, acc in batch:
with torch.no_grad():
g = decode_ofa_mbv3_to_igraph(g)[0]
x_ = x.unsqueeze(0).to(self.device)
mu, logvar = self.model.set_encode(x_)
loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])
cnt = len(x)
self.mvallogger.accum([loss.item() / cnt,
recon.item() / cnt,
kld.item() / cnt])
return self.mvallogger.get('loss')
def meta_test(self, predictor):
if self.data_name == 'all':
for data_name in ['cifar100', 'cifar10', 'mnist', 'svhn', 'aircraft30', 'aircraft100', 'pets']:
self.meta_test_per_dataset(data_name, predictor)
else:
self.meta_test_per_dataset(self.data_name, predictor)
def meta_test_per_dataset(self, data_name, predictor):
# meta_test_path = os.path.join(
# self.save_path, 'meta_test', data_name, 'generated_arch')
meta_test_path = os.path.join(
self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'generated_arch')
if not os.path.exists(meta_test_path):
os.makedirs(meta_test_path)
meta_test_loader = get_meta_test_loader(
self.data_path, data_name, self.num_sample, self.num_class)
print(f'==> generate architectures for {data_name}')
runs = 10 if data_name in ['cifar10', 'cifar100'] else 1
# num_gen_arch = 500 if data_name in ['cifar100'] else self.num_gen_arch
elapsed_time = []
for run in range(1, runs + 1):
print(f'==> run {run}/{runs}')
elapsed_time.append(self.generate_architectures(
meta_test_loader, data_name,
meta_test_path, run, self.num_gen_arch, predictor))
print(f'==> done\n')
# time_path = os.path.join(self.save_path, 'meta_test', data_name, 'time.txt')
time_path = os.path.join(self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'time.txt')
with open(time_path, 'w') as f_time:
msg = f'generator elapsed time {np.mean(elapsed_time):.2f}s'
print(f'==> save time in {time_path}')
f_time.write(msg + '\n')
print(msg)
def generate_architectures(self, meta_test_loader, data_name,
meta_test_path, run, num_gen_arch, predictor):
self.model.eval()
self.model.to(self.device)
architecture_string_lst, pred_acc_lst = [], []
total_cnt, valid_cnt = 0, 0
flag = False
start = time.time()
with torch.no_grad():
for x in meta_test_loader:
x_ = x.unsqueeze(0).to(self.device)
mu, logvar = self.model.set_encode(x_)
z = self.model.reparameterize(mu.unsqueeze(0), logvar.unsqueeze(0))
g_recon = self.model.graph_decode(z)
pred_acc = predictor.forward(x_, g_recon)
architecture_string = decode_igraph_to_ofa_mbv3(g_recon[0])
total_cnt += 1
if architecture_string is not None:
if architecture_string not in architecture_string_lst:
valid_cnt += 1
architecture_string_lst.append(architecture_string)
pred_acc_lst.append(pred_acc.item())
if valid_cnt == num_gen_arch:
flag = True
break
if flag:
break
elapsed = time.time() - start
pred_acc_lst, architecture_string_lst = zip(*sorted(zip(pred_acc_lst,
architecture_string_lst),
key=lambda x: x[0], reverse=True))
spath = os.path.join(meta_test_path, f"run_{run}.txt")
with open(spath, 'w') as f:
print(f'==> save generated architectures in {spath}')
msg = f'elapsed time: {elapsed:6.2f}s '
print(msg)
f.write(msg + '\n')
for i, architecture_string in enumerate(architecture_string_lst):
f.write(f"{architecture_string}\n")
return elapsed
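# Typical meta-test flow (sketch; `args` and `predictor` come from the
# project's parser and Predictor class respectively):
#   g = Generator(args)        # with args.test set, this loads the checkpoint
#   g.meta_test(predictor)     # writes run_{i}.txt files of sorted architectures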

View File

@@ -0,0 +1,396 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import igraph
from set_encoder.setenc_models import SetPool
class GeneratorModel(nn.Module):
def __init__(self, args, graph_config):
super(GeneratorModel, self).__init__()
self.max_n = graph_config['max_n'] # maximum number of vertices
self.nvt = graph_config['num_vertex_type'] # number of vertex types
self.START_TYPE = graph_config['START_TYPE']
self.END_TYPE = graph_config['END_TYPE']
self.hs = args.hs # hidden state size of each vertex
self.nz = args.nz # size of latent representation z
self.gs = args.hs # size of graph state
self.bidir = True # whether to use bidirectional encoding
self.vid = True
self.device = None
self.num_sample = args.num_sample
if self.vid:
self.vs = self.hs + self.max_n # vertex state size = hidden state + vid
else:
self.vs = self.hs
# 0. encoding-related
self.grue_forward = nn.GRUCell(self.nvt, self.hs) # encoder GRU
self.grue_backward = nn.GRUCell(self.nvt, self.hs) # backward encoder GRU
self.enc_g_mu = nn.Linear(self.gs, self.nz) # latent mean
self.enc_g_var = nn.Linear(self.gs, self.nz) # latent var
self.fc1 = nn.Linear(self.gs, self.nz) # latent mean
self.fc2 = nn.Linear(self.gs, self.nz) # latent logvar
# 1. decoding-related
self.grud = nn.GRUCell(self.nvt, self.hs) # decoder GRU
self.fc3 = nn.Linear(self.nz, self.hs) # from latent z to initial hidden state h0
self.add_vertex = nn.Sequential(
nn.Linear(self.hs, self.hs * 2),
nn.ReLU(),
nn.Linear(self.hs * 2, self.nvt)
) # which type of new vertex to add f(h0, hg)
self.add_edge = nn.Sequential(
nn.Linear(self.hs * 2, self.hs * 4),
nn.ReLU(),
nn.Linear(self.hs * 4, 1)
) # whether to add edge between v_i and v_new, f(hvi, hnew)
self.decoding_gate = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.decoding_mapper = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
) # disable bias to ensure padded zeros also mapped to zeros
# 2. gate-related
self.gate_forward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.gate_backward = nn.Sequential(
nn.Linear(self.vs, self.hs),
nn.Sigmoid()
)
self.mapper_forward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
) # disable bias to ensure padded zeros also mapped to zeros
self.mapper_backward = nn.Sequential(
nn.Linear(self.vs, self.hs, bias=False),
)
# 3. bidir-related, to unify sizes
if self.bidir:
self.hv_unify = nn.Sequential(
nn.Linear(self.hs * 2, self.hs),
)
self.hg_unify = nn.Sequential(
nn.Linear(self.gs * 2, self.gs),
)
# 4. other
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.logsoftmax1 = nn.LogSoftmax(1)
# 6. predictor
self.intra_setpool = SetPool(dim_input=512,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.inter_setpool = SetPool(dim_input=self.nz,
num_outputs=1,
dim_output=self.nz,
dim_hidden=self.nz,
mode='sabPF')
self.set_fc = nn.Sequential(
nn.Linear(512, self.nz),
nn.ReLU())
def get_device(self):
if self.device is None:
self.device = next(self.parameters()).device
return self.device
def _get_zeros(self, n, length):
return torch.zeros(n, length).to(self.get_device()) # get a zero hidden state
def _get_zero_hidden(self, n=1):
return self._get_zeros(n, self.hs) # get a zero hidden state
def _one_hot(self, idx, length):
if type(idx) in [list, range]:
if idx == []:
return None
idx = torch.LongTensor(idx).unsqueeze(0).t()
x = torch.zeros((len(idx), length)
).scatter_(1, idx, 1).to(self.get_device())
else:
idx = torch.LongTensor([idx]).unsqueeze(0)
x = torch.zeros((1, length)
).scatter_(1, idx, 1).to(self.get_device())
return x
def _gated(self, h, gate, mapper):
return gate(h) * mapper(h)
def _collate_fn(self, G):
return [g.copy() for g in G]
def _propagate_to(self, G, v, propagator,
H=None, reverse=False, gate=None, mapper=None):
# propagate messages to vertex index v for all graphs in G
# return the new messages (states) at v
G = [g for g in G if g.vcount() > v]
if len(G) == 0:
return
if H is not None:
idx = [i for i, g in enumerate(G) if g.vcount() > v]
H = H[idx]
v_types = [g.vs[v]['type'] for g in G]
X = self._one_hot(v_types, self.nvt)
H_name = 'H_forward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.predecessors(v), self.max_n) for g in G]
if reverse:
H_name = 'H_backward' # name of the hidden states attribute
H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
if self.vid:
vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
gate, mapper = self.gate_backward, self.mapper_backward
else:
H_name = 'H_forward' # name of the hidden states attribute
H_pred = [
[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
if self.vid:
vids = [
self._one_hot(g.predecessors(v), self.max_n) for g in G]
if gate is None:
gate, mapper = self.gate_forward, self.mapper_forward
if self.vid:
H_pred = [[torch.cat(
[x[i], y[i:i + 1]], 1) for i in range(len(x))
] for x, y in zip(H_pred, vids)]
# if h is not provided, use gated sum of v's predecessors' states as the input hidden state
if H is None:
max_n_pred = max([len(x) for x in H_pred]) # maximum number of predecessors
if max_n_pred == 0:
H = self._get_zero_hidden(len(G))
else:
H_pred = [torch.cat(h_pred +
[self._get_zeros(max_n_pred - len(h_pred),
self.vs)], 0).unsqueeze(0)
for h_pred in H_pred] # pad all to same length
H_pred = torch.cat(H_pred, 0) # batch * max_n_pred * vs
H = self._gated(H_pred, gate, mapper).sum(1) # batch * hs
Hv = propagator(X, H)
for i, g in enumerate(G):
g.vs[v][H_name] = Hv[i:i + 1]
return Hv
def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
# perform a series of propagation_to steps starting from v following a topo order
# assume the original vertex indices are in a topological order
if reverse:
prop_order = range(v, -1, -1)
else:
prop_order = range(v, self.max_n)
Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse) # the initial vertex
for v_ in prop_order[1:]:
self._propagate_to(G, v_, propagator, reverse=reverse)
return Hv
def _update_v(self, G, v, H0=None):
# perform a forward propagation step at v when decoding to update v's state
# self._propagate_to(G, v, self.grud, H0, reverse=False)
self._propagate_to(G, v, self.grud, H0,
reverse=False, gate=self.decoding_gate,
mapper=self.decoding_mapper)
return
def _get_vertex_state(self, G, v):
# get the vertex states at v
Hv = []
for g in G:
if v >= g.vcount():
hv = self._get_zero_hidden()
else:
hv = g.vs[v]['H_forward']
Hv.append(hv)
Hv = torch.cat(Hv, 0)
return Hv
def _get_graph_state(self, G, decode=False):
# get the graph states
# when decoding, use the last generated vertex's state as the graph state
# when encoding, use the ending vertex state or unify the starting and ending vertex states
Hg = []
for g in G:
hg = g.vs[g.vcount() - 1]['H_forward']
if self.bidir and not decode: # decoding never uses backward propagation
hg_b = g.vs[0]['H_backward']
hg = torch.cat([hg, hg_b], 1)
Hg.append(hg)
Hg = torch.cat(Hg, 0)
if self.bidir and not decode:
Hg = self.hg_unify(Hg)
return Hg
def graph_encode(self, G):
# encode graphs G into latent vectors
if type(G) != list:
G = [G]
self._propagate_from(G, 0, self.grue_forward,
H0=self._get_zero_hidden(len(G)), reverse=False)
if self.bidir:
self._propagate_from(G, self.max_n - 1, self.grue_backward,
H0=self._get_zero_hidden(len(G)), reverse=True)
Hg = self._get_graph_state(G)
mu, logvar = self.enc_g_mu(Hg), self.enc_g_var(Hg)
return mu, logvar
def set_encode(self, X):
proto_batch = []
for x in X: # X.shape: [32, 400, 512]
cls_protos = self.intra_setpool(
x.view(-1, self.num_sample, 512)).squeeze(1)
proto_batch.append(
self.inter_setpool(cls_protos.unsqueeze(0)))
v = torch.stack(proto_batch).squeeze()
mu, logvar = self.fc1(v), self.fc2(v)
return mu, logvar
def reparameterize(self, mu, logvar, eps_scale=0.01):
# return z ~ N(mu, std)
if self.training:
std = logvar.mul(0.5).exp_()
eps = torch.randn_like(std) * eps_scale
return eps.mul(std).add_(mu)
else:
return mu
def _get_edge_score(self, Hvi, H, H0):
# compute scores for edges from vi based on Hvi, H (current vertex) and H0
# in most cases, H0 need not be explicitly included since Hvi and H contain its information
return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))
def graph_decode(self, z, stochastic=True):
# decode latent vectors z back to graphs
# if stochastic=True, stochastically sample each action from the predicted distribution;
# otherwise, select argmax action deterministically.
H0 = self.tanh(self.fc3(z)) # or relu activation, similar performance
G = [igraph.Graph(directed=True) for _ in range(len(z))]
for g in G:
g.add_vertex(type=self.START_TYPE)
self._update_v(G, 0, H0)
finished = [False] * len(G)
for idx in range(1, self.max_n):
# decide the type of the next added vertex
if idx == self.max_n - 1: # force the last node to be end_type
new_types = [self.END_TYPE] * len(G)
else:
Hg = self._get_graph_state(G, decode=True)
type_scores = self.add_vertex(Hg)
if stochastic:
type_probs = F.softmax(type_scores, 1
).cpu().detach().numpy()
new_types = [np.random.choice(range(self.nvt),
p=type_probs[i]) for i in range(len(G))]
else:
new_types = torch.argmax(type_scores, 1)
new_types = new_types.flatten().tolist()
for i, g in enumerate(G):
if not finished[i]:
g.add_vertex(type=new_types[i])
self._update_v(G, idx)
# decide connections
for vi in range(idx - 1, -1, -1):
Hvi = self._get_vertex_state(G, vi)
H = self._get_vertex_state(G, idx)
ei_score = self._get_edge_score(Hvi, H, H0)
if stochastic:
random_score = torch.rand_like(ei_score)
decisions = random_score < ei_score
else:
decisions = ei_score > 0.5
for i, g in enumerate(G):
if finished[i]:
continue
if new_types[i] == self.END_TYPE:
# if new node is end_type, connect it to all loose-end vertices (out_degree==0)
end_vertices = set([
v.index for v in g.vs.select(_outdegree_eq=0)
if v.index != g.vcount() - 1])
for v in end_vertices:
g.add_edge(v, g.vcount() - 1)
finished[i] = True
continue
if decisions[i, 0]:
g.add_edge(vi, g.vcount() - 1)
self._update_v(G, idx)
for g in G:
del g.vs['H_forward'] # delete hidden states to save GPU memory
return G
def loss(self, mu, logvar, G_true, beta=0.005):
# compute the loss of decoding mu and logvar to true graphs using teacher forcing
# ensure when computing the loss of step i, steps 0 to i-1 are correct
z = self.reparameterize(mu, logvar)
H0 = self.tanh(self.fc3(z)) # or relu activation, similar performance
G = [igraph.Graph(directed=True) for _ in range(len(z))]
for g in G:
g.add_vertex(type=self.START_TYPE)
self._update_v(G, 0, H0)
res = 0 # log likelihood
for v_true in range(1, self.max_n):
# calculate the likelihood of adding true types of nodes
# use start type to denote padding vertices since start type only appears for vertex 0
# and will never be a true type for later vertices, thus it's free to use
true_types = [g_true.vs[v_true]['type']
if v_true < g_true.vcount()
else self.START_TYPE for g_true in G_true]
Hg = self._get_graph_state(G, decode=True)
type_scores = self.add_vertex(Hg)
# vertex log likelihood
vll = self.logsoftmax1(type_scores)[
np.arange(len(G)), true_types].sum()
res = res + vll
for i, g in enumerate(G):
if true_types[i] != self.START_TYPE:
g.add_vertex(type=true_types[i])
self._update_v(G, v_true)
# calculate the likelihood of adding true edges
true_edges = []
for i, g_true in enumerate(G_true):
true_edges.append(g_true.get_adjlist(igraph.IN)[v_true]
if v_true < g_true.vcount() else [])
edge_scores = []
for vi in range(v_true - 1, -1, -1):
Hvi = self._get_vertex_state(G, vi)
H = self._get_vertex_state(G, v_true)
ei_score = self._get_edge_score(Hvi, H, H0)
edge_scores.append(ei_score)
for i, g in enumerate(G):
if vi in true_edges[i]:
g.add_edge(vi, v_true)
self._update_v(G, v_true)
edge_scores = torch.cat(edge_scores[::-1], 1)
ground_truth = torch.zeros_like(edge_scores)
idx1 = [i for i, x in enumerate(true_edges)
for _ in range(len(x))]
idx2 = [xx for x in true_edges for xx in x]
ground_truth[idx1, idx2] = 1.0
# edges log-likelihood
ell = - F.binary_cross_entropy(
edge_scores, ground_truth, reduction='sum')
res = res + ell
res = -res # convert likelihood to loss
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return res + beta * kld, res, kld
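# Standalone sketch of the reparameterization trick used in reparameterize()
# above (illustrative only; shapes are arbitrary):
#   mu, logvar = torch.zeros(1, 56), torch.zeros(1, 56)
#   std = logvar.mul(0.5).exp()
#   z = mu + torch.randn_like(std) * std  # z ~ N(mu, std^2)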

View File

@@ -0,0 +1,37 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'ckpt_120.pt'
dir_path = 'results/generator/model'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file('https://www.dropbox.com/s/zss9yt034hen45h/ckpt_120.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'collected_database.pt'
dir_path = 'data/generator/processed'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading generator {file_name}\n")
download_file('https://www.dropbox.com/s/zgip4aq0w2pkj49/generator_collected_database.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,43 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = 'data/pets'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file('https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")
full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
print(f"Downloading {full_name}\n")
download_file('https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
print("Downloading done.\n")
else:
print(f"{full_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,35 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'ckpt_max_corr.pt'
dir_path = 'results/predictor/model'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading {file_name}\n")
download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
file_name = 'collected_database.pt'
dir_path = 'data/predictor/processed'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
print(f"Downloading predictor {file_name}\n")
# NOTE: this URL duplicates the predictor-checkpoint (ckpt_max_corr.pt) link
# used in results/predictor/model, which looks like a copy-paste slip; the
# intended database URL is not recoverable from this commit, so it is left as-is.
download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
print("Downloading done.\n")
else:
print(f"{file_name} has already been downloaded. Did not download twice.\n")

View File

@@ -0,0 +1,47 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile
def download_file(url, filename):
"""
Helper method handling downloading large files from `url`
to `filename`. Returns a pointer to `filename`.
"""
chunkSize = 1024
r = requests.get(url, stream=True)
with open(filename, 'wb') as f:
pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk: # filter out keep-alive new chunks
pbar.update(len(chunk))
f.write(chunk)
return filename
dir_path = 'data'
if not os.path.exists(dir_path):
os.makedirs(dir_path)
def get_preprocessed_data(file_name, url):
print(f"Downloading {file_name} datasets\n")
full_name = os.path.join(dir_path, file_name)
download_file(url, full_name)
print("Downloading done.\n")
for file_name, url in [
('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
('aircraft100bylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
('mnistbylabel.pt', 'https://www.dropbox.com/s/86rbuic7a7y34e4/mnistbylabel.pt?dl=1'),
('svhnbylabel.pt', 'https://www.dropbox.com/s/yywaelhrsl6egvd/svhnbylabel.pt?dl=1')
]:
get_preprocessed_data(file_name, url)

View File

@@ -0,0 +1,149 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
print(f'==> The number of tasks for meta-training: {len(dataset)}')
loader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1,
collate_fn=collate_fn)
return loader
def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
# num_sample was missing from the signature, which shifted num_class into its
# slot at the call site in Generator.meta_test_per_dataset; pass both through.
dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
print(f'==> Meta-Test dataset {data_name}')
loader = DataLoader(dataset=dataset,
batch_size=100,
shuffle=False,
num_workers=1)
return loader
class MetaTrainDatabase(Dataset):
def __init__(self, data_path, num_sample, is_pred=False):
self.mode = 'train'
self.acc_norm = True
self.num_sample = num_sample
self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
self.dpath = '{}/{}/processed/'.format(data_path, 'predictor' if is_pred else 'generator')
self.dname = 'database_219152_14.0K'
if not os.path.exists(self.dpath + f'{self.dname}_train.pt'):
raise ValueError(f'processed meta-train split not found: {self.dpath}{self.dname}_train.pt')
database = torch.load(self.dpath + f'{self.dname}.pt')
rand_idx = torch.randperm(len(database))
test_len = int(len(database) * 0.15)
idxlst = {'test': rand_idx[:test_len],
'valid': rand_idx[test_len:2 * test_len],
'train': rand_idx[2 * test_len:]}
for m in ['train', 'valid', 'test']:
acc, graph, cls, net, flops = [], [], [], [], []
for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
acc.append(database[idx]['top1'])
net.append(database[idx]['net'])
cls.append(database[idx]['class'])
flops.append(database[idx]['flops'])
if m == 'train':
mean = torch.mean(torch.tensor(acc)).item()
std = torch.std(torch.tensor(acc)).item()
torch.save({'acc': acc,
'class': cls,
'net': net,
'flops': flops,
'mean': mean,
'std': std},
self.dpath + f'{self.dname}_{m}.pt')
self.set_mode(self.mode)
def set_mode(self, mode):
self.mode = mode
data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
self.acc = data['acc']
self.cls = data['class']
self.net = data['net']
self.flops = data['flops']
self.mean = data['mean']
self.std = data['std']
def __len__(self):
return len(self.acc)
def __getitem__(self, index):
data = []
classes = self.cls[index]
acc = self.acc[index]
graph = self.net[index]
for i, cls in enumerate(classes):
cx = self.x[cls.item()][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
if self.acc_norm:
acc = ((acc - self.mean) / self.std) / 100.0
else:
acc = acc / 100.0
return x, graph, torch.tensor(acc).view(1, 1)
class MetaTestDataset(Dataset):
def __init__(self, data_path, data_name, num_sample, num_class=None):
self.num_sample = num_sample
self.data_name = data_name
if data_name == 'aircraft':
data_name = 'aircraft100'
num_class_dict = {
'cifar100': 100,
'cifar10': 10,
'mnist': 10,
'aircraft100': 30,
'svhn': 10,
'pets': 37
}
# 'aircraft30': 30,
# 'aircraft100': 100,
if num_class is not None:
self.num_class = num_class
else:
self.num_class = num_class_dict[data_name]
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
def __len__(self):
return 1000000  # effectively unbounded; each item is sampled on the fly
def __getitem__(self, index):
data = []
classes = list(range(self.num_class))
for cls in classes:
cx = self.x[cls][0]
ridx = torch.randperm(len(cx))
data.append(cx[ridx[:self.num_sample]])
x = torch.cat(data)
return x
def collate_fn(batch):
# x = torch.stack([item[0] for item in batch])
# graph = [item[1] for item in batch]
# acc = torch.stack([item[2] for item in batch])
return batch
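# Loader usage sketch (assumes the processed database and imgnet32bylabel.pt
# exist under `data_path`; each batch is a raw list of (x, graph, acc) tuples
# because collate_fn above returns the batch unchanged):
#   loader = get_meta_train_loader(batch_size=32, data_path='data', num_sample=20)
#   for batch in loader:
#       for x, g, acc in batch:
#           ...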

View File

@@ -0,0 +1,48 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from generator import Generator
from predictor import Predictor
def main():
args = get_parser()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
args.device = torch.device("cuda:0")
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
args.model_path = os.path.join(args.save_path, args.model_name, 'model')
if not os.path.exists(args.model_path):
os.makedirs(args.model_path)
if args.model_name == 'generator':
g = Generator(args)
if args.test:
args.model_path = os.path.join(args.save_path, 'predictor', 'model')
hs = args.hs
args.hs = 512
p = Predictor(args)
args.model_path = os.path.join(args.save_path, args.model_name, 'model')
args.hs = hs
g.meta_test(p)
else:
g.meta_train()
elif args.model_name == 'predictor':
p = Predictor(args)
p.meta_train()
else:
raise ValueError('You should select generator|predictor|train_arch')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,344 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import argparse
import torch
def load_graph_config(graph_data_name, nvt, data_path):
max_n=20
graph_config = {}
graph_config['num_vertex_type'] = nvt + 2 # original types + start/end types
graph_config['max_n'] = max_n + 2 # maximum number of nodes
graph_config['START_TYPE'] = 0 # predefined start vertex type
graph_config['END_TYPE'] = 1 # predefined end vertex type
return graph_config
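# Example: for the 27 layer types in type_dict below, load_graph_config(None, 27, None)
# yields num_vertex_type=29 (27 ops + start/end) and max_n=22 (20 layers + in/out
# nodes); the first and last arguments are unused by this implementation.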
type_dict = {'2-3-3': 0, '2-3-4': 1, '2-3-6': 2,
'2-5-3': 3, '2-5-4': 4, '2-5-6': 5,
'2-7-3': 6, '2-7-4': 7, '2-7-6': 8,
'3-3-3': 9, '3-3-4': 10, '3-3-6': 11,
'3-5-3': 12, '3-5-4': 13, '3-5-6': 14,
'3-7-3': 15, '3-7-4': 16, '3-7-6': 17,
'4-3-3': 18, '4-3-4': 19, '4-3-6': 20,
'4-5-3': 21, '4-5-4': 22, '4-5-6': 23,
'4-7-3': 24, '4-7-4': 25, '4-7-6': 26}
edge_dict = {2: (2, 3, 3), 3: (2, 3, 4), 4: (2, 3, 6),
5: (2, 5, 3), 6: (2, 5, 4), 7: (2, 5, 6),
8: (2, 7, 3), 9: (2, 7, 4), 10: (2, 7, 6),
11: (3, 3, 3), 12: (3, 3, 4), 13: (3, 3, 6),
14: (3, 5, 3), 15: (3, 5, 4), 16: (3, 5, 6),
17: (3, 7, 3), 18: (3, 7, 4), 19: (3, 7, 6),
20: (4, 3, 3), 21: (4, 3, 4), 22: (4, 3, 6),
23: (4, 5, 3), 24: (4, 5, 4), 25: (4, 5, 6),
26: (4, 7, 3), 27: (4, 7, 4), 28: (4, 7, 6)}
def decode_ofa_mbv3_to_igraph(matrix):
# 5 stages, 4 layers for each stage
# d: 2, 3, 4
# e: 3, 4, 6
# k: 3, 5, 7
# stage_depth to one hot
num_stage = 5
num_layer = 4
node_types = torch.zeros(num_stage * num_layer)
depths = []
for i in range(num_stage):
for j in range(num_layer):
depths.append(matrix['d'][i])
for i, (ks, e, d) in enumerate(zip(
matrix['ks'], matrix['e'], depths)):
node_types[i] = type_dict[f'{d}-{ks}-{e}']
n = num_stage * num_layer
g = igraph.Graph(directed=True)
g.add_vertices(n + 2) # + in/out nodes
g.vs[0]['type'] = 0
for i, v in enumerate(node_types):
g.vs[i + 1]['type'] = v + 2 # in node: 0, out node: 1
g.add_edge(i, i + 1)
g.vs[n + 1]['type'] = 1
g.add_edge(n, n + 1)
return g, n + 2
def decode_ofa_mbv3_str_to_igraph(gen_str):
# 5 stages, 4 layers for each stage
# d: 2, 3, 4
# e: 3, 4, 6
# k: 3, 5, 7
# stage_depth to one hot
num_stage = 5
num_layer = 4
node_types = torch.zeros(num_stage * num_layer)
d = []
split_str = gen_str.split('_')
for i, s in enumerate(split_str):
if s == '0-0-0':
node_types[i] = random.randint(0, 26)
else:
node_types[i] = type_dict[s]
n = num_stage * num_layer
g = igraph.Graph(directed=True)
g.add_vertices(n + 2) # + in/out nodes
g.vs[0]['type'] = 0
for i, v in enumerate(node_types):
g.vs[i + 1]['type'] = v + 2 # in node: 0, out node: 1
g.add_edge(i, i + 1)
g.vs[n + 1]['type'] = 1
g.add_edge(n, n + 1)
return g
def is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
# first need to be a valid DAG computation graph
msg = ''
res = is_valid_DAG(g, START_TYPE, END_TYPE)
# in addition, node i must connect to node i+1
res = res and len(g.vs['type']) == 22
if not res:
return res
msg += '{} ({}) '.format(g.vs['type'][1:-1], len(g.vs['type']))
for i in range(5):
stage_group = (g.vs['type'][1:-1][i * 4] - 2) // 9
if stage_group not in (0, 1, 2):
raise ValueError
for j in range(1, 4):
res = res and (g.vs['type'][1:-1][i * 4 + j] - 2) // 9 == stage_group
return res
def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
res = g.is_dag()
n_start, n_end = 0, 0
for v in g.vs:
if v['type'] == START_TYPE:
n_start += 1
elif v['type'] == END_TYPE:
n_end += 1
if v.indegree() == 0 and v['type'] != START_TYPE:
return False
if v.outdegree() == 0 and v['type'] != END_TYPE:
return False
return res and n_start == 1 and n_end == 1
def decode_igraph_to_ofa_mbv3(g):
if not is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
return None
graph = {'ks': [], 'e': [], 'd': [4, 4, 4, 4, 4]}
for i, edge_type in enumerate(g.vs['type'][1:-1]):
edge_type = int(edge_type)
d, ks, e = edge_dict[edge_type]
graph['ks'].append(ks)
graph['e'].append(e)
graph['d'][i // 4] = d
return graph
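# Round-trip sketch (illustrative): a config with uniform depth per stage
# encodes to a valid chain graph and decodes back without loss.
#   matrix = {'ks': [3] * 20, 'e': [4] * 20, 'd': [2, 3, 4, 2, 3]}
#   g, n = decode_ofa_mbv3_to_igraph(matrix)
#   assert decode_igraph_to_ofa_mbv3(g) is not None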
class Accumulator():
def __init__(self, *args):
self.args = args
self.argdict = {}
for i, arg in enumerate(args):
self.argdict[arg] = i
self.sums = [0] * len(args)
self.cnt = 0
def accum(self, val):
val = [val] if type(val) is not list else val
val = [v for v in val if v is not None]
assert (len(val) == len(self.args))
for i in range(len(val)):
if torch.is_tensor(val[i]):
val[i] = val[i].item()
self.sums[i] += val[i]
self.cnt += 1
def clear(self):
self.sums = [0] * len(self.args)
self.cnt = 0
def get(self, arg, avg=True):
i = self.argdict.get(arg, -1)
assert i != -1, f'unknown key: {arg}'
if avg:
return self.sums[i] / (self.cnt + 1e-8)
else:
return self.sums[i]
def print_(self, header=None, time=None,
logfile=None, do_not_print=[], as_int=[],
avg=True):
msg = '' if header is None else header + ': '
if time is not None:
msg += ('(%.3f secs), ' % time)
args = [arg for arg in self.args if arg not in do_not_print]
for arg in args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
if arg in as_int:
msg += ('%s %d, ' % (arg, int(val)))
else:
msg += ('%s %.4f, ' % (arg, val))
print(msg)
if logfile is not None:
logfile.write(msg + '\n')
logfile.flush()
def add_scalars(self, summary, header=None, tag_scalar=None,
step=None, avg=True, args=None):
for arg in self.args:
val = self.sums[self.argdict[arg]]
if avg:
val /= (self.cnt + 1e-8)
tag = f'{header}/{arg}' if header is not None else arg
if tag_scalar is not None:
summary.add_scalars(main_tag=tag,
tag_scalar_dict={tag_scalar: val},
global_step=step)
else:
summary.add_scalar(tag=tag,
scalar_value=val,
global_step=step)
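# Accumulator usage sketch:
#   acc_log = Accumulator('loss', 'recon_loss', 'kld')
#   acc_log.accum([0.50, 0.30, 0.20])
#   acc_log.accum([0.70, 0.40, 0.30])
#   acc_log.get('loss')  # ~0.60, averaged over the two accum() calls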
class Log:
def __init__(self, args, logf, summary=None):
self.args = args
self.logf = logf
self.summary = summary
self.stime = time.time()
self.ep_sttime = None
def print(self, logger, epoch, tag=None, avg=True):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
print(msg)
self.logf.write(msg + '\n')
logger.print_(header=tag, logfile=self.logf, avg=avg)
if self.summary is not None:
logger.add_scalars(
self.summary, header=tag, step=epoch, avg=avg)
logger.clear()
def print_args(self):
argdict = vars(self.args)
print(argdict)
for k, v in argdict.items():
self.logf.write(k + ': ' + str(v) + '\n')
self.logf.write('\n')
def set_time(self):
self.stime = time.time()
def save_time_log(self):
ct = time.time() - self.stime
msg = f'({ct:6.2f}s) meta-training phase done'
print(msg)
self.logf.write(msg + '\n')
def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
if tag == 'train':
ct = time.time() - self.ep_sttime
tt = time.time() - self.stime
msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
self.logf.write(msg + '\n')
print(msg)
self.logf.flush()
# msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
# msg += f'time {time.time() - sttime:6.2f} '
if max_corr_dict is not None:
max_corr = max_corr_dict['corr']
max_loss = max_corr_dict['loss']
msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
msg += f'corr {corr:.4f} ({max_corr:.4f})'
else:
msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
self.logf.write(msg + '\n')
print(msg)
self.logf.flush()
def max_corr_log(self, max_corr_dict):
corr = max_corr_dict['corr']
loss = max_corr_dict['loss']
epoch = max_corr_dict['epoch']
msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
self.logf.write(msg + '\n')
print(msg)
self.logf.flush()
def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
msg = f'[{tag}] Ep {epoch} loss {loss.item() / len(y):0.4f} '
msg += f'pacc {y_pred[0]:0.4f}'
msg += f'({y_pred[0] * 100.0 * acc_std + acc_mean:0.4f}) '
msg += f'acc {y[0]:0.4f}({y[0] * 100 * acc_std + acc_mean:0.4f})'
return msg
def load_model(model, model_path, load_epoch=None, load_max_pt=None):
if load_max_pt is not None:
ckpt_path = os.path.join(model_path, load_max_pt)
else:
ckpt_path = os.path.join(model_path, f'ckpt_{load_epoch}.pt')
print(f"==> load checkpoint for MetaD2A predictor: {ckpt_path} ...")
model.cpu()
model.load_state_dict(torch.load(ckpt_path))
def save_model(epoch, model, model_path, max_corr=None):
print("==> save current model...")
if max_corr is not None:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, 'ckpt_max_corr.pt'))
else:
torch.save(model.cpu().state_dict(),
os.path.join(model_path, f'ckpt_{epoch}.pt'))
def mean_confidence_interval(data, confidence=0.95):
a = 1.0 * np.array(data)
n = len(a)
m, se = np.mean(a), scipy.stats.sem(a)
h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
return m, h
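# Quick sanity check for mean_confidence_interval (illustrative numbers):
#   m, h = mean_confidence_interval([71.2, 70.8, 71.5, 70.9, 71.1])
#   print(f'{m:.2f} +/- {h:.2f}')  # sample mean with a 95% t-interval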

View File

@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .imagenet import *

View File

@@ -0,0 +1,56 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import numpy as np
import torch
__all__ = ['DataProvider']
class DataProvider:
SUB_SEED = 937162211 # random seed for sampling subset
VALID_SEED = 2147483647 # random seed for the validation set
@staticmethod
def name():
""" Return name of the dataset """
raise NotImplementedError
@property
def data_shape(self):
""" Return shape as python list of one data entry """
raise NotImplementedError
@property
def n_classes(self):
""" Return `int` of num classes """
raise NotImplementedError
@property
def save_path(self):
""" local path to save the data """
raise NotImplementedError
@property
def data_url(self):
""" link to download the data """
raise NotImplementedError
@staticmethod
def random_sample_valid_set(train_size, valid_size):
assert train_size > valid_size
g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()
valid_indexes = rand_indexes[:valid_size]
train_indexes = rand_indexes[valid_size:]
return train_indexes, valid_indexes
@staticmethod
def labels_to_one_hot(n_classes, labels):
new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
return new_labels
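# Split sketch: sample a disjoint validation subset from 100 training indices.
#   tr, va = DataProvider.random_sample_valid_set(train_size=100, valid_size=10)
#   assert len(tr) == 90 and len(va) == 10 and not set(tr) & set(va)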

View File

@@ -0,0 +1,225 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from .base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
__all__ = ['ImagenetDataProvider']
class ImagenetDataProvider(DataProvider):
DEFAULT_PATH = '/dataset/imagenet'
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
resize_scale=0.08, distort_color=None, image_size=224,
num_replicas=None, rank=None):
warnings.filterwarnings('ignore')
self._save_path = save_path
self.image_size = image_size # int or list of int
self.distort_color = 'None' if distort_color is None else distort_color
self.resize_scale = resize_scale
self._valid_transform_dict = {}
if not isinstance(self.image_size, int):
from ofa_local.utils.my_dataloader import MyDataLoader  # keep imports within the vendored ofa_local package used above
assert isinstance(self.image_size, list)
self.image_size.sort() # e.g., 160 -> 224
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
for img_size in self.image_size:
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
self.active_img_size = max(self.image_size) # active resolution for test
valid_transforms = self._valid_transform_dict[self.active_img_size]
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
else:
self.active_img_size = self.image_size
valid_transforms = self.build_valid_transform()
train_loader_class = torch.utils.data.DataLoader
train_dataset = self.train_dataset(self.build_train_transform())
if valid_size is not None:
if not isinstance(valid_size, int):
assert isinstance(valid_size, float) and 0 < valid_size < 1
valid_size = int(len(train_dataset) * valid_size)
valid_dataset = self.train_dataset(valid_transforms)
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
if num_replicas is not None:
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
if num_replicas is not None:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True
)
else:
self.train = train_loader_class(
train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
self.valid = None
test_dataset = self.test_dataset(valid_transforms)
if num_replicas is not None:
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
)
else:
self.test = torch.utils.data.DataLoader(
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.active_img_size, self.active_img_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
self._save_path = self.DEFAULT_PATH
if not os.path.exists(self._save_path):
self._save_path = os.path.expanduser('~/dataset/imagenet')
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download %s' % self.name())
def train_dataset(self, _transforms):
return datasets.ImageFolder(self.train_path, _transforms)
def test_dataset(self, _transforms):
return datasets.ImageFolder(self.valid_path, _transforms)
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, image_size=None, print_log=True):
if image_size is None:
image_size = self.image_size
if print_log:
print('Color jitter: %s, resize_scale: %s, img_size: %s' %
(self.distort_color, self.resize_scale, image_size))
if isinstance(image_size, list):
resize_transform_class = MyRandomResizedCrop
print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
else:
resize_transform_class = transforms.RandomResizedCrop
# random_resize_crop -> random_horizontal_flip
train_transforms = [
resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
]
# color augmentation (optional)
color_transform = None
if self.distort_color == 'torch':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif self.distort_color == 'tf':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
if color_transform is not None:
train_transforms.append(color_transform)
train_transforms += [
transforms.ToTensor(),
self.normalize,
]
train_transforms = transforms.Compose(train_transforms)
return train_transforms
def build_valid_transform(self, image_size=None):
if image_size is None:
image_size = self.active_img_size
return transforms.Compose([
transforms.Resize(int(math.ceil(image_size / 0.875))),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
self.normalize,
])
def assign_active_img_size(self, new_img_size):
self.active_img_size = new_img_size
if self.active_img_size not in self._valid_transform_dict:
self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
# change the transform of the valid and test set
self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
# used for resetting BN running statistics
if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
if num_worker is None:
num_worker = self.train.num_workers
n_samples = len(self.train.dataset)
g = torch.Generator()
g.manual_seed(DataProvider.SUB_SEED)
rand_indexes = torch.randperm(n_samples, generator=g).tolist()
new_train_dataset = self.train_dataset(
self.build_train_transform(image_size=self.active_img_size, print_log=False))
chosen_indexes = rand_indexes[:n_images]
if num_replicas is not None:
sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
else:
sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
sub_data_loader = torch.utils.data.DataLoader(
new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
num_workers=num_worker, pin_memory=True,
)
self.__dict__['sub_train_%d' % self.active_img_size] = []
for images, labels in sub_data_loader:
self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
return self.__dict__['sub_train_%d' % self.active_img_size]
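# Usage sketch (the ImageNet path is a placeholder; expects train/ and val/
# ImageFolder layouts beneath it):
#   provider = ImagenetDataProvider(save_path='/dataset/imagenet',
#                                   valid_size=10000, image_size=224)
#   images, labels = next(iter(provider.train))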

View File

@@ -0,0 +1,6 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .dynamic_layers import *
from .dynamic_op import *

View File

@@ -0,0 +1,632 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch
import torch.nn as nn
from collections import OrderedDict
from ofa_local.utils.layers import MBConvLayer, ConvLayer, IdentityLayer, set_layer_from_config
from ofa_local.utils.layers import ResNetBottleneckBlock, LinearLayer
from ofa_local.utils import MyModule, val2list, get_net_device, build_activation, make_divisible, SEModule, MyNetwork
from .dynamic_op import DynamicSeparableConv2d, DynamicConv2d, DynamicBatchNorm2d, DynamicSE, DynamicGroupNorm
from .dynamic_op import DynamicLinear
__all__ = [
'adjust_bn_according_to_idx', 'copy_bn',
'DynamicMBConvLayer', 'DynamicConvLayer', 'DynamicLinearLayer', 'DynamicResNetBottleneckBlock'
]
def adjust_bn_according_to_idx(bn, idx):
bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)
def copy_bn(target_bn, src_bn):
feature_dim = target_bn.num_channels if isinstance(target_bn, nn.GroupNorm) else target_bn.num_features
target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])
class DynamicLinearLayer(MyModule):
def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
super(DynamicLinearLayer, self).__init__()
self.in_features_list = in_features_list
self.out_features = out_features
self.bias = bias
self.dropout_rate = dropout_rate
if self.dropout_rate > 0:
self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
else:
self.dropout = None
self.linear = DynamicLinear(
max_in_features=max(self.in_features_list), max_out_features=self.out_features, bias=self.bias
)
def forward(self, x):
if self.dropout is not None:
x = self.dropout(x)
return self.linear(x)
@property
def module_str(self):
return 'DyLinear(%d, %d)' % (max(self.in_features_list), self.out_features)
@property
def config(self):
return {
'name': DynamicLinear.__name__,
'in_features_list': self.in_features_list,
'out_features': self.out_features,
'bias': self.bias,
'dropout_rate': self.dropout_rate,
}
@staticmethod
def build_from_config(config):
return DynamicLinearLayer(**config)
def get_active_subnet(self, in_features, preserve_weight=True):
sub_layer = LinearLayer(in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate)
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
sub_layer.linear.weight.data.copy_(
self.linear.get_active_weight(self.out_features, in_features).data
)
if self.bias:
sub_layer.linear.bias.data.copy_(
self.linear.get_active_bias(self.out_features).data
)
return sub_layer
def get_active_subnet_config(self, in_features):
return {
'name': LinearLayer.__name__,
'in_features': in_features,
'out_features': self.out_features,
'bias': self.bias,
'dropout_rate': self.dropout_rate,
}
class DynamicMBConvLayer(MyModule):
def __init__(self, in_channel_list, out_channel_list,
kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False):
super(DynamicMBConvLayer, self).__init__()
self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.kernel_size_list = val2list(kernel_size_list)
self.expand_ratio_list = val2list(expand_ratio_list)
self.stride = stride
self.act_func = act_func
self.use_se = use_se
# build modules
max_middle_channel = make_divisible(
round(max(self.in_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)
if max(self.expand_ratio_list) == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func)),
]))
self.depth_conv = nn.Sequential(OrderedDict([
('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, self.stride)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func))
]))
if self.use_se:
self.depth_conv.add_module('se', DynamicSE(max_middle_channel))
self.point_linear = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
self.active_kernel_size = max(self.kernel_size_list)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
in_channel = x.size(1)
if self.inverted_bottleneck is not None:
self.inverted_bottleneck.conv.active_out_channel = \
make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)
self.depth_conv.conv.active_kernel_size = self.active_kernel_size
self.point_linear.conv.active_out_channel = self.active_out_channel
if self.inverted_bottleneck is not None:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@property
def module_str(self):
if self.use_se:
return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
else:
return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
@property
def config(self):
return {
'name': DynamicMBConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size_list': self.kernel_size_list,
'expand_ratio_list': self.expand_ratio_list,
'stride': self.stride,
'act_func': self.act_func,
'use_se': self.use_se,
}
@staticmethod
def build_from_config(config):
return DynamicMBConvLayer(**config)
############################################################################################
@property
def in_channels(self):
return max(self.in_channel_list)
@property
def out_channels(self):
return max(self.out_channel_list)
def active_middle_channel(self, in_channel):
return make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)
############################################################################################
def get_active_subnet(self, in_channel, preserve_weight=True):
# build the new layer
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
middle_channel = self.active_middle_channel(in_channel)
# copy weight from current layer
if sub_layer.inverted_bottleneck is not None:
sub_layer.inverted_bottleneck.conv.weight.data.copy_(
self.inverted_bottleneck.conv.get_active_filter(middle_channel, in_channel).data,
)
copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)
sub_layer.depth_conv.conv.weight.data.copy_(
self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
)
copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)
if self.use_se:
se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=MyNetwork.CHANNEL_DIVISIBLE)
sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
)
sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
self.depth_conv.se.get_active_reduce_bias(se_mid).data
)
sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
)
sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
self.depth_conv.se.get_active_expand_bias(middle_channel).data
)
sub_layer.point_linear.conv.weight.data.copy_(
self.point_linear.conv.get_active_filter(self.active_out_channel, middle_channel).data
)
copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)
return sub_layer
def get_active_subnet_config(self, in_channel):
return {
'name': MBConvLayer.__name__,
'in_channels': in_channel,
'out_channels': self.active_out_channel,
'kernel_size': self.active_kernel_size,
'stride': self.stride,
'expand_ratio': self.active_expand_ratio,
'mid_channels': self.active_middle_channel(in_channel),
'act_func': self.act_func,
'use_se': self.use_se,
}
def re_organize_middle_weights(self, expand_ratio_stage=0):
importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))
if isinstance(self.depth_conv.bn, DynamicGroupNorm):
channel_per_group = self.depth_conv.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.in_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
for expand in sorted_expand_list
]
right = len(importance)
base = - len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.point_linear.conv.conv.weight.data = torch.index_select(
self.point_linear.conv.conv.weight.data, 1, sorted_idx
)
adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
self.depth_conv.conv.conv.weight.data = torch.index_select(
self.depth_conv.conv.conv.weight.data, 0, sorted_idx
)
if self.use_se:
# se expand: output dim 0 reorganize
se_expand = self.depth_conv.se.fc.expand
se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
# se reduce: input dim 1 reorganize
se_reduce = self.depth_conv.se.fc.reduce
se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
# middle weight reorganize
se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)
se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)
if self.inverted_bottleneck is not None:
adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
)
return None
else:
return sorted_idx
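# Editor's note: a hedged usage sketch (my example, not part of the original
# file). A single DynamicMBConvLayer stores the weights of its largest
# configuration; setting active_kernel_size / active_expand_ratio makes the
# forward pass slice (and, for kernels, linearly transform) sub-tensors of
# those weights, so one layer serves every (kernel, expand) combination.
def _demo_dynamic_mbconv():
    layer = DynamicMBConvLayer(in_channel_list=[24], out_channel_list=[40],
                               kernel_size_list=[3, 5, 7], expand_ratio_list=[3, 4, 6])
    layer.active_kernel_size, layer.active_expand_ratio = 3, 4
    x = torch.randn(1, 24, 32, 32)
    assert layer(x).shape == (1, 40, 32, 32)  # 3x3 depthwise kernel, 4x bottleneck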
class DynamicConvLayer(MyModule):
def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1,
use_bn=True, act_func='relu6'):
super(DynamicConvLayer, self).__init__()
self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.use_bn = use_bn
self.act_func = act_func
self.conv = DynamicConv2d(
max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation,
)
if self.use_bn:
self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
self.act = build_activation(self.act_func)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
self.conv.active_out_channel = self.active_out_channel
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
x = self.act(x)
return x
@property
def module_str(self):
return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride)
@property
def config(self):
return {
'name': DynamicConvLayer.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'use_bn': self.use_bn,
'act_func': self.act_func,
}
@staticmethod
def build_from_config(config):
return DynamicConvLayer(**config)
############################################################################################
@property
def in_channels(self):
return max(self.in_channel_list)
@property
def out_channels(self):
return max(self.out_channel_list)
############################################################################################
def get_active_subnet(self, in_channel, preserve_weight=True):
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
sub_layer.conv.weight.data.copy_(self.conv.get_active_filter(self.active_out_channel, in_channel).data)
if self.use_bn:
copy_bn(sub_layer.bn, self.bn.bn)
return sub_layer
def get_active_subnet_config(self, in_channel):
return {
'name': ConvLayer.__name__,
'in_channels': in_channel,
'out_channels': self.active_out_channel,
'kernel_size': self.kernel_size,
'stride': self.stride,
'dilation': self.dilation,
'use_bn': self.use_bn,
'act_func': self.act_func,
}
class DynamicResNetBottleneckBlock(MyModule):
def __init__(self, in_channel_list, out_channel_list, expand_ratio_list=0.25,
kernel_size=3, stride=1, act_func='relu', downsample_mode='avgpool_conv'):
super(DynamicResNetBottleneckBlock, self).__init__()
self.in_channel_list = in_channel_list
self.out_channel_list = out_channel_list
self.expand_ratio_list = val2list(expand_ratio_list)
self.kernel_size = kernel_size
self.stride = stride
self.act_func = act_func
self.downsample_mode = downsample_mode
# build modules
max_middle_channel = make_divisible(
round(max(self.out_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)
self.conv1 = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True)),
]))
self.conv2 = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max_middle_channel, kernel_size, stride)),
('bn', DynamicBatchNorm2d(max_middle_channel)),
('act', build_activation(self.act_func, inplace=True))
]))
self.conv3 = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
if self.stride == 1 and self.in_channel_list == self.out_channel_list:
self.downsample = IdentityLayer(max(self.in_channel_list), max(self.out_channel_list))
elif self.downsample_mode == 'conv':
self.downsample = nn.Sequential(OrderedDict([
('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list), stride=stride)),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
elif self.downsample_mode == 'avgpool_conv':
self.downsample = nn.Sequential(OrderedDict([
('avg_pool', nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)),
('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list))),
('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
]))
else:
raise NotImplementedError
self.final_act = build_activation(self.act_func, inplace=True)
self.active_expand_ratio = max(self.expand_ratio_list)
self.active_out_channel = max(self.out_channel_list)
def forward(self, x):
feature_dim = self.active_middle_channels
self.conv1.conv.active_out_channel = feature_dim
self.conv2.conv.active_out_channel = feature_dim
self.conv3.conv.active_out_channel = self.active_out_channel
if not isinstance(self.downsample, IdentityLayer):
self.downsample.conv.active_out_channel = self.active_out_channel
residual = self.downsample(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = x + residual
x = self.final_act(x)
return x
@property
def module_str(self):
return '(%s, %s)' % (
'%dx%d_BottleneckConv_in->%d->%d_S%d' % (
self.kernel_size, self.kernel_size, self.active_middle_channels, self.active_out_channel, self.stride
),
'Identity' if isinstance(self.downsample, IdentityLayer) else self.downsample_mode,
)
@property
def config(self):
return {
'name': DynamicResNetBottleneckBlock.__name__,
'in_channel_list': self.in_channel_list,
'out_channel_list': self.out_channel_list,
'expand_ratio_list': self.expand_ratio_list,
'kernel_size': self.kernel_size,
'stride': self.stride,
'act_func': self.act_func,
'downsample_mode': self.downsample_mode,
}
@staticmethod
def build_from_config(config):
return DynamicResNetBottleneckBlock(**config)
############################################################################################
@property
def in_channels(self):
return max(self.in_channel_list)
@property
def out_channels(self):
return max(self.out_channel_list)
@property
def active_middle_channels(self):
feature_dim = round(self.active_out_channel * self.active_expand_ratio)
feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
return feature_dim
############################################################################################
def get_active_subnet(self, in_channel, preserve_weight=True):
# build the new layer
sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
sub_layer = sub_layer.to(get_net_device(self))
if not preserve_weight:
return sub_layer
# copy weight from current layer
sub_layer.conv1.conv.weight.data.copy_(
self.conv1.conv.get_active_filter(self.active_middle_channels, in_channel).data)
copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)
sub_layer.conv2.conv.weight.data.copy_(
self.conv2.conv.get_active_filter(self.active_middle_channels, self.active_middle_channels).data)
copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)
sub_layer.conv3.conv.weight.data.copy_(
self.conv3.conv.get_active_filter(self.active_out_channel, self.active_middle_channels).data)
copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)
if not isinstance(self.downsample, IdentityLayer):
sub_layer.downsample.conv.weight.data.copy_(
self.downsample.conv.get_active_filter(self.active_out_channel, in_channel).data)
copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)
return sub_layer
def get_active_subnet_config(self, in_channel):
return {
'name': ResNetBottleneckBlock.__name__,
'in_channels': in_channel,
'out_channels': self.active_out_channel,
'kernel_size': self.kernel_size,
'stride': self.stride,
'expand_ratio': self.active_expand_ratio,
'mid_channels': self.active_middle_channels,
'act_func': self.act_func,
'groups': 1,
'downsample_mode': self.downsample_mode,
}
def re_organize_middle_weights(self, expand_ratio_stage=0):
# conv3 -> conv2
importance = torch.sum(torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3))
if isinstance(self.conv2.bn, DynamicGroupNorm):
channel_per_group = self.conv2.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
for expand in sorted_expand_list
]
right = len(importance)
base = - len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.conv3.conv.conv.weight.data = torch.index_select(self.conv3.conv.conv.weight.data, 1, sorted_idx)
adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 0, sorted_idx)
# conv2 -> conv1
importance = torch.sum(torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3))
if isinstance(self.conv1.bn, DynamicGroupNorm):
channel_per_group = self.conv1.bn.channel_per_group
importance_chunks = torch.split(importance, channel_per_group)
for chunk in importance_chunks:
chunk.data.fill_(torch.mean(chunk))
importance = torch.cat(importance_chunks, dim=0)
if expand_ratio_stage > 0:
sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
sorted_expand_list.sort(reverse=True)
target_width_list = [
make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
for expand in sorted_expand_list
]
right = len(importance)
base = - len(target_width_list) * 1e5
for i in range(expand_ratio_stage + 1):
left = target_width_list[i]
importance[left:right] += base
base += 1e5
right = left
sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 1, sorted_idx)
adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
self.conv1.conv.conv.weight.data = torch.index_select(self.conv1.conv.conv.weight.data, 0, sorted_idx)
return None
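# Editor's note: a hedged sanity-check sketch (my example, not part of the
# original file). re_organize_middle_weights permutes the bottleneck's middle
# channels by importance but applies the same permutation to every tensor that
# indexes them (1x1 convs, depthwise conv, BN buffers), so in eval mode the
# layer computes the same function before and after, up to float round-off.
def _demo_reorg_preserves_function():
    layer = DynamicMBConvLayer([16], [16], kernel_size_list=[3, 5, 7],
                               expand_ratio_list=[3, 4, 6]).eval()
    x = torch.randn(2, 16, 16, 16)
    with torch.no_grad():
        y0 = layer(x)
        layer.re_organize_middle_weights()
        y1 = layer(x)
    assert torch.allclose(y0, y1, atol=1e-5)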

View File

@@ -0,0 +1,314 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter
from ofa_local.utils import get_same_padding, sub_filter_start_end, make_divisible, SEModule, MyNetwork, MyConv2d
__all__ = ['DynamicSeparableConv2d', 'DynamicConv2d', 'DynamicGroupConv2d',
'DynamicBatchNorm2d', 'DynamicGroupNorm', 'DynamicSE', 'DynamicLinear']
class DynamicSeparableConv2d(nn.Module):
KERNEL_TRANSFORM_MODE = 1 # None or 1
def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
super(DynamicSeparableConv2d, self).__init__()
self.max_in_channels = max_in_channels
self.kernel_size_list = kernel_size_list
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
groups=self.max_in_channels, bias=False,
)
self._ks_set = list(set(self.kernel_size_list))
self._ks_set.sort() # e.g., [3, 5, 7]
if self.KERNEL_TRANSFORM_MODE is not None:
# register scaling parameters
# 7to5_matrix, 5to3_matrix
scale_params = {}
for i in range(len(self._ks_set) - 1):
ks_small = self._ks_set[i]
ks_larger = self._ks_set[i + 1]
param_name = '%dto%d' % (ks_larger, ks_small)
# noinspection PyArgumentList
scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
for name, param in scale_params.items():
self.register_parameter(name, param)
self.active_kernel_size = max(self.kernel_size_list)
def get_active_filter(self, in_channel, kernel_size):
out_channel = in_channel
max_kernel_size = max(self.kernel_size_list)
start, end = sub_filter_start_end(max_kernel_size, kernel_size)
filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
start_filter = self.conv.weight[:out_channel, :in_channel, :, :] # start with max kernel
for i in range(len(self._ks_set) - 1, 0, -1):
src_ks = self._ks_set[i]
if src_ks <= kernel_size:
break
target_ks = self._ks_set[i - 1]
start, end = sub_filter_start_end(src_ks, target_ks)
_input_filter = start_filter[:, :, start:end, start:end]
_input_filter = _input_filter.contiguous()
_input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
_input_filter = _input_filter.view(-1, _input_filter.size(2))
_input_filter = F.linear(
_input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
)
_input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
_input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
start_filter = _input_filter
filters = start_filter
return filters
def forward(self, x, kernel_size=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
in_channel = x.size(1)
filters = self.get_active_filter(in_channel, kernel_size).contiguous()
padding = get_same_padding(kernel_size)
filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
y = F.conv2d(
x, filters, None, self.stride, padding, self.dilation, in_channel
)
return y
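# Editor's note: a hedged sketch of the elastic-kernel mechanism (my example,
# not part of the original file). With KERNEL_TRANSFORM_MODE enabled, a small
# kernel is the center crop of the next-larger kernel passed through a learned
# linear map; the maps start as identity, so at initialization the derived 3x3
# filter equals the plain center crop of the 7x7 one.
def _demo_kernel_transform():
    conv = DynamicSeparableConv2d(max_in_channels=8, kernel_size_list=[3, 5, 7])
    w7 = conv.conv.weight            # depthwise filters, shape (8, 1, 7, 7)
    w3 = conv.get_active_filter(in_channel=8, kernel_size=3)
    assert torch.allclose(w3, w7[:, :, 2:5, 2:5])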
class DynamicConv2d(nn.Module):
def __init__(self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1):
super(DynamicConv2d, self).__init__()
self.max_in_channels = max_in_channels
self.max_out_channels = max_out_channels
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.max_in_channels, self.max_out_channels, self.kernel_size, stride=self.stride, bias=False,
)
self.active_out_channel = self.max_out_channels
def get_active_filter(self, out_channel, in_channel):
return self.conv.weight[:out_channel, :in_channel, :, :]
def forward(self, x, out_channel=None):
if out_channel is None:
out_channel = self.active_out_channel
in_channel = x.size(1)
filters = self.get_active_filter(out_channel, in_channel).contiguous()
padding = get_same_padding(self.kernel_size)
filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
return y
class DynamicGroupConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size_list, groups_list, stride=1, dilation=1):
super(DynamicGroupConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size_list = kernel_size_list
self.groups_list = groups_list
self.stride = stride
self.dilation = dilation
self.conv = nn.Conv2d(
self.in_channels, self.out_channels, max(self.kernel_size_list), self.stride,
groups=min(self.groups_list), bias=False,
)
self.active_kernel_size = max(self.kernel_size_list)
self.active_groups = min(self.groups_list)
def get_active_filter(self, kernel_size, groups):
start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
filters = self.conv.weight[:, :, start:end, start:end]
sub_filters = torch.chunk(filters, groups, dim=0)
sub_in_channels = self.in_channels // groups
sub_ratio = filters.size(1) // sub_in_channels
filter_crops = []
for i, sub_filter in enumerate(sub_filters):
part_id = i % sub_ratio
start = part_id * sub_in_channels
filter_crops.append(sub_filter[:, start:start + sub_in_channels, :, :])
filters = torch.cat(filter_crops, dim=0)
return filters
def forward(self, x, kernel_size=None, groups=None):
if kernel_size is None:
kernel_size = self.active_kernel_size
if groups is None:
groups = self.active_groups
filters = self.get_active_filter(kernel_size, groups).contiguous()
padding = get_same_padding(kernel_size)
filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
y = F.conv2d(
x, filters, None, self.stride, padding, self.dilation, groups,
)
return y
class DynamicBatchNorm2d(nn.Module):
SET_RUNNING_STATISTICS = False
def __init__(self, max_feature_dim):
super(DynamicBatchNorm2d, self).__init__()
self.max_feature_dim = max_feature_dim
self.bn = nn.BatchNorm2d(self.max_feature_dim)
@staticmethod
def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
return bn(x)
else:
exponential_average_factor = 0.0
if bn.training and bn.track_running_stats:
if bn.num_batches_tracked is not None:
bn.num_batches_tracked += 1
if bn.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = bn.momentum
return F.batch_norm(
x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
exponential_average_factor, bn.eps,
)
def forward(self, x):
feature_dim = x.size(1)
y = self.bn_forward(x, self.bn, feature_dim)
return y
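# Editor's note: a hedged sketch (my example, not part of the original file):
# when fewer than max_feature_dim channels are active, DynamicBatchNorm2d
# normalizes with the leading slice of the shared statistics and affine terms.
def _demo_dynamic_bn():
    bn = DynamicBatchNorm2d(8).eval()
    x = torch.randn(2, 4, 5, 5)      # only 4 of the 8 channels are active
    ref = F.batch_norm(x, bn.bn.running_mean[:4], bn.bn.running_var[:4],
                       bn.bn.weight[:4], bn.bn.bias[:4], False, 0.0, bn.bn.eps)
    assert torch.allclose(bn(x), ref)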
class DynamicGroupNorm(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None):
super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
self.channel_per_group = channel_per_group
def forward(self, x):
n_channels = x.size(1)
n_groups = n_channels // self.channel_per_group
return F.group_norm(x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)
@property
def bn(self):
return self
class DynamicSE(SEModule):
def __init__(self, max_channel):
super(DynamicSE, self).__init__(max_channel)
def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_filters = torch.chunk(self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1)
return torch.cat([
sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters
], dim=1)
def get_active_reduce_bias(self, num_mid):
return self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None
def get_active_expand_weight(self, num_mid, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.expand.weight[:in_channel, :num_mid, :, :]
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_filters = torch.chunk(self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0)
return torch.cat([
sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters
], dim=0)
def get_active_expand_bias(self, in_channel, groups=None):
if groups is None or groups == 1:
return self.fc.expand.bias[:in_channel] if self.fc.expand.bias is not None else None
else:
assert in_channel % groups == 0
sub_in_channels = in_channel // groups
sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
return torch.cat([
sub_bias[:sub_in_channels] for sub_bias in sub_bias_list
], dim=0)
def forward(self, x, groups=None):
in_channel = x.size(1)
num_mid = make_divisible(in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)
y = x.mean(3, keepdim=True).mean(2, keepdim=True)
# reduce
reduce_filter = self.get_active_reduce_weight(num_mid, in_channel, groups=groups).contiguous()
reduce_bias = self.get_active_reduce_bias(num_mid)
y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
# relu
y = self.fc.relu(y)
# expand
expand_filter = self.get_active_expand_weight(num_mid, in_channel, groups=groups).contiguous()
expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
# hard sigmoid
y = self.fc.h_sigmoid(y)
return x * y
class DynamicLinear(nn.Module):
def __init__(self, max_in_features, max_out_features, bias=True):
super(DynamicLinear, self).__init__()
self.max_in_features = max_in_features
self.max_out_features = max_out_features
self.bias = bias
self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)
self.active_out_features = self.max_out_features
def get_active_weight(self, out_features, in_features):
return self.linear.weight[:out_features, :in_features]
def get_active_bias(self, out_features):
return self.linear.bias[:out_features] if self.bias else None
def forward(self, x, out_features=None):
if out_features is None:
out_features = self.active_out_features
in_features = x.size(1)
weight = self.get_active_weight(out_features, in_features).contiguous()
bias = self.get_active_bias(out_features)
y = F.linear(x, weight, bias)
return y
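# Editor's note: a hedged sketch (my example, not part of the original file):
# DynamicLinear infers the active input width from the batch and takes the
# top-left sub-matrix of its weight for the requested output width.
def _demo_dynamic_linear():
    fc = DynamicLinear(max_in_features=16, max_out_features=10)
    x = torch.randn(2, 8)            # active input width inferred from x
    y = fc(x, out_features=5)        # uses fc.linear.weight[:5, :8]
    assert y.shape == (2, 5)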

View File

@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .ofa_proxyless import OFAProxylessNASNets
from .ofa_mbv3 import OFAMobileNetV3
from .ofa_resnets import OFAResNets

View File

@@ -0,0 +1,336 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import random
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import MobileNetV3
from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.utils.layers import set_layer_from_config
import gin
__all__ = ['OFAMobileNetV3']
@gin.configurable
class OFAMobileNetV3(MobileNetV3):
def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
ks_list=3, expand_ratio_list=6, depth_list=4, dropblock=False, block_size=0):
self.width_mult = width_mult
self.ks_list = val2list(ks_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.depth_list = val2list(depth_list, 1)
self.ks_list.sort()
self.expand_ratio_list.sort()
self.depth_list.sort()
base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]
final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
stride_stages = [1, 2, 2, 2, 1, 2]
act_stages = ['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish']
se_stages = [False, False, True, False, True, True]
n_block_list = [1] + [max(self.depth_list)] * 5
width_list = []
for base_width in base_stage_width[:-2]:
width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
width_list.append(width)
input_channel, first_block_dim = width_list[0], width_list[1]
# first conv layer
first_conv = ConvLayer(3, input_channel, kernel_size=3, stride=2, act_func='h_swish')
first_block_conv = MBConvLayer(
in_channels=input_channel, out_channels=first_block_dim, kernel_size=3, stride=stride_stages[0],
expand_ratio=1, act_func=act_stages[0], use_se=se_stages[0],
)
first_block = ResidualBlock(
first_block_conv,
IdentityLayer(first_block_dim, first_block_dim) if input_channel == first_block_dim else None,
dropout_rate, dropblock, block_size
)
# inverted residual blocks
self.block_group_info = []
blocks = [first_block]
_block_index = 1
feature_dim = first_block_dim
for width, n_block, s, act_func, use_se in zip(width_list[2:], n_block_list[1:],
stride_stages[1:], act_stages[1:], se_stages[1:]):
self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block
output_channel = width
for i in range(n_block):
if i == 0:
stride = s
else:
stride = 1
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=val2list(feature_dim), out_channel_list=val2list(output_channel),
kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list,
stride=stride, act_func=act_func, use_se=use_se,
)
if stride == 1 and feature_dim == output_channel:
shortcut = IdentityLayer(feature_dim, feature_dim)
else:
shortcut = None
blocks.append(ResidualBlock(mobile_inverted_conv, shortcut,
dropout_rate, dropblock, block_size))
feature_dim = output_channel
# final expand layer, feature mix layer & classifier
final_expand_layer = ConvLayer(feature_dim, final_expand_width, kernel_size=1, act_func='h_swish')
feature_mix_layer = ConvLayer(
final_expand_width, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
super(OFAMobileNetV3, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
""" MyNetwork required methods """
@staticmethod
def name():
return 'OFAMobileNetV3'
def forward(self, x):
# first conv
x = self.first_conv(x)
# first block
x = self.blocks[0](x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)
x = self.final_expand_layer(x)
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
x = self.feature_mix_layer(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
_str += self.final_expand_layer.module_str + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str
@property
def config(self):
return {
'name': OFAMobileNetV3.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'final_expand_layer': self.final_expand_layer.config,
'feature_mix_layer': self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
    @staticmethod
    def build_from_config(config):
        raise ValueError('does not support this function')
@property
def grouped_block_index(self):
return self.block_group_info
def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
if '.mobile_inverted_conv.' in key:
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
if new_key in model_dict:
pass
elif '.bn.bn.' in new_key:
new_key = new_key.replace('.bn.bn.', '.bn.')
elif '.conv.conv.weight' in new_key:
new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
elif '.linear.linear.' in new_key:
new_key = new_key.replace('.linear.linear.', '.linear.')
##############################################################################
elif '.linear.' in new_key:
new_key = new_key.replace('.linear.', '.linear.linear.')
elif 'bn.' in new_key:
new_key = new_key.replace('bn.', 'bn.bn.')
elif 'conv.weight' in new_key:
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
model_dict[new_key] = state_dict[key]
super(OFAMobileNetV3, self).load_state_dict(model_dict)
""" set, sample and get active sub-networks """
def set_max_net(self):
self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))
def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
ks = val2list(ks, len(self.blocks) - 1)
expand_ratio = val2list(e, len(self.blocks) - 1)
depth = val2list(d, len(self.block_group_info))
for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
if k is not None:
block.conv.active_kernel_size = k
if e is not None:
block.conv.active_expand_ratio = e
for i, d in enumerate(depth):
if d is not None:
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
def set_constraint(self, include_list, constraint_type='depth'):
if constraint_type == 'depth':
self.__dict__['_depth_include_list'] = include_list.copy()
elif constraint_type == 'expand_ratio':
self.__dict__['_expand_include_list'] = include_list.copy()
elif constraint_type == 'kernel_size':
self.__dict__['_ks_include_list'] = include_list.copy()
else:
raise NotImplementedError
def clear_constraint(self):
self.__dict__['_depth_include_list'] = None
self.__dict__['_expand_include_list'] = None
self.__dict__['_ks_include_list'] = None
def sample_active_subnet(self):
ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
else self.__dict__['_ks_include_list']
expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
else self.__dict__['_expand_include_list']
depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
self.__dict__['_depth_include_list']
# sample kernel size
ks_setting = []
if not isinstance(ks_candidates[0], list):
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
for k_set in ks_candidates:
k = random.choice(k_set)
ks_setting.append(k)
# sample expand ratio
expand_setting = []
if not isinstance(expand_candidates[0], list):
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
for e_set in expand_candidates:
e = random.choice(e_set)
expand_setting.append(e)
# sample depth
depth_setting = []
if not isinstance(depth_candidates[0], list):
depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
for d_set in depth_candidates:
d = random.choice(d_set)
depth_setting.append(d)
self.set_active_subnet(ks_setting, expand_setting, depth_setting)
return {
'ks': ks_setting,
'e': expand_setting,
'd': depth_setting,
}
def get_active_subnet(self, preserve_weight=True):
first_conv = copy.deepcopy(self.first_conv)
blocks = [copy.deepcopy(self.blocks[0])]
final_expand_layer = copy.deepcopy(self.final_expand_layer)
feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
classifier = copy.deepcopy(self.classifier)
input_channel = blocks[0].conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
copy.deepcopy(self.blocks[idx].shortcut),
copy.deepcopy(self.blocks[idx].dropout_rate),
copy.deepcopy(self.blocks[idx].dropblock),
copy.deepcopy(self.blocks[idx].block_size),
))
input_channel = stage_blocks[-1].conv.out_channels
blocks += stage_blocks
_subnet = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet
def get_active_net_config(self):
# first conv
first_conv_config = self.first_conv.config
first_block_config = self.blocks[0].config
final_expand_config = self.final_expand_layer.config
feature_mix_layer_config = self.feature_mix_layer.config
classifier_config = self.classifier.config
block_config_list = [first_block_config]
input_channel = first_block_config['conv']['out_channels']
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append({
'name': ResidualBlock.__name__,
'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
})
input_channel = self.blocks[idx].conv.active_out_channel
block_config_list += stage_blocks
return {
'name': MobileNetV3.__name__,
'bn': self.get_bn_param(),
'first_conv': first_conv_config,
'blocks': block_config_list,
'final_expand_layer': final_expand_config,
'feature_mix_layer': feature_mix_layer_config,
'classifier': classifier_config,
}
""" Width Related Methods """
def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks[1:]:
block.conv.re_organize_middle_weights(expand_ratio_stage)
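# Editor's note: a hedged end-to-end sketch (my example, not part of the
# original file; it assumes the surrounding repo's ResidualBlock/MobileNetV3
# classes behave as used above). Sampling picks a random (ks, e, d) setting,
# and get_active_subnet materializes it as a standalone MobileNetV3.
def _demo_ofa_mbv3():
    import torch
    net = OFAMobileNetV3(ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6],
                         depth_list=[2, 3, 4])
    cfg = net.sample_active_subnet()          # {'ks': [...], 'e': [...], 'd': [...]}
    subnet = net.get_active_subnet(preserve_weight=True).eval()
    with torch.no_grad():
        out = subnet(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 1000)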

View File

@@ -0,0 +1,331 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import random
from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.imagenet_classification.elastic_nn.modules import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks.proxyless_nets import ProxylessNASNets
__all__ = ['OFAProxylessNASNets']
class OFAProxylessNASNets(ProxylessNASNets):
def __init__(self, n_classes=1000, bn_param=(0.1, 1e-3), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
ks_list=3, expand_ratio_list=6, depth_list=4):
self.width_mult = width_mult
self.ks_list = val2list(ks_list, 1)
self.expand_ratio_list = val2list(expand_ratio_list, 1)
self.depth_list = val2list(depth_list, 1)
self.ks_list.sort()
self.expand_ratio_list.sort()
self.depth_list.sort()
if base_stage_width == 'google':
# MobileNetV2 Stage Width
base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
else:
# ProxylessNAS Stage Width
base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]
input_channel = make_divisible(base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
first_block_width = make_divisible(base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
)
# first block
first_block_conv = MBConvLayer(
in_channels=input_channel, out_channels=first_block_width, kernel_size=3, stride=1,
expand_ratio=1, act_func='relu6',
)
first_block = ResidualBlock(first_block_conv, None)
input_channel = first_block_width
# inverted residual blocks
self.block_group_info = []
blocks = [first_block]
_block_index = 1
stride_stages = [2, 2, 2, 1, 2, 1]
n_block_list = [max(self.depth_list)] * 5 + [1]
width_list = []
for base_width in base_stage_width[2:-1]:
width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
width_list.append(width)
for width, n_block, s in zip(width_list, n_block_list, stride_stages):
self.block_group_info.append([_block_index + i for i in range(n_block)])
_block_index += n_block
output_channel = width
for i in range(n_block):
if i == 0:
stride = s
else:
stride = 1
mobile_inverted_conv = DynamicMBConvLayer(
in_channel_list=val2list(input_channel, 1), out_channel_list=val2list(output_channel, 1),
kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list, stride=stride, act_func='relu6',
)
if stride == 1 and input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)
blocks.append(mb_inverted_block)
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6',
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
super(OFAProxylessNASNets, self).__init__(first_conv, blocks, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
# runtime_depth
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
""" MyNetwork required methods """
@staticmethod
def name():
return 'OFAProxylessNASNets'
def forward(self, x):
# first conv
x = self.first_conv(x)
# first block
x = self.blocks[0](x)
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
x = self.blocks[idx](x)
# feature_mix_layer
x = self.feature_mix_layer(x)
x = x.mean(3).mean(2)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
_str += self.blocks[0].module_str + '\n'
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str + '\n'
return _str
@property
def config(self):
return {
'name': OFAProxylessNASNets.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
    @staticmethod
    def build_from_config(config):
        raise ValueError('does not support this function')
@property
def grouped_block_index(self):
return self.block_group_info
def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
if '.mobile_inverted_conv.' in key:
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
if new_key in model_dict:
pass
elif '.bn.bn.' in new_key:
new_key = new_key.replace('.bn.bn.', '.bn.')
elif '.conv.conv.weight' in new_key:
new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
elif '.linear.linear.' in new_key:
new_key = new_key.replace('.linear.linear.', '.linear.')
##############################################################################
elif '.linear.' in new_key:
new_key = new_key.replace('.linear.', '.linear.linear.')
elif 'bn.' in new_key:
new_key = new_key.replace('bn.', 'bn.bn.')
elif 'conv.weight' in new_key:
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
model_dict[new_key] = state_dict[key]
super(OFAProxylessNASNets, self).load_state_dict(model_dict)
""" set, sample and get active sub-networks """
def set_max_net(self):
self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))
def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
ks = val2list(ks, len(self.blocks) - 1)
expand_ratio = val2list(e, len(self.blocks) - 1)
depth = val2list(d, len(self.block_group_info))
for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
if k is not None:
block.conv.active_kernel_size = k
if e is not None:
block.conv.active_expand_ratio = e
for i, d in enumerate(depth):
if d is not None:
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
def set_constraint(self, include_list, constraint_type='depth'):
if constraint_type == 'depth':
self.__dict__['_depth_include_list'] = include_list.copy()
elif constraint_type == 'expand_ratio':
self.__dict__['_expand_include_list'] = include_list.copy()
elif constraint_type == 'kernel_size':
self.__dict__['_ks_include_list'] = include_list.copy()
else:
raise NotImplementedError
def clear_constraint(self):
self.__dict__['_depth_include_list'] = None
self.__dict__['_expand_include_list'] = None
self.__dict__['_ks_include_list'] = None
def sample_active_subnet(self):
ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
else self.__dict__['_ks_include_list']
expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
else self.__dict__['_expand_include_list']
depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
self.__dict__['_depth_include_list']
# sample kernel size
ks_setting = []
if not isinstance(ks_candidates[0], list):
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
for k_set in ks_candidates:
k = random.choice(k_set)
ks_setting.append(k)
# sample expand ratio
expand_setting = []
if not isinstance(expand_candidates[0], list):
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
for e_set in expand_candidates:
e = random.choice(e_set)
expand_setting.append(e)
# sample depth
depth_setting = []
if not isinstance(depth_candidates[0], list):
depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
for d_set in depth_candidates:
d = random.choice(d_set)
depth_setting.append(d)
depth_setting[-1] = 1
self.set_active_subnet(ks_setting, expand_setting, depth_setting)
return {
'ks': ks_setting,
'e': expand_setting,
'd': depth_setting,
}
def get_active_subnet(self, preserve_weight=True):
first_conv = copy.deepcopy(self.first_conv)
blocks = [copy.deepcopy(self.blocks[0])]
feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
classifier = copy.deepcopy(self.classifier)
input_channel = blocks[0].conv.out_channels
# blocks
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append(ResidualBlock(
self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
copy.deepcopy(self.blocks[idx].shortcut)
))
input_channel = stage_blocks[-1].conv.out_channels
blocks += stage_blocks
_subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
_subnet.set_bn_param(**self.get_bn_param())
return _subnet
def get_active_net_config(self):
first_conv_config = self.first_conv.config
first_block_config = self.blocks[0].config
feature_mix_layer_config = self.feature_mix_layer.config
classifier_config = self.classifier.config
block_config_list = [first_block_config]
input_channel = first_block_config['conv']['out_channels']
for stage_id, block_idx in enumerate(self.block_group_info):
depth = self.runtime_depth[stage_id]
active_idx = block_idx[:depth]
stage_blocks = []
for idx in active_idx:
stage_blocks.append({
'name': ResidualBlock.__name__,
'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
})
try:
input_channel = self.blocks[idx].conv.active_out_channel
                except AttributeError:  # static blocks expose out_channels, not active_out_channel
input_channel = self.blocks[idx].conv.out_channels
block_config_list += stage_blocks
return {
'name': ProxylessNASNets.__name__,
'bn': self.get_bn_param(),
'first_conv': first_conv_config,
'blocks': block_config_list,
'feature_mix_layer': feature_mix_layer_config,
'classifier': classifier_config,
}
""" Width Related Methods """
def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks[1:]:
block.conv.re_organize_middle_weights(expand_ratio_stage)
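# Editor's note: a hedged sketch of the search-space constraint API (my example,
# not part of the original file; building the net assumes the surrounding repo's
# ResidualBlock/ProxylessNASNets classes). set_constraint restricts what
# sample_active_subnet may draw along one dimension of the space.
def _demo_proxyless_constraint():
    net = OFAProxylessNASNets(ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6],
                              depth_list=[2, 3, 4])
    net.set_constraint([3], constraint_type='kernel_size')
    cfg = net.sample_active_subnet()
    assert all(k == 3 for k in cfg['ks'])     # every block sampled with ks=3
    net.clear_constraint()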

View File

@@ -0,0 +1,267 @@
import random
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicConvLayer, DynamicLinearLayer
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicResNetBottleneckBlock
from ofa_local.utils.layers import IdentityLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import ResNets
from ofa_local.utils import make_divisible, val2list, MyNetwork
__all__ = ['OFAResNets']
class OFAResNets(ResNets):
def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0,
depth_list=2, expand_ratio_list=0.25, width_mult_list=1.0):
self.depth_list = val2list(depth_list)
self.expand_ratio_list = val2list(expand_ratio_list)
self.width_mult_list = val2list(width_mult_list)
# sort
self.depth_list.sort()
self.expand_ratio_list.sort()
self.width_mult_list.sort()
input_channel = [
make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
]
mid_input_channel = [
make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE) for channel in input_channel
]
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = [
make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
]
n_block_list = [base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST]
stride_list = [1, 2, 2, 2]
# build input stem
input_stem = [
DynamicConvLayer(val2list(3), mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
ResidualBlock(
DynamicConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
IdentityLayer(mid_input_channel, mid_input_channel)
),
DynamicConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
]
# blocks
blocks = []
for d, width, s in zip(n_block_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = DynamicResNetBottleneckBlock(
input_channel, width, expand_ratio_list=self.expand_ratio_list,
kernel_size=3, stride=stride, act_func='relu', downsample_mode='avgpool_conv',
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = DynamicLinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
super(OFAResNets, self).__init__(input_stem, blocks, classifier)
# set bn param
self.set_bn_param(*bn_param)
# runtime_depth
self.input_stem_skipping = 0
self.runtime_depth = [0] * len(n_block_list)
@property
def ks_list(self):
return [3]
@staticmethod
def name():
return 'OFAResNets'
def forward(self, x):
for layer in self.input_stem:
if self.input_stem_skipping > 0 \
and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
pass
else:
x = layer(x)
x = self.max_pooling(x)
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
x = self.blocks[idx](x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = ''
for layer in self.input_stem:
if self.input_stem_skipping > 0 \
and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
pass
else:
_str += layer.module_str + '\n'
_str += 'max_pooling(ks=3, stride=2)\n'
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
_str += self.blocks[idx].module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': OFAResNets.__name__,
'bn': self.get_bn_param(),
'input_stem': [
layer.config for layer in self.input_stem
],
'blocks': [
block.config for block in self.blocks
],
'classifier': self.classifier.config,
}
    @staticmethod
    def build_from_config(config):
        raise ValueError('does not support this function')
def load_state_dict(self, state_dict, **kwargs):
model_dict = self.state_dict()
for key in state_dict:
new_key = key
if new_key in model_dict:
pass
elif '.linear.' in new_key:
new_key = new_key.replace('.linear.', '.linear.linear.')
elif 'bn.' in new_key:
new_key = new_key.replace('bn.', 'bn.bn.')
elif 'conv.weight' in new_key:
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
else:
raise ValueError(new_key)
assert new_key in model_dict, '%s' % new_key
model_dict[new_key] = state_dict[key]
super(OFAResNets, self).load_state_dict(model_dict)
""" set, sample and get active sub-networks """
def set_max_net(self):
self.set_active_subnet(d=max(self.depth_list), e=max(self.expand_ratio_list), w=len(self.width_mult_list) - 1)
def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
expand_ratio = val2list(e, len(self.blocks))
width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)
for block, e in zip(self.blocks, expand_ratio):
if e is not None:
block.active_expand_ratio = e
if width_mult[0] is not None:
self.input_stem[1].conv.active_out_channel = self.input_stem[0].active_out_channel = \
self.input_stem[0].out_channel_list[width_mult[0]]
if width_mult[1] is not None:
self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[width_mult[1]]
if depth[0] is not None:
self.input_stem_skipping = (depth[0] != max(self.depth_list))
for stage_id, (block_idx, d, w) in enumerate(zip(self.grouped_block_index, depth[1:], width_mult[2:])):
if d is not None:
self.runtime_depth[stage_id] = max(self.depth_list) - d
if w is not None:
for idx in block_idx:
self.blocks[idx].active_out_channel = self.blocks[idx].out_channel_list[w]
def sample_active_subnet(self):
# sample expand ratio
expand_setting = []
for block in self.blocks:
expand_setting.append(random.choice(block.expand_ratio_list))
# sample depth
depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
for stage_id in range(len(ResNets.BASE_DEPTH_LIST)):
depth_setting.append(random.choice(self.depth_list))
# sample width_mult
width_mult_setting = [
random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
]
for stage_id, block_idx in enumerate(self.grouped_block_index):
stage_first_block = self.blocks[block_idx[0]]
width_mult_setting.append(
random.choice(list(range(len(stage_first_block.out_channel_list))))
)
arch_config = {
'd': depth_setting,
'e': expand_setting,
'w': width_mult_setting
}
self.set_active_subnet(**arch_config)
return arch_config
def get_active_subnet(self, preserve_weight=True):
input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
if self.input_stem_skipping <= 0:
input_stem.append(ResidualBlock(
self.input_stem[1].conv.get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight),
IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel)
))
input_stem.append(self.input_stem[2].get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight))
input_channel = self.input_stem[2].active_out_channel
blocks = []
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
blocks.append(self.blocks[idx].get_active_subnet(input_channel, preserve_weight))
input_channel = self.blocks[idx].active_out_channel
classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
subnet = ResNets(input_stem, blocks, classifier)
subnet.set_bn_param(**self.get_bn_param())
return subnet
def get_active_net_config(self):
input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
if self.input_stem_skipping <= 0:
input_stem_config.append({
'name': ResidualBlock.__name__,
'conv': self.input_stem[1].conv.get_active_subnet_config(self.input_stem[0].active_out_channel),
'shortcut': IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel),
})
input_stem_config.append(self.input_stem[2].get_active_subnet_config(self.input_stem[0].active_out_channel))
input_channel = self.input_stem[2].active_out_channel
blocks_config = []
for stage_id, block_idx in enumerate(self.grouped_block_index):
depth_param = self.runtime_depth[stage_id]
active_idx = block_idx[:len(block_idx) - depth_param]
for idx in active_idx:
blocks_config.append(self.blocks[idx].get_active_subnet_config(input_channel))
input_channel = self.blocks[idx].active_out_channel
classifier_config = self.classifier.get_active_subnet_config(input_channel)
return {
'name': ResNets.__name__,
'bn': self.get_bn_param(),
'input_stem': input_stem_config,
'blocks': blocks_config,
'classifier': classifier_config,
}
""" Width Related Methods """
def re_organize_middle_weights(self, expand_ratio_stage=0):
for block in self.blocks:
block.re_organize_middle_weights(expand_ratio_stage)
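
A minimal usage sketch of the elastic API above (assumptions: the OFAResNets constructor defined earlier in this file is callable with its defaults; `ofa_net` is a placeholder name):
import torch

ofa_net = OFAResNets()                                # assumed: constructor defaults suffice
arch_config = ofa_net.sample_active_subnet()          # {'d': [...], 'e': [...], 'w': [...]}
subnet = ofa_net.get_active_subnet(preserve_weight=True)
with torch.no_grad():
    out = subnet(torch.randn(1, 3, 224, 224))         # ImageNet-sized dummy input
print(arch_config, out.shape)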

View File

@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .progressive_shrinking import *

View File

@@ -0,0 +1,320 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import torch.nn as nn
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from ofa.utils import AverageMeter, cross_entropy_loss_with_soft_target
from ofa.utils import DistributedMetric, list_mean, subset_mean, val2list, MyRandomResizedCrop
from ofa.imagenet_classification.run_manager import DistributedRunManager
__all__ = [
'validate', 'train_one_epoch', 'train', 'load_models',
'train_elastic_depth', 'train_elastic_expand', 'train_elastic_width_mult',
]
def validate(run_manager, epoch=0, is_test=False, image_size_list=None,
ks_list=None, expand_ratio_list=None, depth_list=None, width_mult_list=None, additional_setting=None):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
dynamic_net.eval()
if image_size_list is None:
image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
if ks_list is None:
ks_list = dynamic_net.ks_list
if expand_ratio_list is None:
expand_ratio_list = dynamic_net.expand_ratio_list
if depth_list is None:
depth_list = dynamic_net.depth_list
if width_mult_list is None:
if 'width_mult_list' in dynamic_net.__dict__:
width_mult_list = list(range(len(dynamic_net.width_mult_list)))
else:
width_mult_list = [0]
subnet_settings = []
for d in depth_list:
for e in expand_ratio_list:
for k in ks_list:
for w in width_mult_list:
for img_size in image_size_list:
subnet_settings.append([{
'image_size': img_size,
'd': d,
'e': e,
'ks': k,
'w': w,
}, 'R%s-D%s-E%s-K%s-W%s' % (img_size, d, e, k, w)])
if additional_setting is not None:
subnet_settings += additional_setting
losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []
valid_log = ''
for setting, name in subnet_settings:
run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30, 'train', should_print=False)
run_manager.run_config.data_provider.assign_active_img_size(setting.pop('image_size'))
dynamic_net.set_active_subnet(**setting)
run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)
run_manager.reset_running_statistics(dynamic_net)
loss, (top1, top5) = run_manager.validate(epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
losses_of_subnets.append(loss)
top1_of_subnets.append(top1)
top5_of_subnets.append(top5)
valid_log += '%s (%.3f), ' % (name, top1)
return list_mean(losses_of_subnets), list_mean(top1_of_subnets), list_mean(top5_of_subnets), valid_log
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
dynamic_net = run_manager.network
distributed = isinstance(run_manager, DistributedRunManager)
# switch to train mode
dynamic_net.train()
if distributed:
run_manager.run_config.train_loader.sampler.set_epoch(epoch)
MyRandomResizedCrop.EPOCH = epoch
nBatch = len(run_manager.run_config.train_loader)
data_time = AverageMeter()
losses = DistributedMetric('train_loss') if distributed else AverageMeter()
metric_dict = run_manager.get_metric_dict()
with tqdm(total=nBatch,
desc='Train Epoch #{}'.format(epoch + 1),
disable=distributed and not run_manager.is_root) as t:
end = time.time()
for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = run_manager.run_config.warmup_adjust_learning_rate(
run_manager.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
)
else:
new_lr = run_manager.run_config.adjust_learning_rate(
run_manager.optimizer, epoch - warmup_epochs, i, nBatch
)
images, labels = images.cuda(), labels.cuda()
target = labels
# soft target
if args.kd_ratio > 0:
args.teacher_model.train()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# clear gradients
dynamic_net.zero_grad()
loss_of_subnets = []
# compute output
subnet_str = ''
for _ in range(args.dynamic_batch_size):
# set random seed before sampling
subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, _, 0))
random.seed(subnet_seed)
subnet_settings = dynamic_net.sample_active_subnet()
subnet_str += '%d: ' % _ + ','.join(['%s_%s' % (
key, '%.1f' % subset_mean(val, 0) if isinstance(val, list) else val
) for key, val in subnet_settings.items()]) + ' || '
output = run_manager.net(images)
if args.kd_ratio == 0:
loss = run_manager.train_criterion(output, labels)
loss_type = 'ce'
else:
if args.kd_type == 'ce':
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
else:
kd_loss = F.mse_loss(output, soft_logits)
loss = args.kd_ratio * kd_loss + run_manager.train_criterion(output, labels)
loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)
# measure accuracy and record loss
loss_of_subnets.append(loss)
run_manager.update_metric(metric_dict, output, target)
loss.backward()
run_manager.optimizer.step()
losses.update(list_mean(loss_of_subnets), images.size(0))
t.set_postfix({
'loss': losses.avg.item(),
**run_manager.get_metric_vals(metric_dict, return_dict=True),
'R': images.size(2),
'lr': new_lr,
'loss_type': loss_type,
'seed': str(subnet_seed),
'str': subnet_str,
'data_time': data_time.avg,
})
t.update(1)
end = time.time()
return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
def train(run_manager, args, validate_func=None):
distributed = isinstance(run_manager, DistributedRunManager)
if validate_func is None:
validate_func = validate
for epoch in range(run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs):
train_loss, (train_top1, train_top5) = train_one_epoch(
run_manager, args, epoch, args.warmup_epochs, args.warmup_lr)
if (epoch + 1) % args.validation_frequency == 0:
val_loss, val_acc, val_acc5, _val_log = validate_func(run_manager, epoch=epoch, is_test=False)
# best_acc
is_best = val_acc > run_manager.best_acc
run_manager.best_acc = max(run_manager.best_acc, val_acc)
if not distributed or run_manager.is_root:
val_log = 'Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f})'. \
format(epoch + 1 - args.warmup_epochs, run_manager.run_config.n_epochs, val_loss, val_acc,
run_manager.best_acc)
val_log += ', Train top-1 {top1:.3f}, Train loss {loss:.3f}\t'.format(top1=train_top1, loss=train_loss)
val_log += _val_log
run_manager.write_log(val_log, 'valid', should_print=False)
run_manager.save_model({
'epoch': epoch,
'best_acc': run_manager.best_acc,
'optimizer': run_manager.optimizer.state_dict(),
'state_dict': run_manager.network.state_dict(),
}, is_best=is_best)
def load_models(run_manager, dynamic_net, model_path=None):
# specify init path
init = torch.load(model_path, map_location='cpu')['state_dict']
dynamic_net.load_state_dict(init)
run_manager.write_log('Loaded init from %s' % model_path, 'valid')
def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
depth_stage_list = dynamic_net.depth_list.copy()
depth_stage_list.sort(reverse=True)
n_stages = len(depth_stage_list) - 1
current_stage = n_stages - 1
# load pretrained models
if run_manager.start_epoch == 0 and not args.resume:
validate_func_dict['depth_list'] = sorted(dynamic_net.depth_list)
load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
# validate after loading weights
run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
validate(run_manager, is_test=True, **validate_func_dict), 'valid')
else:
assert args.resume
run_manager.write_log(
'-' * 30 + 'Supporting Elastic Depth: %s -> %s' %
(depth_stage_list[:current_stage + 1], depth_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
)
# add depth list constraints
if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.expand_ratio_list)) == 1:
validate_func_dict['depth_list'] = depth_stage_list
else:
validate_func_dict['depth_list'] = sorted({min(depth_stage_list), max(depth_stage_list)})
# train
train_func(
run_manager, args,
lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
)
def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
expand_stage_list = dynamic_net.expand_ratio_list.copy()
expand_stage_list.sort(reverse=True)
n_stages = len(expand_stage_list) - 1
current_stage = n_stages - 1
# load pretrained models
if run_manager.start_epoch == 0 and not args.resume:
validate_func_dict['expand_ratio_list'] = sorted(dynamic_net.expand_ratio_list)
load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
validate(run_manager, is_test=True, **validate_func_dict), 'valid')
else:
assert args.resume
run_manager.write_log(
'-' * 30 + 'Supporting Elastic Expand Ratio: %s -> %s' %
(expand_stage_list[:current_stage + 1], expand_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
)
if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
validate_func_dict['expand_ratio_list'] = expand_stage_list
else:
validate_func_dict['expand_ratio_list'] = sorted({min(expand_stage_list), max(expand_stage_list)})
# train
train_func(
run_manager, args,
lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
)
def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
dynamic_net = run_manager.net
if isinstance(dynamic_net, nn.DataParallel):
dynamic_net = dynamic_net.module
width_stage_list = dynamic_net.width_mult_list.copy()
width_stage_list.sort(reverse=True)
n_stages = len(width_stage_list) - 1
current_stage = n_stages - 1
if run_manager.start_epoch == 0 and not args.resume:
load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
if current_stage == 0:
dynamic_net.re_organize_middle_weights(expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1)
run_manager.write_log('reorganize_middle_weights (expand_ratio_stage=%d)'
% (len(dynamic_net.expand_ratio_list) - 1), 'valid')
try:
dynamic_net.re_organize_outer_weights()
run_manager.write_log('reorganize_outer_weights', 'valid')
except Exception:
pass
run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
validate(run_manager, is_test=True, **validate_func_dict), 'valid')
else:
assert args.resume
run_manager.write_log(
'-' * 30 + 'Supporting Elastic Width Mult: %s -> %s' %
(width_stage_list[:current_stage + 1], width_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
)
validate_func_dict['width_mult_list'] = sorted({0, len(width_stage_list) - 1})
# train
train_func(
run_manager, args,
lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
)
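
For context, a sketch of how these helpers are typically wired together; `run_manager` and `args` are assumed to be configured elsewhere, and `args` must carry the fields referenced above (kd_ratio, kd_type, dynamic_batch_size, warmup_epochs, warmup_lr, validation_frequency, resume, ofa_checkpoint_path):
validate_func_dict = {
    'image_size_list': [224],
    'ks_list': [7],              # validate with the largest kernel only
    'expand_ratio_list': [6],    # and the largest expand ratio
    'depth_list': [4],           # overwritten inside train_elastic_depth
}
train_elastic_depth(train, run_manager, args, validate_func_dict)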

View File

@@ -0,0 +1,70 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch.nn.functional as F
import torch.nn as nn
import torch
from ofa_local.utils import AverageMeter, get_net_device, DistributedTensor
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d
__all__ = ['set_running_statistics']
def set_running_statistics(model, data_loader, distributed=False):
bn_mean = {}
bn_var = {}
forward_model = copy.deepcopy(model)
for name, m in forward_model.named_modules():
if isinstance(m, nn.BatchNorm2d):
if distributed:
bn_mean[name] = DistributedTensor(name + '#mean')
bn_var[name] = DistributedTensor(name + '#var')
else:
bn_mean[name] = AverageMeter()
bn_var[name] = AverageMeter()
def new_forward(bn, mean_est, var_est):
def lambda_forward(x):
batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
batch_var = (x - batch_mean) * (x - batch_mean)
batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
batch_mean = torch.squeeze(batch_mean)
batch_var = torch.squeeze(batch_var)
mean_est.update(batch_mean.data, x.size(0))
var_est.update(batch_var.data, x.size(0))
# bn forward using calculated mean & var
_feature_dim = batch_mean.size(0)
return F.batch_norm(
x, batch_mean, batch_var, bn.weight[:_feature_dim],
bn.bias[:_feature_dim], False,
0.0, bn.eps,
)
return lambda_forward
m.forward = new_forward(m, bn_mean[name], bn_var[name])
if len(bn_mean) == 0:
# skip if there are no batch normalization layers in the network
return
with torch.no_grad():
DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
for images, labels in data_loader:
images = images.to(get_net_device(forward_model))
forward_model(images)
DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False
for name, m in model.named_modules():
if name in bn_mean and bn_mean[name].count > 0:
feature_dim = bn_mean[name].avg.size(0)
assert isinstance(m, nn.BatchNorm2d)
m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
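
Typical use (sketch; `ofa_net` and `calib_loader` are assumed, e.g. a sub-train loader of roughly 2000 images): after activating a sub-network, its BN statistics must be recalibrated because the stored running stats belong to the full super-network.
ofa_net.set_active_subnet(d=2, e=0.25, w=0)     # hypothetical subnet setting
set_running_statistics(ofa_net, calib_loader)   # recompute BN mean/var in place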

View File

@@ -0,0 +1,18 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .proxyless_nets import *
from .mobilenet_v3 import *
from .resnets import *
def get_net_by_name(name):
if name == ProxylessNASNets.__name__:
return ProxylessNASNets
elif name == MobileNetV3.__name__:
return MobileNetV3
elif name == ResNets.__name__:
return ResNets
else:
raise ValueError('unrecognized type of network: %s' % name)
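
A one-line round trip (sketch; the config path is hypothetical): the factory maps the `name` field written by each network's `config` property back to its class.
import json

config = json.load(open('./exp/net.config'))    # hypothetical dump location
net = get_net_by_name(config['name']).build_from_config(config)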

View File

@@ -0,0 +1,218 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import copy
import torch.nn as nn
from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d
__all__ = ['MobileNetV3', 'MobileNetV3Large']
class MobileNetV3(MyNetwork):
def __init__(self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier):
super(MobileNetV3, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.final_expand_layer = final_expand_layer
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
self.feature_mix_layer = feature_mix_layer
self.classifier = classifier
def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.final_expand_layer(x)
x = self.global_avg_pool(x) # global average pooling
x = self.feature_mix_layer(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
_str += self.final_expand_layer.module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': MobileNetV3.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'final_expand_layer': self.final_expand_layer.config,
'feature_mix_layer': self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config['first_conv'])
final_expand_layer = set_layer_from_config(config['final_expand_layer'])
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
classifier = set_layer_from_config(config['classifier'])
blocks = []
for block_config in config['blocks']:
blocks.append(ResidualBlock.build_from_config(block_config))
net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-5)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
m.conv.point_linear.bn.weight.data.zero_()
@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list
@staticmethod
def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
)
# build mobile blocks
feature_dim = input_channel
blocks = []
for stage_id, block_config_list in cfg.items():
for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
mb_conv = MBConvLayer(
feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
)
if stride == 1 and out_channel == feature_dim:
shortcut = IdentityLayer(out_channel, out_channel)
else:
shortcut = None
blocks.append(ResidualBlock(mb_conv, shortcut))
feature_dim = out_channel
# final expand layer
final_expand_layer = ConvLayer(
feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
)
# feature mix layer
feature_mix_layer = ConvLayer(
feature_dim * 6, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
)
# classifier
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
@staticmethod
def adjust_cfg(cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
for i, (stage_id, block_config_list) in enumerate(cfg.items()):
for block_config in block_config_list:
if ks is not None and stage_id != '0':
block_config[0] = ks
if expand_ratio is not None and stage_id != '0':
block_config[-1] = expand_ratio
block_config[1] = None
if stage_width_list is not None:
block_config[2] = stage_width_list[i]
if depth_param is not None and stage_id != '0':
new_block_config_list = [block_config_list[0]]
new_block_config_list += [copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)]
cfg[stage_id] = new_block_config_list
return cfg
def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()
for key in state_dict:
if key not in current_state_dict:
assert '.mobile_inverted_conv.' in key
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(MobileNetV3, self).load_state_dict(current_state_dict)
class MobileNetV3Large(MobileNetV3):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0.2,
ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
input_channel = 16
last_channel = 1280
input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
if width_mult > 1.0 else last_channel
cfg = {
# k, exp, c, se, nl, s, e,
'0': [
[3, 16, 16, False, 'relu', 1, 1],
],
'1': [
[3, 64, 24, False, 'relu', 2, None], # 4
[3, 72, 24, False, 'relu', 1, None], # 3
],
'2': [
[5, 72, 40, True, 'relu', 2, None], # 3
[5, 120, 40, True, 'relu', 1, None], # 3
[5, 120, 40, True, 'relu', 1, None], # 3
],
'3': [
[3, 240, 80, False, 'h_swish', 2, None], # 6
[3, 200, 80, False, 'h_swish', 1, None], # 2.5
[3, 184, 80, False, 'h_swish', 1, None], # 2.3
[3, 184, 80, False, 'h_swish', 1, None], # 2.3
],
'4': [
[3, 480, 112, True, 'h_swish', 1, None], # 6
[3, 672, 112, True, 'h_swish', 1, None], # 6
],
'5': [
[5, 672, 160, True, 'h_swish', 2, None], # 6
[5, 960, 160, True, 'h_swish', 1, None], # 6
[5, 960, 160, True, 'h_swish', 1, None], # 6
]
}
cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
# apply the width multiplier to the mobile setting: scale `exp` (index 1) and `c` (index 2)
for stage_id, block_config_list in cfg.items():
for block_config in block_config_list:
if block_config[1] is not None:
block_config[1] = make_divisible(block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
block_config[2] = make_divisible(block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
first_conv, blocks, final_expand_layer, feature_mix_layer, classifier = self.build_net_via_cfg(
cfg, input_channel, last_channel, n_classes, dropout_rate
)
super(MobileNetV3Large, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(*bn_param)
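
Construction sketch for the static variant above (values are illustrative): `adjust_cfg` lets one override kernel size, expand ratio and per-stage depth before the blocks are built.
net = MobileNetV3Large(n_classes=1000, width_mult=1.0,
                       ks=5, expand_ratio=4, depth_param=3)
print(net.module_str)   # one summary line per layer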

View File

@@ -0,0 +1,210 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import json
import torch.nn as nn
from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import download_url, make_divisible, val2list, MyNetwork, MyGlobalAvgPool2d
__all__ = ['proxyless_base', 'ProxylessNASNets', 'MobileNetV2']
def proxyless_base(net_config=None, n_classes=None, bn_param=None, dropout_rate=None,
local_path='~/.torch/proxylessnas/'):
assert net_config is not None, 'Please input a network config'
if 'http' in net_config:
net_config_path = download_url(net_config, local_path)
else:
net_config_path = net_config
net_config_json = json.load(open(net_config_path, 'r'))
if n_classes is not None:
net_config_json['classifier']['out_features'] = n_classes
if dropout_rate is not None:
net_config_json['classifier']['dropout_rate'] = dropout_rate
net = ProxylessNASNets.build_from_config(net_config_json)
if bn_param is not None:
net.set_bn_param(*bn_param)
return net
class ProxylessNASNets(MyNetwork):
def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
super(ProxylessNASNets, self).__init__()
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.feature_mix_layer = feature_mix_layer
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
self.classifier = classifier
def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
if self.feature_mix_layer is not None:
x = self.feature_mix_layer(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = self.first_conv.module_str + '\n'
for block in self.blocks:
_str += block.module_str + '\n'
_str += self.feature_mix_layer.module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': ProxylessNASNets.__name__,
'bn': self.get_bn_param(),
'first_conv': self.first_conv.config,
'blocks': [
block.config for block in self.blocks
],
'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
first_conv = set_layer_from_config(config['first_conv'])
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
classifier = set_layer_from_config(config['classifier'])
blocks = []
for block_config in config['blocks']:
blocks.append(ResidualBlock.build_from_config(block_config))
net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-3)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResidualBlock):
if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
m.conv.point_linear.bn.weight.data.zero_()
@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks[1:], 1):
if block.shortcut is None and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list
def load_state_dict(self, state_dict, **kwargs):
current_state_dict = self.state_dict()
for key in state_dict:
if key not in current_state_dict:
assert '.mobile_inverted_conv.' in key
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
else:
new_key = key
current_state_dict[new_key] = state_dict[key]
super(ProxylessNASNets, self).load_state_dict(current_state_dict)
class MobileNetV2(ProxylessNASNets):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-3), dropout_rate=0.2,
ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
ks = 3 if ks is None else ks
expand_ratio = 6 if expand_ratio is None else expand_ratio
input_channel = 32
last_channel = 1280
input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
if width_mult > 1.0 else last_channel
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[expand_ratio, 24, 2, 2],
[expand_ratio, 32, 3, 2],
[expand_ratio, 64, 4, 2],
[expand_ratio, 96, 3, 1],
[expand_ratio, 160, 3, 2],
[expand_ratio, 320, 1, 1],
]
if depth_param is not None:
assert isinstance(depth_param, int)
for i in range(1, len(inverted_residual_setting) - 1):
inverted_residual_setting[i][2] = depth_param
if stage_width_list is not None:
for i in range(len(inverted_residual_setting)):
inverted_residual_setting[i][1] = stage_width_list[i]
ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
_pt = 0
# first conv layer
first_conv = ConvLayer(
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
)
# inverted residual blocks
blocks = []
for t, c, n, s in inverted_residual_setting:
output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
for i in range(n):
if i == 0:
stride = s
else:
stride = 1
if t == 1:
kernel_size = 3
else:
kernel_size = ks[_pt]
_pt += 1
mobile_inverted_conv = MBConvLayer(
in_channels=input_channel, out_channels=output_channel, kernel_size=kernel_size, stride=stride,
expand_ratio=t,
)
if stride == 1:
if input_channel == output_channel:
shortcut = IdentityLayer(input_channel, input_channel)
else:
shortcut = None
else:
shortcut = None
blocks.append(
ResidualBlock(mobile_inverted_conv, shortcut)
)
input_channel = output_channel
# 1x1_conv before global average pooling
feature_mix_layer = ConvLayer(
input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act',
)
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
super(MobileNetV2, self).__init__(first_conv, blocks, feature_mix_layer, classifier)
# set bn param
self.set_bn_param(*bn_param)
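
Sketch: the same class covers the standard MobileNetV2 and reduced variants (values are illustrative).
net = MobileNetV2()                                   # standard ImageNet MobileNetV2
small = MobileNetV2(width_mult=0.5, depth_param=2)    # thinner and shallower variant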

View File

@@ -0,0 +1,192 @@
import torch.nn as nn
from ofa_local.utils.layers import set_layer_from_config, ConvLayer, IdentityLayer, LinearLayer
from ofa_local.utils.layers import ResNetBottleneckBlock, ResidualBlock
from ofa_local.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d
__all__ = ['ResNets', 'ResNet50', 'ResNet50D']
class ResNets(MyNetwork):
BASE_DEPTH_LIST = [2, 2, 4, 2]
STAGE_WIDTH_LIST = [256, 512, 1024, 2048]
def __init__(self, input_stem, blocks, classifier):
super(ResNets, self).__init__()
self.input_stem = nn.ModuleList(input_stem)
self.max_pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
self.blocks = nn.ModuleList(blocks)
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
self.classifier = classifier
def forward(self, x):
for layer in self.input_stem:
x = layer(x)
x = self.max_pooling(x)
for block in self.blocks:
x = block(x)
x = self.global_avg_pool(x)
x = self.classifier(x)
return x
@property
def module_str(self):
_str = ''
for layer in self.input_stem:
_str += layer.module_str + '\n'
_str += 'max_pooling(ks=3, stride=2)\n'
for block in self.blocks:
_str += block.module_str + '\n'
_str += self.global_avg_pool.__repr__() + '\n'
_str += self.classifier.module_str
return _str
@property
def config(self):
return {
'name': ResNets.__name__,
'bn': self.get_bn_param(),
'input_stem': [
layer.config for layer in self.input_stem
],
'blocks': [
block.config for block in self.blocks
],
'classifier': self.classifier.config,
}
@staticmethod
def build_from_config(config):
classifier = set_layer_from_config(config['classifier'])
input_stem = []
for layer_config in config['input_stem']:
input_stem.append(set_layer_from_config(layer_config))
blocks = []
for block_config in config['blocks']:
blocks.append(set_layer_from_config(block_config))
net = ResNets(input_stem, blocks, classifier)
if 'bn' in config:
net.set_bn_param(**config['bn'])
else:
net.set_bn_param(momentum=0.1, eps=1e-5)
return net
def zero_last_gamma(self):
for m in self.modules():
if isinstance(m, ResNetBottleneckBlock) and isinstance(m.downsample, IdentityLayer):
m.conv3.bn.weight.data.zero_()
@property
def grouped_block_index(self):
info_list = []
block_index_list = []
for i, block in enumerate(self.blocks):
if not isinstance(block.downsample, IdentityLayer) and len(block_index_list) > 0:
info_list.append(block_index_list)
block_index_list = []
block_index_list.append(i)
if len(block_index_list) > 0:
info_list.append(block_index_list)
return info_list
def load_state_dict(self, state_dict, **kwargs):
super(ResNets, self).load_state_dict(state_dict)
class ResNet50(ResNets):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
expand_ratio=None, depth_param=None):
expand_ratio = 0.25 if expand_ratio is None else expand_ratio
input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
depth_list = [3, 4, 6, 3]
if depth_param is not None:
for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
depth_list[i] = depth + depth_param
stride_list = [1, 2, 2, 2]
# build input stem
input_stem = [ConvLayer(
3, input_channel, kernel_size=7, stride=2, use_bn=True, act_func='relu', ops_order='weight_bn_act',
)]
# blocks
blocks = []
for d, width, s in zip(depth_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = ResNetBottleneckBlock(
input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
act_func='relu', downsample_mode='conv',
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
super(ResNet50, self).__init__(input_stem, blocks, classifier)
# set bn param
self.set_bn_param(*bn_param)
class ResNet50D(ResNets):
def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
expand_ratio=None, depth_param=None):
expand_ratio = 0.25 if expand_ratio is None else expand_ratio
input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
mid_input_channel = make_divisible(input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
for i, width in enumerate(stage_width_list):
stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
depth_list = [3, 4, 6, 3]
if depth_param is not None:
for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
depth_list[i] = depth + depth_param
stride_list = [1, 2, 2, 2]
# build input stem
input_stem = [
ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
ResidualBlock(
ConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
IdentityLayer(mid_input_channel, mid_input_channel)
),
ConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
]
# blocks
blocks = []
for d, width, s in zip(depth_list, stage_width_list, stride_list):
for i in range(d):
stride = s if i == 0 else 1
bottleneck_block = ResNetBottleneckBlock(
input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
act_func='relu', downsample_mode='avgpool_conv',
)
blocks.append(bottleneck_block)
input_channel = width
# classifier
classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)
super(ResNet50D, self).__init__(input_stem, blocks, classifier)
# set bn param
self.set_bn_param(*bn_param)
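
Sketch: `depth_param` is added on top of BASE_DEPTH_LIST, so `depth_param=1` yields per-stage depths [3, 3, 5, 3] in place of the default [3, 4, 6, 3].
net = ResNet50D(n_classes=1000, width_mult=1.0, depth_param=1)
print(net.config['name'])   # 'ResNets'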

View File

@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .run_config import *
from .run_manager import *
from .distributed_run_manager import *

View File

@@ -0,0 +1,381 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import json
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from ofa_local.utils import cross_entropy_with_label_smoothing, cross_entropy_loss_with_soft_target, write_log, init_models
from ofa_local.utils import DistributedMetric, list_mean, get_net_info, accuracy, AverageMeter, mix_labels, mix_images
from ofa_local.utils import MyRandomResizedCrop
__all__ = ['DistributedRunManager']
class DistributedRunManager:
def __init__(self, path, net, run_config, hvd_compression, backward_steps=1, is_root=False, init=True):
import horovod.torch as hvd
self.path = path
self.net = net
self.run_config = run_config
self.is_root = is_root
self.best_acc = 0.0
self.start_epoch = 0
os.makedirs(self.path, exist_ok=True)
self.net.cuda()
cudnn.benchmark = True
if init and self.is_root:
init_models(self.net, self.run_config.model_init)
if self.is_root:
# print net info
net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
with open('%s/net_info.txt' % self.path, 'w') as fout:
fout.write(json.dumps(net_info, indent=4) + '\n')
try:
fout.write(self.net.module_str + '\n')
except Exception:
fout.write('%s does not support `module_str`' % type(self.net))
fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
fout.write('%s\n' % self.net)
# criterion
if isinstance(self.run_config.mixup_alpha, float):
self.train_criterion = cross_entropy_loss_with_soft_target
elif self.run_config.label_smoothing > 0:
self.train_criterion = lambda pred, target: \
cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
else:
self.train_criterion = nn.CrossEntropyLoss()
self.test_criterion = nn.CrossEntropyLoss()
# optimizer
if self.run_config.no_decay_keys:
keys = self.run_config.no_decay_keys.split('#')
net_params = [
self.net.get_parameters(keys, mode='exclude'), # parameters with weight decay
self.net.get_parameters(keys, mode='include'), # parameters without weight decay
]
else:
# noinspection PyBroadException
try:
net_params = self.network.weight_parameters()
except Exception:
net_params = []
for param in self.network.parameters():
if param.requires_grad:
net_params.append(param)
self.optimizer = self.run_config.build_optimizer(net_params)
self.optimizer = hvd.DistributedOptimizer(
self.optimizer, named_parameters=self.net.named_parameters(), compression=hvd_compression,
backward_passes_per_step=backward_steps,
)
""" save path and log path """
@property
def save_path(self):
if self.__dict__.get('_save_path', None) is None:
save_path = os.path.join(self.path, 'checkpoint')
os.makedirs(save_path, exist_ok=True)
self.__dict__['_save_path'] = save_path
return self.__dict__['_save_path']
@property
def logs_path(self):
if self.__dict__.get('_logs_path', None) is None:
logs_path = os.path.join(self.path, 'logs')
os.makedirs(logs_path, exist_ok=True)
self.__dict__['_logs_path'] = logs_path
return self.__dict__['_logs_path']
@property
def network(self):
return self.net
@network.setter
def network(self, new_val):
self.net = new_val
def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
if self.is_root:
write_log(self.logs_path, log_str, prefix, should_print, mode)
""" save & load model & save_config & broadcast """
def save_config(self, extra_run_config=None, extra_net_config=None):
if self.is_root:
run_save_path = os.path.join(self.path, 'run.config')
if not os.path.isfile(run_save_path):
run_config = self.run_config.config
if extra_run_config is not None:
run_config.update(extra_run_config)
json.dump(run_config, open(run_save_path, 'w'), indent=4)
print('Run configs dumped to %s' % run_save_path)
try:
net_save_path = os.path.join(self.path, 'net.config')
net_config = self.net.config
if extra_net_config is not None:
net_config.update(extra_net_config)
json.dump(net_config, open(net_save_path, 'w'), indent=4)
print('Network configs dumped to %s' % net_save_path)
except Exception:
print('%s does not support net config' % type(self.net))
def save_model(self, checkpoint=None, is_best=False, model_name=None):
if self.is_root:
if checkpoint is None:
checkpoint = {'state_dict': self.net.state_dict()}
if model_name is None:
model_name = 'checkpoint.pth.tar'
latest_fname = os.path.join(self.save_path, 'latest.txt')
model_path = os.path.join(self.save_path, model_name)
with open(latest_fname, 'w') as _fout:
_fout.write(model_path + '\n')
torch.save(checkpoint, model_path)
if is_best:
best_path = os.path.join(self.save_path, 'model_best.pth.tar')
torch.save({'state_dict': checkpoint['state_dict']}, best_path)
def load_model(self, model_fname=None):
if self.is_root:
latest_fname = os.path.join(self.save_path, 'latest.txt')
if model_fname is None and os.path.exists(latest_fname):
with open(latest_fname, 'r') as fin:
model_fname = fin.readline()
if model_fname[-1] == '\n':
model_fname = model_fname[:-1]
# noinspection PyBroadException
try:
if model_fname is None or not os.path.exists(model_fname):
model_fname = '%s/checkpoint.pth.tar' % self.save_path
with open(latest_fname, 'w') as fout:
fout.write(model_fname + '\n')
print("=> loading checkpoint '{}'".format(model_fname))
checkpoint = torch.load(model_fname, map_location='cpu')
except Exception:
self.write_log('failed to load checkpoint from %s' % self.save_path, 'valid')
return
self.net.load_state_dict(checkpoint['state_dict'])
if 'epoch' in checkpoint:
self.start_epoch = checkpoint['epoch'] + 1
if 'best_acc' in checkpoint:
self.best_acc = checkpoint['best_acc']
if 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.write_log("=> loaded checkpoint '{}'".format(model_fname), 'valid')
# noinspection PyArgumentList
def broadcast(self):
import horovod.torch as hvd
self.start_epoch = hvd.broadcast(torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name='start_epoch').item()
self.best_acc = hvd.broadcast(torch.Tensor(1).fill_(self.best_acc)[0], 0, name='best_acc').item()
hvd.broadcast_parameters(self.net.state_dict(), 0)
hvd.broadcast_optimizer_state(self.optimizer, 0)
""" metric related """
def get_metric_dict(self):
return {
'top1': DistributedMetric('top1'),
'top5': DistributedMetric('top5'),
}
def update_metric(self, metric_dict, output, labels):
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
metric_dict['top1'].update(acc1[0], output.size(0))
metric_dict['top5'].update(acc5[0], output.size(0))
def get_metric_vals(self, metric_dict, return_dict=False):
if return_dict:
return {
key: metric_dict[key].avg.item() for key in metric_dict
}
else:
return [metric_dict[key].avg.item() for key in metric_dict]
def get_metric_names(self):
return 'top1', 'top5'
""" train & validate """
def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False):
if net is None:
net = self.net
if data_loader is None:
if is_test:
data_loader = self.run_config.test_loader
else:
data_loader = self.run_config.valid_loader
net.eval()
losses = DistributedMetric('val_loss')
metric_dict = self.get_metric_dict()
with torch.no_grad():
with tqdm(total=len(data_loader),
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str),
disable=no_logs or not self.is_root) as t:
for i, (images, labels) in enumerate(data_loader):
images, labels = images.cuda(), labels.cuda()
# compute output
output = net(images)
loss = self.test_criterion(output, labels)
# measure accuracy and record loss
losses.update(loss, images.size(0))
self.update_metric(metric_dict, output, labels)
t.set_postfix({
'loss': losses.avg.item(),
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
})
t.update(1)
return losses.avg.item(), self.get_metric_vals(metric_dict)
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
if net is None:
net = self.net
if isinstance(self.run_config.data_provider.image_size, list):
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
for img_size in self.run_config.data_provider.image_size:
img_size_list.append(img_size)
self.run_config.data_provider.assign_active_img_size(img_size)
self.reset_running_statistics(net=net)
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
loss_list.append(loss)
top1_list.append(top1)
top5_list.append(top5)
return img_size_list, loss_list, top1_list, top5_list
else:
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
self.net.train()
self.run_config.train_loader.sampler.set_epoch(epoch) # required by distributed sampler
MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
nBatch = len(self.run_config.train_loader)
losses = DistributedMetric('train_loss')
metric_dict = self.get_metric_dict()
data_time = AverageMeter()
with tqdm(total=nBatch,
desc='Train Epoch #{}'.format(epoch + 1),
disable=not self.is_root) as t:
end = time.time()
for i, (images, labels) in enumerate(self.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = self.run_config.warmup_adjust_learning_rate(
self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
)
else:
new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)
images, labels = images.cuda(), labels.cuda()
target = labels
if isinstance(self.run_config.mixup_alpha, float):
# transform data
random.seed(int('%d%.3d' % (i, epoch)))
lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
images = mix_images(images, lam)
labels = mix_labels(
labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
)
# soft target
if args.teacher_model is not None:
args.teacher_model.train()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# compute output
output = self.net(images)
if args.teacher_model is None:
loss = self.train_criterion(output, labels)
loss_type = 'ce'
else:
if args.kd_type == 'ce':
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
else:
kd_loss = F.mse_loss(output, soft_logits)
loss = args.kd_ratio * kd_loss + self.train_criterion(output, labels)
loss_type = '%.1fkd+ce' % args.kd_ratio
# update
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# measure accuracy and record loss
losses.update(loss, images.size(0))
self.update_metric(metric_dict, output, target)
t.set_postfix({
'loss': losses.avg.item(),
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
'lr': new_lr,
'loss_type': loss_type,
'data_time': data_time.avg,
})
t.update(1)
end = time.time()
return losses.avg.item(), self.get_metric_vals(metric_dict)
def train(self, args, warmup_epochs=5, warmup_lr=0):
for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epochs, warmup_lr)
img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(epoch, is_test=False)
is_best = list_mean(val_top1) > self.best_acc
self.best_acc = max(self.best_acc, list_mean(val_top1))
if self.is_root:
val_log = '[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t' \
'Train {6} {top1:.3f}\tloss {train_loss:.3f}\t'. \
format(epoch + 1 - warmup_epochs, self.run_config.n_epochs, list_mean(val_loss),
list_mean(val_top1), self.best_acc, list_mean(val_top5), *self.get_metric_names(),
top1=train_top1, train_loss=train_loss)
for i_s, v_a in zip(img_size, val_top1):
val_log += '(%d, %.3f), ' % (i_s, v_a)
self.write_log(val_log, prefix='valid', should_print=False)
self.save_model({
'epoch': epoch,
'best_acc': self.best_acc,
'optimizer': self.optimizer.state_dict(),
'state_dict': self.net.state_dict(),
}, is_best=is_best)
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
if net is None:
net = self.net
if data_loader is None:
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
set_running_statistics(net, data_loader)
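
A Horovod launch sketch (assumptions: horovod is installed, `net` exists, the experiment path is a placeholder, and the ImageNet layout expected by the data provider is in place):
import horovod.torch as hvd
import torch

hvd.init()
torch.cuda.set_device(hvd.local_rank())
run_config = DistributedImageNetRunConfig(num_replicas=hvd.size(), rank=hvd.rank())
run_manager = DistributedRunManager(
    './exp/teacher', net, run_config,        # './exp/teacher' is a placeholder path
    hvd_compression=hvd.Compression.none,
    backward_steps=1, is_root=(hvd.rank() == 0),
)
run_manager.save_config()
run_manager.broadcast()   # sync epoch, best_acc, weights and optimizer state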

View File

@@ -0,0 +1,161 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from ofa_local.utils import calc_learning_rate, build_optimizer
from ofa_local.imagenet_classification.data_providers import ImagenetDataProvider
__all__ = ['RunConfig', 'ImagenetRunConfig', 'DistributedImageNetRunConfig']
class RunConfig:
def __init__(self, n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha, model_init, validation_frequency, print_frequency):
self.n_epochs = n_epochs
self.init_lr = init_lr
self.lr_schedule_type = lr_schedule_type
self.lr_schedule_param = lr_schedule_param
self.dataset = dataset
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.valid_size = valid_size
self.opt_type = opt_type
self.opt_param = opt_param
self.weight_decay = weight_decay
self.label_smoothing = label_smoothing
self.no_decay_keys = no_decay_keys
self.mixup_alpha = mixup_alpha
self.model_init = model_init
self.validation_frequency = validation_frequency
self.print_frequency = print_frequency
@property
def config(self):
config = {}
for key in self.__dict__:
if not key.startswith('_'):
config[key] = self.__dict__[key]
return config
def copy(self):
return RunConfig(**self.config)
""" learning rate """
def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
""" adjust learning of a given optimizer and return the new learning rate """
new_lr = calc_learning_rate(epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
def warmup_adjust_learning_rate(self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0):
T_cur = epoch * nBatch + batch + 1
new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
""" data provider """
@property
def data_provider(self):
raise NotImplementedError
@property
def train_loader(self):
return self.data_provider.train
@property
def valid_loader(self):
return self.data_provider.valid
@property
def test_loader(self):
return self.data_provider.test
def random_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
return self.data_provider.build_sub_train_loader(n_images, batch_size, num_worker, num_replicas, rank)
""" optimizer """
def build_optimizer(self, net_params):
return build_optimizer(net_params,
self.opt_type, self.opt_param, self.init_lr, self.weight_decay, self.no_decay_keys)
class ImagenetRunConfig(RunConfig):
def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
super(ImagenetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha,
model_init, validation_frequency, print_frequency
)
self.n_worker = n_worker
self.resize_scale = resize_scale
self.distort_color = distort_color
self.image_size = image_size
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
)
return self.__dict__['_data_provider']
class DistributedImageNetRunConfig(ImagenetRunConfig):
def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
dataset='imagenet', train_batch_size=64, test_batch_size=64, valid_size=None,
opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
n_worker=8, resize_scale=0.08, distort_color='tf', image_size=224,
**kwargs):
super(DistributedImageNetRunConfig, self).__init__(
n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
dataset, train_batch_size, test_batch_size, valid_size,
opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
mixup_alpha, model_init, validation_frequency, print_frequency, n_worker, resize_scale, distort_color,
image_size, **kwargs
)
self._num_replicas = kwargs['num_replicas']
self._rank = kwargs['rank']
@property
def data_provider(self):
if self.__dict__.get('_data_provider', None) is None:
if self.dataset == ImagenetDataProvider.name():
DataProviderClass = ImagenetDataProvider
else:
raise NotImplementedError
self.__dict__['_data_provider'] = DataProviderClass(
train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
distort_color=self.distort_color, image_size=self.image_size,
num_replicas=self._num_replicas, rank=self._rank,
)
return self.__dict__['_data_provider']
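
Worked numbers for the warm-up rule above, i.e. new_lr = T_cur / T_total * (init_lr - warmup_lr) + warmup_lr (pure arithmetic from the formula, not measured values):
cfg = ImagenetRunConfig(n_epochs=150, init_lr=0.05)   # defaults: cosine schedule
# warmup_lr=0, nBatch=1000, 5 warm-up epochs => T_total = 5000
#   epoch 0, batch 0   -> T_cur = 1    -> lr = 1/5000 * 0.05 = 1e-5
#   epoch 4, batch 999 -> T_cur = 5000 -> lr = 0.05 (hand-off to cosine decay)
print(cfg.config['lr_schedule_type'])                 # 'cosine'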

View File

@@ -0,0 +1,375 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
from ofa_local.utils import MyRandomResizedCrop
__all__ = ['RunManager']
class RunManager:
def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
self.path = path
self.net = net
self.run_config = run_config
self.best_acc = 0
self.start_epoch = 0
os.makedirs(self.path, exist_ok=True)
# move network to GPU if available
if torch.cuda.is_available() and (not no_gpu):
self.device = torch.device('cuda:0')
self.net = self.net.to(self.device)
cudnn.benchmark = True
else:
self.device = torch.device('cpu')
# initialize model (default)
if init:
init_models(self.net, run_config.model_init)
# net info
net_info = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
with open('%s/net_info.txt' % self.path, 'w') as fout:
fout.write(json.dumps(net_info, indent=4) + '\n')
# noinspection PyBroadException
try:
fout.write(self.network.module_str + '\n')
except Exception:
pass
fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
fout.write('%s\n' % self.network)
# criterion
if isinstance(self.run_config.mixup_alpha, float):
self.train_criterion = cross_entropy_loss_with_soft_target
elif self.run_config.label_smoothing > 0:
self.train_criterion = \
lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
else:
self.train_criterion = nn.CrossEntropyLoss()
self.test_criterion = nn.CrossEntropyLoss()
# optimizer
if self.run_config.no_decay_keys:
keys = self.run_config.no_decay_keys.split('#')
net_params = [
self.network.get_parameters(keys, mode='exclude'), # parameters with weight decay
self.network.get_parameters(keys, mode='include'), # parameters without weight decay
]
else:
# noinspection PyBroadException
try:
net_params = self.network.weight_parameters()
except Exception:
net_params = []
for param in self.network.parameters():
if param.requires_grad:
net_params.append(param)
self.optimizer = self.run_config.build_optimizer(net_params)
self.net = torch.nn.DataParallel(self.net)
""" save path and log path """
@property
def save_path(self):
if self.__dict__.get('_save_path', None) is None:
save_path = os.path.join(self.path, 'checkpoint')
os.makedirs(save_path, exist_ok=True)
self.__dict__['_save_path'] = save_path
return self.__dict__['_save_path']
@property
def logs_path(self):
if self.__dict__.get('_logs_path', None) is None:
logs_path = os.path.join(self.path, 'logs')
os.makedirs(logs_path, exist_ok=True)
self.__dict__['_logs_path'] = logs_path
return self.__dict__['_logs_path']
@property
def network(self):
return self.net.module if isinstance(self.net, nn.DataParallel) else self.net
def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
write_log(self.logs_path, log_str, prefix, should_print, mode)
""" save and load models """
def save_model(self, checkpoint=None, is_best=False, model_name=None):
if checkpoint is None:
checkpoint = {'state_dict': self.network.state_dict()}
if model_name is None:
model_name = 'checkpoint.pth.tar'
checkpoint['dataset'] = self.run_config.dataset # add `dataset` info to the checkpoint
latest_fname = os.path.join(self.save_path, 'latest.txt')
model_path = os.path.join(self.save_path, model_name)
with open(latest_fname, 'w') as fout:
fout.write(model_path + '\n')
torch.save(checkpoint, model_path)
if is_best:
best_path = os.path.join(self.save_path, 'model_best.pth.tar')
torch.save({'state_dict': checkpoint['state_dict']}, best_path)
def load_model(self, model_fname=None):
latest_fname = os.path.join(self.save_path, 'latest.txt')
if model_fname is None and os.path.exists(latest_fname):
with open(latest_fname, 'r') as fin:
model_fname = fin.readline()
if model_fname[-1] == '\n':
model_fname = model_fname[:-1]
# noinspection PyBroadException
try:
if model_fname is None or not os.path.exists(model_fname):
model_fname = '%s/checkpoint.pth.tar' % self.save_path
with open(latest_fname, 'w') as fout:
fout.write(model_fname + '\n')
print("=> loading checkpoint '{}'".format(model_fname))
checkpoint = torch.load(model_fname, map_location='cpu')
except Exception:
print('failed to load checkpoint from %s' % self.save_path)
return {}
self.network.load_state_dict(checkpoint['state_dict'])
if 'epoch' in checkpoint:
self.start_epoch = checkpoint['epoch'] + 1
if 'best_acc' in checkpoint:
self.best_acc = checkpoint['best_acc']
if 'optimizer' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}'".format(model_fname))
return checkpoint
def save_config(self, extra_run_config=None, extra_net_config=None):
""" dump run_config and net_config to the model_folder """
run_save_path = os.path.join(self.path, 'run.config')
if not os.path.isfile(run_save_path):
run_config = self.run_config.config
if extra_run_config is not None:
run_config.update(extra_run_config)
json.dump(run_config, open(run_save_path, 'w'), indent=4)
print('Run config dumped to %s' % run_save_path)
try:
net_save_path = os.path.join(self.path, 'net.config')
net_config = self.network.config
if extra_net_config is not None:
net_config.update(extra_net_config)
json.dump(net_config, open(net_save_path, 'w'), indent=4)
print('Network config dumped to %s' % net_save_path)
except Exception:
print('%s does not support net config' % type(self.network))
""" metric related """
def get_metric_dict(self):
return {
'top1': AverageMeter(),
'top5': AverageMeter(),
}
def update_metric(self, metric_dict, output, labels):
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
metric_dict['top1'].update(acc1[0].item(), output.size(0))
metric_dict['top5'].update(acc5[0].item(), output.size(0))
def get_metric_vals(self, metric_dict, return_dict=False):
if return_dict:
return {
key: metric_dict[key].avg for key in metric_dict
}
else:
return [metric_dict[key].avg for key in metric_dict]
def get_metric_names(self):
return 'top1', 'top5'
""" train and test """
def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False, train_mode=False):
if net is None:
net = self.net
if not isinstance(net, nn.DataParallel):
net = nn.DataParallel(net)
if data_loader is None:
data_loader = self.run_config.test_loader if is_test else self.run_config.valid_loader
if train_mode:
net.train()
else:
net.eval()
losses = AverageMeter()
metric_dict = self.get_metric_dict()
with torch.no_grad():
with tqdm(total=len(data_loader),
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
for i, (images, labels) in enumerate(data_loader):
images, labels = images.to(self.device), labels.to(self.device)
# compute output
output = net(images)
loss = self.test_criterion(output, labels)
# measure accuracy and record loss
self.update_metric(metric_dict, output, labels)
losses.update(loss.item(), images.size(0))
t.set_postfix({
'loss': losses.avg,
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
})
t.update(1)
return losses.avg, self.get_metric_vals(metric_dict)
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
if net is None:
net = self.network
if isinstance(self.run_config.data_provider.image_size, list):
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
for img_size in self.run_config.data_provider.image_size:
img_size_list.append(img_size)
self.run_config.data_provider.assign_active_img_size(img_size)
self.reset_running_statistics(net=net)
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
loss_list.append(loss)
top1_list.append(top1)
top5_list.append(top5)
return img_size_list, loss_list, top1_list, top5_list
else:
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
# switch to train mode
self.net.train()
MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
nBatch = len(self.run_config.train_loader)
losses = AverageMeter()
metric_dict = self.get_metric_dict()
data_time = AverageMeter()
with tqdm(total=nBatch,
desc='{} Train Epoch #{}'.format(self.run_config.dataset, epoch + 1)) as t:
end = time.time()
for i, (images, labels) in enumerate(self.run_config.train_loader):
MyRandomResizedCrop.BATCH = i
data_time.update(time.time() - end)
if epoch < warmup_epochs:
new_lr = self.run_config.warmup_adjust_learning_rate(
self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
)
else:
new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)
images, labels = images.to(self.device), labels.to(self.device)
target = labels
if isinstance(self.run_config.mixup_alpha, float):
# transform data
lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
images = mix_images(images, lam)
labels = mix_labels(
labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
)
# soft target
if args.teacher_model is not None:
args.teacher_model.train()
with torch.no_grad():
soft_logits = args.teacher_model(images).detach()
soft_label = F.softmax(soft_logits, dim=1)
# compute output
output = self.net(images)
loss = self.train_criterion(output, labels)
if args.teacher_model is None:
loss_type = 'ce'
else:
if args.kd_type == 'ce':
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
else:
kd_loss = F.mse_loss(output, soft_logits)
loss = args.kd_ratio * kd_loss + loss
loss_type = '%.1fkd+ce' % args.kd_ratio
# compute gradient and do SGD step
self.net.zero_grad() # or self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# measure accuracy and record loss
losses.update(loss.item(), images.size(0))
self.update_metric(metric_dict, output, target)
t.set_postfix({
'loss': losses.avg,
**self.get_metric_vals(metric_dict, return_dict=True),
'img_size': images.size(2),
'lr': new_lr,
'loss_type': loss_type,
'data_time': data_time.avg,
})
t.update(1)
end = time.time()
return losses.avg, self.get_metric_vals(metric_dict)
def train(self, args, warmup_epoch=0, warmup_lr=0):
for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epoch, warmup_lr)
if (epoch + 1) % self.run_config.validation_frequency == 0:
img_size, val_loss, val_acc, val_acc5 = self.validate_all_resolution(epoch=epoch, is_test=False)
is_best = np.mean(val_acc) > self.best_acc
self.best_acc = max(self.best_acc, np.mean(val_acc))
val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\t{5} {3:.3f} ({4:.3f})'. \
format(epoch + 1 - warmup_epoch, self.run_config.n_epochs,
np.mean(val_loss), np.mean(val_acc), self.best_acc, self.get_metric_names()[0])
val_log += '\t{2} {0:.3f}\tTrain {1} {top1:.3f}\tloss {train_loss:.3f}\t'. \
format(np.mean(val_acc5), *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
for i_s, v_a in zip(img_size, val_acc):
val_log += '(%d, %.3f), ' % (i_s, v_a)
self.write_log(val_log, prefix='valid', should_print=False)
else:
is_best = False
self.save_model({
'epoch': epoch,
'best_acc': self.best_acc,
'optimizer': self.optimizer.state_dict(),
'state_dict': self.network.state_dict(),
}, is_best=is_best)
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
if net is None:
net = self.network
if data_loader is None:
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
set_running_statistics(net, data_loader)
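# Editor's note: a minimal wiring sketch, not part of the original file; names
# are assumptions. `train` only reads `teacher_model`, `kd_type` and `kd_ratio`
# from `args`, so a simple namespace is enough when training without
# knowledge distillation:
#
# from types import SimpleNamespace
# run_manager = RunManager('./exp/demo', net, run_config)  # net: any nn.Module
# args = SimpleNamespace(teacher_model=None, kd_type='ce', kd_ratio=0)
# run_manager.train(args)
# loss, (top1, top5) = run_manager.validate(is_test=True)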

View File

@@ -0,0 +1,87 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import json
import torch
from ofa_local.utils import download_url
from ofa_local.imagenet_classification.networks import get_net_by_name, proxyless_base
from ofa_local.imagenet_classification.elastic_nn.networks import OFAMobileNetV3, OFAProxylessNASNets, OFAResNets
__all__ = [
'ofa_specialized', 'ofa_net',
'proxylessnas_net', 'proxylessnas_mobile', 'proxylessnas_cpu', 'proxylessnas_gpu',
]
def ofa_specialized(net_id, pretrained=True):
url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_specialized/'
net_config = json.load(open(
download_url(url_base + net_id + '/net.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
))
net = get_net_by_name(net_config['name']).build_from_config(net_config)
image_size = json.load(open(
download_url(url_base + net_id + '/run.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
))['image_size']
if pretrained:
init = torch.load(
download_url(url_base + net_id + '/init', model_dir='.torch/ofa_specialized/%s/' % net_id),
map_location='cpu'
)['state_dict']
net.load_state_dict(init)
return net, image_size
def ofa_net(net_id, pretrained=True):
if net_id == 'ofa_proxyless_d234_e346_k357_w1.3':
net = OFAProxylessNASNets(
dropout_rate=0, width_mult=1.3, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)
elif net_id == 'ofa_mbv3_d234_e346_k357_w1.0':
net = OFAMobileNetV3(
dropout_rate=0, width_mult=1.0, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)
elif net_id == 'ofa_mbv3_d234_e346_k357_w1.2':
net = OFAMobileNetV3(
dropout_rate=0, width_mult=1.2, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
)
elif net_id == 'ofa_resnet50':
net = OFAResNets(
dropout_rate=0, depth_list=[0, 1, 2], expand_ratio_list=[0.2, 0.25, 0.35], width_mult_list=[0.65, 0.8, 1.0]
)
net_id = 'ofa_resnet50_d=0+1+2_e=0.2+0.25+0.35_w=0.65+0.8+1.0'
else:
raise ValueError('Not supported: %s' % net_id)
if pretrained:
url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_nets/'
init = torch.load(
download_url(url_base + net_id, model_dir='.torch/ofa_nets'),
map_location='cpu')['state_dict']
net.load_state_dict(init)
return net
def proxylessnas_net(net_id, pretrained=True):
net = proxyless_base(
net_config='https://hanlab.mit.edu/files/proxylessNAS/%s.config' % net_id,
)
if pretrained:
net.load_state_dict(torch.load(
download_url('https://hanlab.mit.edu/files/proxylessNAS/%s.pth' % net_id), map_location='cpu'
)['state_dict'])
return net  # fix: previously missing, so proxylessnas_mobile/cpu/gpu returned None
def proxylessnas_mobile(pretrained=True):
return proxylessnas_net('proxyless_mobile', pretrained)
def proxylessnas_cpu(pretrained=True):
return proxylessnas_net('proxyless_cpu', pretrained)
def proxylessnas_gpu(pretrained=True):
return proxylessnas_net('proxyless_gpu', pretrained)
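# Editor's note: usage sketch, not part of the original file. `set_active_subnet`
# and `get_active_subnet` come from the elastic networks imported above; in the
# OFA API scalar arguments are broadcast to every block.
#
# net = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=False)
# net.set_active_subnet(ks=7, e=6, d=4)                  # largest subnet in the space
# subnet = net.get_active_subnet(preserve_weight=True)   # standalone MobileNetV3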

View File

@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .acc_dataset import *
from .acc_predictor import *
from .arch_encoder import *

View File

@@ -0,0 +1,181 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from ofa.utils import list_mean
__all__ = ['net_setting2id', 'net_id2setting', 'AccuracyDataset']
def net_setting2id(net_setting):
return json.dumps(net_setting)
def net_id2setting(net_id):
return json.loads(net_id)
class RegDataset(torch.utils.data.Dataset):
def __init__(self, inputs, targets):
super(RegDataset, self).__init__()
self.inputs = inputs
self.targets = targets
def __getitem__(self, index):
return self.inputs[index], self.targets[index]
def __len__(self):
return self.inputs.size(0)
class AccuracyDataset:
def __init__(self, path):
self.path = path
os.makedirs(self.path, exist_ok=True)
@property
def net_id_path(self):
return os.path.join(self.path, 'net_id.dict')
@property
def acc_src_folder(self):
return os.path.join(self.path, 'src')
@property
def acc_dict_path(self):
return os.path.join(self.path, 'acc.dict')
# TODO: support parallel building
def build_acc_dataset(self, run_manager, ofa_network, n_arch=1000, image_size_list=None):
# load net_id_list, random sample if not exist
if os.path.isfile(self.net_id_path):
net_id_list = json.load(open(self.net_id_path))
else:
net_id_list = set()
while len(net_id_list) < n_arch:
net_setting = ofa_network.sample_active_subnet()
net_id = net_setting2id(net_setting)
net_id_list.add(net_id)
net_id_list = list(net_id_list)
net_id_list.sort()
json.dump(net_id_list, open(self.net_id_path, 'w'), indent=4)
image_size_list = [128, 160, 192, 224] if image_size_list is None else image_size_list
with tqdm(total=len(net_id_list) * len(image_size_list), desc='Building Acc Dataset') as t:
for image_size in image_size_list:
# load val dataset into memory
val_dataset = []
run_manager.run_config.data_provider.assign_active_img_size(image_size)
for images, labels in run_manager.run_config.valid_loader:
val_dataset.append((images, labels))
# save path
os.makedirs(self.acc_src_folder, exist_ok=True)
acc_save_path = os.path.join(self.acc_src_folder, '%d.dict' % image_size)
acc_dict = {}
# load existing acc dict
if os.path.isfile(acc_save_path):
existing_acc_dict = json.load(open(acc_save_path, 'r'))
else:
existing_acc_dict = {}
for net_id in net_id_list:
net_setting = net_id2setting(net_id)
key = net_setting2id({**net_setting, 'image_size': image_size})
if key in existing_acc_dict:
acc_dict[key] = existing_acc_dict[key]
t.set_postfix({
'net_id': net_id,
'image_size': image_size,
'info_val': acc_dict[key],
'status': 'loading',
})
t.update()
continue
ofa_network.set_active_subnet(**net_setting)
run_manager.reset_running_statistics(ofa_network)
net_setting_str = ','.join(['%s_%s' % (
key, '%.1f' % list_mean(val) if isinstance(val, list) else val
) for key, val in net_setting.items()])
loss, (top1, top5) = run_manager.validate(
run_str=net_setting_str, net=ofa_network, data_loader=val_dataset, no_logs=True,
)
info_val = top1
t.set_postfix({
'net_id': net_id,
'image_size': image_size,
'info_val': info_val,
})
t.update()
acc_dict.update({
key: info_val
})
json.dump(acc_dict, open(acc_save_path, 'w'), indent=4)
def merge_acc_dataset(self, image_size_list=None):
# load existing data
merged_acc_dict = {}
for fname in os.listdir(self.acc_src_folder):
if '.dict' not in fname:
continue
image_size = int(fname.split('.dict')[0])
if image_size_list is not None and image_size not in image_size_list:
print('Skip ', fname)
continue
full_path = os.path.join(self.acc_src_folder, fname)
partial_acc_dict = json.load(open(full_path))
merged_acc_dict.update(partial_acc_dict)
print('loaded %s' % full_path)
json.dump(merged_acc_dict, open(self.acc_dict_path, 'w'), indent=4)
return merged_acc_dict
def build_acc_data_loader(self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16):
# load data
acc_dict = json.load(open(self.acc_dict_path))
X_all = []
Y_all = []
with tqdm(total=len(acc_dict), desc='Loading data') as t:
for k, v in acc_dict.items():
dic = json.loads(k)
X_all.append(arch_encoder.arch2feature(dic))
Y_all.append(v / 100.) # range: 0 - 1
t.update()
base_acc = np.mean(Y_all)
# convert to torch tensor
X_all = torch.tensor(X_all, dtype=torch.float)
Y_all = torch.tensor(Y_all)
# random shuffle
shuffle_idx = torch.randperm(len(X_all))
X_all = X_all[shuffle_idx]
Y_all = Y_all[shuffle_idx]
# split data
idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
val_idx = X_all.size(0) // 5 * 4
X_train, Y_train = X_all[:idx], Y_all[:idx]
X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
print('Train Size: %d,' % len(X_train), 'Valid Size: %d' % len(X_test))
# build data loader
train_dataset = RegDataset(X_train, Y_train)
val_dataset = RegDataset(X_test, Y_test)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=n_workers
)
valid_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=n_workers
)
return train_loader, valid_loader, base_acc
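# Editor's note: end-to-end sketch for this file, not part of the original;
# `run_manager` and `ofa_network` must be constructed as in run_manager.py and
# model_zoo.py, and the path is illustrative.
#
# dataset = AccuracyDataset('./exp/acc_dataset')
# dataset.build_acc_dataset(run_manager, ofa_network, n_arch=1000)
# dataset.merge_acc_dataset()
# train_loader, valid_loader, base_acc = dataset.build_acc_data_loader(arch_encoder)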

View File

@@ -0,0 +1,50 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import numpy as np
import torch
import torch.nn as nn
__all__ = ['AccuracyPredictor']
class AccuracyPredictor(nn.Module):
def __init__(self, arch_encoder, hidden_size=400, n_layers=3,
checkpoint_path=None, device='cuda:0'):
super(AccuracyPredictor, self).__init__()
self.arch_encoder = arch_encoder
self.hidden_size = hidden_size
self.n_layers = n_layers
self.device = device
# build layers
layers = []
for i in range(self.n_layers):
layers.append(nn.Sequential(
nn.Linear(self.arch_encoder.n_dim if i == 0 else self.hidden_size, self.hidden_size),
nn.ReLU(inplace=True),
))
layers.append(nn.Linear(self.hidden_size, 1, bias=False))
self.layers = nn.Sequential(*layers)
self.base_acc = nn.Parameter(torch.zeros(1, device=self.device), requires_grad=False)
if checkpoint_path is not None and os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
self.load_state_dict(checkpoint)
print('Loaded checkpoint from %s' % checkpoint_path)
self.layers = self.layers.to(self.device)
def forward(self, x):
y = self.layers(x).squeeze()
return y + self.base_acc
def predict_acc(self, arch_dict_list):
X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
X = torch.tensor(np.array(X)).float().to(self.device)
return self.forward(X)
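# Editor's note: usage sketch, not part of the original file; `arch_encoder`
# is e.g. the MobileNetArchEncoder defined in arch_encoder.py below.
#
# predictor = AccuracyPredictor(arch_encoder, checkpoint_path=None, device='cpu')
# archs = [arch_encoder.random_sample_arch() for _ in range(8)]
# pred = predictor.predict_acc(archs)   # predicted accuracies, roughly in [0, 1]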

View File

@@ -0,0 +1,315 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import random
import numpy as np
from ofa.imagenet_classification.networks import ResNets
__all__ = ['MobileNetArchEncoder', 'ResNetArchEncoder']
class MobileNetArchEncoder:
SPACE_TYPE = 'mbv3'
def __init__(self, image_size_list=None, ks_list=None, expand_list=None, depth_list=None, n_stage=None):
self.image_size_list = [224] if image_size_list is None else image_size_list
self.ks_list = [3, 5, 7] if ks_list is None else ks_list
self.expand_list = [3, 4, 6] if expand_list is None else [int(expand) for expand in expand_list]
self.depth_list = [2, 3, 4] if depth_list is None else depth_list
if n_stage is not None:
self.n_stage = n_stage
elif self.SPACE_TYPE == 'mbv2':
self.n_stage = 6
elif self.SPACE_TYPE == 'mbv3':
self.n_stage = 5
else:
raise NotImplementedError
# build info dict
self.n_dim = 0
self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
self._build_info_dict(target='r')
self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
self._build_info_dict(target='k')
self._build_info_dict(target='e')
@property
def max_n_blocks(self):
if self.SPACE_TYPE == 'mbv3':
return self.n_stage * max(self.depth_list)
elif self.SPACE_TYPE == 'mbv2':
return (self.n_stage - 1) * max(self.depth_list) + 1
else:
raise NotImplementedError
def _build_info_dict(self, target):
if target == 'r':
target_dict = self.r_info
target_dict['L'].append(self.n_dim)
for img_size in self.image_size_list:
target_dict['val2id'][img_size] = self.n_dim
target_dict['id2val'][self.n_dim] = img_size
self.n_dim += 1
target_dict['R'].append(self.n_dim)
else:
if target == 'k':
target_dict = self.k_info
choices = self.ks_list
elif target == 'e':
target_dict = self.e_info
choices = self.expand_list
else:
raise NotImplementedError
for i in range(self.max_n_blocks):
target_dict['val2id'].append({})
target_dict['id2val'].append({})
target_dict['L'].append(self.n_dim)
for k in choices:
target_dict['val2id'][i][k] = self.n_dim
target_dict['id2val'][i][self.n_dim] = k
self.n_dim += 1
target_dict['R'].append(self.n_dim)
def arch2feature(self, arch_dict):
ks, e, d, r = arch_dict['ks'], arch_dict['e'], arch_dict['d'], arch_dict['image_size']
feature = np.zeros(self.n_dim)
for i in range(self.max_n_blocks):
nowd = i % max(self.depth_list)
stg = i // max(self.depth_list)
if nowd < d[stg]:
feature[self.k_info['val2id'][i][ks[i]]] = 1
feature[self.e_info['val2id'][i][e[i]]] = 1
feature[self.r_info['val2id'][r]] = 1
return feature
def feature2arch(self, feature):
img_sz = self.r_info['id2val'][
int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
]
assert img_sz in self.image_size_list
arch_dict = {'ks': [], 'e': [], 'd': [], 'image_size': img_sz}
d = 0
for i in range(self.max_n_blocks):
skip = True
for j in range(self.k_info['L'][i], self.k_info['R'][i]):
if feature[j] == 1:
arch_dict['ks'].append(self.k_info['id2val'][i][j])
skip = False
break
for j in range(self.e_info['L'][i], self.e_info['R'][i]):
if feature[j] == 1:
arch_dict['e'].append(self.e_info['id2val'][i][j])
assert not skip
skip = False
break
if skip:
arch_dict['e'].append(0)
arch_dict['ks'].append(0)
else:
d += 1
if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks:
arch_dict['d'].append(d)
d = 0
return arch_dict
def random_sample_arch(self):
return {
'ks': random.choices(self.ks_list, k=self.max_n_blocks),
'e': random.choices(self.expand_list, k=self.max_n_blocks),
'd': random.choices(self.depth_list, k=self.n_stage),
'image_size': random.choice(self.image_size_list)
}
def mutate_resolution(self, arch_dict, mutate_prob):
if random.random() < mutate_prob:
arch_dict['image_size'] = random.choice(self.image_size_list)
return arch_dict
def mutate_arch(self, arch_dict, mutate_prob):
for i in range(self.max_n_blocks):
if random.random() < mutate_prob:
arch_dict['ks'][i] = random.choice(self.ks_list)
arch_dict['e'][i] = random.choice(self.expand_list)
for i in range(self.n_stage):
if random.random() < mutate_prob:
arch_dict['d'][i] = random.choice(self.depth_list)
return arch_dict
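# Editor's note: encode/decode round-trip sketch, not part of the original
# file. Depth and image size survive exactly; `ks`/`e` entries of inactive
# blocks decode to 0 by design, so only active positions match the input.
#
# encoder = MobileNetArchEncoder(image_size_list=[128, 160, 192, 224])
# arch = encoder.random_sample_arch()
# decoded = encoder.feature2arch(encoder.arch2feature(arch))
# assert decoded['image_size'] == arch['image_size'] and decoded['d'] == arch['d']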
class ResNetArchEncoder:
def __init__(self, image_size_list=None, depth_list=None, expand_list=None, width_mult_list=None,
base_depth_list=None):
self.image_size_list = [224] if image_size_list is None else image_size_list
self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
self.depth_list = [0, 1, 2] if depth_list is None else depth_list
self.width_mult_list = [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list
self.base_depth_list = ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list
""" build info dict """
self.n_dim = 0
# resolution
self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
self._build_info_dict(target='r')
# input stem skip
self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
self._build_info_dict(target='input_stem_d')
# width_mult
self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
self._build_info_dict(target='width_mult')
# expand ratio
self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
self._build_info_dict(target='e')
@property
def n_stage(self):
return len(self.base_depth_list)
@property
def max_n_blocks(self):
return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)
def _build_info_dict(self, target):
if target == 'r':
target_dict = self.r_info
target_dict['L'].append(self.n_dim)
for img_size in self.image_size_list:
target_dict['val2id'][img_size] = self.n_dim
target_dict['id2val'][self.n_dim] = img_size
self.n_dim += 1
target_dict['R'].append(self.n_dim)
elif target == 'input_stem_d':
target_dict = self.input_stem_d_info
target_dict['L'].append(self.n_dim)
for skip in [0, 1]:
target_dict['val2id'][skip] = self.n_dim
target_dict['id2val'][self.n_dim] = skip
self.n_dim += 1
target_dict['R'].append(self.n_dim)
elif target == 'e':
target_dict = self.e_info
choices = self.expand_list
for i in range(self.max_n_blocks):
target_dict['val2id'].append({})
target_dict['id2val'].append({})
target_dict['L'].append(self.n_dim)
for e in choices:
target_dict['val2id'][i][e] = self.n_dim
target_dict['id2val'][i][self.n_dim] = e
self.n_dim += 1
target_dict['R'].append(self.n_dim)
elif target == 'width_mult':
target_dict = self.width_mult_info
choices = list(range(len(self.width_mult_list)))
for i in range(self.n_stage + 2):
target_dict['val2id'].append({})
target_dict['id2val'].append({})
target_dict['L'].append(self.n_dim)
for w in choices:
target_dict['val2id'][i][w] = self.n_dim
target_dict['id2val'][i][self.n_dim] = w
self.n_dim += 1
target_dict['R'].append(self.n_dim)
def arch2feature(self, arch_dict):
d, e, w, r = arch_dict['d'], arch_dict['e'], arch_dict['w'], arch_dict['image_size']
input_stem_skip = 1 if d[0] > 0 else 0
d = d[1:]
feature = np.zeros(self.n_dim)
feature[self.r_info['val2id'][r]] = 1
feature[self.input_stem_d_info['val2id'][input_stem_skip]] = 1
for i in range(self.n_stage + 2):
feature[self.width_mult_info['val2id'][i][w[i]]] = 1
start_pt = 0
for i, base_depth in enumerate(self.base_depth_list):
depth = base_depth + d[i]
for j in range(start_pt, start_pt + depth):
feature[self.e_info['val2id'][j][e[j]]] = 1
start_pt += max(self.depth_list) + base_depth
return feature
def feature2arch(self, feature):
img_sz = self.r_info['id2val'][
int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
]
input_stem_skip = self.input_stem_d_info['id2val'][
int(np.argmax(feature[self.input_stem_d_info['L'][0]:self.input_stem_d_info['R'][0]])) +
self.input_stem_d_info['L'][0]
] * 2
assert img_sz in self.image_size_list
arch_dict = {'d': [input_stem_skip], 'e': [], 'w': [], 'image_size': img_sz}
for i in range(self.n_stage + 2):
arch_dict['w'].append(
self.width_mult_info['id2val'][i][
int(np.argmax(feature[self.width_mult_info['L'][i]:self.width_mult_info['R'][i]])) +
self.width_mult_info['L'][i]
]
)
d = 0
skipped = 0
stage_id = 0
for i in range(self.max_n_blocks):
skip = True
for j in range(self.e_info['L'][i], self.e_info['R'][i]):
if feature[j] == 1:
arch_dict['e'].append(self.e_info['id2val'][i][j])
skip = False
break
if skip:
arch_dict['e'].append(0)
skipped += 1
else:
d += 1
if i + 1 == self.max_n_blocks or (skipped + d) % \
(max(self.depth_list) + self.base_depth_list[stage_id]) == 0:
arch_dict['d'].append(d - self.base_depth_list[stage_id])
d, skipped = 0, 0
stage_id += 1
return arch_dict
def random_sample_arch(self):
return {
'd': [random.choice([0, 2])] + random.choices(self.depth_list, k=self.n_stage),
'e': random.choices(self.expand_list, k=self.max_n_blocks),
'w': random.choices(list(range(len(self.width_mult_list))), k=self.n_stage + 2),
'image_size': random.choice(self.image_size_list)
}
def mutate_resolution(self, arch_dict, mutate_prob):
if random.random() < mutate_prob:
arch_dict['image_size'] = random.choice(self.image_size_list)
return arch_dict
def mutate_arch(self, arch_dict, mutate_prob):
# input stem skip
if random.random() < mutate_prob:
arch_dict['d'][0] = random.choice([0, 2])
# depth
for i in range(1, len(arch_dict['d'])):
if random.random() < mutate_prob:
arch_dict['d'][i] = random.choice(self.depth_list)
# width_mult
for i in range(len(arch_dict['w'])):
if random.random() < mutate_prob:
arch_dict['w'][i] = random.choice(list(range(len(self.width_mult_list))))
# expand ratio
for i in range(len(arch_dict['e'])):
if random.random() < mutate_prob:
arch_dict['e'][i] = random.choice(self.expand_list)
return arch_dict  # return the mutated dict, mirroring MobileNetArchEncoder.mutate_arch
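# Editor's note: mutation sketch, not part of the original file. With the
# `return arch_dict` added above, the mutate_* calls can be chained as with the
# MobileNet encoder; `d[0]` stays in {0, 2} (input-stem skip), matching
# random_sample_arch.
#
# import copy
# encoder = ResNetArchEncoder(image_size_list=[224])
# parent = encoder.random_sample_arch()
# child = encoder.mutate_resolution(copy.deepcopy(parent), mutate_prob=0.1)
# child = encoder.mutate_arch(child, mutate_prob=0.1)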

View File

@@ -0,0 +1,71 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import os
import copy
from .latency_lookup_table import *
class BaseEfficiencyModel:
def __init__(self, ofa_net):
self.ofa_net = ofa_net
def get_active_subnet_config(self, arch_dict):
arch_dict = copy.deepcopy(arch_dict)
image_size = arch_dict.pop('image_size')
self.ofa_net.set_active_subnet(**arch_dict)
active_net_config = self.ofa_net.get_active_net_config()
return active_net_config, image_size
def get_efficiency(self, arch_dict):
raise NotImplementedError
class ProxylessNASFLOPsModel(BaseEfficiencyModel):
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return ProxylessNASLatencyTable.count_flops_given_config(active_net_config, image_size)
class Mbv3FLOPsModel(BaseEfficiencyModel):
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return MBv3LatencyTable.count_flops_given_config(active_net_config, image_size)
class ResNet50FLOPsModel(BaseEfficiencyModel):
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return ResNet50LatencyTable.count_flops_given_config(active_net_config, image_size)
class ProxylessNASLatencyModel(BaseEfficiencyModel):
def __init__(self, ofa_net, lookup_table_path_dict):
super(ProxylessNASLatencyModel, self).__init__(ofa_net)
self.latency_tables = {}
for image_size, path in lookup_table_path_dict.items():
self.latency_tables[image_size] = ProxylessNASLatencyTable(
local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)
class Mbv3LatencyModel(BaseEfficiencyModel):
def __init__(self, ofa_net, lookup_table_path_dict):
super(Mbv3LatencyModel, self).__init__(ofa_net)
self.latency_tables = {}
for image_size, path in lookup_table_path_dict.items():
self.latency_tables[image_size] = MBv3LatencyTable(
local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))
def get_efficiency(self, arch_dict):
active_net_config, image_size = self.get_active_subnet_config(arch_dict)
return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)
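# Editor's note: usage sketch, not part of the original file; it assumes the
# OFA MobileNetV3 supernet from model_zoo.py and the `get_active_net_config`
# API of the OFA elastic nets. FLOPs models need no lookup table:
#
# efficiency_model = Mbv3FLOPsModel(ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=False))
# arch = {'ks': [3] * 20, 'e': [3] * 20, 'd': [2] * 5, 'image_size': 224}
# mflops = efficiency_model.get_efficiency(arch)   # MFLOPs of that subnet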

View File

@@ -0,0 +1,387 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
import yaml
from ofa.utils import download_url, make_divisible, MyNetwork
__all__ = ['count_conv_flop', 'ProxylessNASLatencyTable', 'MBv3LatencyTable', 'ResNet50LatencyTable']
def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups):
out_h = out_w = out_size
delta_ops = in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups
return delta_ops
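# Editor's note: worked example, not part of the original file. A 3x3 stem
# convolution mapping 3 -> 16 channels at 112x112 output resolution costs
#   count_conv_flop(112, 3, 16, 3, 1) == 3 * 16 * 3 * 3 * 112 * 112 == 5419008
# multiply-accumulates, i.e. ~5.4 MFLOPs after the `/ 1e6` scaling used below.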
class LatencyTable(object):
def __init__(self, local_dir='~/.ofa/latency_tools/',
url='https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml'):
if url.startswith('http'):
fname = download_url(url, local_dir, overwrite=True)
else:
fname = url
with open(fname, 'r') as fp:
self.lut = yaml.safe_load(fp)  # yaml.load without an explicit Loader is removed in PyYAML >= 6
@staticmethod
def repr_shape(shape):
if isinstance(shape, (list, tuple)):
return 'x'.join(str(_) for _ in shape)
elif isinstance(shape, str):
return shape
else:
raise TypeError('unsupported shape: %r' % (shape,))
def query(self, **kwargs):
raise NotImplementedError
def predict_network_latency(self, net, image_size):
raise NotImplementedError
def predict_network_latency_given_config(self, net_config, image_size):
raise NotImplementedError
@staticmethod
def count_flops_given_config(net_config, image_size=224):
raise NotImplementedError
class ProxylessNASLatencyTable(LatencyTable):
def query(self, l_type: str, input_shape, output_shape, expand=None, ks=None, stride=None, id_skip=None):
"""
:param l_type:
Layer type must be one of the following:
1. `Conv`: The initial 3x3 conv with stride 2.
2. `Conv_1`: feature_mix_layer
3. `Logits`: All operations after `Conv_1`.
4. `expanded_conv`: MobileInvertedResidual
:param input_shape: input shape (h, w, #channels)
:param output_shape: output shape (h, w, #channels)
:param expand: expansion ratio
:param ks: kernel size
:param stride:
:param id_skip: indicates whether the block has a residual connection
"""
infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]
if l_type in ('expanded_conv',):
assert None not in (expand, ks, stride, id_skip)
infos += ['expand:%d' % expand, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip]
key = '-'.join(infos)
return self.lut[key]['mean']
def predict_network_latency(self, net, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
)
# blocks
fsize = (image_size + 1) // 2
for block in net.blocks:
mb_conv = block.conv
shortcut = block.shortcut
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv.stride + 1)  # i.e. ceil(fsize / stride)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
expand=mb_conv.expand_ratio, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip
)
predicted_latency += block_latency
fsize = out_fz
# feature mix layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net.feature_mix_layer.in_channels],
[fsize, fsize, net.feature_mix_layer.out_channels]
)
# classifier
predicted_latency += self.query(
'Logits', [fsize, fsize, net.classifier.in_features], [net.classifier.out_features] # 1000
)
return predicted_latency
def predict_network_latency_given_config(self, net_config, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
shortcut = block['shortcut']
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
expand=mb_conv['expand_ratio'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip
)
predicted_latency += block_latency
fsize = out_fz
# feature mix layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net_config['feature_mix_layer']['in_channels']],
[fsize, fsize, net_config['feature_mix_layer']['out_channels']]
)
# classifier
predicted_latency += self.query(
'Logits', [fsize, fsize, net_config['classifier']['in_features']],
[net_config['classifier']['out_features']] # 1000
)
return predicted_latency
@staticmethod
def count_flops_given_config(net_config, image_size=224):
flops = 0
# first conv
flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
if mb_conv is None:
continue
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
if mb_conv['mid_channels'] is None:
mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
if mb_conv['expand_ratio'] != 1:
# inverted bottleneck
flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
# depth conv
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
mb_conv['kernel_size'], mb_conv['mid_channels'])
# point linear
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
fsize = out_fz
# feature mix layer
flops += count_conv_flop(fsize, net_config['feature_mix_layer']['in_channels'],
net_config['feature_mix_layer']['out_channels'], 1, 1)
# classifier
flops += count_conv_flop(1, net_config['classifier']['in_features'],
net_config['classifier']['out_features'], 1, 1)
return flops / 1e6 # MFLOPs
class MBv3LatencyTable(LatencyTable):
def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
se=None, h_swish=None):
infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]
if l_type in ('expanded_conv',):
assert None not in (mid, ks, stride, id_skip, se, h_swish)
infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip,
'se:%d' % se, 'hs:%d' % h_swish]
key = '-'.join(infos)
return self.lut[key]['mean']
def predict_network_latency(self, net, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
)
# blocks
fsize = (image_size + 1) // 2
for block in net.blocks:
mb_conv = block.conv
shortcut = block.shortcut
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv.stride + 1)
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip,
se=1 if mb_conv.use_se else 0, h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
)
predicted_latency += block_latency
fsize = out_fz
# final expand layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
[fsize, fsize, net.final_expand_layer.out_channels],
)
# global average pooling
predicted_latency += self.query(
'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
[1, 1, net.final_expand_layer.out_channels],
)
# feature mix layer
predicted_latency += self.query(
'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
[1, 1, net.feature_mix_layer.out_channels]
)
# classifier
predicted_latency += self.query(
'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
)
return predicted_latency
def predict_network_latency_given_config(self, net_config, image_size=224):
predicted_latency = 0
# first conv
predicted_latency += self.query(
'Conv', [image_size, image_size, 3],
[(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
shortcut = block['shortcut']
if mb_conv is None:
continue
if shortcut is None:
idskip = 0
else:
idskip = 1
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
if mb_conv['mid_channels'] is None:
mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
block_latency = self.query(
'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
mid=mb_conv['mid_channels'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip,
se=1 if mb_conv['use_se'] else 0, h_swish=1 if mb_conv['act_func'] == 'h_swish' else 0,
)
predicted_latency += block_latency
fsize = out_fz
# final expand layer
predicted_latency += self.query(
'Conv_1', [fsize, fsize, net_config['final_expand_layer']['in_channels']],
[fsize, fsize, net_config['final_expand_layer']['out_channels']],
)
# global average pooling
predicted_latency += self.query(
'AvgPool2D', [fsize, fsize, net_config['final_expand_layer']['out_channels']],
[1, 1, net_config['final_expand_layer']['out_channels']],
)
# feature mix layer
predicted_latency += self.query(
'Conv_2', [1, 1, net_config['feature_mix_layer']['in_channels']],
[1, 1, net_config['feature_mix_layer']['out_channels']]
)
# classifier
predicted_latency += self.query(
'Logits', [1, 1, net_config['classifier']['in_features']], [net_config['classifier']['out_features']]
)
return predicted_latency
@staticmethod
def count_flops_given_config(net_config, image_size=224):
flops = 0
# first conv
flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
# blocks
fsize = (image_size + 1) // 2
for block in net_config['blocks']:
mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
if mb_conv is None:
continue
out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
if mb_conv['mid_channels'] is None:
mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
if mb_conv['expand_ratio'] != 1:
# inverted bottleneck
flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
# depth conv
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
mb_conv['kernel_size'], mb_conv['mid_channels'])
if mb_conv['use_se']:
# SE layer
se_mid = make_divisible(mb_conv['mid_channels'] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE)
flops += count_conv_flop(1, mb_conv['mid_channels'], se_mid, 1, 1)
flops += count_conv_flop(1, se_mid, mb_conv['mid_channels'], 1, 1)
# point linear
flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
fsize = out_fz
# final expand layer
flops += count_conv_flop(fsize, net_config['final_expand_layer']['in_channels'],
net_config['final_expand_layer']['out_channels'], 1, 1)
# feature mix layer
flops += count_conv_flop(1, net_config['feature_mix_layer']['in_channels'],
net_config['feature_mix_layer']['out_channels'], 1, 1)
# classifier
flops += count_conv_flop(1, net_config['classifier']['in_features'],
net_config['classifier']['out_features'], 1, 1)
return flops / 1e6 # MFLOPs
class ResNet50LatencyTable(LatencyTable):
def query(self, **kwargs):
raise NotImplementedError
def predict_network_latency(self, net, image_size):
raise NotImplementedError
def predict_network_latency_given_config(self, net_config, image_size):
raise NotImplementedError
@staticmethod
def count_flops_given_config(net_config, image_size=224):
flops = 0
# input stem
for layer_config in net_config['input_stem']:
if layer_config['name'] != 'ConvLayer':
layer_config = layer_config['conv']
in_channel = layer_config['in_channels']
out_channel = layer_config['out_channels']
out_image_size = int((image_size - 1) / layer_config['stride'] + 1)
flops += count_conv_flop(out_image_size, in_channel, out_channel,
layer_config['kernel_size'], layer_config.get('groups', 1))
image_size = out_image_size
# max pooling
image_size = int((image_size - 1) / 2 + 1)
# ResNetBottleneckBlocks
for block_config in net_config['blocks']:
in_channel = block_config['in_channels']
out_channel = block_config['out_channels']
out_image_size = int((image_size - 1) / block_config['stride'] + 1)
mid_channel = block_config['mid_channels'] if block_config['mid_channels'] is not None \
else round(out_channel * block_config['expand_ratio'])
mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE)
# conv1
flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1)
# conv2
flops += count_conv_flop(out_image_size, mid_channel, mid_channel,
block_config['kernel_size'], block_config['groups'])
# conv3
flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1)
# downsample
if block_config['stride'] == 1 and in_channel == out_channel:
pass
else:
flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1)
image_size = out_image_size
# final classifier
flops += count_conv_flop(1, net_config['classifier']['in_features'],
net_config['classifier']['out_features'], 1, 1)
return flops / 1e6 # MFLOPs

View File

@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.
from .evolution import *

Some files were not shown because too many files have changed in this diff.