Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1
This commit is contained in:
@@ -98,7 +98,10 @@ def main(xargs, nas_bench):
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10':
|
||||
dataname = 'cifar10-valid'
|
||||
else:
|
||||
dataname = xargs.dataset
|
||||
if xargs.data_path is not None:
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
@@ -148,7 +151,7 @@ def main(xargs, nas_bench):
|
||||
start_time = time.time()
|
||||
log_prob, action = select_action( policy )
|
||||
arch = policy.generate_arch( action )
|
||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
reward, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
|
||||
trace.append( (reward, arch) )
|
||||
# accumulate time
|
||||
if total_costs + cost_time < xargs.time_budget:
|
||||
|
Reference in New Issue
Block a user