update NAS-Bench
This commit is contained in:
@@ -6,14 +6,15 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkSETN(nn.Module):
|
||||
|
||||
def __init__(self, C, N, steps, multiplier, stem_multiplier, num_classes, search_space, affine, track_running_stats):
|
||||
def __init__(self, C: int, N: int, steps: int, multiplier: int, stem_multiplier: int,
|
||||
num_classes: int, search_space: List[Text], affine: bool, track_running_stats: bool):
|
||||
super(NASNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
@@ -45,6 +46,16 @@ class NASNetworkSETN(nn.Module):
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||
self.arch_reduce_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
|
||||
self.mode = 'urs'
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ['urs', 'joint', 'select', 'dynamic']
|
||||
self.mode = mode
|
||||
if mode == 'dynamic':
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list( self.stem.parameters() ) + list( self.cells.parameters() )
|
||||
@@ -70,6 +81,24 @@ class NASNetworkSETN(nn.Module):
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = '{:}<-{:}'.format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[ self.edge2index[node_str] ]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[ op_index ]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append( tuple(xlist) )
|
||||
return Structure( genotypes )
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
@@ -94,9 +123,6 @@ class NASNetworkSETN(nn.Module):
|
||||
def forward(self, inputs):
|
||||
normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
|
||||
reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)
|
||||
with torch.no_grad():
|
||||
normal_hardwts_cpu = normal_hardwts.detach().cpu()
|
||||
reduce_hardwts_cpu = reduce_hardwts.detach().cpu()
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
|
Reference in New Issue
Block a user