upload
169
sota/cnn/genotypes.py
Normal file
@@ -0,0 +1,169 @@
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

PRIMITIVES = [
    'none',
    'noise',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5'
]
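# Genotype encoding: 'normal' and 'reduce' are lists of [op_name, input_node]
# pairs, two pairs per intermediate node (so 8 pairs describe 4 nodes). Nodes 0
# and 1 are the cell inputs; normal_concat/reduce_concat = range(2, 6) means the
# four intermediate node outputs are concatenated to form the cell output
# (see Cell._compile in sota/cnn/model.py).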
######## S1-S4 Space ########
#### cifar10 s1 - s4
init_pt_s1_C10_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["skip_connect", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["dil_conv_5x5", 2], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s1_C10_2 = Genotype(normal=[["skip_connect", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 3], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s2_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s2_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1]], reduce_concat=range(2, 6))
init_pt_s3_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C10_0 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["noise", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C10_2 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
#### cifar100 s1 - s4
init_pt_s1_C100_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
init_pt_s1_C100_2 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 3], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s2_C100_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s2_C100_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_C100_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_C100_2 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C100_0 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["sep_conv_3x3", 1], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_C100_2 = Genotype(normal=[["noise", 0], ["sep_conv_3x3", 1], ["noise", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
#### svhn s1 - s4
init_pt_s1_svhn_0 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["avg_pool_3x3", 1], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["dil_conv_5x5", 2], ["dil_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s1_svhn_2 = Genotype(normal=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["max_pool_3x3", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_5x5", 3], ["avg_pool_3x3", 0], ["dil_conv_5x5", 3]], reduce_concat=range(2, 6))
init_pt_s2_svhn_0 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s2_svhn_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_svhn_0 = Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s3_svhn_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_svhn_0 = Genotype(normal=[["noise", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["noise", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["noise", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s4_svhn_2 = Genotype(normal=[["sep_conv_3x3", 0], ["noise", 1], ["noise", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
######## DARTS Space ########

#### init-100-N10
init_pt_s5_C10_0_100_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#### global op greedy
global_pt_s5_C10_0_100_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
global_pt_s5_C10_1_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
global_pt_s5_C10_2_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
global_pt_s5_C10_3_100_N10 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
#### 2500_sample
sample_2500_0 = Genotype(normal=[["dil_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 3], ["sep_conv_5x5", 2], ["dil_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["skip_connect", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["dil_conv_5x5", 3]], reduce_concat=range(2, 6))
sample_2500_1 = Genotype(normal=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["avg_pool_3x3", 2], ["dil_conv_5x5", 1], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
sample_2500_2 = Genotype(normal=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
sample_2500_3 = Genotype(normal=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["dil_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["dil_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 0], ["max_pool_3x3", 1], ["avg_pool_3x3", 0], ["max_pool_3x3", 1], ["avg_pool_3x3", 1], ["skip_connect", 3]], reduce_concat=range(2, 6))
#### 20000_sample
sample_20000_0 = Genotype(normal=[["skip_connect", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 1], ["skip_connect", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 2], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["dil_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 0], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
sample_20000_1 = Genotype(normal=[["skip_connect", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["avg_pool_3x3", 2], ["dil_conv_5x5", 1], ["dil_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
sample_20000_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["skip_connect", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 3]], reduce_concat=range(2, 6))
sample_20000_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["dil_conv_3x3", 2], ["dil_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
#### 50000_sample
sample_50000_0 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["skip_connect", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["max_pool_3x3", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_3x3", 1], ["dil_conv_5x5", 0], ["max_pool_3x3", 1]], reduce_concat=range(2, 6))
sample_50000_1 = Genotype(normal=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["dil_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["max_pool_3x3", 1], ["dil_conv_3x3", 1], ["dil_conv_5x5", 2]], reduce_concat=range(2, 6))
sample_50000_2 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["dil_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_5x5", 0], ["dil_conv_3x3", 1], ["skip_connect", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 3]], reduce_concat=range(2, 6))
sample_50000_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["skip_connect", 0], ["dil_conv_3x3", 2], ["dil_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
#### random
random_max_0 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
random_max_1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
random_max_2 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
random_max_3 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#### ImageNet-1k
init_pt_s5_in_0_100_N10=Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_in_1_100_N10=Genotype(normal=[["skip_connect", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_in_2_100_N10=Genotype(normal=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["dil_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_in_3_100_N10=Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
#### N1
init_pt_s5_C10_0_N1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_N1 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_N1 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_N1 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#### N5

##### V1
init_pt_s5_C10_0_1_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_1_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_1_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_1_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
##### V10
init_pt_s5_C10_0_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_5x5", 1], ["sep_conv_5x5", 4]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_10_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["skip_connect", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_10_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
##### V100
init_pt_s5_C10_0_100_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_100_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_100_N5 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_100_N5 = Genotype(normal=[["skip_connect", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
#### N10

##### V1
init_pt_s5_C10_0_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 0], ["dil_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_1_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_5x5", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
##### V10
init_pt_s5_C10_0_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 4]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_1_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_2_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce=[["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]], reduce_concat=range(2, 6))
init_pt_s5_C10_3_10_N10 = Genotype(normal=[["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce=[["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
# fisher
cf10_fisher = Genotype(normal=[["avg_pool_3x3", 0], ["avg_pool_3x3", 1], ["avg_pool_3x3", 0], ["dil_conv_3x3", 1],["avg_pool_3x3", 0], ["skip_connect", 2],["sep_conv_5x5", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["max_pool_3x3", 0], ["max_pool_3x3", 2], ["sep_conv_3x3", 0], ["dil_conv_5x5", 3], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
# grasp
cf10_grasp = Genotype(normal=[["avg_pool_3x3", 0], ["avg_pool_3x3", 1], ["skip_connect", 0], ["sep_conv_5x5", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["skip_connect", 1], ["avg_pool_3x3", 0], ["skip_connect", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["max_pool_3x3", 1], ["sep_conv_3x3", 3]], reduce_concat=range(2, 6))
# jacob_cov
cf10_jacob_cov = Genotype(normal=[["max_pool_3x3", 0], ["dil_conv_3x3", 1], ["dil_conv_3x3", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["sep_conv_3x3", 3], ["sep_conv_5x5", 0], ["dil_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_3x3", 0], ["max_pool_3x3", 1], ["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 3], ["dil_conv_3x3", 1], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
# meco
cf10_meco = Genotype(normal=[["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["skip_connect", 1], ["dil_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce= [["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["dil_conv_5x5", 0], ["dil_conv_5x5", 1], ["dil_conv_5x5", 1], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
# synflow
cf10_synflow = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]], reduce_concat=range(2, 6))
# zico
cf10_zico= Genotype(normal= [["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["skip_connect", 0], ["sep_conv_3x3", 2]], reduce_concat=range(2, 6))
# snip
cf10_snip = Genotype(normal= [["sep_conv_3x3", 0], ["avg_pool_3x3", 1], ["dil_conv_5x5", 0], ["sep_conv_5x5", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 3], ["sep_conv_3x3", 2], ["sep_conv_5x5", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 1], ["dil_conv_3x3", 0], ["sep_conv_3x3", 4]], reduce_concat=range(2, 6))
# fisher
cf100_fisher = Genotype(normal= [["sep_conv_3x3", 0], ["max_pool_3x3", 1], ["sep_conv_5x5", 0], ["max_pool_3x3", 1], ["dil_conv_3x3", 1], ["skip_connect", 3], ["dil_conv_5x5", 0], ["skip_connect", 1]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_3x3", 1], ["dil_conv_3x3", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["max_pool_3x3", 1], ["sep_conv_3x3", 4]] , reduce_concat=range(2, 6))
# grasp
cf100_grasp= Genotype(normal= [["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["avg_pool_3x3", 0], ["sep_conv_3x3", 4]] , normal_concat=range(2, 6), reduce= [["max_pool_3x3", 0], ["sep_conv_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_3x3", 2], ["skip_connect", 0], ["dil_conv_3x3", 1], ["dil_conv_3x3", 1], ["sep_conv_3x3", 2]] , reduce_concat=range(2, 6))
# jacob_cov
cf100_jacob_cov = Genotype(normal= [["max_pool_3x3", 0], ["avg_pool_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_5x5", 1], ["avg_pool_3x3", 0], ["avg_pool_3x3", 3], ["dil_conv_5x5", 1], ["dil_conv_5x5", 4]], normal_concat=range(2, 6), reduce= [["skip_connect", 0], ["sep_conv_5x5", 1], ["avg_pool_3x3", 0], ["skip_connect", 2], ["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["dil_conv_3x3", 0], ["dil_conv_5x5", 1]] , reduce_concat=range(2, 6))
# meco
cf100_meco = Genotype(normal= [["dil_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_5x5", 2], ["sep_conv_3x3", 3], ["dil_conv_5x5", 0], ["sep_conv_3x3", 2]], normal_concat=range(2, 6), reduce= [["avg_pool_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3], ["dil_conv_3x3", 0], ["sep_conv_3x3", 1]] , reduce_concat=range(2, 6))
# snip
cf100_snip = Genotype(normal= [["sep_conv_5x5", 0], ["skip_connect", 1], ["sep_conv_3x3", 1], ["sep_conv_5x5", 2], ["skip_connect", 0], ["sep_conv_3x3", 2], ["dil_conv_3x3", 0], ["max_pool_3x3", 3]], normal_concat=range(2, 6), reduce= [["dil_conv_3x3", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["skip_connect", 2], ["skip_connect", 0], ["skip_connect", 2], ["dil_conv_5x5", 1], ["sep_conv_5x5", 2]] , reduce_concat=range(2, 6))
# synflow
cf100_synflow = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 1], ["sep_conv_5x5", 2], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] , normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1]] , reduce_concat=range(2, 6))
# zico
cf100_zico = Genotype(normal= [["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["sep_conv_5x5", 1], ["sep_conv_3x3", 0], ["sep_conv_3x3", 3]], normal_concat=range(2, 6), reduce= [["sep_conv_5x5", 0], ["sep_conv_3x3", 1], ["sep_conv_5x5", 0], ["dil_conv_5x5", 1], ["sep_conv_5x5", 0], ["sep_conv_3x3", 2], ["sep_conv_3x3", 0], ["sep_conv_3x3", 1]] , reduce_concat=range(2, 6))
40
sota/cnn/hdf5.py
Normal file
@@ -0,0 +1,40 @@
import h5py
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader


class H5Dataset(Dataset):
    def __init__(self, h5_path, transform=None):
        self.h5_path = h5_path
        self.h5_file = None
        self.length = len(h5py.File(h5_path, 'r'))
        self.transform = transform

    def __getitem__(self, index):
        # Loading in __getitem__ allows us to use multiple processes for data
        # loading, because hdf5 file handles aren't picklable and so can't be
        # transferred across processes.
        # https://discuss.pytorch.org/t/hdf5-a-data-format-for-pytorch/40379
        # https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
        # TODO: possibly look at __getstate__ and __setstate__ as a more elegant solution
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r', libver="latest", swmr=True)

        record = self.h5_file[str(index)]

        if self.transform:
            x = Image.fromarray(record['data'][()])
            x = self.transform(x)
        else:
            x = torch.from_numpy(record['data'][()])

        y = record['target'][()]
        y = torch.from_numpy(np.asarray(y))

        return (x, y)

    def __len__(self):
        return self.length
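# Hypothetical sketch (not part of the original commit) of the
# __getstate__/__setstate__ alternative mentioned in the TODO above: drop the
# unpicklable h5py handle when the Dataset is pickled for a DataLoader worker
# and let the worker reopen it lazily on its first __getitem__ call.
class _PicklableH5Dataset(H5Dataset):
    def __getstate__(self):
        state = self.__dict__.copy()
        state['h5_file'] = None  # h5py file handles cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)  # handle is reopened lazily in __getitem__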
336
sota/cnn/init_projection.py
Normal file
@@ -0,0 +1,336 @@
import sys
sys.path.insert(0, '../../')
import numpy as np
import torch
import logging
import torch.utils
from copy import deepcopy
from foresight.pruners import *

torch.set_printoptions(precision=4, sci_mode=False)


def sample_op(model, input, target, args, cell_type, selected_eid=None):
    ''' operation '''
    #### macros
    num_edges, num_ops = model.num_edges, model.num_ops
    candidate_flags = model.candidate_flags[cell_type]
    proj_crit = args.proj_crit[cell_type]

    #### select an edge
    if selected_eid is None:
        remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
        selected_eid = np.random.choice(remain_eids, size=1)[0]
        logging.info('selected edge: %d %s', selected_eid, cell_type)

    select_opid = np.random.choice(np.array(range(num_ops)), size=1)[0]
    return selected_eid, select_opid

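# project_op performs perturbation-style operation selection: for the chosen
# edge it masks out one candidate operation at a time, re-scores the supernet
# with the configured zero-cost criterion, and keeps the operation whose
# removal degrades the score the most (np.nanargmin over the masked scores).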
def project_op(model, input, target, args, cell_type, proj_queue=None, selected_eid=None):
    ''' operation '''
    #### macros
    num_edges, num_ops = model.num_edges, model.num_ops
    candidate_flags = model.candidate_flags[cell_type]
    proj_crit = args.proj_crit[cell_type]

    #### select an edge
    if selected_eid is None:
        remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
        # print(num_edges, num_ops, remain_eids)
        if args.edge_decision == "random":
            selected_eid = np.random.choice(remain_eids, size=1)[0]
            logging.info('selected edge: %d %s', selected_eid, cell_type)
        elif args.edge_decision == 'reverse':
            selected_eid = remain_eids[-1]
            logging.info('selected edge: %d %s', selected_eid, cell_type)
        else:
            selected_eid = remain_eids[0]
            logging.info('selected node: %d %s', selected_eid, cell_type)

    #### select the best operation
    if proj_crit == 'jacob':
        crit_idx = 3
        compare = lambda x, y: x < y
    else:
        crit_idx = 0
        compare = lambda x, y: x < y

    if args.dataset == 'cifar100':
        n_classes = 100
    elif args.dataset == 'imagenet16-120':
        n_classes = 120
    else:
        n_classes = 10

    best_opid = 0
    crit_extrema = None
    crit_list = []
    op_ids = []
    for opid in range(num_ops):
        ## projection
        weights = model.get_projected_weights(cell_type)
        proj_mask = torch.ones_like(weights[selected_eid])
        proj_mask[opid] = 0
        weights[selected_eid] = weights[selected_eid] * proj_mask

        # ## proj evaluation
        # with torch.no_grad():
        #     valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
        #     crit = valid_stats
        #     crit_list.append(crit)
        #     if crit_extrema is None or compare(crit, crit_extrema):
        #         crit_extrema = crit
        #         best_opid = opid

        ## proj evaluation
        if proj_crit == 'jacob':
            crit = Jocab_Score(model, cell_type, input, target, weights=weights)
        else:
            cache_weight = model.proj_weights[cell_type][selected_eid]
            cache_flag = model.candidate_flags[cell_type][selected_eid]

            for idx in range(num_ops):
                if idx == opid:
                    model.proj_weights[cell_type][selected_eid][opid] = 0
                else:
                    model.proj_weights[cell_type][selected_eid][idx] = 1.0 / num_ops

            model.candidate_flags[cell_type][selected_eid] = False
            # print(model.get_projected_weights())
            if proj_crit == 'comb':
                synflow = predictive.find_measures(model,
                                                   proj_queue,
                                                   ('random', 1, n_classes),
                                                   torch.device("cuda"),
                                                   measure_names=['synflow'])
                var = predictive.find_measures(model,
                                               proj_queue,
                                               ('random', 1, n_classes),
                                               torch.device("cuda"),
                                               measure_names=['var'])
                # print(synflow, var)
                comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
                measures = {'comb': comb}
            else:
                measures = predictive.find_measures(model,
                                                    proj_queue,
                                                    ('random', 1, n_classes),
                                                    torch.device("cuda"),
                                                    measure_names=[proj_crit])

            # print(measures)
            for idx in range(num_ops):
                model.proj_weights[cell_type][selected_eid][idx] = 0
            model.candidate_flags[cell_type][selected_eid] = cache_flag
            crit = measures[proj_crit]

        crit_list.append(crit)
        op_ids.append(opid)

    best_opid = op_ids[np.nanargmin(crit_list)]

    #### project
    logging.info('best opid: %d', best_opid)
    logging.info(crit_list)
    return selected_eid, best_opid

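# project_global_op is the 'global_op_greedy' variant: instead of scoring the
# operations of one preselected edge, it sweeps every remaining (edge, op)
# pair with Jocab_Score and returns the pair whose masking lowers the score
# the most.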
def project_global_op(model, input, target, args, infer=None, cell_type='normal', selected_eid=None):
    ''' operation '''
    #### macros
    num_edges, num_ops = model.num_edges, model.num_ops
    candidate_flags = model.candidate_flags[cell_type]
    proj_crit = args.proj_crit[cell_type]

    remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]

    #### select the best operation
    if proj_crit == 'jacob':
        crit_idx = 3
        compare = lambda x, y: x < y

    best_opid = 0
    crit_extrema = None
    best_eid = None
    for eid in remain_eids:
        for opid in range(num_ops):
            ## projection
            weights = model.get_projected_weights(cell_type)
            proj_mask = torch.ones_like(weights[eid])
            proj_mask[opid] = 0
            weights[eid] = weights[eid] * proj_mask

            ## proj evaluation
            # weights_dict = {cell_type: weights}
            with torch.no_grad():
                valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
                crit = valid_stats
            if crit_extrema is None or compare(crit, crit_extrema):
                crit_extrema = crit
                best_opid = opid
                best_eid = eid

    #### project
    logging.info('best opid: %d', best_opid)
    # logging.info(crit_list)
    return best_eid, best_opid

def sample_edge(model, input, target, args, cell_type, selected_eid=None):
    ''' topology '''
    #### macros
    candidate_flags = model.candidate_flags_edge[cell_type]
    proj_crit = args.proj_crit[cell_type]

    #### select a node
    remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
    selected_nid = np.random.choice(remain_nids, size=1)[0]
    logging.info('selected node: %d %s', selected_nid, cell_type)

    eids = deepcopy(model.nid2eids[selected_nid])

    while len(eids) > 2:
        elected_eid = np.random.choice(eids, size=1)[0]
        eids.remove(elected_eid)

    return selected_nid, eids

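# project_edge prunes the topology around the selected node: it repeatedly
# drops the input edge whose removal leaves the Jocab_Score highest (i.e. the
# least important edge) until only the two strongest input edges remain.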
def project_edge(model, input, target, args, cell_type):
    ''' topology '''
    #### macros
    candidate_flags = model.candidate_flags_edge[cell_type]
    proj_crit = args.proj_crit[cell_type]

    #### select a node
    remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
    if args.edge_decision == "random":
        selected_nid = np.random.choice(remain_nids, size=1)[0]
        logging.info('selected node: %d %s', selected_nid, cell_type)
    elif args.edge_decision == 'reverse':
        selected_nid = remain_nids[-1]
        logging.info('selected node: %d %s', selected_nid, cell_type)
    else:
        selected_nid = np.random.choice(remain_nids, size=1)[0]
        logging.info('selected node: %d %s', selected_nid, cell_type)

    #### select top2 edges
    if proj_crit == 'jacob':
        crit_idx = 3
        compare = lambda x, y: x < y
    else:
        crit_idx = 3
        compare = lambda x, y: x < y

    eids = deepcopy(model.nid2eids[selected_nid])
    crit_list = []
    while len(eids) > 2:
        eid_todel = None
        crit_extrema = None
        for eid in eids:
            weights = model.get_projected_weights(cell_type)
            weights[eid].data.fill_(0)

            ## proj evaluation
            with torch.no_grad():
                valid_stats = Jocab_Score(model, cell_type, input, target, weights=weights)
                crit = valid_stats

            crit_list.append(crit)
            if crit_extrema is None or not compare(crit, crit_extrema):  # find out bad edges
                crit_extrema = crit
                eid_todel = eid

        eids.remove(eid_todel)

    #### project
    logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
    # logging.info(crit_list)
    return selected_nid, eids

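# pt_project runs the whole discretization: one projection decision per
# minibatch, first choosing an operation for each of the model.num_edges edges
# (normal and reduce cells in lockstep), then pruning each node's inputs down
# to two edges, until num_projs decisions have been made.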
def pt_project(train_queue, model, args):
    model.eval()

    #### macros
    num_projs = model.num_edges + len(model.nid2eids.keys())
    args.proj_crit = {'normal': args.proj_crit_normal, 'reduce': args.proj_crit_reduce}
    proj_queue = train_queue

    epoch = 0
    for step, (input, target) in enumerate(proj_queue):
        if epoch < model.num_edges:
            logging.info('project op')

            if args.edge_decision == 'global_op_greedy':
                selected_eid_normal, best_opid_normal = project_global_op(model, input, target, args, cell_type='normal')
            elif args.edge_decision == 'sample':
                selected_eid_normal, best_opid_normal = sample_op(model, input, target, args, cell_type='normal')
            else:
                selected_eid_normal, best_opid_normal = project_op(model, input, target, args, proj_queue=proj_queue, cell_type='normal')
            model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
            if args.edge_decision == 'global_op_greedy':
                selected_eid_reduce, best_opid_reduce = project_global_op(model, input, target, args, cell_type='reduce')
            elif args.edge_decision == 'sample':
                selected_eid_reduce, best_opid_reduce = sample_op(model, input, target, args, cell_type='reduce')
            else:
                selected_eid_reduce, best_opid_reduce = project_op(model, input, target, args, proj_queue=proj_queue, cell_type='reduce')
            model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')

        else:
            logging.info('project edge')
            if args.edge_decision == 'sample':
                selected_nid_normal, eids_normal = sample_edge(model, input, target, args, cell_type='normal')
                model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
                selected_nid_reduce, eids_reduce = sample_edge(model, input, target, args, cell_type='reduce')
                model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
            else:
                selected_nid_normal, eids_normal = project_edge(model, input, target, args, cell_type='normal')
                model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
                selected_nid_reduce, eids_reduce = project_edge(model, input, target, args, cell_type='reduce')
                model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
        epoch += 1

        if epoch == num_projs:
            break

    return

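# Note: Jocab_Score computes a NASWOT-style score (Mellor et al., "Neural
# Architecture Search without Training"): each ReLU's binary activation
# pattern over the batch contributes to the kernel K, and the architecture is
# scored by the log-determinant of K via hooklogdet. Higher scores indicate
# more distinguishable activation patterns across the batch.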
def Jocab_Score(ori_model, cell_type, input, target, weights=None):
    model = deepcopy(ori_model)
    model.eval()
    if cell_type == 'reduce':
        model.proj_weights['reduce'] = weights
        model.proj_weights['normal'] = model.get_projected_weights('normal')
    else:
        model.proj_weights['normal'] = weights
        model.proj_weights['reduce'] = model.get_projected_weights('reduce')

    batch_size = input.shape[0]
    model.K = torch.zeros(batch_size, batch_size).cuda()

    def counting_forward_hook(module, inp, out):
        try:
            if isinstance(inp, tuple):
                inp = inp[0]
            inp = inp.view(inp.size(0), -1)
            x = (inp > 0).float()
            K = x @ x.t()
            K2 = (1. - x) @ (1. - x.t())
            model.K = model.K + K + K2
        except:
            pass

    for name, module in model.named_modules():
        if 'ReLU' in str(type(module)):
            module.register_forward_hook(counting_forward_hook)

    input = input.cuda()

    model(input, using_proj=True)
    score = hooklogdet(model.K.cpu().numpy())

    del model
    return score


def hooklogdet(K, labels=None):
    s, ld = np.linalg.slogdet(K)
    return ld
133
sota/cnn/model.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sota.cnn.operations import *
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from nasbench201.utils import drop_path
|
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
|
||||
if reduction:
|
||||
op_names, indices = zip(*genotype.reduce)
|
||||
concat = genotype.reduce_concat
|
||||
else:
|
||||
op_names, indices = zip(*genotype.normal)
|
||||
concat = genotype.normal_concat
|
||||
self._compile(C, op_names, indices, concat, reduction)
|
||||
|
||||
def _compile(self, C, op_names, indices, concat, reduction):
|
||||
assert len(op_names) == len(indices)
|
||||
self._steps = len(op_names) // 2
|
||||
self._concat = concat
|
||||
self.multiplier = len(concat)
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for name, index in zip(op_names, indices):
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
op = OPS[name](C, stride, True)
|
||||
self._ops += [op]
|
||||
self._indices = indices
|
||||
|
||||
def forward(self, s0, s1, drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2*i]]
|
||||
h2 = states[self._indices[2*i+1]]
|
||||
op1 = self._ops[2*i]
|
||||
op2 = self._ops[2*i+1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
s = h1 + h2
|
||||
states += [s]
|
||||
return torch.cat([states[i] for i in self._concat], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHead(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHead, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
# image size = 2 x 2
|
||||
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
|
||||
|
||||
class Network(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(Network, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
|
||||
stem_multiplier = 3
|
||||
C_curr = stem_multiplier*C
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr)
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = False
|
||||
for i in range(layers):
|
||||
if i in [layers//3, 2*layers//3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
|
||||
if i == 2*layers//3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, num_classes)
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
logits_aux = None
|
||||
s0 = s1 = self.stem(input)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2*self._layers//3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0), -1))
|
||||
return logits, logits_aux
|
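# Usage sketch (illustrative values only): the evaluation network is built from a fixed
# genotype, e.g. one of the entries in sota/cnn/genotypes.py, and returns (logits, logits_aux).
# Note that drop_path_prob must be set by the caller before the first forward pass.
#
#   model = Network(C=36, num_classes=10, layers=20, auxiliary=True, genotype=genotype)
#   model.drop_path_prob = 0.2
#   logits, logits_aux = model(images)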
150
sota/cnn/model_imagenet.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
# from optimizers.darts.operations import *
|
||||
import torch
import torch.nn as nn
from torch.autograd import Variable

from sota.cnn.operations import *
|
||||
#from optimizers.darts.utils import drop_path
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1.-drop_prob
|
||||
mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
|
||||
x.div_(keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
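# drop_path above implements per-sample path dropout: with probability drop_prob an entire
# sample's branch output is zeroed, and the survivors are rescaled by 1/keep_prob so the
# expected value is unchanged (inverted-dropout convention).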
||||
|
||||
class Cell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
print(C_prev_prev, C_prev, C)
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction:
|
||||
op_names, indices = zip(*genotype.reduce)
|
||||
concat = genotype.reduce_concat
|
||||
else:
|
||||
op_names, indices = zip(*genotype.normal)
|
||||
concat = genotype.normal_concat
|
||||
self._compile(C, op_names, indices, concat, reduction)
|
||||
|
||||
def _compile(self, C, op_names, indices, concat, reduction):
|
||||
assert len(op_names) == len(indices)
|
||||
self._steps = len(op_names) // 2
|
||||
self._concat = concat
|
||||
self.multiplier = len(concat)
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for name, index in zip(op_names, indices):
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
op = OPS[name](C, stride, True)
|
||||
self._ops += [op]
|
||||
self._indices = indices
|
||||
|
||||
def forward(self, s0, s1, drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2 * i]]
|
||||
h2 = states[self._indices[2 * i + 1]]
|
||||
op1 = self._ops[2 * i]
|
||||
op2 = self._ops[2 * i + 1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
s = h1 + h2
|
||||
states += [s]
|
||||
return torch.cat([states[i] for i in self._concat], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 14x14"""
|
||||
super(AuxiliaryHeadImageNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
|
||||
# Commenting it out for consistency with the experiments in the paper.
|
||||
# nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, num_classes, layers, auxiliary, genotype):
|
||||
super(NetworkImageNet, self).__init__()
|
||||
self._layers = layers
|
||||
self._auxiliary = auxiliary
|
||||
self.drop_path_prob = 0.0
|
||||
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C, C, C
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
reduction_prev = True
|
||||
for i in range(layers):
|
||||
if i in [layers // 3, 2 * layers // 3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
|
||||
if i == 2 * layers // 3:
|
||||
C_to_auxiliary = C_prev
|
||||
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
|
||||
self.global_pooling = nn.AvgPool2d(7)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def forward(self, input):
|
||||
logits_aux = None
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == 2 * self._layers // 3:
|
||||
if self._auxiliary and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0), -1))
|
||||
return logits, logits_aux
|
288
sota/cnn/model_search.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
from sota.cnn.operations import *
|
||||
from sota.cnn.genotypes import Genotype
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from nasbench201.utils import drop_path
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
def __init__(self, C, stride, PRIMITIVES):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in PRIMITIVES:
|
||||
op = OPS[primitive](C, stride, False)
|
||||
if 'pool' in primitive:
|
||||
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
|
||||
self._ops.append(op)
|
||||
|
||||
def forward(self, x, weights):
|
||||
ret = sum(w * op(x, block_input=True) if w == 0 else w * op(x) for w, op in zip(weights, self._ops) if w != 0)
|
||||
return ret
|
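# MixedOp returns the weighted sum of all candidate operations on an edge. Operations whose
# architecture weight is exactly zero (e.g. after projection) are skipped entirely, so a
# projected edge only pays for its single selected operation at forward time.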
||||
|
||||
|
||||
class Cell(nn.Module):
|
||||
|
||||
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(Cell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.primitives = self.PRIMITIVES['primitives_reduct' if reduction else 'primitives_normal']
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
|
||||
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
self._bns = nn.ModuleList()
|
||||
|
||||
edge_index = 0
|
||||
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(C, stride, self.primitives[edge_index])
|
||||
self._ops.append(op)
|
||||
edge_index += 1
|
||||
|
||||
def forward(self, s0, s1, weights, drop_prob=0.):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
offset = 0
|
||||
for i in range(self._steps):
|
||||
if drop_prob > 0. and self.training:
|
||||
s = sum(drop_path(self._ops[offset+j](h, weights[offset+j]), drop_prob) for j, h in enumerate(states))
|
||||
else:
|
||||
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
|
||||
offset += len(states)
|
||||
states.append(s)
|
||||
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
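# Each intermediate node i sums the mixed-op outputs of all its predecessor states (the two
# cell inputs plus previously computed nodes); the cell output concatenates the last
# `multiplier` intermediate nodes along the channel dimension.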
||||
|
||||
|
||||
class Network(nn.Module):
|
||||
def __init__(self, C, num_classes, layers, criterion, primitives, args,
|
||||
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0, nettype='cifar'):
|
||||
super(Network, self).__init__()
|
||||
#### original code
|
||||
self._C = C
|
||||
self._num_classes = num_classes
|
||||
self._layers = layers
|
||||
self._criterion = criterion
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.drop_path_prob = drop_path_prob
|
||||
self.nettype = nettype
|
||||
|
||||
nn.Module.PRIMITIVES = primitives  # expose the search space so every Cell can read self.PRIMITIVES
self.op_names = primitives
|
||||
|
||||
C_curr = stem_multiplier*C
|
||||
if self.nettype == 'cifar':
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr)
|
||||
)
|
||||
else:
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C_curr // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C_curr // 2, C_curr, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C_curr, C_curr, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C_curr),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
self.cells = nn.ModuleList()
|
||||
if self.nettype == 'cifar':
|
||||
reduction_prev = False
|
||||
else:
|
||||
reduction_prev = True
|
||||
for i in range(layers):
|
||||
if i in [layers//3, 2*layers//3]:
|
||||
C_curr *= 2
|
||||
reduction = True
|
||||
else:
|
||||
reduction = False
|
||||
cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, multiplier*C_curr
|
||||
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
self._initialize_alphas()
|
||||
|
||||
#### optimizer
|
||||
self._args = args
|
||||
self.optimizer = torch.optim.SGD(
|
||||
self.get_weights(),
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
nesterov=args.nesterov)
|
||||
|
||||
|
||||
def reset_optimizer(self, lr, momentum, weight_decay):
|
||||
del self.optimizer
|
||||
self.optimizer = torch.optim.SGD(
|
||||
self.get_weights(),
|
||||
lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay)
|
||||
|
||||
def _loss(self, input, target, return_logits=False):
|
||||
logits = self(input)
|
||||
loss = self._criterion(logits, target)
|
||||
return (loss, logits) if return_logits else loss
|
||||
|
||||
def _initialize_alphas(self):
|
||||
k = sum(1 for i in range(self._steps) for n in range(2+i))
|
||||
num_ops = len(self.PRIMITIVES['primitives_normal'][0])
|
||||
self.num_edges = k
|
||||
self.num_ops = num_ops
|
||||
|
||||
self.alphas_normal = self._initialize_alphas_numpy(k, num_ops)
|
||||
self.alphas_reduce = self._initialize_alphas_numpy(k, num_ops)
|
||||
self._arch_parameters = [ # must be in this order!
|
||||
self.alphas_normal,
|
||||
self.alphas_reduce,
|
||||
]
|
||||
|
||||
def _initialize_alphas_numpy(self, k, num_ops):
|
||||
''' init from specified arch '''
|
||||
return Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
|
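# Architecture parameters are initialised as small (1e-3) Gaussian noise, so the initial
# softmax over operations is close to uniform on every edge.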
||||
|
||||
def forward(self, input):
|
||||
weights = self.get_softmax()
|
||||
weights_normal = weights['normal']
|
||||
weights_reduce = weights['reduce']
|
||||
|
||||
if self.nettype == 'cifar':
|
||||
s0 = s1 = self.stem(input)
|
||||
else:
|
||||
print('imagenet')
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = weights_reduce
|
||||
else:
|
||||
weights = weights_normal
|
||||
|
||||
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
|
||||
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0),-1))
|
||||
|
||||
return logits
|
||||
|
||||
def step(self, input, target, args, shared=None):
|
||||
assert shared is None, 'gradient sharing disabled'
|
||||
|
||||
Lt, logit_t = self._loss(input, target, return_logits=True)
|
||||
Lt.backward()
|
||||
|
||||
nn.utils.clip_grad_norm_(self.get_weights(), args.grad_clip)
|
||||
self.optimizer.step()
|
||||
|
||||
return logit_t, Lt
|
||||
|
||||
#### utils
|
||||
def set_arch_parameters(self, new_alphas):
|
||||
for alpha, new_alpha in zip(self.arch_parameters(), new_alphas):
|
||||
alpha.data.copy_(new_alpha.data)
|
||||
|
||||
def get_softmax(self):
|
||||
weights_normal = F.softmax(self.alphas_normal, dim=-1)
|
||||
weights_reduce = F.softmax(self.alphas_reduce, dim=-1)
|
||||
return {'normal':weights_normal, 'reduce':weights_reduce}
|
||||
|
||||
def printing(self, logging, option='all'):
|
||||
weights = self.get_softmax()
|
||||
if option in ['all', 'normal']:
|
||||
weights_normal = weights['normal']
|
||||
logging.info(weights_normal)
|
||||
if option in ['all', 'reduce']:
|
||||
weights_reduce = weights['reduce']
|
||||
logging.info(weights_reduce)
|
||||
|
||||
def arch_parameters(self):
|
||||
return self._arch_parameters
|
||||
|
||||
def get_weights(self):
|
||||
return self.parameters()
|
||||
|
||||
def new(self):
|
||||
model_new = Network(self._C, self._num_classes, self._layers, self._criterion, self.PRIMITIVES, self._args,\
|
||||
drop_path_prob=self.drop_path_prob).cuda()
|
||||
for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
|
||||
x.data.copy_(y.data)
|
||||
return model_new
|
||||
|
||||
def clip(self):
|
||||
for p in self.arch_parameters():
|
||||
for line in p:
|
||||
max_index = line.argmax()
|
||||
line.data.clamp_(0, 1)
|
||||
if line.sum() == 0.0:
|
||||
line.data[max_index] = 1.0
|
||||
line.data.div_(line.sum())
|
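# clip() crudely projects each row of the architecture parameters back onto the probability
# simplex: clamp to [0, 1], restore the argmax entry to 1 if the whole row was clamped to
# zero, then renormalise the row to sum to 1.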
||||
|
||||
def genotype(self):
|
||||
def _parse(weights, normal=True):
|
||||
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct'] ## two are equal for Darts space
|
||||
|
||||
gene = []
|
||||
n = 2
|
||||
start = 0
|
||||
for i in range(self._steps):
|
||||
end = start + n
|
||||
W = weights[start:end].copy()
|
||||
|
||||
try:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
|
||||
except ValueError: # This error happens when the 'none' op is not present in the ops
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
|
||||
|
||||
for j in edges:
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
if 'none' in PRIMITIVES[j]:
|
||||
if k != PRIMITIVES[j].index('none'):
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
else:
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[start+j][k_best], j))
|
||||
start = end
|
||||
n += 1
|
||||
return gene
|
||||
|
||||
gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy(), True)
|
||||
gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy(), False)
|
||||
|
||||
concat = range(2+self._steps-self._multiplier, self._steps+2)
|
||||
genotype = Genotype(
|
||||
normal=gene_normal, normal_concat=concat,
|
||||
reduce=gene_reduce, reduce_concat=concat
|
||||
)
|
||||
return genotype
|
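# genotype() discretises the supernet: for every intermediate node it keeps the two incoming
# edges with the strongest non-'none' operation weight, then records the best operation on
# each kept edge, and concatenates the last `multiplier` nodes as the cell output.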
213
sota/cnn/model_search_darts_proj.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from sota.cnn.operations import *
|
||||
from sota.cnn.genotypes import Genotype
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from sota.cnn.model_search import Network
|
||||
|
||||
class DartsNetworkProj(Network):
|
||||
def __init__(self, C, num_classes, layers, criterion, primitives, args,
|
||||
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0):
|
||||
super(DartsNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
|
||||
steps=steps, multiplier=multiplier, stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob)
|
||||
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### proj flags
|
||||
def _initialize_topology_dicts(self):
|
||||
self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
|
||||
self.nid2selected_eids = {
|
||||
'normal': {0:[],1:[],2:[]},
|
||||
'reduce': {0:[],1:[],2:[]},
|
||||
}
|
||||
|
||||
def _initialize_flags(self):
|
||||
self.candidate_flags = {
|
||||
'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
} # must be in this order
|
||||
self.candidate_flags_edge = {
|
||||
'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
}
|
||||
|
||||
def _initialize_proj_weights(self):
|
||||
''' data structures used for proj '''
|
||||
if isinstance(self.alphas_normal, list):
|
||||
alphas_normal = torch.stack(self.alphas_normal, dim=0)
|
||||
alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
|
||||
else:
|
||||
alphas_normal = self.alphas_normal
|
||||
alphas_reduce = self.alphas_reduce
|
||||
|
||||
self.proj_weights = { # for hard/soft assignment after project
|
||||
'normal': torch.zeros_like(alphas_normal),
|
||||
'reduce': torch.zeros_like(alphas_reduce),
|
||||
}
|
||||
|
||||
#### proj function
|
||||
def project_op(self, eid, opid, cell_type):
|
||||
self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
|
||||
self.candidate_flags[cell_type][eid] = False
|
||||
|
||||
def project_edge(self, nid, eids, cell_type):
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in eids: # not top2
|
||||
self.proj_weights[cell_type][eid].data.fill_(0)
|
||||
self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
|
||||
self.candidate_flags_edge[cell_type][nid] = False
|
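# Projection bookkeeping: project_op hard-assigns a one-hot operation to edge `eid` and
# removes that edge from the op-level candidate set; project_edge keeps only the chosen
# (top-2) incoming edges of node `nid`, zeroes the projected weights of the discarded
# edges, and marks the node as decided.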
||||
|
||||
#### critical function
|
||||
def get_projected_weights(self, cell_type):
|
||||
''' used in forward and genotype '''
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
## proj op
|
||||
for eid in range(self.num_edges):
|
||||
if not self.candidate_flags[cell_type][eid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
## proj edge
|
||||
for nid in self.nid2eids:
|
||||
if not self.candidate_flags_edge[cell_type][nid]: ## projected node
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in self.nid2selected_eids[cell_type][nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def get_all_projected_weights(self, cell_type):
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
for eid in range(self.num_edges):
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
for nid in self.nid2eids:
|
||||
for eid in self.nid2eids[nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, input, weights_dict=None, using_proj=False):
|
||||
if using_proj:
|
||||
weights_normal = self.get_all_projected_weights('normal')
|
||||
weights_reduce = self.get_all_projected_weights('reduce')
|
||||
else:
|
||||
if weights_dict is None or 'normal' not in weights_dict:
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
else:
|
||||
weights_normal = weights_dict['normal']
|
||||
if weights_dict is None or 'reduce' not in weights_dict:
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
else:
|
||||
weights_reduce = weights_dict['reduce']
|
||||
|
||||
|
||||
|
||||
s0 = s1 = self.stem(input)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = weights_reduce
|
||||
else:
|
||||
weights = weights_normal
|
||||
|
||||
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
|
||||
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0),-1))
|
||||
|
||||
return logits
|
||||
|
||||
def reset_arch_parameters(self):
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### utils
|
||||
def printing(self, logging, option='all'):
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
|
||||
if option in ['all', 'normal']:
|
||||
logging.info('\n%s', weights_normal)
|
||||
if option in ['all', 'reduce']:
|
||||
logging.info('\n%s', weights_reduce)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights, normal=True):
|
||||
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
|
||||
|
||||
gene = []
|
||||
n = 2
|
||||
start = 0
|
||||
for i in range(self._steps):
|
||||
end = start + n
|
||||
W = weights[start:end].copy()
|
||||
|
||||
try:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
|
||||
except ValueError:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
|
||||
|
||||
for j in edges:
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
if 'none' in PRIMITIVES[j]:
|
||||
if k != PRIMITIVES[j].index('none'):
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
else:
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[start+j][k_best], j))
|
||||
start = end
|
||||
n += 1
|
||||
return gene
|
||||
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
|
||||
gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)
|
||||
|
||||
concat = range(2+self._steps-self._multiplier, self._steps+2)
|
||||
genotype = Genotype(
|
||||
normal=gene_normal, normal_concat=concat,
|
||||
reduce=gene_reduce, reduce_concat=concat
|
||||
)
|
||||
return genotype
|
||||
|
||||
def get_state_dict(self, epoch, architect, scheduler):
|
||||
model_state_dict = {
|
||||
'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
|
||||
'state_dict': self.state_dict(),
|
||||
'alpha': self.arch_parameters(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'arch_optimizer': architect.optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
#### projection
|
||||
'nid2eids': self.nid2eids,
|
||||
'nid2selected_eids': self.nid2selected_eids,
|
||||
'candidate_flags': self.candidate_flags,
|
||||
'candidate_flags_edge': self.candidate_flags_edge,
|
||||
'proj_weights': self.proj_weights,
|
||||
}
|
||||
return model_state_dict
|
||||
|
||||
def set_state_dict(self, architect, scheduler, checkpoint):
|
||||
#### common
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.set_arch_parameters(checkpoint['alpha'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
|
||||
#### projection
|
||||
self.nid2eids = checkpoint['nid2eids']
|
||||
self.nid2selected_eids = checkpoint['nid2selected_eids']
|
||||
self.candidate_flags = checkpoint['candidate_flags']
|
||||
self.candidate_flags_edge = checkpoint['candidate_flags_edge']
|
||||
self.proj_weights = checkpoint['proj_weights']
|
214
sota/cnn/model_search_imagenet_proj.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
import torch.nn as nn
|
||||
from sota.cnn.operations import *
|
||||
from sota.cnn.genotypes import Genotype
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
from sota.cnn.model_search import Network
|
||||
|
||||
class ImageNetNetworkProj(Network):
|
||||
def __init__(self, C, num_classes, layers, criterion, primitives, args,
|
||||
steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0, nettype='imagenet'):
|
||||
super(ImageNetNetworkProj, self).__init__(C, num_classes, layers, criterion, primitives, args,
|
||||
steps=steps, multiplier=multiplier, stem_multiplier=stem_multiplier, drop_path_prob=drop_path_prob, nettype=nettype)
|
||||
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### proj flags
|
||||
def _initialize_topology_dicts(self):
|
||||
self.nid2eids = {0:[2,3,4], 1:[5,6,7,8], 2:[9,10,11,12,13]}
|
||||
self.nid2selected_eids = {
|
||||
'normal': {0:[],1:[],2:[]},
|
||||
'reduce': {0:[],1:[],2:[]},
|
||||
}
|
||||
|
||||
def _initialize_flags(self):
|
||||
self.candidate_flags = {
|
||||
'normal':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce':torch.tensor(self.num_edges * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
} # must be in this order
|
||||
self.candidate_flags_edge = {
|
||||
'normal': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
'reduce': torch.tensor(3 * [True], requires_grad=False, dtype=torch.bool).cuda(),
|
||||
}
|
||||
|
||||
def _initialize_proj_weights(self):
|
||||
''' data structures used for proj '''
|
||||
if isinstance(self.alphas_normal, list):
|
||||
alphas_normal = torch.stack(self.alphas_normal, dim=0)
|
||||
alphas_reduce = torch.stack(self.alphas_reduce, dim=0)
|
||||
else:
|
||||
alphas_normal = self.alphas_normal
|
||||
alphas_reduce = self.alphas_reduce
|
||||
|
||||
self.proj_weights = { # for hard/soft assignment after project
|
||||
'normal': torch.zeros_like(alphas_normal),
|
||||
'reduce': torch.zeros_like(alphas_reduce),
|
||||
}
|
||||
|
||||
#### proj function
|
||||
def project_op(self, eid, opid, cell_type):
|
||||
self.proj_weights[cell_type][eid][opid] = 1 ## hard by default
|
||||
self.candidate_flags[cell_type][eid] = False
|
||||
|
||||
def project_edge(self, nid, eids, cell_type):
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in eids: # not top2
|
||||
self.proj_weights[cell_type][eid].data.fill_(0)
|
||||
self.nid2selected_eids[cell_type][nid] = deepcopy(eids)
|
||||
self.candidate_flags_edge[cell_type][nid] = False
|
||||
|
||||
#### critical function
|
||||
def get_projected_weights(self, cell_type):
|
||||
''' used in forward and genotype '''
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
## proj op
|
||||
for eid in range(self.num_edges):
|
||||
if not self.candidate_flags[cell_type][eid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
## proj edge
|
||||
for nid in self.nid2eids:
|
||||
if not self.candidate_flags_edge[cell_type][nid]: ## projected node
|
||||
for eid in self.nid2eids[nid]:
|
||||
if eid not in self.nid2selected_eids[cell_type][nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def get_all_projected_weights(self, cell_type):
|
||||
weights = self.get_softmax()[cell_type]
|
||||
|
||||
for eid in range(self.num_edges):
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
for nid in self.nid2eids:
|
||||
for eid in self.nid2eids[nid]:
|
||||
weights[eid].data.copy_(self.proj_weights[cell_type][eid])
|
||||
|
||||
return weights
|
||||
|
||||
def forward(self, input, weights_dict=None, using_proj=False):
|
||||
if using_proj:
|
||||
weights_normal = self.get_all_projected_weights('normal')
|
||||
weights_reduce = self.get_all_projected_weights('reduce')
|
||||
else:
|
||||
if weights_dict is None or 'normal' not in weights_dict:
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
else:
|
||||
weights_normal = weights_dict['normal']
|
||||
if weights_dict is None or 'reduce' not in weights_dict:
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
else:
|
||||
weights_reduce = weights_dict['reduce']
|
||||
|
||||
|
||||
|
||||
s0 = self.stem0(input)
|
||||
s1 = self.stem1(s0)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
weights = weights_reduce
|
||||
else:
|
||||
weights = weights_normal
|
||||
|
||||
s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
|
||||
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0),-1))
|
||||
|
||||
return logits
|
||||
|
||||
def reset_arch_parameters(self):
|
||||
self._initialize_flags()
|
||||
self._initialize_proj_weights()
|
||||
self._initialize_topology_dicts()
|
||||
|
||||
#### utils
|
||||
def printing(self, logging, option='all'):
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
|
||||
if option in ['all', 'normal']:
|
||||
logging.info('\n%s', weights_normal)
|
||||
if option in ['all', 'reduce']:
|
||||
logging.info('\n%s', weights_reduce)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights, normal=True):
|
||||
PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
|
||||
|
||||
gene = []
|
||||
n = 2
|
||||
start = 0
|
||||
for i in range(self._steps):
|
||||
end = start + n
|
||||
W = weights[start:end].copy()
|
||||
|
||||
try:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
|
||||
except ValueError:
|
||||
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
|
||||
|
||||
for j in edges:
|
||||
k_best = None
|
||||
for k in range(len(W[j])):
|
||||
if 'none' in PRIMITIVES[j]:
|
||||
if k != PRIMITIVES[j].index('none'):
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
else:
|
||||
if k_best is None or W[j][k] > W[j][k_best]:
|
||||
k_best = k
|
||||
gene.append((PRIMITIVES[start+j][k_best], j))
|
||||
start = end
|
||||
n += 1
|
||||
return gene
|
||||
|
||||
weights_normal = self.get_projected_weights('normal')
|
||||
weights_reduce = self.get_projected_weights('reduce')
|
||||
gene_normal = _parse(weights_normal.data.cpu().numpy(), True)
|
||||
gene_reduce = _parse(weights_reduce.data.cpu().numpy(), False)
|
||||
|
||||
concat = range(2+self._steps-self._multiplier, self._steps+2)
|
||||
genotype = Genotype(
|
||||
normal=gene_normal, normal_concat=concat,
|
||||
reduce=gene_reduce, reduce_concat=concat
|
||||
)
|
||||
return genotype
|
||||
|
||||
def get_state_dict(self, epoch, architect, scheduler):
|
||||
model_state_dict = {
|
||||
'epoch': epoch, ## no +1 because we are saving before projection / at the beginning of an epoch
|
||||
'state_dict': self.state_dict(),
|
||||
'alpha': self.arch_parameters(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'arch_optimizer': architect.optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
#### projection
|
||||
'nid2eids': self.nid2eids,
|
||||
'nid2selected_eids': self.nid2selected_eids,
|
||||
'candidate_flags': self.candidate_flags,
|
||||
'candidate_flags_edge': self.candidate_flags_edge,
|
||||
'proj_weights': self.proj_weights,
|
||||
}
|
||||
return model_state_dict
|
||||
|
||||
def set_state_dict(self, architect, scheduler, checkpoint):
|
||||
#### common
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.set_arch_parameters(checkpoint['alpha'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
architect.optimizer.load_state_dict(checkpoint['arch_optimizer'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
|
||||
#### projection
|
||||
self.nid2eids = checkpoint['nid2eids']
|
||||
self.nid2selected_eids = checkpoint['nid2selected_eids']
|
||||
self.candidate_flags = checkpoint['candidate_flags']
|
||||
self.candidate_flags_edge = checkpoint['candidate_flags_edge']
|
||||
self.proj_weights = checkpoint['proj_weights']
|
236
sota/cnn/networks_proposal.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import time
|
||||
import glob
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import shutil
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import argparse
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
import json
|
||||
import copy
|
||||
|
||||
from sota.cnn.model_search import Network as DartsNetwork
|
||||
from sota.cnn.model_search_darts_proj import DartsNetworkProj
|
||||
from sota.cnn.model_search_imagenet_proj import ImageNetNetworkProj
|
||||
# from optimizers.darts.architect import Architect as DartsArchitect
|
||||
from nasbench201.architect_ig import Architect
|
||||
from sota.cnn.spaces import spaces_dict
|
||||
from foresight.pruners import *
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from sota.cnn.init_projection import pt_project
|
||||
from hdf5 import H5Dataset
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
parser = argparse.ArgumentParser("sota")
|
||||
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
|
||||
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
|
||||
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
|
||||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
|
||||
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
|
||||
parser.add_argument('--seed', type=int, default=666, help='random seed')
|
||||
|
||||
#model opt related config
|
||||
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
|
||||
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--nesterov', action='store_true', default=True, help='use Nesterov momentum for SGD')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
|
||||
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
|
||||
|
||||
#system config
|
||||
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
|
||||
parser.add_argument('--save', type=str, default='exp', help='experiment name')
|
||||
parser.add_argument('--save_path', type=str, default='../../experiments/sota', help='root directory for saving experiments')
|
||||
#search space config
|
||||
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
|
||||
parser.add_argument('--search_space', type=str, default='s5', help='search space to choose from')
|
||||
parser.add_argument('--pool_size', type=int, default=10, help='number of models to propose')
|
||||
|
||||
## projection
|
||||
parser.add_argument('--edge_decision', type=str, default='random', choices=['random','reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once', 'sample'], help='used for both proj_op and proj_edge')
|
||||
parser.add_argument('--proj_crit_normal', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
|
||||
parser.add_argument('--proj_crit_reduce', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
|
||||
parser.add_argument('--proj_crit_edge', type=str, default='meco', choices=['loss', 'acc', 'jacob', 'comb', 'synflow', 'snip', 'fisher', 'var', 'cor', 'norm', 'grad_norm', 'grasp', 'jacob_cov', 'meco', 'zico'])
|
||||
parser.add_argument('--proj_mode_edge', type=str, default='reg', choices=['reg'],
|
||||
help='edge projection evaluation mode, reg: one edge at a time')
|
||||
args = parser.parse_args()
|
||||
|
||||
#### args augment
|
||||
|
||||
expid = args.save
|
||||
args.save = '{}/{}-search-{}-{}-{}-{}-{}'.format(args.save_path,
|
||||
args.dataset, args.save, args.search_space, args.seed, args.pool_size, args.proj_crit_normal)
|
||||
|
||||
if not args.edge_decision == 'random':
|
||||
args.save += '-' + args.edge_decision
|
||||
|
||||
scripts_to_save = glob.glob('*.py') + glob.glob('../../nasbench201/architect*.py') + glob.glob('../../optimizers/darts/architect.py')
|
||||
if os.path.exists(args.save):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
|
||||
print('proceed to override saving directory')
|
||||
shutil.rmtree(args.save)
|
||||
else:
|
||||
exit(0)
|
||||
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
|
||||
|
||||
#### logging
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
log_file = 'log.txt'
|
||||
log_path = os.path.join(args.save, log_file)
|
||||
logging.info('======> log filename: %s', log_file)
|
||||
|
||||
if os.path.exists(log_path):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(log_file)) == 'y':
|
||||
print('proceed to override log file directory')
|
||||
else:
|
||||
exit(0)
|
||||
|
||||
fh = logging.FileHandler(log_path, mode='w')
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
elif args.dataset == 'imagenet':
|
||||
n_classes = 1000
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
def main():
|
||||
torch.set_num_threads(3)
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
np.random.seed(args.seed)
|
||||
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
|
||||
torch.cuda.set_device(gpu)
|
||||
cudnn.benchmark = True
|
||||
torch.manual_seed(args.seed)
|
||||
cudnn.enabled = True
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
logging.info('gpu device = %d' % gpu)
|
||||
logging.info("args = %s", args)
|
||||
|
||||
#### model
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = criterion.cuda()
|
||||
|
||||
## darts
|
||||
if args.dataset == 'imagenet':
|
||||
model = ImageNetNetworkProj(args.init_channels, n_classes, args.layers, criterion, spaces_dict[args.search_space], args)
|
||||
else:
|
||||
model = DartsNetworkProj(args.init_channels, n_classes, args.layers, criterion, spaces_dict[args.search_space], args)
|
||||
model = model.cuda()
|
||||
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
|
||||
|
||||
#### data
|
||||
if args.dataset == 'imagenet':
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.2),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
#for test
|
||||
#from nasbench201.DownsampledImageNet import ImageNet16
|
||||
# train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
# n_classes = 10
|
||||
train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
|
||||
#valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
|
||||
|
||||
else:
|
||||
if args.dataset == 'cifar10':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
|
||||
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'cifar100':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
|
||||
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'svhn':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
|
||||
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
|
||||
valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
|
||||
|
||||
num_train = len(train_data)
|
||||
indices = list(range(num_train))
|
||||
split = int(np.floor(args.train_portion * num_train))
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
|
||||
pin_memory=True)
|
||||
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
|
||||
pin_memory=True)
|
||||
# for x, y in train_queue:
|
||||
# from torchvision import transforms
|
||||
# unloader = transforms.ToPILImage()
|
||||
# image = x.cpu().clone() # clone the tensor
|
||||
# image = image.squeeze(0) # remove the fake batch dimension
|
||||
# image = unloader(image)
|
||||
# image.save('example.jpg')
|
||||
|
||||
# print(x.size())
|
||||
# exit()
|
||||
|
||||
|
||||
#### projection
|
||||
networks_pool={}
|
||||
networks_pool['search_space'] = args.search_space
|
||||
networks_pool['dataset'] = args.dataset
|
||||
networks_pool['networks'] = []
|
||||
for i in range(args.pool_size):
|
||||
network_info={}
|
||||
logging.info('SEARCHING MODEL {} / {}'.format(i+1, args.pool_size))
|
||||
pt_project(train_queue, model, args)
|
||||
|
||||
## logging
|
||||
num_params = ig_utils.count_parameters_in_Compact(model)
|
||||
genotype = model.genotype()
|
||||
json_data = {}
|
||||
json_data['normal'] = genotype.normal
|
||||
json_data['normal_concat'] = [x for x in genotype.normal_concat]
|
||||
json_data['reduce'] = genotype.reduce
|
||||
json_data['reduce_concat'] = [x for x in genotype.reduce_concat]
|
||||
json_string = json.dumps(json_data)
|
||||
logging.info(json_string)
|
||||
network_info['id'] = str(i)
|
||||
network_info['genotype'] = json_string
|
||||
networks_pool['networks'].append(network_info)
|
||||
model.reset_arch_parameters()
|
||||
|
||||
with open(os.path.join(args.save,'networks_pool.json'), 'w') as save_file:
|
||||
json.dump(networks_pool, save_file)
|
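# The resulting networks_pool.json has the form
# {"search_space": ..., "dataset": ..., "networks": [{"id": "0", "genotype": "<json string>"}, ...]},
# with one entry per architecture proposed by the projection loop above.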
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
181
sota/cnn/operations.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
OPS = {
|
||||
'noise': lambda C, stride, affine: NoiseOp(stride, 0., 1.),
|
||||
'none': lambda C, stride, affine: Zero(stride),
|
||||
'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
|
||||
'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
|
||||
'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
|
||||
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
|
||||
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
|
||||
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
|
||||
'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
|
||||
nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
|
||||
nn.BatchNorm2d(C, affine=affine)
|
||||
),
|
||||
'sep_conv_3x3_skip': lambda C, stride, affine: SepConvSkip(C, C, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5_skip': lambda C, stride, affine: SepConvSkip(C, C, 5, stride, 2, affine=affine),
|
||||
'dil_conv_3x3_skip': lambda C, stride, affine: DilConvSkip(C, C, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5_skip': lambda C, stride, affine: DilConvSkip(C, C, 5, stride, 4, 2, affine=affine),
|
||||
}
|
||||
|
||||
|
||||
class NoiseOp(nn.Module):
|
||||
def __init__(self, stride, mean, std):
|
||||
super(NoiseOp, self).__init__()
|
||||
self.stride = stride
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
if self.stride != 1:
|
||||
x_new = x[:,:,::self.stride,::self.stride]
|
||||
else:
|
||||
x_new = x
|
||||
noise = Variable(x_new.data.new(x_new.size()).normal_(self.mean, self.std))
|
||||
|
||||
return noise
|
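# NoiseOp discards its input (optionally strided) and returns Gaussian noise of the same
# shape; it is the 'noise' primitive used by the S4 genotypes to replace harmful
# skip connections.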
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
|
||||
super(DilConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
if self.stride == 1:
|
||||
return x.mul(0.)
|
||||
return x[:, :, ::self.stride, ::self.stride].mul(0.)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, affine=True):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
assert C_out % 2 == 0
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
x = self.relu(x)
|
||||
out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
|
||||
#### operations with skip
|
||||
class DilConvSkip(nn.Module):
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
|
||||
super(DilConvSkip, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x) + x
|
||||
|
||||
|
||||
class SepConvSkip(nn.Module):
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConvSkip, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x, block_input=False):
|
||||
if block_input:
|
||||
x = x*0
|
||||
return self.op(x) + x
|
248
sota/cnn/projection.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import numpy as np
|
||||
import torch
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import torch.utils
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
torch.set_printoptions(precision=4, sci_mode=False)
|
||||
|
||||
|
||||
def project_op(model, proj_queue, args, infer, cell_type, selected_eid=None):
|
||||
''' operation '''
|
||||
#### macros
|
||||
num_edges, num_ops = model.num_edges, model.num_ops
|
||||
candidate_flags = model.candidate_flags[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select an edge
|
||||
if selected_eid is None:
|
||||
remain_eids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
if args.edge_decision == "random":
|
||||
selected_eid = np.random.choice(remain_eids, size=1)[0]
|
||||
logging.info('selected edge: %d %s', selected_eid, cell_type)
|
||||
|
||||
#### select the best operation
|
||||
if proj_crit == 'loss':
|
||||
crit_idx = 1
|
||||
compare = lambda x, y: x > y
|
||||
elif proj_crit == 'acc':
|
||||
crit_idx = 0
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
best_opid = 0
|
||||
crit_extrema = None
|
||||
for opid in range(num_ops):
|
||||
## projection
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
proj_mask = torch.ones_like(weights[selected_eid])
|
||||
proj_mask[opid] = 0
|
||||
weights[selected_eid] = weights[selected_eid] * proj_mask
|
||||
|
||||
## proj evaluation
|
||||
weights_dict = {cell_type:weights}
|
||||
valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
|
||||
crit = valid_stats[crit_idx]
|
||||
|
||||
if crit_extrema is None or compare(crit, crit_extrema):
|
||||
crit_extrema = crit
|
||||
best_opid = opid
|
||||
logging.info('valid_acc %f', valid_stats[0])
|
||||
logging.info('valid_loss %f', valid_stats[1])
|
||||
|
||||
#### project
|
||||
logging.info('best opid: %d', best_opid)
|
||||
return selected_eid, best_opid
|
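# Operation selection by perturbation: for each candidate operation, mask it out of the
# projected weights, re-evaluate on the projection queue, and pick the operation whose
# removal degrades the criterion the most (highest loss / lowest accuracy), i.e. the
# operation the supernet relies on most for this edge.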
||||
|
||||
|
||||
def project_edge(model, proj_queue, args, infer, cell_type):
|
||||
''' topology '''
|
||||
#### macros
|
||||
candidate_flags = model.candidate_flags_edge[cell_type]
|
||||
proj_crit = args.proj_crit[cell_type]
|
||||
|
||||
#### select an edge
|
||||
remain_nids = torch.nonzero(candidate_flags).cpu().numpy().T[0]
|
||||
if args.edge_decision == "random":
|
||||
selected_nid = np.random.choice(remain_nids, size=1)[0]
|
||||
logging.info('selected node: %d %s', selected_nid, cell_type)
|
||||
|
||||
#### select top2 edges
|
||||
if proj_crit == 'loss':
|
||||
crit_idx = 1
|
||||
compare = lambda x, y: x > y
|
||||
elif proj_crit == 'acc':
|
||||
crit_idx = 0
|
||||
compare = lambda x, y: x < y
|
||||
|
||||
eids = deepcopy(model.nid2eids[selected_nid])
|
||||
while len(eids) > 2:
|
||||
eid_todel = None
|
||||
crit_extrema = None
|
||||
for eid in eids:
|
||||
weights = model.get_projected_weights(cell_type)
|
||||
weights[eid].data.fill_(0)
|
||||
weights_dict = {cell_type:weights}
|
||||
|
||||
## proj evaluation
|
||||
valid_stats = infer(proj_queue, model, log=False, _eval=False, weights_dict=weights_dict)
|
||||
crit = valid_stats[crit_idx]
|
||||
|
||||
if crit_extrema is None or not compare(crit, crit_extrema): # find out bad edges
|
||||
crit_extrema = crit
|
||||
eid_todel = eid
|
||||
logging.info('valid_acc %f', valid_stats[0])
|
||||
logging.info('valid_loss %f', valid_stats[1])
|
||||
eids.remove(eid_todel)
|
||||
|
||||
#### project
|
||||
logging.info('top2 edges: (%d, %d)', eids[0], eids[1])
|
||||
return selected_nid, eids
|
||||
|
||||
|
||||
def pt_project(train_queue, valid_queue, model, architect, optimizer,
|
||||
epoch, args, infer, perturb_alpha, epsilon_alpha):
|
||||
model.train()
|
||||
model.printing(logging)
|
||||
|
||||
train_acc, train_obj = infer(train_queue, model, log=False)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
logging.info('train_loss %f', train_obj)
|
||||
|
||||
valid_acc, valid_obj = infer(valid_queue, model, log=False)
|
||||
logging.info('valid_acc %f', valid_acc)
|
||||
logging.info('valid_loss %f', valid_obj)
|
||||
|
||||
objs = ig_utils.AvgrageMeter()
|
||||
top1 = ig_utils.AvgrageMeter()
|
||||
top5 = ig_utils.AvgrageMeter()
|
||||
|
||||
|
||||
#### macros
|
||||
num_projs = model.num_edges + len(model.nid2eids.keys()) - 1 ## -1 because projections happen at both epoch 0 and the final epoch
|
||||
tune_epochs = args.proj_intv * num_projs + 1
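## one projection step every proj_intv epochs, with weight/alpha tuning in between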
|
||||
proj_intv = args.proj_intv
|
||||
args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
|
||||
proj_queue = valid_queue
|
||||
|
||||
|
||||
#### reset optimizer
|
||||
model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)
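## restart the weight optimizer with a 10x smaller learning rate and a fresh cosine schedule over the tuning epochs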
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
model.optimizer, float(tune_epochs), eta_min=args.learning_rate_min)
|
||||
|
||||
|
||||
#### load proj checkpoints
|
||||
start_epoch = 0
|
||||
if args.dev_resume_epoch >= 0:
|
||||
filename = os.path.join(args.dev_resume_checkpoint_dir, 'checkpoint_{}.pth.tar'.format(args.dev_resume_epoch))
|
||||
if os.path.isfile(filename):
|
||||
logging.info("=> loading projection checkpoint '{}'".format(filename))
|
||||
checkpoint = torch.load(filename, map_location='cpu')
|
||||
start_epoch = checkpoint['epoch']
|
||||
model.set_state_dict(architect, scheduler, checkpoint)
|
||||
model.set_arch_parameters(checkpoint['alpha'])
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
|
||||
else:
|
||||
logging.info("=> no checkpoint found at '{}'".format(filename))
|
||||
exit(0)
|
||||
|
||||
|
||||
#### projecting and tuning
|
||||
for epoch in range(start_epoch, tune_epochs):
|
||||
logging.info('epoch %d', epoch)
|
||||
|
||||
## project
|
||||
if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
|
||||
## saving every projection
|
||||
save_state_dict = model.get_state_dict(epoch, architect, scheduler)
|
||||
ig_utils.save_checkpoint(save_state_dict, False, args.dev_save_checkpoint_dir, per_epoch=True)
|
||||
|
||||
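## phase 1: project one operation per interval on every edge; phase 2: project edge topology node by node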
if epoch < proj_intv * model.num_edges:
|
||||
logging.info('project op')
|
||||
|
||||
selected_eid_normal, best_opid_normal = project_op(model, proj_queue, args, infer, cell_type='normal')
|
||||
model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
|
||||
selected_eid_reduce, best_opid_reduce = project_op(model, proj_queue, args, infer, cell_type='reduce')
|
||||
model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')
|
||||
|
||||
model.printing(logging)
|
||||
else:
|
||||
logging.info('project edge')
|
||||
|
||||
selected_nid_normal, eids_normal = project_edge(model, proj_queue, args, infer, cell_type='normal')
|
||||
model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
|
||||
selected_nid_reduce, eids_reduce = project_edge(model, proj_queue, args, infer, cell_type='reduce')
|
||||
model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')
|
||||
|
||||
model.printing(logging)
|
||||
|
||||
## tune
|
||||
for step, (input, target) in enumerate(train_queue):
|
||||
model.train()
|
||||
n = input.size(0)
|
||||
|
||||
## fetch data
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
input_search, target_search = next(iter(valid_queue))
|
||||
input_search = input_search.cuda()
|
||||
target_search = target_search.cuda(non_blocking=True)
|
||||
|
||||
## train alpha
|
||||
optimizer.zero_grad(); architect.optimizer.zero_grad()
|
||||
architect.step(input, target, input_search, target_search,
|
||||
return_logits=True)
|
||||
|
||||
## sdarts
|
||||
if perturb_alpha:
|
||||
# transform arch_parameters to prob (for perturbation)
|
||||
model.softmax_arch_parameters()
|
||||
optimizer.zero_grad(); architect.optimizer.zero_grad()
|
||||
perturb_alpha(model, input, target, epsilon_alpha)
|
||||
|
||||
## train weight
|
||||
optimizer.zero_grad(); architect.optimizer.zero_grad()
|
||||
logits, loss = model.step(input, target, args)
|
||||
|
||||
## sdarts
|
||||
if perturb_alpha:
|
||||
## restore alpha to unperturbed arch_parameters
|
||||
model.restore_arch_parameters()
|
||||
|
||||
## logging
|
||||
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
if args.fast:
|
||||
break
|
||||
|
||||
## one epoch end
|
||||
model.printing(logging)
|
||||
|
||||
train_acc, train_obj = infer(train_queue, model, log=False)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
logging.info('train_loss %f', train_obj)
|
||||
|
||||
valid_acc, valid_obj = infer(valid_queue, model, log=False)
|
||||
logging.info('valid_acc %f', valid_acc)
|
||||
logging.info('valid_loss %f', valid_obj)
|
||||
|
||||
|
||||
logging.info('projection finished')
|
||||
model.printing(logging)
|
||||
num_params = ig_utils.count_parameters_in_Compact(model)
|
||||
genotype = model.genotype()
|
||||
logging.info('param size = %f', num_params)
|
||||
logging.info('genotype = %s', genotype)
|
||||
|
||||
return
|
103
sota/cnn/spaces.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
|
||||
|
||||
primitives_1 = OrderedDict([('primitives_normal', [['skip_connect', 'dil_conv_3x3'],
                                                    ['skip_connect', 'dil_conv_5x5'],
                                                    ['skip_connect', 'dil_conv_5x5'],
                                                    ['skip_connect', 'sep_conv_3x3'],
                                                    ['skip_connect', 'dil_conv_3x3'],
                                                    ['max_pool_3x3', 'skip_connect'],
                                                    ['skip_connect', 'sep_conv_3x3'],
                                                    ['skip_connect', 'sep_conv_3x3'],
                                                    ['skip_connect', 'dil_conv_3x3'],
                                                    ['skip_connect', 'sep_conv_3x3'],
                                                    ['max_pool_3x3', 'skip_connect'],
                                                    ['skip_connect', 'dil_conv_3x3'],
                                                    ['dil_conv_3x3', 'dil_conv_5x5'],
                                                    ['dil_conv_3x3', 'dil_conv_5x5']]),
                            ('primitives_reduct', [['max_pool_3x3', 'avg_pool_3x3'],
                                                   ['max_pool_3x3', 'dil_conv_3x3'],
                                                   ['max_pool_3x3', 'avg_pool_3x3'],
                                                   ['max_pool_3x3', 'avg_pool_3x3'],
                                                   ['skip_connect', 'dil_conv_5x5'],
                                                   ['max_pool_3x3', 'avg_pool_3x3'],
                                                   ['max_pool_3x3', 'sep_conv_3x3'],
                                                   ['skip_connect', 'dil_conv_3x3'],
                                                   ['skip_connect', 'dil_conv_5x5'],
                                                   ['max_pool_3x3', 'avg_pool_3x3'],
                                                   ['max_pool_3x3', 'avg_pool_3x3'],
                                                   ['skip_connect', 'dil_conv_5x5'],
                                                   ['skip_connect', 'dil_conv_5x5'],
                                                   ['skip_connect', 'dil_conv_5x5']])])
|
||||
|
||||
primitives_2 = OrderedDict([('primitives_normal', 14 * [['skip_connect', 'sep_conv_3x3']]),
                            ('primitives_reduct', 14 * [['skip_connect', 'sep_conv_3x3']])])

primitives_3 = OrderedDict([('primitives_normal', 14 * [['none', 'skip_connect', 'sep_conv_3x3']]),
                            ('primitives_reduct', 14 * [['none', 'skip_connect', 'sep_conv_3x3']])])

primitives_4 = OrderedDict([('primitives_normal', 14 * [['noise', 'sep_conv_3x3']]),
                            ('primitives_reduct', 14 * [['noise', 'sep_conv_3x3']])])

PRIMITIVES = [
    #'none',        # 0
    'max_pool_3x3', # 0
    'avg_pool_3x3', # 1
    'skip_connect', # 2
    'sep_conv_3x3', # 3
    'sep_conv_5x5', # 4
    'dil_conv_3x3', # 5
    'dil_conv_5x5'  # 6
]

primitives_5 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES]),
                            ('primitives_reduct', 14 * [PRIMITIVES])])

primitives_6 = OrderedDict([('primitives_normal', 14 * [['sep_conv_5x5']]),
                            ('primitives_reduct', 14 * [['sep_conv_5x5']])])

spaces_dict = {
    's1': primitives_1,
    's2': primitives_2,
    's3': primitives_3,
    's4': primitives_4,
    's5': primitives_5, # DARTS Space
    's6': primitives_6,
}
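# s1-s4 are reduced search spaces; s5 is the full 7-op DARTS space; s6 restricts every edge to sep_conv_5x5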
|
309
sota/cnn/train.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, '../../')
|
||||
import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
import nasbench201.utils as ig_utils
|
||||
import logging
|
||||
import argparse
|
||||
import shutil
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import json
|
||||
from sota.cnn.model import Network
import sota.cnn.genotypes as genotypes  # needed for eval("genotypes.%s" % args.arch) when --from_dir is not set
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collections import namedtuple
|
||||
|
||||
parser = argparse.ArgumentParser("cifar")
|
||||
parser.add_argument('--data', type=str, default='../../data',
|
||||
help='location of the data corpus')
|
||||
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
|
||||
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
|
||||
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
|
||||
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
|
||||
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
|
||||
parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
|
||||
parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=20, help='total number of layers')
|
||||
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
|
||||
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
|
||||
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
|
||||
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
|
||||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
|
||||
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
|
||||
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
|
||||
parser.add_argument('--save', type=str, default='exp', help='experiment name')
|
||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||
parser.add_argument('--arch', type=str, default='c100_s4_pgd', help='which architecture to use')
|
||||
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
|
||||
#### common
|
||||
parser.add_argument('--resume_epoch', type=int, default=0, help="load ckpt, start training at resume_epoch")
|
||||
parser.add_argument('--ckpt_interval', type=int, default=50, help="interval (epoch) for saving checkpoints")
|
||||
parser.add_argument('--resume_expid', type=str, default='', help="full expid to resume from, name == ckpt folder name")
|
||||
parser.add_argument('--fast', action='store_true', default=False, help="fast mode for debugging")
|
||||
parser.add_argument('--queue', action='store_true', default=False, help="queueing for gpu")
|
||||
|
||||
parser.add_argument('--from_dir', action='store_true', default=True, help="load arch from dir")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_network_pool(ckpt_path):
|
||||
with open(os.path.join(ckpt_path, 'best_networks.json'), 'r') as save_file:
|
||||
networks_pool = json.load(save_file)
|
||||
return networks_pool['networks']
|
||||
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
#### args augment
|
||||
expid = args.save
|
||||
|
||||
print(args.from_dir)
|
||||
if args.from_dir:
|
||||
id_name = os.path.split(args.arch)[1]
|
||||
# print('aaaaaaa', args.arch)
|
||||
args.arch = load_network_pool(args.arch)
|
||||
args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
|
||||
args.dataset, args.save, id_name, args.seed)
|
||||
else:
|
||||
args.save = '../../experiments/sota/{}/eval/{}-{}-{}'.format(
|
||||
args.dataset, args.save, args.arch, args.seed)
|
||||
if args.cutout:
|
||||
args.save += '-cutout-' + str(args.cutout_length) + '-' + str(args.cutout_prob)
|
||||
if args.auxiliary:
|
||||
args.save += '-auxiliary-' + str(args.auxiliary_weight)
|
||||
|
||||
#### logging
|
||||
if args.resume_epoch > 0: # do not delete dir if resume:
|
||||
args.save = '../../experiments/sota/{}/{}'.format(args.dataset, args.resume_expid)
|
||||
assert os.path.exists(args.save), 'resume but {} does not exist!'.format(args.save)
|
||||
else:
|
||||
scripts_to_save = glob.glob('*.py')
|
||||
if os.path.exists(args.save):
|
||||
if input("WARNING: {} exists, override?[y/n]".format(args.save)) == 'y':
|
||||
print('proceed to override saving directory')
|
||||
shutil.rmtree(args.save)
|
||||
else:
|
||||
exit(0)
|
||||
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)
|
||||
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
log_file = 'log_resume_{}.txt'.format(args.resume_epoch) if args.resume_epoch > 0 else 'log.txt'
|
||||
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
if args.dataset == 'cifar100':
|
||||
n_classes = 100
|
||||
else:
|
||||
n_classes = 10
|
||||
|
||||
|
||||
def seed_torch(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
cudnn.deterministic = True
|
||||
cudnn.benchmark = False
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_num_threads(3)
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
#### gpu queueing
|
||||
if args.queue:
|
||||
ig_utils.queue_gpu()
|
||||
|
||||
gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
|
||||
torch.cuda.set_device(gpu)
|
||||
cudnn.enabled = True
|
||||
seed_torch(args.seed)
|
||||
|
||||
logging.info('gpu device = %d' % gpu)
|
||||
logging.info("args = %s", args)
|
||||
|
||||
if args.from_dir:
|
||||
genotype_config = json.loads(args.arch)
|
||||
genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'],
|
||||
reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])
|
||||
else:
|
||||
genotype = eval("genotypes.%s" % args.arch)
|
||||
|
||||
model = Network(args.init_channels, n_classes, args.layers, args.auxiliary, genotype)
|
||||
model = model.cuda()
|
||||
|
||||
logging.info("param size = %fMB", ig_utils.count_parameters_in_MB(model))
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = criterion.cuda()
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
if args.dataset == 'cifar10':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
|
||||
train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'cifar100':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
|
||||
train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
|
||||
valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
|
||||
elif args.dataset == 'svhn':
|
||||
train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
|
||||
train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
|
||||
valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=0)
|
||||
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=0)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, float(args.epochs),
|
||||
# eta_min=1e-4
|
||||
)
|
||||
|
||||
#### resume
|
||||
start_epoch = 0
|
||||
if args.resume_epoch > 0:
|
||||
logging.info('loading checkpoint from {}'.format(expid))
|
||||
filename = os.path.join(args.save, 'checkpoint_{}.pth.tar'.format(args.resume_epoch))
|
||||
|
||||
if os.path.isfile(filename):
|
||||
print("=> loading checkpoint '{}'".format(filename))
|
||||
checkpoint = torch.load(filename, map_location='cpu')
|
||||
resume_epoch = checkpoint['epoch'] # epoch
|
||||
model.load_state_dict(checkpoint['state_dict']) # model
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
|
||||
start_epoch = args.resume_epoch
|
||||
print("=> loaded checkpoint '{}' (epoch {})".format(filename, resume_epoch))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(filename))
|
||||
|
||||
#### main training
|
||||
best_valid_acc = 0
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
lr = scheduler.get_lr()[0]
|
||||
if args.cutout:
|
||||
train_transform.transforms[-1].cutout_prob = args.cutout_prob
|
||||
logging.info('epoch %d lr %e cutout_prob %e', epoch, lr,
|
||||
train_transform.transforms[-1].cutout_prob)
|
||||
else:
|
||||
logging.info('epoch %d lr %e', epoch, lr)
|
||||
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
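## drop-path probability is ramped linearly from 0 to drop_path_prob over training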
|
||||
|
||||
train_acc, train_obj = train(train_queue, model, criterion, optimizer)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
writer.add_scalar('Acc/train', train_acc, epoch)
|
||||
writer.add_scalar('Obj/train', train_obj, epoch)
|
||||
|
||||
## scheduler
|
||||
scheduler.step()
|
||||
|
||||
valid_acc, valid_obj = infer(valid_queue, model, criterion)
|
||||
logging.info('valid_acc %f', valid_acc)
|
||||
writer.add_scalar('Acc/valid', valid_acc, epoch)
|
||||
writer.add_scalar('Obj/valid', valid_obj, epoch)
|
||||
|
||||
## checkpoint
|
||||
if (epoch + 1) % args.ckpt_interval == 0:
|
||||
save_state_dict = {
|
||||
'epoch': epoch + 1,
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
}
|
||||
ig_utils.save_checkpoint(save_state_dict, False, args.save, per_epoch=True)
|
||||
|
||||
best_valid_acc = max(best_valid_acc, valid_acc)
|
||||
logging.info('best valid_acc %f', best_valid_acc)
|
||||
writer.close()
|
||||
|
||||
|
||||
def train(train_queue, model, criterion, optimizer):
|
||||
objs = ig_utils.AvgrageMeter()
|
||||
top1 = ig_utils.AvgrageMeter()
|
||||
top5 = ig_utils.AvgrageMeter()
|
||||
model.train()
|
||||
|
||||
for step, (input, target) in enumerate(train_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
logits, logits_aux = model(input)
|
||||
loss = criterion(logits, target)
|
||||
if args.auxiliary:
|
||||
loss_aux = criterion(logits_aux, target)
|
||||
loss += args.auxiliary_weight * loss_aux
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
if args.fast:
|
||||
logging.info('//// WARNING: FAST MODE')
|
||||
break
|
||||
|
||||
return top1.avg, objs.avg
|
||||
|
||||
|
||||
def infer(valid_queue, model, criterion):
|
||||
objs = ig_utils.AvgrageMeter()
|
||||
top1 = ig_utils.AvgrageMeter()
|
||||
top5 = ig_utils.AvgrageMeter()
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for step, (input, target) in enumerate(valid_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
logits, _ = model(input)
|
||||
loss = criterion(logits, target)
|
||||
|
||||
prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
if args.fast:
|
||||
logging.info('//// WARNING: FAST MODE')
|
||||
break
|
||||
|
||||
return top1.avg, objs.avg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
254
sota/cnn/train_imagenet.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import sys
|
||||
sys.path.insert(0, '../../')
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
import torch.utils
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as transforms
|
||||
from torch.autograd import Variable
|
||||
|
||||
import nasbench201.utils as utils
|
||||
from sota.cnn.model_imagenet import NetworkImageNet as Network
|
||||
import sota.cnn.genotypes as genotypes
|
||||
from sota.cnn.hdf5 import H5Dataset
|
||||
|
||||
parser = argparse.ArgumentParser("imagenet")
|
||||
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
|
||||
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
|
||||
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
|
||||
parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
|
||||
parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
|
||||
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
|
||||
parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
|
||||
parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, default=14, help='total number of layers')
|
||||
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
|
||||
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
|
||||
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
|
||||
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
|
||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||
parser.add_argument('--arch', type=str, default='c10_s3_pgd', help='which architecture to use')
|
||||
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
|
||||
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
|
||||
parser.add_argument('--gamma', type=float, default=0.97, help='learning rate decay')
|
||||
parser.add_argument('--decay_period', type=int, default=1, help='epochs between two learning rate decays')
|
||||
parser.add_argument('--parallel', action='store_true', default=False, help='darts parallelism')
|
||||
parser.add_argument('--load', action='store_true', default=False, help='whether to load a checkpoint and continue training')
|
||||
args = parser.parse_args()
|
||||
|
||||
args.save = '../../experiments/sota/imagenet/eval/{}-{}-{}-{}'.format(
|
||||
args.save, time.strftime("%Y%m%d-%H%M%S"), args.arch, args.seed)
|
||||
if args.auxiliary:
|
||||
args.save += '-auxiliary-' + str(args.auxiliary_weight)
|
||||
args.save += '-' + str(np.random.randint(10000))
|
||||
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
|
||||
|
||||
log_format = '%(asctime)s %(message)s'
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
||||
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
||||
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
|
||||
fh.setFormatter(logging.Formatter(log_format))
|
||||
logging.getLogger().addHandler(fh)
|
||||
writer = SummaryWriter(args.save + '/runs')
|
||||
|
||||
|
||||
CLASSES = 1000
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
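## label smoothing: mix (1 - epsilon) of the one-hot target with a uniform distribution over classes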
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
def seed_torch(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
cudnn.deterministic = True
|
||||
cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
if not torch.cuda.is_available():
|
||||
logging.info('no gpu device available')
|
||||
sys.exit(1)
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
cudnn.enabled = True
|
||||
seed_torch(args.seed)
|
||||
|
||||
logging.info('gpu device = %d' % args.gpu)
|
||||
logging.info("args = %s", args)
|
||||
|
||||
genotype = eval("genotypes.%s" % args.arch)
|
||||
model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
|
||||
|
||||
if args.parallel:
|
||||
model = nn.DataParallel(model).cuda()
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
|
||||
logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = criterion.cuda()
|
||||
criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
|
||||
criterion_smooth = criterion_smooth.cuda()
|
||||
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
args.learning_rate,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
train_transform = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.2),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
test_transform = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])
|
||||
|
||||
train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
|
||||
valid_data = H5Dataset(os.path.join(args.data, 'imagenet-val-256.h5'), transform=test_transform)
|
||||
|
||||
train_queue = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
|
||||
|
||||
valid_queue = torch.utils.data.DataLoader(
|
||||
valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
|
||||
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)
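## step decay: multiply the learning rate by gamma every decay_period epochs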
|
||||
|
||||
if args.load:
|
||||
model, optimizer, start_epoch, best_acc_top1 = utils.load_checkpoint(
|
||||
model, optimizer, '../../experiments/sota/imagenet/eval/EXP-20200210-143540-c10_s3_pgd-0-auxiliary-0.4-2753')
|
||||
else:
|
||||
best_acc_top1 = 0
|
||||
start_epoch = 0
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
|
||||
model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
|
||||
|
||||
train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
|
||||
logging.info('train_acc %f', train_acc)
|
||||
writer.add_scalar('Acc/train', train_acc, epoch)
|
||||
writer.add_scalar('Obj/train', train_obj, epoch)
|
||||
scheduler.step()
|
||||
|
||||
valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
|
||||
logging.info('valid_acc_top1 %f', valid_acc_top1)
|
||||
logging.info('valid_acc_top5 %f', valid_acc_top5)
|
||||
writer.add_scalar('Acc/valid_top1', valid_acc_top1, epoch)
|
||||
writer.add_scalar('Acc/valid_top5', valid_acc_top5, epoch)
|
||||
|
||||
is_best = False
|
||||
if valid_acc_top1 > best_acc_top1:
|
||||
best_acc_top1 = valid_acc_top1
|
||||
is_best = True
|
||||
|
||||
utils.save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'state_dict': model.state_dict(),
|
||||
'best_acc_top1': best_acc_top1,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
}, is_best, args.save)
|
||||
|
||||
|
||||
def train(train_queue, model, criterion, optimizer):
|
||||
objs = utils.AvgrageMeter()
|
||||
top1 = utils.AvgrageMeter()
|
||||
top5 = utils.AvgrageMeter()
|
||||
model.train()
|
||||
|
||||
for step, (input, target) in enumerate(train_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
logits, logits_aux = model(input)
|
||||
loss = criterion(logits, target)
|
||||
if args.auxiliary:
|
||||
loss_aux = criterion(logits_aux, target)
|
||||
loss += args.auxiliary_weight * loss_aux
|
||||
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
return top1.avg, objs.avg
|
||||
|
||||
|
||||
def infer(valid_queue, model, criterion):
|
||||
objs = utils.AvgrageMeter()
|
||||
top1 = utils.AvgrageMeter()
|
||||
top5 = utils.AvgrageMeter()
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for step, (input, target) in enumerate(valid_queue):
|
||||
input = input.cuda()
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
logits, _ = model(input)
|
||||
loss = criterion(logits, target)
|
||||
|
||||
prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
|
||||
n = input.size(0)
|
||||
objs.update(loss.data, n)
|
||||
top1.update(prec1.data, n)
|
||||
top5.update(prec5.data, n)
|
||||
|
||||
if step % args.report_freq == 0:
|
||||
logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
|
||||
|
||||
return top1.avg, top5.avg, objs.avg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
67
sota/cnn/visualize.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import sys
|
||||
import genotypes
|
||||
from graphviz import Digraph
|
||||
|
||||
|
||||
def plot(genotype, filename, mode=''):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='40', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='40', height='0.5', width='0.5',
|
||||
penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
|
||||
g.body.extend(['rankdir=LR'])
|
||||
|
||||
# g.body.extend(['ratio=0.15'])
|
||||
# g.view()
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
assert len(genotype) % 2 == 0
|
||||
steps = len(genotype) // 2
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for i in range(steps):
|
||||
for k in [2 * i, 2 * i + 1]:
|
||||
op, j = genotype[k]
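# j indexes node i's input: 0 -> c_{k-2}, 1 -> c_{k-1}, otherwise intermediate node j-2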
|
||||
if j == 0:
|
||||
u = "c_{k-2}"
|
||||
elif j == 1:
|
||||
u = "c_{k-1}"
|
||||
else:
|
||||
u = str(j - 2)
|
||||
v = str(i)
|
||||
|
||||
if mode == 'cue' and op != 'skip_connect' and op != 'noise':
|
||||
g.edge(u, v, label=op, fillcolor='gray', color='red', fontcolor='red')
|
||||
else:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 2:
|
||||
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
genotype_name = sys.argv[1]
|
||||
try:
|
||||
genotype = eval('genotypes.{}'.format(genotype_name))
|
||||
# print(genotype)
|
||||
except AttributeError:
|
||||
print("{} is not specified in genotypes.py".format(genotype_name))
|
||||
sys.exit(1)
|
||||
|
||||
mode = 'cue'
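# 'cue' mode highlights every op except skip_connect and noise in red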
|
||||
path = '../../figs/genotypes/cnn_{}/'.format(mode)
|
||||
# print(genotype.normal)
|
||||
plot(genotype.normal, path + genotype_name + "_normal", mode=mode)
|
||||
plot(genotype.reduce, path + genotype_name + "_reduce", mode=mode)
|
144
sota/cnn/visualize_full.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import sys
|
||||
import genotypes
|
||||
import numpy as np
|
||||
from graphviz import Digraph
|
||||
|
||||
|
||||
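# maps each of the 14 supernet edge ids to its (input node, output node) pair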
supernet_dict = {
|
||||
0: ('c_{k-2}', '0'),
|
||||
1: ('c_{k-1}', '0'),
|
||||
2: ('c_{k-2}', '1'),
|
||||
3: ('c_{k-1}', '1'),
|
||||
4: ('0', '1'),
|
||||
5: ('c_{k-2}', '2'),
|
||||
6: ('c_{k-1}', '2'),
|
||||
7: ('0', '2'),
|
||||
8: ('1', '2'),
|
||||
9: ('c_{k-2}', '3'),
|
||||
10: ('c_{k-1}', '3'),
|
||||
11: ('0', '3'),
|
||||
12: ('1', '3'),
|
||||
13: ('2', '3'),
|
||||
}
|
||||
steps = 4
|
||||
|
||||
def plot_space(primitives, filename):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='20', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
g.body.extend(['rankdir=LR'])
|
||||
g.body.extend(['ratio=50.0'])
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
|
||||
steps = 4
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
n = 2
|
||||
start = 0
|
||||
nodes_indx = ["c_{k-2}", "c_{k-1}"]
|
||||
for i in range(steps):
|
||||
end = start + n
|
||||
p = primitives[start:end]
|
||||
v = str(i)
|
||||
for node, prim in zip(nodes_indx, p):
|
||||
u = node
|
||||
for op in prim:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
start = end
|
||||
n += 1
|
||||
nodes_indx.append(v)
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
def plot(genotype, filename):
|
||||
g = Digraph(
|
||||
format='pdf',
|
||||
edge_attr=dict(fontsize='100', fontname="times"),
|
||||
node_attr=dict(style='filled', shape='rect', align='center', fontsize='100', height='0.5', width='0.5', penwidth='2', fontname="times"),
|
||||
engine='dot')
|
||||
g.body.extend(['rankdir=LR'])
|
||||
g.body.extend(['ratio=0.3'])
|
||||
|
||||
g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
num_edges = len(genotype)
|
||||
|
||||
for i in range(steps):
|
||||
g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
for eid in range(num_edges):
|
||||
op = genotype[eid]
|
||||
u, v = supernet_dict[eid]
|
||||
if op != 'skip_connect':
|
||||
g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
|
||||
else:
|
||||
g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
for i in range(steps):
|
||||
g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
g.render(filename, view=False)
|
||||
|
||||
|
||||
|
||||
# def plot(genotype, filename):
|
||||
# g = Digraph(
|
||||
# format='pdf',
|
||||
# edge_attr=dict(fontsize='100', fontname="times", penwidth='3'),
|
||||
# node_attr=dict(style='filled', shape='rect', align='center', fontsize='100', height='0.5', width='0.5',
|
||||
# penwidth='2', fontname="times"),
|
||||
# engine='dot')
|
||||
# g.body.extend(['rankdir=LR'])
|
||||
|
||||
# g.node("c_{k-2}", fillcolor='darkseagreen2')
|
||||
# g.node("c_{k-1}", fillcolor='darkseagreen2')
|
||||
# num_edges = len(genotype)
|
||||
|
||||
# for i in range(steps):
|
||||
# g.node(str(i), fillcolor='lightblue')
|
||||
|
||||
# for eid in range(num_edges):
|
||||
# op = genotype[eid]
|
||||
# u, v = supernet_dict[eid]
|
||||
# if op != 'skip_connect':
|
||||
# g.edge(u, v, label=op, fillcolor="gray", color='red', fontcolor='red')
|
||||
# else:
|
||||
# g.edge(u, v, label=op, fillcolor="gray")
|
||||
|
||||
# g.node("c_{k}", fillcolor='palegoldenrod')
|
||||
# for i in range(steps):
|
||||
# g.edge(str(i), "c_{k}", fillcolor="gray")
|
||||
|
||||
# g.render(filename, view=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#### visualize the supernet ####
|
||||
if len(sys.argv) != 2:
|
||||
print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
|
||||
sys.exit(1)
|
||||
|
||||
genotype_name = sys.argv[1]
|
||||
assert 'supernet' in genotype_name, 'this script only supports supernet visualization'
|
||||
try:
|
||||
genotype = eval('genotypes.{}'.format(genotype_name))
|
||||
except AttributeError:
|
||||
print("{} is not specified in genotypes.py".format(genotype_name))
|
||||
sys.exit(1)
|
||||
|
||||
path = '../../figs/genotypes/cnn_supernet_cue/'
|
||||
plot(genotype.normal, path + genotype_name + "_normal")
|
||||
plot(genotype.reduce, path + genotype_name + "_reduce")
|