Update TAS abd FBV2 for NAS-Bench

This commit is contained in:
D-X-Y
2020-07-24 12:56:34 +00:00
parent b9fbe5577c
commit 4a2292a863
8 changed files with 491 additions and 12 deletions

View File

@@ -5,13 +5,14 @@
import torch
import torch.nn as nn
from copy import deepcopy
from ..cell_operations import OPS
from models.cell_operations import OPS
# Cell for NAS-Bench-201
class InferCell(nn.Module):
def __init__(self, genotype, C_in, C_out, stride):
def __init__(self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True):
super(InferCell, self).__init__()
self.layers = nn.ModuleList()
@@ -24,9 +25,9 @@ class InferCell(nn.Module):
cur_innod = []
for (op_name, op_in) in node_info:
if op_in == 0:
layer = OPS[op_name](C_in , C_out, stride, True, True)
layer = OPS[op_name](C_in , C_out, stride, affine, track_running_stats)
else:
layer = OPS[op_name](C_out, C_out, 1, True, True)
layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
cur_index.append( len(self.layers) )
cur_innod.append( op_in )
self.layers.append( layer )