Add more algorithms
This commit is contained in:
173
lib/nas_infer_model/DXYs/base_cells.py
Normal file
173
lib/nas_infer_model/DXYs/base_cells.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import math
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .construct_utils import drop_path
|
||||
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
|
||||
def __init__(self, C, stride, PRIMITIVES):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
self.name2idx = {}
|
||||
for idx, primitive in enumerate(PRIMITIVES):
|
||||
op = OPS[primitive](C, C, stride, False)
|
||||
self._ops.append(op)
|
||||
assert primitive not in self.name2idx, '{:} has already in'.format(primitive)
|
||||
self.name2idx[primitive] = idx
|
||||
|
||||
def forward(self, x, weights, op_name):
|
||||
if op_name is None:
|
||||
if weights is None:
|
||||
return [op(x) for op in self._ops]
|
||||
else:
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
else:
|
||||
op_index = self.name2idx[op_name]
|
||||
return self._ops[op_index](x)
|
||||
|
||||
|
||||
|
||||
class SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
|
||||
super(SearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.PRIMITIVES = deepcopy(PRIMITIVES)
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self._use_residual = use_residual
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(C, stride, self.PRIMITIVES)
|
||||
self._ops.append(op)
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
|
||||
if modes[0] is None:
|
||||
if modes[1] == 'normal':
|
||||
output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
|
||||
elif modes[1] == 'only_W':
|
||||
output = self.__forwardOnlyW(S0, S1, drop_prob)
|
||||
else:
|
||||
test_genotype = modes[0]
|
||||
if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat
|
||||
else : operations, concats = test_genotype.normal, test_genotype.normal_concat
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
|
||||
for i, (opA, opB) in enumerate(operations):
|
||||
A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
|
||||
B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
|
||||
state = A + B
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
output = torch.cat([states[i] for i in concats], dim=1)
|
||||
if self._use_residual and S1.size() == output.size():
|
||||
return S1 + output
|
||||
else: return output
|
||||
|
||||
def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
x = self._ops[offset+j](h, weights[offset+j], None)
|
||||
if self.training and drop_prob > 0.:
|
||||
x = drop_path(x, math.pow(drop_prob, 1./len(states)))
|
||||
clist.append( x )
|
||||
connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
|
||||
state = sum(w * node for w, node in zip(connection, clist))
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
def __forwardOnlyW(self, S0, S1, drop_prob):
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
xs = self._ops[offset+j](h, None, None)
|
||||
clist += xs
|
||||
if self.training and drop_prob > 0.:
|
||||
xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
|
||||
else: xlist = clist
|
||||
state = sum(xlist) * 2 / len(xlist)
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
|
||||
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(InferCell, self).__init__()
|
||||
print(C_prev_prev, C_prev, C)
|
||||
|
||||
if reduction_prev is None:
|
||||
self.preprocess0 = Identity()
|
||||
elif reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
|
||||
else : step_ops, concat = genotype.normal, genotype.normal_concat
|
||||
self._steps = len(step_ops)
|
||||
self._concat = concat
|
||||
self._multiplier = len(concat)
|
||||
self._ops = nn.ModuleList()
|
||||
self._indices = []
|
||||
for operations in step_ops:
|
||||
for name, index in operations:
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
if reduction_prev is None and index == 0:
|
||||
op = OPS[name](C_prev_prev, C, stride, True)
|
||||
else:
|
||||
op = OPS[name](C , C, stride, True)
|
||||
self._ops.append( op )
|
||||
self._indices.append( index )
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, S0, S1, drop_prob):
|
||||
s0 = self.preprocess0(S0)
|
||||
s1 = self.preprocess1(S1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2*i]]
|
||||
h2 = states[self._indices[2*i+1]]
|
||||
op1 = self._ops[2*i]
|
||||
op2 = self._ops[2*i+1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
|
||||
state = h1 + h2
|
||||
states += [state]
|
||||
output = torch.cat([states[i] for i in self._concat], dim=1)
|
||||
return output
|
Reference in New Issue
Block a user