Add more algorithms
lib/nas_infer_model/DXYs/construct_utils.py (new file, 60 lines added)
@@ -0,0 +1,60 @@
import torch
import torch.nn.functional as F

def drop_path(x, drop_prob):
  # Drop-path regularization: randomly zero whole samples in the batch and
  # rescale the survivors by 1/keep_prob so the expected activation is preserved.
  if drop_prob > 0.:
    keep_prob = 1. - drop_prob
    # One Bernoulli(keep_prob) value per sample, broadcast over C/H/W.
    mask = x.new_zeros(x.size(0), 1, 1, 1)
    mask = mask.bernoulli_(keep_prob)
    x = torch.div(x, keep_prob)
    x.mul_(mask)
  return x
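
A minimal usage sketch of drop_path (not part of the diff above; the tensor shape and the 0.2 drop probability are illustrative assumptions):

x = torch.randn(4, 16, 32, 32)   # (batch, channels, height, width), made-up sizes
x = drop_path(x, 0.2)            # roughly 20% of samples zeroed, the rest rescaled by 1/0.8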

def return_alphas_str(basemodel):
  # Build a human-readable summary of the architecture parameters (alphas)
  # exposed by the model; attributes that are absent are simply skipped.
  if hasattr(basemodel, 'alphas_normal'):
    string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1))
  else:
    string = ''
  if hasattr(basemodel, 'alphas_reduce'):
    string = string + '\nreduce : {:}'.format(F.softmax(basemodel.alphas_reduce, dim=-1))

  if hasattr(basemodel, 'get_adjacency'):
    # Mix each node's softmaxed connection weights with its adjacency matrix and
    # report the per-edge values for the normal and the reduction cell.
    adjacency = basemodel.get_adjacency()
    for i in range(len(adjacency)):
      weight = F.softmax(basemodel.connect_normal[str(i)], dim=-1)
      adj = torch.mm(weight, adjacency[i]).view(-1)
      adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
      string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj))
    for i in range(len(adjacency)):
      weight = F.softmax(basemodel.connect_reduce[str(i)], dim=-1)
      adj = torch.mm(weight, adjacency[i]).view(-1)
      adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
      string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj))

  if hasattr(basemodel, 'alphas_connect'):
    # Two-way connection weights: column 0 is the zero choice, column 1 the identity choice.
    weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
    ZERO = ['{:.3f}'.format(x) for x in weight[:, 0].tolist()]
    IDEN = ['{:.3f}'.format(x) for x in weight[:, 1].tolist()]
    string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format(list(basemodel.alphas_connect.size()), ZERO, IDEN)
  else:
    string = string + '\nconnect = None'

  if hasattr(basemodel, 'get_gcn_out'):
    outputs = basemodel.get_gcn_out(True)
    for i, output in enumerate(outputs):
      string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1))

  return string
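
A small sketch of calling return_alphas_str with a stand-in model; the DummyModel class and its (14, 8) alpha shapes are hypothetical, chosen only to exercise the alphas_normal / alphas_reduce branches:

class DummyModel:
  def __init__(self):
    self.alphas_normal = torch.randn(14, 8)   # (edges, candidate ops), illustrative sizes
    self.alphas_reduce = torch.randn(14, 8)

print(return_alphas_str(DummyModel()))        # softmaxed normal/reduce alphas, then 'connect = None'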

def remove_duplicate_archs(all_archs):
  # Keep the first occurrence of each architecture, comparing string representations.
  archs = []
  str_archs = ['{:}'.format(x) for x in all_archs]
  for i, arch_x in enumerate(str_archs):
    choose = True
    for j in range(i):
      if arch_x == str_archs[j]:
        choose = False
        break
    if choose:
      archs.append(all_archs[i])
  return archs
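
A quick sketch of remove_duplicate_archs; the architecture strings below are made-up placeholders, since real genotypes would come from the search code:

candidates = ['skip_connect|sep_conv_3x3', 'sep_conv_3x3|skip_connect', 'skip_connect|sep_conv_3x3']
print(remove_duplicate_archs(candidates))   # the third entry is dropped as a duplicate of the first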