Upgrade NAS-Bench-201 to APIv1.3/FILEv1.1
This commit is contained in:
@@ -28,7 +28,10 @@ def main(xargs, nas_bench):
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'
|
||||
if xargs.dataset == 'cifar10':
|
||||
dataname = 'cifar10-valid'
|
||||
else:
|
||||
dataname = xargs.dataset
|
||||
if xargs.data_path is not None:
|
||||
train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
|
||||
split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
|
||||
@@ -62,7 +65,7 @@ def main(xargs, nas_bench):
|
||||
#for idx in range(xargs.random_num):
|
||||
while total_time_cost < xargs.time_budget:
|
||||
arch = random_arch()
|
||||
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info)
|
||||
accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
|
||||
if total_time_cost + cost_time > xargs.time_budget: break
|
||||
else: total_time_cost += cost_time
|
||||
history.append(arch)
|
||||
|
Reference in New Issue
Block a user