To answer issue #119
This commit is contained in:
@@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module):
|
||||
feature = cell.forward_gdas(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas"
|
||||
elif self.mode == "gdas_v1":
|
||||
feature = cell.forward_gdas_v1(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas_v1"
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
|
@@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module):
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
|
||||
def forward_gdas_v1(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = hardwts[self.edge2index[node_str]]
|
||||
argmaxs = index[self.edge2index[node_str]].item()
|
||||
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
|
||||
inter_nodes.append(weigsum)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# joint
|
||||
def forward_joint(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
@@ -152,6 +166,9 @@ class NAS201SearchCell(nn.Module):
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
def __init__(self, space, C, stride, affine, track_running_stats):
|
||||
super(MixedOp, self).__init__()
|
||||
@@ -167,7 +184,6 @@ class MixedOp(nn.Module):
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetSearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user