Update TAS abd FBV2 for NAS-Bench
This commit is contained in:
@@ -5,13 +5,14 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
from models.cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_in, C_out, stride):
|
||||
def __init__(self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True):
|
||||
super(InferCell, self).__init__()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
@@ -24,9 +25,9 @@ class InferCell(nn.Module):
|
||||
cur_innod = []
|
||||
for (op_name, op_in) in node_info:
|
||||
if op_in == 0:
|
||||
layer = OPS[op_name](C_in , C_out, stride, True, True)
|
||||
layer = OPS[op_name](C_in , C_out, stride, affine, track_running_stats)
|
||||
else:
|
||||
layer = OPS[op_name](C_out, C_out, 1, True, True)
|
||||
layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
|
||||
cur_index.append( len(self.layers) )
|
||||
cur_innod.append( op_in )
|
||||
self.layers.append( layer )
|
||||
|
Reference in New Issue
Block a user