To answer issue #119

This commit is contained in:
D-X-Y
2022-03-20 23:12:12 -07:00
parent d2cef525f3
commit 8d0799dfb1
3 changed files with 35 additions and 4 deletions

View File

@@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module):
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas"
elif self.mode == "gdas_v1":
feature = cell.forward_gdas_v1(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas_v1"
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:

View File

@@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module):
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
def forward_gdas_v1(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
@@ -152,6 +166,9 @@ class NAS201SearchCell(nn.Module):
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
@@ -167,7 +184,6 @@ class MixedOp(nn.Module):
return sum(w * op(x) for w, op in zip(weights, self._ops))
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(
self,