Sync NATS-Bench's v1.0 and update algorithm names
This commit is contained in:
@@ -43,7 +43,7 @@ from models import get_cell_based_tiny_net, get_search_spaces
|
||||
from nats_bench import create
|
||||
|
||||
|
||||
# Ad-hoc for TuNAS
|
||||
# Ad-hoc for RL algorithms.
|
||||
class ExponentialMovingAverage(object):
|
||||
"""Class that maintains an exponential moving average."""
|
||||
|
||||
|
@@ -44,8 +44,8 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suf
|
||||
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
|
||||
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
|
||||
alg2name['channel-wise interpolation'] = 'tas-affine0_BN0-AWD0.001{:}'.format(suffix)
|
||||
alg2name['masking + Gumbel-Softmax'] = 'fbv2-affine0_BN0-AWD0.001{:}'.format(suffix)
|
||||
alg2name['masking + sampling'] = 'tunas-affine0_BN0-AWD0.0{:}'.format(suffix)
|
||||
alg2name['masking + Gumbel-Softmax'] = 'mask_gumbel-affine0_BN0-AWD0.001{:}'.format(suffix)
|
||||
alg2name['masking + sampling'] = 'mask_rl-affine0_BN0-AWD0.0{:}'.format(suffix)
|
||||
for alg, name in alg2name.items():
|
||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
||||
alg2data = OrderedDict()
|
||||
|
Reference in New Issue
Block a user